diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py
index 586d70c400f..91a3bf0cffb 100644
--- a/megatron/core/transformer/moe/fused_a2a.py
+++ b/megatron/core/transformer/moe/fused_a2a.py
@@ -3,6 +3,7 @@
 # Copyright (c) 2025 DeepSeek
 # Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
 
+import os
 from typing import Optional
 
 from megatron.core.utils import internal_api
@@ -268,6 +269,7 @@ def set_deepep_num_sms(num_sms):
 
 try:
+    import hybrid_ep_cpp
     from deep_ep import HybridEPBuffer
 
     HAVE_HYBRIDEP = True
@@ -275,12 +277,68 @@ def set_deepep_num_sms(num_sms):
     HAVE_HYBRIDEP = False
 
 _hybrid_ep_buffer = None
 
+_HYBRID_EP_TOKEN_ALIGNMENT = 16
+_HYBRID_EP_MIN_BUFFER_TOKENS = 512
+_HYBRID_EP_IB_QP_MAX_DEPTH = 65535
+_HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN = 3
+
+
+def _round_up_to_multiple(value: int, multiple: int) -> int:
+    return ((value + multiple - 1) // multiple) * multiple
+
+
+def _hybrid_ep_num_nodes(group: torch.distributed.ProcessGroup) -> int:
+    """Mirror HybridEP's NVLink-domain detection without constructing the full buffer."""
+    ranks_per_nvlink_domain_env = os.getenv("NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN")
+    if ranks_per_nvlink_domain_env is not None:
+        ranks_per_nvlink_domain = int(ranks_per_nvlink_domain_env)
+    else:
+        allocator = hybrid_ep_cpp.ExtendedMemoryAllocator()
+        ranks_per_nvlink_domain = allocator.detect_accessible_ranks(group)
+
+    assert group.size() % ranks_per_nvlink_domain == 0, (
+        f"The number of ranks {group.size()} should be divisible by the number of ranks per "
+        f"NVLink domain {ranks_per_nvlink_domain}."
+    )
+    return group.size() // ranks_per_nvlink_domain
+
+
+def _hybrid_ep_uses_internode_rdma(group: torch.distributed.ProcessGroup) -> bool:
+    if _hybrid_ep_buffer is not None and hasattr(_hybrid_ep_buffer, "num_of_nodes"):
+        return _hybrid_ep_buffer.num_of_nodes > 1
+    return _hybrid_ep_num_nodes(group) > 1
+
+
+def _validate_hybrid_ep_ib_tx_depth(num_tokens: int, group: torch.distributed.ProcessGroup) -> None:
+    buffer_tokens = max(
+        _round_up_to_multiple(num_tokens, _HYBRID_EP_TOKEN_ALIGNMENT), _HYBRID_EP_MIN_BUFFER_TOKENS
+    )
+    tx_depth = _HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN * buffer_tokens + 1
+    if tx_depth <= _HYBRID_EP_IB_QP_MAX_DEPTH:
+        return
+
+    if not _hybrid_ep_uses_internode_rdma(group):
+        return
+
+    max_supported_tokens = (
+        ((_HYBRID_EP_IB_QP_MAX_DEPTH - 1) // _HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN)
+        // _HYBRID_EP_TOKEN_ALIGNMENT
+        * _HYBRID_EP_TOKEN_ALIGNMENT
+    )
+    raise ValueError(
+        f"HybridEP InfiniBand dispatch queue pair depth ({tx_depth}) exceeds the hardware "
+        f"limit of {_HYBRID_EP_IB_QP_MAX_DEPTH}. DeepEP computes this depth from the "
+        f"tokens per rank rounded up to a {_HYBRID_EP_TOKEN_ALIGNMENT}-token buffer "
+        f"alignment ({buffer_tokens}). Reduce sequence length or micro-batch size, or "
+        f"increase Tensor Parallelism (TP) / Context Parallelism (CP), so tokens per rank "
+        f"are at most {max_supported_tokens} for multi-node HybridEP."
+    )
 
 
 def init_hybrid_ep_buffer(
     group: torch.distributed.ProcessGroup,
     hidden_dim: int,
-    seq_len: int,
+    num_tokens: int,
     num_local_experts: int,
     num_sms_dispatch_api: Optional[int] = None,
     num_sms_combine_api: Optional[int] = None,
@@ -301,8 +359,8 @@ def init_hybrid_ep_buffer(
             Process group for HybridEP all-to-all communication.
         hidden_dim (int):
             Hidden dimension of the input tensor.
-        seq_len (int):
-            Maximum sequence length of the input tensor.
+        num_tokens (int):
+            Maximum token count of the input tensor.
         num_local_experts (int):
             Number of local experts.
         num_sms_dispatch_api (Optional[int]):
@@ -330,7 +388,7 @@ def init_hybrid_ep_buffer(
     _hybrid_ep_buffer = HybridEPBuffer(
         group=group,
         hidden_dim=hidden_dim,
-        max_num_of_tokens_per_rank=seq_len,
+        max_num_of_tokens_per_rank=num_tokens,
        num_local_experts=num_local_experts,
         use_fp8=fp8_dispatch,
         **kwargs,
@@ -386,13 +444,14 @@ def forward(
         num_blocks_permute = None
         num_blocks_unpermute = None
 
+        num_tokens, hidden_dim = x.shape[-2:]
+        _validate_hybrid_ep_ib_tx_depth(num_tokens, group)
         if _hybrid_ep_buffer is None:
-            seq_len, hidden_dim = x.shape[-2:]
             fp8_dispatch = False  # Currently, we do not support fp8 dispatch
             init_hybrid_ep_buffer(
                 group,
                 hidden_dim,
-                seq_len,
+                num_tokens,
                 num_local_experts,
                 num_sms_dispatch_api,
                 num_sms_combine_api,
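
Note (not part of the patch): a minimal standalone sketch that re-derives the multi-node token limit from the constants added above. The names here are local to the sketch; only the constants and the tx_depth formula (3 * buffer_tokens + 1) come from the diff. It shows that the largest tokens-per-rank value that passes _validate_hybrid_ep_ib_tx_depth on multi-node HybridEP is 21840.

# Standalone sketch: re-derive the limit reported in the new ValueError message.
ALIGNMENT = 16           # _HYBRID_EP_TOKEN_ALIGNMENT
MIN_BUFFER_TOKENS = 512  # _HYBRID_EP_MIN_BUFFER_TOKENS
QP_MAX_DEPTH = 65535     # _HYBRID_EP_IB_QP_MAX_DEPTH
DEPTH_PER_TOKEN = 3      # _HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN

def round_up(value, multiple):
    return ((value + multiple - 1) // multiple) * multiple

def tx_depth(num_tokens):
    # Same formula as _validate_hybrid_ep_ib_tx_depth: depth = 3 * buffer_tokens + 1,
    # where buffer_tokens is the token count rounded up to 16 with a 512-token floor.
    buffer_tokens = max(round_up(num_tokens, ALIGNMENT), MIN_BUFFER_TOKENS)
    return DEPTH_PER_TOKEN * buffer_tokens + 1

max_tokens = ((QP_MAX_DEPTH - 1) // DEPTH_PER_TOKEN) // ALIGNMENT * ALIGNMENT
assert max_tokens == 21840
assert tx_depth(21840) == 65521 <= QP_MAX_DEPTH  # largest aligned value that passes
assert tx_depth(21856) == 65569 > QP_MAX_DEPTH   # next aligned value would raise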