From 047658e5949aa4250723d7c1477522e54de8d305 Mon Sep 17 00:00:00 2001
From: Gao Deng
Date: Mon, 20 Apr 2026 16:32:19 -0700
Subject: [PATCH] Add high-priority A2A stream and HybridEP preprocessing SMs

Port the main-applicable parts of PR #4401.

- Add high_priority_a2a_comm_stream for combined 1F1B A2A overlap stream
  creation.
- Add moe_hybridep_num_sms_preprocessing and plumb it through HybridEP
  dispatch initialization.

The dev PR's overflow_flag handle-index change targets
moe_expert_rank_capacity_factor code that is not present on main, so it is
intentionally omitted here.

Original dev commit: 7ae10f2931e93242da5a38840749ad414b604dcc.
---
 megatron/core/pipeline_parallel/combined_1f1b.py  |  4 ++--
 megatron/core/pipeline_parallel/utils.py          |  8 ++++++--
 megatron/core/transformer/moe/fused_a2a.py        | 12 ++++++++++++
 megatron/core/transformer/moe/token_dispatcher.py |  1 +
 megatron/core/transformer/transformer_config.py   |  7 +++++++
 tests/unit_tests/models/test_hybrid_moe_model.py  |  2 ++
 6 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/megatron/core/pipeline_parallel/combined_1f1b.py b/megatron/core/pipeline_parallel/combined_1f1b.py
