megatron/core/pipeline_parallel/combined_1f1b.py (4 changes: 2 additions & 2 deletions)
@@ -52,7 +52,7 @@ def combined_1f1b_schedule_for_no_pipelining(
     Phases 4: 4th microbatch backward
     """
 
-    set_streams()
+    set_streams(high_priority=config.high_priority_a2a_comm_stream)
     # The forward step for the first microbatch is executed alone, no a2a overlapping
     output_tensor, num_tokens, _ = combined_forward_backward_step(
         forward_step_func,
@@ -178,7 +178,7 @@ def combined_1f1b_schedule_for_interleaved_pipelining():
     # backward_step_helper_postprocess()
     """
 
-    set_streams()
+    set_streams(high_priority=config.high_priority_a2a_comm_stream)
     # forward prepare
     f_model_chunk_id = None
     f_microbatch_id = None
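Note: a minimal standalone sketch of the overlap this flag targets (illustrative only; the tensors and ops below stand in for real microbatch compute and a2a traffic). Work queued on a separate high-priority stream can be scheduled ahead of default-stream compute, which is what keeps the token exchange from being starved during 1f1b.

import torch

# In CUDA, numerically lower priority values mean higher scheduling priority,
# so a priority of -1 outranks the default stream's 0.
comm_stream = torch.cuda.Stream(device="cuda", priority=-1)

a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")

out = a @ a                      # stand-in for microbatch compute
with torch.cuda.stream(comm_stream):
    exchanged = b * 2.0          # stand-in for the a2a kernel

# The default stream must wait on the side stream before consuming its output.
torch.cuda.current_stream().wait_stream(comm_stream)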
megatron/core/pipeline_parallel/utils.py (8 changes: 6 additions & 2 deletions)
@@ -341,14 +341,18 @@ def run(
 _COMM_STREAM = None
 
 
-def set_streams(comm_stream=None):
+def set_streams(comm_stream=None, high_priority=False):
     """Set the stream for communication operations."""
     global _COMM_STREAM
 
     # Set communication stream
     if _COMM_STREAM is None:
         if comm_stream is None:
-            comm_stream = torch.cuda.Stream(device="cuda")
+            if high_priority:
+                _, high = torch.cuda.Stream.priority_range()
+                comm_stream = torch.cuda.Stream(device="cuda", priority=high)
+            else:
+                comm_stream = torch.cuda.Stream(device="cuda")
         _COMM_STREAM = comm_stream
 
 
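Note: the new branch queries the device's priority range rather than hard-coding a value. As a hedged sketch of the same lazily-created singleton, using only the documented priority keyword of torch.cuda.Stream directly (lower, i.e. more negative, values are higher priority) instead of querying the range:

import torch

_COMM_STREAM = None

def get_comm_stream(high_priority: bool = False) -> torch.cuda.Stream:
    # Lazily create one shared communication stream, mirroring set_streams.
    # priority=-1 requests a high-priority stream; 0 is the default priority.
    global _COMM_STREAM
    if _COMM_STREAM is None:
        _COMM_STREAM = torch.cuda.Stream(
            device="cuda", priority=-1 if high_priority else 0
        )
    return _COMM_STREAM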
megatron/core/transformer/moe/fused_a2a.py (12 changes: 12 additions & 0 deletions)
@@ -345,6 +345,7 @@ def init_hybrid_ep_buffer(
     num_blocks_permute: Optional[int] = None,
     num_blocks_unpermute: Optional[int] = None,
     fp8_dispatch: bool = False,
+    num_sms_preprocessing_api: Optional[int] = None,
 ) -> None:
     '''
     Initialize the HybridEP buffer, including buffer allocation and metadata
@@ -373,6 +374,8 @@
             Number of blocks used by the unpermute part.
         fp8_dispatch (bool):
             Whether to use FP8 communication during the dispatch phase.
+        num_sms_preprocessing_api (Optional[int]):
+            Number of SMs used by the preprocessing (metadata scan) kernel.
     '''
     assert not fp8_dispatch, "HybridEP dispatcher does not support fp8 dispatch now"
     global _hybrid_ep_buffer
@@ -385,6 +388,8 @@
         kwargs['num_blocks_permute'] = num_blocks_permute
     if num_blocks_unpermute is not None:
         kwargs['num_blocks_unpermute'] = num_blocks_unpermute
+    if num_sms_preprocessing_api is not None:
+        kwargs['num_sms_preprocessing_api'] = num_sms_preprocessing_api
     _hybrid_ep_buffer = HybridEPBuffer(
         group=group,
         hidden_dim=hidden_dim,
@@ -423,6 +428,7 @@ def forward(
         fused=False,
         num_permuted_tokens=None,
         pad_multiple=None,
+        num_sms_preprocessing_api=108,
     ):
         '''
         Forward pass of fused dispatch of the HybridEP backend
@@ -458,6 +464,7 @@
             num_blocks_permute,
             num_blocks_unpermute,
             fp8_dispatch,
+            num_sms_preprocessing_api,
         )
         # If we provide the num_permuted_tokens, we do not need to use sync to
         # wait for the data in pinned memory ready
@@ -518,6 +525,7 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper
             None,
             None,
             None,
+            None,
         )
 
 
@@ -577,6 +585,7 @@ def hybrid_ep_dispatch(
     fused=False,
     num_permuted_tokens=None,
     pad_multiple=None,
+    num_sms_preprocessing_api=108,
 ):
     '''
     Perform fused dispatch for "permute + dispatch a2a + permute" using the
@@ -608,6 +617,8 @@
         pad_multiple (int):
             Alignment multiple required for FP8 GEMM. If not provided, no padding
             is performed.
+        num_sms_preprocessing_api (int):
+            Number of SMs used by the preprocessing (metadata scan) kernel.
     '''
     return HybridEPDispatch.apply(
         x,
@@ -622,6 +633,7 @@
         fused,
         num_permuted_tokens,
         pad_multiple,
+        num_sms_preprocessing_api,
     )
 
 @internal_api
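Note: the extra None added in backward is required by the autograd contract: backward must return exactly one gradient slot per argument passed to Function.apply, with None for non-differentiable arguments such as num_sms_preprocessing_api. A minimal sketch of that pattern (the class and names here are illustrative, not part of the diff):

import torch

class ScaleWithConfig(torch.autograd.Function):
    """Toy Function taking one tensor and one non-tensor config argument."""

    @staticmethod
    def forward(ctx, x, num_sms):
        ctx.num_sms = num_sms   # stash the config on ctx; it is not a tensor
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        # One return slot per forward input: a gradient for x,
        # and None for the non-differentiable num_sms argument.
        return grad_out * 2.0, None

x = torch.randn(4, requires_grad=True)
ScaleWithConfig.apply(x, 64).sum().backward()
assert torch.allclose(x.grad, torch.full((4,), 2.0))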
megatron/core/transformer/moe/token_dispatcher.py (1 change: 1 addition & 0 deletions)
@@ -1078,6 +1078,7 @@ def dispatch(
                 num_permuted_tokens=self.num_permuted_tokens,
                 pad_multiple=self.pad_multiple,
                 fused=self.config.moe_permute_fusion_into_hybridep,
+                num_sms_preprocessing_api=self.config.moe_hybridep_num_sms_preprocessing,
             )
         )
 
megatron/core/transformer/transformer_config.py (7 changes: 7 additions & 0 deletions)
@@ -601,6 +601,10 @@ class TransformerConfig(ModelParallelConfig):
     """Python import path to a callable quantizer factory, e.g., package.module.quantizer_factory.
     Required when fp4_recipe is custom."""
 
+    high_priority_a2a_comm_stream: bool = False
+    """If True, the communication stream created by set_streams for combined 1f1b
+    a2a overlap is created with CUDA high priority."""
+
     ####################
     # MoE related
     ####################
@@ -821,6 +825,9 @@
     When permute_fusion_into_hybridep is True, this sets the number
     of SMs for the unpermute part (only 1 block per SM)."""
 
+    moe_hybridep_num_sms_preprocessing: int = 108
+    """Number of SMs to use for HybridEP preprocessing (metadata scan kernel)."""
+
     ##################
     # Context Parallel
     ##################
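Note: assuming the usual required TransformerConfig fields, enabling both new knobs looks like the sketch below. The field values are illustrative; the default of 108 SMs matches a full A100 GPU, so smaller values leave SMs free for overlapped compute.

from megatron.core.transformer.transformer_config import TransformerConfig

# Illustrative values; only the last two fields are new in this change.
config = TransformerConfig(
    num_layers=4,
    hidden_size=1024,
    num_attention_heads=8,
    high_priority_a2a_comm_stream=True,       # high-priority a2a comm stream
    moe_hybridep_num_sms_preprocessing=64,    # cap SMs for the metadata scan
)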
tests/unit_tests/models/test_hybrid_moe_model.py (2 changes: 2 additions & 0 deletions)
@@ -127,6 +127,7 @@
     "hidden_dropout": 0.0,
     "hidden_size": 2688,
     "hierarchical_context_parallel_sizes": None,
+    "high_priority_a2a_comm_stream": False,
     "inference_fuse_tp_communication": False,
     "inference_rng_tracker": False,
     "inference_sampling_seed": 42,
@@ -163,6 +164,7 @@
     "moe_flex_dispatcher_backend": "deepep",
     "moe_grouped_gemm": True,
     "moe_hybridep_num_sms": None,
+    "moe_hybridep_num_sms_preprocessing": 108,
     "moe_hybridep_num_blocks_permute": None,
     "moe_hybridep_num_blocks_unpermute": None,
     "moe_input_jitter_eps": None,