[training, perf] feat: Add eval-time context parallelism via decentralized PG rebinding #3755
Conversation
…lized PG rebinding

Enables validation to run with a higher CP degree than training, addressing the wasted-work problem when DP_train >> validation set size. The mechanism swaps cached process-group references on every module from train_pgs to eval_pgs for the duration of evaluate(), then restores them on exit.

- src/megatron/bridge/training/eval_cp.py: install_pg_collection() walks the module tree and rebinds pg_collection, named CP-bearing group attrs, and TEDotProductAttention CP comm state via set_context_parallel_group(). RotaryEmbedding instances are intentionally skipped: their cp_group must stay at the train CP group because the embedding-level sequence scatter already produces a per-rank input of seq/(tp*cp_eval) tokens. Pre-slicing RoPE freqs by eval cp here would double-shard them, causing a shape mismatch in _apply_rotary_pos_emb_bshd. The @lru_cache on RotaryEmbedding.forward is cleared defensively.
- eval_cp_context(): context manager wrapping install/restore.
- DistributedInitConfig.eval_context_parallel_size and GlobalState.train_pgs / .eval_pgs fields wire the feature through the existing config/state surfaces.
- examples/decentralized_pg/pretrain_qwen3_eval_cp.py: 4-GPU end-to-end demo on Qwen3-4B (TP=2, CP_train=1, CP_eval=2) covering training, baseline eval at CP=1, eval-CP at CP=2, and verify-restore at CP=1.

Constraints: requires use_decentralized_pg=True; no FSDP / CUDA graphs / hierarchical CP; seq_length % (2 * cp_eval) == 0.

Signed-off-by: Chen Cui <chcui@nvidia.com>
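The install/restore shape described above is a plain try/finally context manager. A minimal sketch with dummy objects: `install_pg_collection` and `eval_cp_context` are the names the commit introduces, but this body is illustrative only (the real version also rebinds named CP group attrs, rewires TEDotProductAttention via set_context_parallel_group(), and clears the RoPE lru_cache).

```python
from contextlib import contextmanager


class DummyModule:
    """Illustrative stand-in for an nn.Module holding a cached PG reference."""

    def __init__(self, pg_collection, children=()):
        self.pg_collection = pg_collection
        self.children = list(children)

    def modules(self):
        yield self
        for child in self.children:
            yield from child.modules()


def install_pg_collection(model, pgs):
    # Sketch: only the pg_collection attribute is swapped here; the real
    # implementation also handles TE internals and rotary caches.
    for module in model.modules():
        if hasattr(module, "pg_collection"):
            module.pg_collection = pgs


@contextmanager
def eval_cp_context(model, train_pgs, eval_pgs):
    install_pg_collection(model, eval_pgs)
    try:
        yield
    finally:
        # Restore is simply a second install with the training collection,
        # so eval PGs never leak past the context even on exceptions.
        install_pg_collection(model, train_pgs)


leaf = DummyModule("train_pgs")
root = DummyModule("train_pgs", children=[leaf])
with eval_cp_context(root, "train_pgs", "eval_pgs"):
    assert leaf.pg_collection == "eval_pgs"  # rebound for the eval block
assert leaf.pg_collection == "train_pgs"     # restored on exit
```

The finally-based restore is what makes the mechanism safe to wrap around an evaluate() call that may raise.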
…ch correctness check + live-switch demo
Changes
-------
src/megatron/bridge/training/eval_cp.py
- On install_pg_collection, mutate the shared TransformerConfig.context_parallel_size
to the live cp_size. This is the only runtime read of config.cp on the standard
GPT eval path (RotaryEmbedding.get_rotary_seq_len). Without it, train CP=N → eval
CP=M with M != N would compute rotary_seq_len from the static training cp, then
apply RoPE freqs at the wrong positions — the forward would not crash but the
eval loss would be silently wrong.
- Drop the RotaryEmbedding skip; with config.cp now live, the cp_group rebinding
+ lru_cache clear produces freqs that match the query shape in both directions.
- Restoration on context exit is implicit: eval_cp_context calls install_pg_collection
a second time with train_pgs, which re-mutates config.cp back.
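The silent-wrong-loss failure mode is easy to see numerically. A sketch, assuming get_rotary_seq_len reconstructs the global sequence length as per-rank length times tp times the config's cp (the shape bookkeeping the commit describes); the helper here is a model of that read, not the real MCore code:

```python
def rotary_seq_len(per_rank_len: int, tp: int, cp: int) -> int:
    """Hedged model of RotaryEmbedding.get_rotary_seq_len: undo the SP(TP)
    and CP sharding recorded in the config to recover the global length."""
    return per_rank_len * tp * cp


seq, tp, cp_train, cp_eval = 1024, 2, 1, 2
per_rank = seq // (tp * cp_eval)  # 256 tokens after the SP scatter at eval

# config.context_parallel_size mutated to the live eval value: correct length.
assert rotary_seq_len(per_rank, tp, cp_eval) == seq

# Stale training cp: RoPE freqs computed for half the sequence and applied
# at the wrong positions. Nothing crashes; the eval loss is silently wrong.
assert rotary_seq_len(per_rank, tp, cp_train) == seq // 2
```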
examples/decentralized_pg/pretrain_qwen3_eval_cp.py
- Step 7a (new): correctness check feeding the SAME deterministic batch through
eval at CP=cp_train and CP=cp_eval, asserting |Δ lm_loss| < 5e-3 (bf16 ring-attention
tolerance). Sidesteps the data iterator since DP_train != DP_eval would otherwise
feed different shards.
- Step 7 (replaces former Steps 7-10): manual interleaved train + eval loop that
calls eval_cp_context EVERY iteration, exercising live CP-group switching across
a tight train_step → evaluate cycle.
- Relax cp_eval > cp_train to cp_eval != cp_train (both directions now supported).
Verified end-to-end on 4 GPUs (TP=2, PP=1):
- cp_train=1 / cp_eval=2: Step 7a |Δ| = 2.07e-04, live-switch loop 3/3 OK
- cp_train=2 / cp_eval=1: Step 7a |Δ| = 2.07e-04, live-switch loop 3/3 OK
Signed-off-by: Chen Cui <chcui@nvidia.com>
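Step 7a's assertion reduces to a scalar tolerance check on one fixed batch. A minimal sketch; the helper name is illustrative, while the 5e-3 tolerance and ~2e-4 observed delta are the numbers from this commit message:

```python
TOL = 5e-3  # bf16 ring-attention tolerance used by Step 7a


def check_losses_match(loss_cp_train: float, loss_cp_eval: float, tol: float = TOL) -> float:
    """Both losses come from evaluate() on the SAME deterministic batch,
    once per CP layout. Hypothetical helper, not code from the PR."""
    delta = abs(loss_cp_train - loss_cp_eval)
    assert delta < tol, f"|delta lm_loss| = {delta:.2e} exceeds {tol:.0e}"
    return delta


# A delta of the observed magnitude (2.07e-04) comfortably clears the bound.
assert check_losses_match(7.123, 7.123 + 2.07e-4) < TOL
```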
…eval-CP example

Adds Step 7b that times one full evaluate() call on the SAME fixed batch under each CP layout (warmup=3, timed=10) and reports mean/median/std plus the observed ratio. Sidesteps the data iterator the same way the verification step does, so the two runs see identical input.

The docstring is explicit that the bench only captures per-iter compute scaling vs CP comm overhead. It does NOT capture the dominant term in the 2.3x estimate in eval_cp_plan.md, which is the pipeline-bubble shrinkage at PP > 1. With PP=1 in this 4-GPU smoke test, expect CP=cp_eval to be modestly slower or modestly faster than CP=cp_train depending on whether comm hides behind compute at this model size, not a 2x win.

Observed on 4xH100 (TP=2, PP=1, Qwen3-4B downsized):

- seq=1024, layers=2: CP=1 14.9 ms vs CP=2 20.7 ms -> 0.72x
- seq=8192, layers=4: CP=1 25.3 ms vs CP=2 33.7 ms -> 0.75x

Both correctness checks PASS to ~2e-4 (bf16 ring-attention tolerance).

Signed-off-by: Chen Cui <chcui@nvidia.com>
… 7b numbers with PP=2/PP=4 wins

Step 7a indexed loss_dict["lm loss"] unconditionally, which only works when every rank is on the last PP stage. With PP > 1, first-stage ranks get an empty dict and the assertion crashed with KeyError. Broadcast the scalar from the last PP stage across the PP group so the assertion runs uniformly.

With PP > 1 unblocked, Step 7b actually demonstrates the headline win:

- seq=1024, layers=4, TP=2 PP=1: CP=1 8.9 ms vs CP=2 12.4 ms -> 0.72x (regression)
- seq=8192, layers=8, TP=2 PP=1: CP=1 26.3 ms vs CP=2 21.3 ms -> 1.23x
- seq=8192, layers=8, TP=2 PP=2: CP=1 39.5 ms vs CP=2 32.2 ms -> 1.23x
- seq=8192, layers=8, TP=1 PP=4: CP=1 53.1 ms vs CP=2 32.3 ms -> 1.64x
- seq=16384, layers=8, TP=2 PP=2: CP=1 82.5 ms vs CP=2 46.5 ms -> 1.78x
- seq=16384, layers=8, TP=1 PP=4: CP=1 117.7 ms vs CP=2 62.8 ms -> 1.88x

Refresh the Step 7b docstring + trailing Note to drop the PP=1-only caveat and walk through both compounding effects (O(N^2) attention + bubble shrink). Add PP=2/PP=4 invocations to the "How to Run" block; flag short-seq + PP=1 as the regression case.

Verified on 8xH100; all Step 7a correctness checks PASS at the listed shapes.

Signed-off-by: Chen Cui <chcui@nvidia.com>
```python
        f"world_size ({world_size}) must be divisible by TP*PP*CP "
        f"({tp_size}*{pp_size}*{cp_size}={model_parallel_size})"
    )
dp_size = world_size // model_parallel_size
```
Nit: bare print() — repo convention requires print_rank_0() or logging.getLogger. The rank-0 guard makes it functionally equivalent, but switching to print_rank_0() would be consistent with the rest of this file.
Suggested change:

```python
dp_size = world_size // model_parallel_size
print_rank_0(f"\nCreating ProcessGroupCollection{tag}: TP={tp_size} CP={cp_size} DP={dp_size} PP={pp_size}")
```
Light Review: eval-time context parallelism via decentralized PG rebinding

The core eval_cp.py module looks solid. The context-manager pattern with finally-based restore is the right approach, and the RotaryEmbedding skip is well-justified with the double-shard reasoning. The _GROUP_ATTRS single-source-of-truth dict is a clean design.

**Missing test coverage (blocking):** The PR body itself flags two unchecked items: (1) a unit test for install_pg_collection rebinding semantics, and (2) a 2-GPU functional test covering the eval-CP path. install_pg_collection is the heart of this feature: it mutates module state across the entire model tree, including TE internals and the shared TransformerConfig. A unit test that verifies rebinding semantics (correct attributes set, rotary cache cleared, config.context_parallel_size updated, restore on exception) would catch regressions from MCore changes without needing GPUs. This should land with or shortly after this PR.

**Minor: bare print() in example (inline comment posted):** Line 184 of the example uses print() instead of print_rank_0(). Functionally equivalent due to the rank guard, but inconsistent with the rest of the file.

Suggested test cases:

No perf tests impacted.
Summary
Enables validation to run with a higher CP degree than training, addressing the wasted-work problem when DP_train is large and the validation set is small. The mechanism swaps cached process-group references on every module from
`train_pgs` to `eval_pgs` for the duration of `evaluate()`, then restores them on exit. Works only with the decentralized process-group path (`use_decentralized_pg=True`); no Megatron-Core changes required.

Motivation
For an 8192-GPU run with TP=2, PP=4, CP_train=1 (so DP_train=1024) and a 1024-sample validation set, you must set GA=4 to keep the PP pipeline full, processing 4096 sample-steps — 4× redundant work vs evaluating each sample once. With CP_eval=4, DP_eval=256 covers all 1024 samples in one iter with no redundancy and ~CP_eval× lower activation memory.
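The motivation arithmetic above checks out directly, assuming the standard world = TP*PP*CP*DP factorization (all numbers are the ones quoted in the PR body):

```python
world = 8192
tp, pp, cp_train = 2, 4, 1
val_samples = 1024
ga = 4  # gradient-accumulation steps needed to keep the PP pipeline full

dp_train = world // (tp * pp * cp_train)   # 1024 data-parallel ranks at train CP
sample_steps = dp_train * ga               # 4096 sample-steps per eval iteration
assert dp_train == 1024 and sample_steps == 4096
assert sample_steps // val_samples == 4    # 4x redundant work at CP_eval = 1

cp_eval = 4
dp_eval = world // (tp * pp * cp_eval)     # 256 data-parallel ranks at eval CP
assert dp_eval == 256
# With GA=4, the 256 eval DP ranks cover all 1024 samples exactly once.
assert dp_eval * ga == val_samples
```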
What's added
- `src/megatron/bridge/training/eval_cp.py`: `install_pg_collection()` walks the module tree and rebinds:
  - `pg_collection` (used by `TransformerLayer`, `DotProductAttention`, …)
  - named CP-bearing group attrs (`cp_group`, `tp_cp_group`, `tp_dp_cp_group`, `dp_cp_group`)
  - `TEDotProductAttention` internal CP comm state via `set_context_parallel_group()`

  `RotaryEmbedding` / `MultimodalRotaryEmbedding` instances are intentionally skipped; see the in-code comment for the double-shard bug this avoids. The `@lru_cache` on `forward` is cleared defensively.
- `eval_cp_context()` is a context manager wrapping install/restore.
- `src/megatron/bridge/training/config.py`: `DistributedInitConfig.eval_context_parallel_size: int | None = None`
- `src/megatron/bridge/training/state.py`: `GlobalState.train_pgs` / `.eval_pgs` (`ProcessGroupCollection | None`)
- `examples/decentralized_pg/pretrain_qwen3_eval_cp.py`: 4-GPU end-to-end demo on Qwen3-4B (TP=2, CP_train=1, CP_eval=2): training → baseline eval at CP=1 → eval-CP at CP=2 → verify-restore at CP=1.

Constraints
- `dist.use_decentralized_pg=True`
- `seq_length % (2 * cp_eval) == 0`

Test plan
- [x] 4-GPU example (TP=2 CP_train=1 CP_eval=2, Qwen3-4B): train 3 iters, baseline eval CP=1 succeeds (loss 7.870), eval-CP CP=2 succeeds (loss 8.836), verify-restore eval CP=1 succeeds (loss 7.978)
- [x] `uv run pre-commit run` passes (ruff, ruff-format)
- [ ] Unit test for `install_pg_collection` rebinding semantics

Notes for reviewers
- `RotaryEmbedding.cp_group` is deliberately not updated. With the `pre_process` SP scatter at eval, the embedding is sharded across tp*cp_eval, so `decoder_input.size(0) = seq/(tp*cp_eval)`. After QKV's TP AllGather, query is `[seq/cp_eval, …]`. `get_rotary_seq_len` multiplies by tp_size and the config's cp_size (= cp_train), so freqs land at `seq/cp_eval` before CP slicing. If we then sliced freqs by cp_eval we'd double-shard them; a concrete repro shows query `(512, …)` vs freqs `(256, …)` mismatch in `_apply_rotary_pos_emb_bshd`. Skipping the rotary cp_group update keeps freqs at `seq/cp_eval`, matching query.
- The same `_apply_rotary_pos_emb_bshd` path is shared, so the same logic also applies to standard RoPE.
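The shape bookkeeping in the first note can be replayed numerically; seq=1024 is inferred here from the quoted `(512, …)` query shape at these parallel degrees, and the freqs-length model assumes get_rotary_seq_len behaves as the note describes:

```python
seq, tp, cp_train, cp_eval = 1024, 2, 1, 2

decoder_len = seq // (tp * cp_eval)      # 256: per-rank input after the SP scatter
query_len = decoder_len * tp             # 512: seq/cp_eval after QKV's TP AllGather
assert query_len == seq // cp_eval == 512

# get_rotary_seq_len with the (train) config: decoder_len * tp_size * cp_train,
# which already lands at seq/cp_eval — exactly the query length.
freqs_len = decoder_len * tp * cp_train  # 512
assert freqs_len == query_len            # the skip keeps freqs matching query

# If the rotary cp_group were rebound and freqs then sliced by cp_eval too,
# they would be double-sharded: the (512, …) vs (256, …) mismatch in the repro.
double_sharded = freqs_len // cp_eval    # 256
assert double_sharded == 256 and double_sharded != query_len
```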