71 changes: 65 additions & 6 deletions megatron/core/transformer/moe/fused_a2a.py
@@ -3,6 +3,7 @@
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE

import os
from typing import Optional

from megatron.core.utils import internal_api
@@ -268,19 +269,76 @@ def set_deepep_num_sms(num_sms):


try:
    import hybrid_ep_cpp
    from deep_ep import HybridEPBuffer

    HAVE_HYBRIDEP = True
except ImportError:
    HAVE_HYBRIDEP = False

_hybrid_ep_buffer = None
# DeepEP rounds the per-rank token buffer up to this alignment.
_HYBRID_EP_TOKEN_ALIGNMENT = 16
_HYBRID_EP_MIN_BUFFER_TOKENS = 512
# Hardware limit on the InfiniBand queue pair depth (2**16 - 1).
_HYBRID_EP_IB_QP_MAX_DEPTH = 65535
# Each buffered token contributes this many entries to the dispatch queue depth.
_HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN = 3


def _round_up_to_multiple(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple
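# e.g. _round_up_to_multiple(500, 16) == 512; values already on the boundary are unchanged.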


def _hybrid_ep_num_nodes(group: torch.distributed.ProcessGroup) -> int:
    """Mirror HybridEP's NVLink-domain detection without constructing the full buffer."""
    ranks_per_nvlink_domain_env = os.getenv("NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN")
    if ranks_per_nvlink_domain_env is not None:
        ranks_per_nvlink_domain = int(ranks_per_nvlink_domain_env)
    else:
        allocator = hybrid_ep_cpp.ExtendedMemoryAllocator()
        ranks_per_nvlink_domain = allocator.detect_accessible_ranks(group)

    assert group.size() % ranks_per_nvlink_domain == 0, (
        f"The number of ranks {group.size()} must be divisible by the number of ranks per "
        f"NVLink domain {ranks_per_nvlink_domain}."
    )
    return group.size() // ranks_per_nvlink_domain
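# e.g. a 16-rank group with NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=8 spans 2 nodes.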


def _hybrid_ep_uses_internode_rdma(group: torch.distributed.ProcessGroup) -> bool:
    """Return True when the HybridEP group spans more than one NVLink domain."""
    if _hybrid_ep_buffer is not None and hasattr(_hybrid_ep_buffer, "num_of_nodes"):
        return _hybrid_ep_buffer.num_of_nodes > 1
    return _hybrid_ep_num_nodes(group) > 1


def _validate_hybrid_ep_ib_tx_depth(num_tokens: int, group: torch.distributed.ProcessGroup) -> None:
    """Fail fast if the per-rank token count would overflow the IB dispatch queue pair depth."""
    buffer_tokens = max(
        _round_up_to_multiple(num_tokens, _HYBRID_EP_TOKEN_ALIGNMENT), _HYBRID_EP_MIN_BUFFER_TOKENS
    )
    tx_depth = _HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN * buffer_tokens + 1
    if tx_depth <= _HYBRID_EP_IB_QP_MAX_DEPTH:
        return

    # The depth limit only matters when dispatch actually crosses the RDMA fabric.
    if not _hybrid_ep_uses_internode_rdma(group):
        return

    max_supported_tokens = (
        ((_HYBRID_EP_IB_QP_MAX_DEPTH - 1) // _HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN)
        // _HYBRID_EP_TOKEN_ALIGNMENT
        * _HYBRID_EP_TOKEN_ALIGNMENT
    )
    raise ValueError(
        f"HybridEP InfiniBand dispatch queue pair depth ({tx_depth}) exceeds the hardware "
        f"limit of {_HYBRID_EP_IB_QP_MAX_DEPTH}. DeepEP computes this depth from the "
        f"tokens per rank rounded up to a {_HYBRID_EP_TOKEN_ALIGNMENT}-token buffer "
        f"alignment ({buffer_tokens}). Reduce the sequence length or micro-batch size, or "
        f"increase Tensor Parallelism (TP) / Context Parallelism (CP), so that tokens per "
        f"rank are at most {max_supported_tokens} for multi-node HybridEP."
    )
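# With the constants above: tx_depth = 3 * buffer_tokens + 1 <= 65535 caps buffer_tokens at
# 21844, which rounds down to max_supported_tokens = 21840 at 16-token alignment.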


def init_hybrid_ep_buffer(
    group: torch.distributed.ProcessGroup,
    hidden_dim: int,
    seq_len: int,
    num_tokens: int,
    num_local_experts: int,
    num_sms_dispatch_api: Optional[int] = None,
    num_sms_combine_api: Optional[int] = None,
@@ -301,8 +359,8 @@ def init_hybrid_ep_buffer(
            Process group for HybridEP all-to-all communication.
        hidden_dim (int):
            Hidden dimension of the input tensor.
        seq_len (int):
            Maximum sequence length of the input tensor.
        num_tokens (int):
            Maximum number of tokens per rank in the input tensor.
        num_local_experts (int):
            Number of local experts.
        num_sms_dispatch_api (Optional[int]):
@@ -330,7 +388,7 @@ def init_hybrid_ep_buffer(
    _hybrid_ep_buffer = HybridEPBuffer(
        group=group,
        hidden_dim=hidden_dim,
        max_num_of_tokens_per_rank=seq_len,
        max_num_of_tokens_per_rank=num_tokens,
        num_local_experts=num_local_experts,
        use_fp8=fp8_dispatch,
        **kwargs,
@@ -386,13 +444,14 @@ def forward(
        num_blocks_permute = None
        num_blocks_unpermute = None

        num_tokens, hidden_dim = x.shape[-2:]
        _validate_hybrid_ep_ib_tx_depth(num_tokens, group)
        if _hybrid_ep_buffer is None:
            seq_len, hidden_dim = x.shape[-2:]
            fp8_dispatch = False  # Currently, we do not support fp8 dispatch
            init_hybrid_ep_buffer(
                group,
                hidden_dim,
                num_tokens,
                num_local_experts,
                num_sms_dispatch_api,
                num_sms_combine_api,
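A standalone sketch of the queue-depth arithmetic this change enforces, for anyone reviewing the limit. The constants are copied from the diff; the `tx_depth` helper below is illustrative, not part of the PR:

```python
TOKEN_ALIGNMENT = 16      # _HYBRID_EP_TOKEN_ALIGNMENT
MIN_BUFFER_TOKENS = 512   # _HYBRID_EP_MIN_BUFFER_TOKENS
DEPTH_PER_TOKEN = 3       # _HYBRID_EP_IB_DISPATCH_DEPTH_PER_TOKEN
QP_MAX_DEPTH = 65535      # _HYBRID_EP_IB_QP_MAX_DEPTH


def tx_depth(num_tokens: int) -> int:
    """Reproduce the depth computed in _validate_hybrid_ep_ib_tx_depth."""
    aligned = ((num_tokens + TOKEN_ALIGNMENT - 1) // TOKEN_ALIGNMENT) * TOKEN_ALIGNMENT
    buffer_tokens = max(aligned, MIN_BUFFER_TOKENS)
    return DEPTH_PER_TOKEN * buffer_tokens + 1


assert tx_depth(8192) == 24577         # a typical per-rank token count fits comfortably
assert tx_depth(21840) == 65521        # the largest aligned count that still fits
assert tx_depth(21856) > QP_MAX_DEPTH  # one alignment step more overflows the queue pair
```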