
[training, perf] feat: Add eval-time context parallelism via decentralized PG rebinding#3755

Open
cuichenx wants to merge 4 commits into main from feat/eval-time-context-parallelism

Conversation


@cuichenx cuichenx commented May 8, 2026

Summary

Enables validation to run with a higher CP degree than training, addressing the wasted-work problem when DP_train is large and the validation set is small. The mechanism swaps cached process-group references on every module from train_pgs to eval_pgs for the duration of evaluate(), then restores them on exit. Works only with the decentralized process-group path (use_decentralized_pg=True); no Megatron-Core changes required.

Motivation

For an 8192-GPU run with TP=2, PP=4, CP_train=1 (so DP_train=1024) and a 1024-sample validation set, you must set GA=4 to keep the PP pipeline full, processing 4096 sample-steps — 4× redundant work vs evaluating each sample once. With CP_eval=4, DP_eval=256 covers all 1024 samples in one iter with no redundancy and ~CP_eval× lower activation memory.
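The accounting above can be sketched in a few lines (the helper name and the fixed-GA assumption are illustrative, not from the codebase):

```python
def eval_work(world_size, tp, pp, cp, num_val_samples, ga):
    """Return (dp, sample_steps, redundancy) for one eval pass (illustrative)."""
    dp = world_size // (tp * pp * cp)
    steps = dp * ga                              # one sample per DP rank per GA micro-step
    redundancy = steps / min(steps, num_val_samples)
    return dp, steps, redundancy

# Training layout: TP=2, PP=4, CP=1 on 8192 GPUs -> DP=1024, 4x redundant
assert eval_work(8192, 2, 4, 1, 1024, ga=4) == (1024, 4096, 4.0)
# Eval layout: CP=4 shrinks DP to 256 -> every sample seen exactly once
assert eval_work(8192, 2, 4, 4, 1024, ga=4) == (256, 1024, 1.0)
```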

What's added

  • src/megatron/bridge/training/eval_cp.py: install_pg_collection() walks the module tree and rebinds:

    • pg_collection (used by TransformerLayer, DotProductAttention, …)
    • Named CP-bearing group attributes (cp_group, tp_cp_group, tp_dp_cp_group, dp_cp_group)
    • TEDotProductAttention internal CP comm state via set_context_parallel_group()

    RotaryEmbedding / MultimodalRotaryEmbedding instances are intentionally skipped — see the in-code comment for the double-shard bug this avoids. The @lru_cache on forward is cleared defensively.

    eval_cp_context() is a context manager wrapping install/restore.

  • src/megatron/bridge/training/config.py: DistributedInitConfig.eval_context_parallel_size: int | None = None

  • src/megatron/bridge/training/state.py: GlobalState.train_pgs / .eval_pgs (ProcessGroupCollection | None)

  • examples/decentralized_pg/pretrain_qwen3_eval_cp.py — 4-GPU end-to-end demo on Qwen3-4B (TP=2, CP_train=1, CP_eval=2): training → baseline eval at CP=1 → eval-CP at CP=2 → verify-restore at CP=1.

Constraints

  • Requires dist.use_decentralized_pg=True
  • No FSDP / CUDA graphs / hierarchical CP today
  • seq_length % (2 * cp_eval) == 0
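The last constraint reflects CP's load-balanced split, where each sequence is cut into 2*cp chunks; a minimal validation sketch (helper name is hypothetical):

```python
def check_eval_cp_seq(seq_length: int, cp_eval: int) -> None:
    """Raise if seq_length cannot be split into the 2*cp chunks CP expects."""
    if seq_length % (2 * cp_eval) != 0:
        raise ValueError(
            f"seq_length={seq_length} must be divisible by 2*cp_eval={2 * cp_eval}"
        )

check_eval_cp_seq(8192, 4)   # OK: 8192 % 8 == 0
```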

Test plan

  • [x] 4-GPU manual run (TP=2 CP_train=1 CP_eval=2, Qwen3-4B): train 3 iters, baseline eval CP=1 succeeds (loss 7.870), eval-CP CP=2 succeeds (loss 8.836), verify-restore eval CP=1 succeeds (loss 7.978)
  • [x] uv run pre-commit run passes (ruff, ruff-format)
  • [ ] Add a unit test for install_pg_collection rebinding semantics
  • [ ] Add a 2-GPU functional test (within CI budget) covering the eval-CP path
  • [ ] Multi-node benchmark vs the original 8K-GPU motivating scenario

Notes for reviewers

  • RotaryEmbedding.cp_group deliberately not updated. With pre_process SP scatter at eval, the embedding is sharded across tp*cp_eval so decoder_input.size(0) = seq/(tp*cp_eval). After QKV's TP AllGather, query is [seq/cp_eval, …]. get_rotary_seq_len multiplies by tp_size and the config's cp_size (= cp_train), so freqs land at seq/cp_eval before CP slicing. If we then slice freqs by cp_eval we'd double-shard them — concrete repro shows query (512,…) vs freqs (256,…) mismatch in _apply_rotary_pos_emb_bshd. Skipping the rotary cp_group update keeps freqs at seq/cp_eval, matching query.
  • The example uses Qwen3 with YaRN RoPE end-to-end, which exercises the unfused _apply_rotary_pos_emb_bshd path; the same logic also applies to standard RoPE.
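The shape argument in the first note can be checked with plain integer arithmetic on the repro numbers (all variables below are illustrative stand-ins for tensor sequence dimensions):

```python
seq, tp, cp_train, cp_eval = 1024, 2, 1, 2

decoder_input = seq // (tp * cp_eval)            # pre_process SP scatter: 256
query = decoder_input * tp                       # after QKV TP AllGather: 512 = seq/cp_eval
rotary_seq_len = decoder_input * tp * cp_train   # get_rotary_seq_len uses config cp = cp_train
freqs_if_skipped = rotary_seq_len                # rotary cp_group not updated: no extra slice
freqs_if_sliced = rotary_seq_len // cp_eval      # double-sharded: the reported mismatch

assert query == 512 and freqs_if_skipped == query
assert freqs_if_sliced == 256   # query (512,...) vs freqs (256,...) in _apply_rotary_pos_emb_bshd
```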

…lized PG rebinding

Enables validation to run with a higher CP degree than training, addressing
the wasted-work problem when DP_train >> validation set size. The mechanism
swaps cached process-group references on every module from train_pgs to
eval_pgs for the duration of evaluate(), then restores them on exit.

- src/megatron/bridge/training/eval_cp.py: install_pg_collection() walks the
  module tree and rebinds pg_collection, named CP-bearing group attrs, and
  TEDotProductAttention CP comm state via set_context_parallel_group().
  RotaryEmbedding instances are intentionally skipped: their cp_group must
  stay at the train CP group because the embedding-level sequence scatter
  already produces a per-rank input of seq/(tp*cp_eval) tokens. Pre-slicing
  RoPE freqs by eval cp here would double-shard them, causing a shape
  mismatch in _apply_rotary_pos_emb_bshd. The @lru_cache on RotaryEmbedding
  forward is cleared defensively.
- eval_cp_context() context manager wrapping install/restore.
- DistributedInitConfig.eval_context_parallel_size and GlobalState
  .train_pgs / .eval_pgs fields wire the feature through the existing
  config / state surfaces.
- examples/decentralized_pg/pretrain_qwen3_eval_cp.py: 4-GPU end-to-end demo
  on Qwen3-4B (TP=2, CP_train=1, CP_eval=2) covering training, baseline
  eval at CP=1, eval-CP at CP=2, and verify-restore at CP=1.

Constraints: requires use_decentralized_pg=True; no FSDP / CUDA graphs /
hierarchical CP; seq_length % (2 * cp_eval) == 0.

Signed-off-by: Chen Cui <chcui@nvidia.com>

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cuichenx cuichenx added the area:perf (Performance optimizations and benchmarking), area:training (Training loop, callbacks, and runtime integration), and feature (New capabilities, enhancements, or enablement work) labels May 8, 2026
cuichenx added 2 commits May 8, 2026 13:04
…ch correctness check + live-switch demo

Changes
-------
src/megatron/bridge/training/eval_cp.py
  - On install_pg_collection, mutate the shared TransformerConfig.context_parallel_size
    to the live cp_size. This is the only runtime read of config.cp on the standard
    GPT eval path (RotaryEmbedding.get_rotary_seq_len). Without it, train CP=N → eval
    CP=M with M != N would compute rotary_seq_len from the static training cp, then
    apply RoPE freqs at the wrong positions — the forward would not crash but the
    eval loss would be silently wrong.
  - Drop the RotaryEmbedding skip; with config.cp now live, the cp_group rebinding
    + lru_cache clear produces freqs that match the query shape in both directions.
  - Restoration on context exit is implicit: eval_cp_context calls install_pg_collection
    a second time with train_pgs, which re-mutates config.cp back.
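The restore-by-re-install behavior can be modeled in a few lines (class and function names are illustrative, not the real config surface):

```python
class SharedConfig:
    """Stand-in for the shared TransformerConfig."""
    def __init__(self, cp): self.context_parallel_size = cp

def install(cfg, live_cp):
    # Mutate the shared config to the live cp_size; the only runtime read on
    # the standard GPT eval path is RotaryEmbedding.get_rotary_seq_len.
    cfg.context_parallel_size = live_cp

cfg = SharedConfig(cp=1)
install(cfg, 2)                           # enter eval_cp_context with eval PGs
assert cfg.context_parallel_size == 2
install(cfg, 1)                           # exit: re-install with train PGs re-mutates it back
assert cfg.context_parallel_size == 1
```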

examples/decentralized_pg/pretrain_qwen3_eval_cp.py
  - Step 7a (new): correctness check feeding the SAME deterministic batch through
    eval at CP=cp_train and CP=cp_eval, asserting |Δ lm_loss| < 5e-3 (bf16 ring-attention
    tolerance). Sidesteps the data iterator since DP_train != DP_eval would otherwise
    feed different shards.
  - Step 7 (replaces former Steps 7-10): manual interleaved train + eval loop that
    calls eval_cp_context EVERY iteration, exercising live CP-group switching across
    a tight train_step → evaluate cycle.
  - Relax cp_eval > cp_train to cp_eval != cp_train (both directions now supported).

Verified end-to-end on 4 GPUs (TP=2, PP=1):
  - cp_train=1 / cp_eval=2: Step 7a |Δ| = 2.07e-04, live-switch loop 3/3 OK
  - cp_train=2 / cp_eval=1: Step 7a |Δ| = 2.07e-04, live-switch loop 3/3 OK

Signed-off-by: Chen Cui <chcui@nvidia.com>
…eval-CP example

Adds Step 7b that times one full evaluate() call on the SAME fixed batch under
each CP layout (warmup=3, timed=10) and reports mean/median/std + observed
ratio. Sidesteps the data iterator the same way the verification step does so
the two runs see identical input.
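A rough sketch of that harness shape (run_eval stands in for one full evaluate() call on the fixed batch; this is not the example's actual code):

```python
import statistics
import time

def bench(run_eval, warmup=3, timed=10):
    """Warm up, then time repeated runs; return (mean, median, std) in ms."""
    for _ in range(warmup):
        run_eval()
    ms = []
    for _ in range(timed):
        t0 = time.perf_counter()
        run_eval()
        ms.append((time.perf_counter() - t0) * 1e3)
    return statistics.mean(ms), statistics.median(ms), statistics.pstdev(ms)
```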

The docstring is explicit that the bench only captures per-iter compute scaling
vs CP comm overhead. It does NOT capture the dominant term in the 2.3x estimate
in eval_cp_plan.md, which is the pipeline-bubble shrinkage at PP > 1. With PP=1
in this 4-GPU smoke test, expect CP=cp_eval to be modestly slower or modestly
faster than CP=cp_train depending on whether comm hides behind compute at this
model size — not a 2x win.

Observed on 4xH100 (TP=2, PP=1, Qwen3-4B downsized):
  seq=1024, layers=2:   CP=1 14.9 ms vs CP=2 20.7 ms  -> 0.72x
  seq=8192, layers=4:   CP=1 25.3 ms vs CP=2 33.7 ms  -> 0.75x
Both correctness checks PASS to ~2e-4 (bf16 ring-attention tolerance).

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx cuichenx marked this pull request as ready for review May 8, 2026 21:30
… 7b numbers with PP=2/PP=4 wins

Step 7a indexed loss_dict["lm loss"] unconditionally, which only works when
every rank is on the last PP stage. With PP > 1, first-stage ranks get an
empty dict and the assertion crashed with KeyError. Broadcast the scalar from
the last PP stage across the PP group so the assertion runs uniformly.
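A single-process model of that fix (the real code would use torch.distributed.broadcast over the PP group; this sketch fakes the group as a list of per-rank loss dicts):

```python
def broadcast_lm_loss(loss_dicts_per_pp_rank):
    """Return the loss value every PP rank asserts on after the broadcast."""
    src = len(loss_dicts_per_pp_rank) - 1          # last PP stage owns the loss
    loss = loss_dicts_per_pp_rank[src]["lm loss"]
    return [loss] * len(loss_dicts_per_pp_rank)    # uniform value on every rank

pp4 = [{}, {}, {}, {"lm loss": 7.87}]              # PP=4: first stages get empty dicts
assert broadcast_lm_loss(pp4) == [7.87] * 4        # no KeyError on first-stage ranks
```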

With PP > 1 unblocked, Step 7b actually demonstrates the headline win:

  seq=1024,  layers=4, TP=2 PP=1:  CP=1  8.9 ms vs CP=2 12.4 ms -> 0.72x (regression)
  seq=8192,  layers=8, TP=2 PP=1:  CP=1 26.3 ms vs CP=2 21.3 ms -> 1.23x
  seq=8192,  layers=8, TP=2 PP=2:  CP=1 39.5 ms vs CP=2 32.2 ms -> 1.23x
  seq=8192,  layers=8, TP=1 PP=4:  CP=1 53.1 ms vs CP=2 32.3 ms -> 1.64x
  seq=16384, layers=8, TP=2 PP=2:  CP=1 82.5 ms vs CP=2 46.5 ms -> 1.78x
  seq=16384, layers=8, TP=1 PP=4:  CP=1 117.7 ms vs CP=2 62.8 ms -> 1.88x

Refresh the Step 7b docstring + trailing Note to drop the PP=1-only caveat
and walk through both compounding effects (O(N^2) attention + bubble shrink).
Add PP=2/PP=4 invocations to the "How to Run" block; flag short-seq + PP=1
as the regression case.
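The bubble-shrink half of that argument can be sanity-checked with the standard 1F1B bubble fraction (pp-1)/(m+pp-1), an assumed first-order model rather than a measurement: with a fixed sample count, raising CP multiplies the microbatch count m by cp.

```python
def bubble_fraction(pp: int, microbatches: int) -> float:
    """Idle fraction of a 1F1B pipeline schedule (first-order model)."""
    return (pp - 1) / (microbatches + pp - 1)

# CP=2 doubles m at PP=4: bubble drops from 3/7 to 3/11.
assert bubble_fraction(4, 8) < bubble_fraction(4, 4)
```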

Verified on 8xH100; all step-7a correctness checks PASS at the listed shapes.

Signed-off-by: Chen Cui <chcui@nvidia.com>
f"world_size ({world_size}) must be divisible by TP*PP*CP "
f"({tp_size}*{pp_size}*{cp_size}={model_parallel_size})"
)
dp_size = world_size // model_parallel_size

Nit: bare print() — repo convention requires print_rank_0() or logging.getLogger. The rank-0 guard makes it functionally equivalent, but switching to print_rank_0() would be consistent with the rest of this file.

Suggested change
dp_size = world_size // model_parallel_size
print_rank_0(f"\nCreating ProcessGroupCollection{tag}: TP={tp_size} CP={cp_size} DP={dp_size} PP={pp_size}")


claude Bot commented May 8, 2026

Light Review — eval-time context parallelism via decentralized PG rebinding

The core eval_cp.py module looks solid. The context-manager pattern with finally-based restore is the right approach, and the RotaryEmbedding skip is well-justified with the double-shard reasoning. The _GROUP_ATTRS single-source-of-truth dict is a clean design.

Missing test coverage (blocking) — The PR body itself flags two unchecked items: (1) unit test for install_pg_collection rebinding semantics, and (2) 2-GPU functional test covering the eval-CP path. The install_pg_collection function is the heart of this feature — it mutates module state across the entire model tree, including TE internals and the shared TransformerConfig. A unit test that verifies rebinding semantics (correct attributes set, rotary cache cleared, config.context_parallel_size updated, restore on exception) would catch regressions from MCore changes without needing GPUs. This should land with or shortly after this PR.

Minor: bare print() in example (inline comment posted) — Line 184 of the example uses print() instead of print_rank_0(). Functionally equivalent due to the rank guard, but inconsistent with the rest of the file.

Suggested test cases: none; no perf tests impacted.