index f4f222ad2a1..ec689e8fe7f 100644
--- a/megatron/core/pipeline_parallel/combined_1f1b.py
+++ b/megatron/core/pipeline_parallel/combined_1f1b.py
@@ -52,7 +52,7 @@ def combined_1f1b_schedule_for_no_pipelining(
     Phases 4: 4th microbatch backward
     """
 
-    set_streams()
+    set_streams(high_priority=config.high_priority_a2a_comm_stream)
     # The forward step for the first microbatch is executed alone, no a2a overlapping
     output_tensor, num_tokens, _ = combined_forward_backward_step(
         forward_step_func,
@@ -178,7 +178,7 @@ def combined_1f1b_schedule_for_interleaved_pipelining():
     #     backward_step_helper_postprocess()
     """
 
-    set_streams()
+    set_streams(high_priority=config.high_priority_a2a_comm_stream)
     # forward prepare
     f_model_chunk_id = None
     f_microbatch_id = None
diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py
index f316806ead7..48ce3d34a3c 100644
--- a/megatron/core/pipeline_parallel/utils.py
+++ b/megatron/core/pipeline_parallel/utils.py
@@ -341,14 +341,18 @@ def run(
 _COMM_STREAM = None
 
 
-def set_streams(comm_stream=None):
+def set_streams(comm_stream=None, high_priority=False):
     """Set the stream for communication operations."""
     global _COMM_STREAM
 
     # Set communication stream
     if _COMM_STREAM is None:
         if comm_stream is None:
-            comm_stream = torch.cuda.Stream(device="cuda")
+            if high_priority:
+                _, high = torch.cuda.Stream.priority_range()
+                comm_stream = torch.cuda.Stream(device="cuda", priority=high)
+            else:
+                comm_stream = torch.cuda.Stream(device="cuda")
 
     _COMM_STREAM = comm_stream
 
diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py
index 586d70c400f..07f33deca6c 100644
--- a/megatron/core/transformer/moe/fused_a2a.py
+++ b/megatron/core/transformer/moe/fused_a2a.py
@@ -287,6 +287,7 @@ def init_hybrid_ep_buffer(
     num_blocks_permute: Optional[int] = None,
     num_blocks_unpermute: Optional[int] = None,
     fp8_dispatch: bool = False,
+    num_sms_preprocessing_api: Optional[int] = None,
 ) -> None:
     '''
     Initialize the HybridEP buffer, including buffer allocation and metadata
@@ -315,6 +316,8 @@
             Number of blocks used by the unpermute part.
         fp8_dispatch (bool):
             Whether to use FP8 communication during the dispatch phase.
+        num_sms_preprocessing_api (Optional[int]):
+            Number of SMs used by the preprocessing (metadata scan) kernel.
     '''
     assert not fp8_dispatch, "HybridEP dispatcher does not support fp8 dispatch now"
     global _hybrid_ep_buffer
@@ -327,6 +330,8 @@ def init_hybrid_ep_buffer(
         kwargs['num_blocks_permute'] = num_blocks_permute
     if num_blocks_unpermute is not None:
         kwargs['num_blocks_unpermute'] = num_blocks_unpermute
+    if num_sms_preprocessing_api is not None:
+        kwargs['num_sms_preprocessing_api'] = num_sms_preprocessing_api
     _hybrid_ep_buffer = HybridEPBuffer(
         group=group,
         hidden_dim=hidden_dim,
@@ -365,6 +370,7 @@ def forward(
         fused=False,
         num_permuted_tokens=None,
         pad_multiple=None,
+        num_sms_preprocessing_api=108,
     ):
         '''
         Forward pass of fused dispatch of the HybridEP backend
@@ -399,6 +405,7 @@ def forward(
             num_blocks_permute,
             num_blocks_unpermute,
             fp8_dispatch,
+            num_sms_preprocessing_api,
         )
         # If we provide the num_permuted_tokens, we do not need to use sync to
         # wait for the data in pinned memory ready
@@ -459,6 +466,7 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper
             None,
             None,
             None,
+            None,
         )
 
 
@@ -518,6 +526,7 @@ def hybrid_ep_dispatch(
     fused=False,
     num_permuted_tokens=None,
     pad_multiple=None,
+    num_sms_preprocessing_api=108,
 ):
     '''
     Perform fused dispatch for "permute + dispatch a2a + permute" using the
@@ -549,6 +558,8 @@ def hybrid_ep_dispatch(
         pad_multiple (int):
            Alignment multiple required for FP8 GEMM.
            If not provided, no padding is performed.
+        num_sms_preprocessing_api (int):
+            Number of SMs used by the preprocessing (metadata scan) kernel.
     '''
     return HybridEPDispatch.apply(
         x,
@@ -563,6 +574,7 @@ def hybrid_ep_dispatch(
         fused,
         num_permuted_tokens,
         pad_multiple,
+        num_sms_preprocessing_api,
     )
 
 @internal_api
diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py
index 717e285a249..11d1eda9260 100644
--- a/megatron/core/transformer/moe/token_dispatcher.py
+++ b/megatron/core/transformer/moe/token_dispatcher.py
@@ -1078,6 +1078,7 @@ def dispatch(
                 num_permuted_tokens=self.num_permuted_tokens,
                 pad_multiple=self.pad_multiple,
                 fused=self.config.moe_permute_fusion_into_hybridep,
+                num_sms_preprocessing_api=self.config.moe_hybridep_num_sms_preprocessing,
             )
         )
 
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index bb044787b9c..88f1d518cab 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -601,6 +601,10 @@ class TransformerConfig(ModelParallelConfig):
     """Python import path to a callable quantizer factory, e.g., package.module.quantizer_factory.
     Required when fp4_recipe is custom."""
 
+    high_priority_a2a_comm_stream: bool = False
+    """If True, the communication stream created by set_streams for combined 1f1b
+    a2a overlap is created with CUDA high priority."""
+
     ####################
     # MoE related
     ####################
@@ -821,6 +825,9 @@ class TransformerConfig(ModelParallelConfig):
     When permute_fusion_into_hybridep is True, this sets the number of SMs
     for the unpermute part (only 1 block per SM)."""
 
+    moe_hybridep_num_sms_preprocessing: int = 108
+    """Number of SMs to use for HybridEP preprocessing (metadata scan kernel)."""
+
     ##################
     # Context Parallel
     ##################
diff --git a/tests/unit_tests/models/test_hybrid_moe_model.py b/tests/unit_tests/models/test_hybrid_moe_model.py
index 3935964c975..eea933c7322 100644
--- a/tests/unit_tests/models/test_hybrid_moe_model.py
+++ b/tests/unit_tests/models/test_hybrid_moe_model.py
@@ -127,6 +127,7 @@
     "hidden_dropout": 0.0,
     "hidden_size": 2688,
     "hierarchical_context_parallel_sizes": None,
+    "high_priority_a2a_comm_stream": False,
     "inference_fuse_tp_communication": False,
     "inference_rng_tracker": False,
     "inference_sampling_seed": 42,
@@ -163,6 +164,7 @@
     "moe_flex_dispatcher_backend": "deepep",
     "moe_grouped_gemm": True,
     "moe_hybridep_num_sms": None,
+    "moe_hybridep_num_sms_preprocessing": 108,
     "moe_hybridep_num_blocks_permute": None,
     "moe_hybridep_num_blocks_unpermute": None,
     "moe_input_jitter_eps": None,