13 changes: 10 additions & 3 deletions megatron/core/transformer/moe/moe_utils.py
@@ -586,13 +586,15 @@ def group_limited_topk(
num_experts: int,
num_groups: int,
group_topk: int,
+ group_scoring_topk: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform top-k routing on a subset of expert groups.

When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
- (specifically, the sum of top-2 expert scores within each group)
+ (specifically, the sum of the top-`group_scoring_topk` expert scores within each
+ group, where `group_scoring_topk` defaults to `topk // group_topk` when None)
3. From these selected groups, 'moe_router_topk' individual experts are chosen

Two common use cases:
@@ -616,9 +618,10 @@ def group_limited_topk(
Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor.
"""
# Organize the experts into groups
- # Select groups based on sum of top-(topk/group_topk) routing scores within each group
+ # Select groups based on the sum of the top-`group_scoring_k` routing scores within each group
+ group_scoring_k = group_scoring_topk if group_scoring_topk is not None else topk // group_topk
group_scores = (
- scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1)
+ scores.view(num_tokens, num_groups, -1).topk(group_scoring_k, dim=-1)[0].sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
@@ -678,6 +681,7 @@ def topk_routing_with_score_function(
use_pre_softmax: bool = False,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
+ group_scoring_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
score_function: str = "softmax",
expert_bias: Optional[torch.Tensor] = None,
@@ -694,6 +698,8 @@
selection. Defaults to False.
num_groups (int, optional): Number of groups for routed experts. Defaults to None.
group_topk (int, optional): Number of selected groups for each token. Defaults to None.
+ group_scoring_topk (int, optional): Number of top expert scores per group used to rank
+ groups. Defaults to None, meaning `topk // group_topk`.
scaling_factor (float, optional): Scaling factor of routing score in top-k selection.
Defaults to None.
score_function (str, optional): The score function to use. Can be "softmax", "sigmoid"
@@ -774,6 +780,7 @@ def _compute_topk(
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
+ group_scoring_topk=group_scoring_topk,
)
else:
# Sorting top-k turned off during inference
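Taken together, the hunks above change how expert groups are ranked before the final expert selection. For reviewers, here is a self-contained sketch of the resulting algorithm; the function name, comments, and driver values are illustrative, not the library API, and `group_limited_topk_sketch` mirrors the diff's logic rather than reproducing the full Megatron implementation:

```python
from typing import Optional, Tuple

import torch


def group_limited_topk_sketch(
    scores: torch.Tensor,  # [num_tokens, num_experts] routing scores
    topk: int,  # experts selected per token
    num_experts: int,
    num_groups: int,  # experts are split evenly into this many groups
    group_topk: int,  # groups kept per token
    group_scoring_topk: Optional[int] = None,  # the knob this PR adds
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rank groups by the sum of their top-`group_scoring_topk` expert scores,
    then pick the final top-k experts from the selected groups only."""
    num_tokens = scores.shape[0]
    # The old behavior is preserved as the default: topk // group_topk.
    k = group_scoring_topk if group_scoring_topk is not None else topk // group_topk
    # Score each group: [num_tokens, num_groups]
    group_scores = scores.view(num_tokens, num_groups, -1).topk(k, dim=-1)[0].sum(dim=-1)
    # Keep the best `group_topk` groups per token.
    group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Mask out experts in unselected groups, then take the final top-k.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, num_experts // num_groups)
        .reshape(num_tokens, -1)
    )
    masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
    return probs, top_indices


# DeepSeek-V3-style shape for illustration: 256 experts in 8 groups, 4 groups
# and 8 experts per token; the default group_scoring_topk is then 8 // 4 = 2.
scores = torch.rand(3, 256).softmax(dim=-1)
probs, idx = group_limited_topk_sketch(
    scores, topk=8, num_experts=256, num_groups=8, group_topk=4
)
```

Because the default reproduces `topk // group_topk`, existing configurations (including the top-2-per-group scoring mentioned in the old docstring) should behave identically unless the new argument is set explicitly.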
4 changes: 4 additions & 0 deletions megatron/core/transformer/moe/router.py
@@ -618,6 +618,7 @@ def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
use_pre_softmax=self.config.moe_router_pre_softmax,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
+ group_scoring_topk=self.config.moe_router_group_scoring_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
score_function=self.score_function,
expert_bias=self.expert_bias,
@@ -763,6 +764,7 @@ def _compiled_topk_routing(
use_pre_softmax,
num_groups,
group_topk,
+ group_scoring_topk,
scaling_factor,
score_function,
expert_bias,
@@ -776,6 +778,7 @@
use_pre_softmax=use_pre_softmax,
num_groups=num_groups,
group_topk=group_topk,
+ group_scoring_topk=group_scoring_topk,
scaling_factor=scaling_factor,
score_function=score_function,
expert_bias=expert_bias,
@@ -793,6 +796,7 @@ def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
use_pre_softmax=self.config.moe_router_pre_softmax,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
+ group_scoring_topk=self.config.moe_router_group_scoring_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
score_function=self.score_function,
expert_bias=self.expert_bias,
5 changes: 5 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -687,6 +687,11 @@ class TransformerConfig(ModelParallelConfig):
moe_router_group_topk: Optional[int] = None
"""Number of selected groups for group-limited routing."""

+ moe_router_group_scoring_topk: Optional[int] = None
+ """Number of top expert scores per group used to rank groups in group-limited routing.
+ If None, defaults to `moe_router_topk // moe_router_group_topk`.
+ """

moe_router_pre_softmax: bool = False
"""Enable pre-softmax(pre-sigmoid) routing for MoE, which means softmax is before the
top-k selection.
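On the configuration side, a minimal sketch of how the new field would be set follows; all values are illustrative placeholders, and the import path is assumed from megatron-core's usual layout:

```python
from megatron.core.transformer import TransformerConfig

# Illustrative DeepSeek-V3-like routing shape: 256 experts in 8 groups,
# 4 groups and 8 experts per token. Setting moe_router_group_scoring_topk=2
# makes the previously implicit default explicit
# (moe_router_topk // moe_router_group_topk = 8 // 4 = 2); any other value
# changes how groups are ranked.
config = TransformerConfig(
    num_layers=2,  # placeholder non-MoE fields
    hidden_size=1024,
    num_attention_heads=8,
    num_moe_experts=256,
    moe_router_topk=8,
    moe_router_num_groups=8,
    moe_router_group_topk=4,
    moe_router_group_scoring_topk=2,
)
```

Leaving the field at its `None` default keeps the pre-PR group-ranking behavior, so existing configs need no changes.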
3 changes: 3 additions & 0 deletions megatron/elastification/flextron_elasticity_hooks.py
@@ -652,6 +652,7 @@ def topk_softmax_with_capacity(
use_pre_softmax: bool = False,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
+ group_scoring_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
deterministic_mode: bool = False,
score_function: str = "softmax",
@@ -702,6 +703,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
+ group_scoring_topk=group_scoring_topk,
)
else:
return torch.topk(scores, k=topk, dim=1)
@@ -835,6 +837,7 @@ def wrapped_routing(logits, **kwargs):
use_pre_softmax=self.config.moe_router_pre_softmax,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
+ group_scoring_topk=self.config.moe_router_group_scoring_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.config.moe_router_score_function,