diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index f258f3474ae..e134cca1d6a 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -586,13 +586,15 @@ def group_limited_topk( num_experts: int, num_groups: int, group_topk: int, + group_scoring_topk: int | None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Perform top-k routing on a subset of expert groups. When using group-limited routing: 1. Experts are divided into 'moe_router_num_groups' equal-sized groups 2. For each token, 'moe_router_group_topk' groups are selected based on routing scores - (specifically, the sum of top-2 expert scores within each group) + (specifically, the sum of top-`group_scoring_topk` expert scores within each group + (`topk // group_topk` if `group_scoring_topk` is None)) 3. From these selected groups, 'moe_router_topk' individual experts are chosen Two common use cases: @@ -616,9 +618,10 @@ def group_limited_topk( Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor. 
""" # Organize the experts into groups - # Select groups based on sum of top-(topk/group_topk) routing scores within each group + # Select groups based on sum of top-k routing scores within each group + group_scoring_k = group_scoring_topk if group_scoring_topk is not None else topk // group_topk group_scores = ( - scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1) + scores.view(num_tokens, num_groups, -1).topk(group_scoring_k, dim=-1)[0].sum(dim=-1) ) group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1] group_mask = torch.zeros_like(group_scores) @@ -678,6 +681,7 @@ def topk_routing_with_score_function( use_pre_softmax: bool = False, num_groups: Optional[int] = None, group_topk: Optional[int] = None, + group_scoring_topk: int | None = None, scaling_factor: Optional[float] = None, score_function: str = "softmax", expert_bias: Optional[torch.Tensor] = None, @@ -694,6 +698,8 @@ def topk_routing_with_score_function( selection. Defaults to False. num_groups (int, optional): Number of groups for routed experts. Defaults to None. group_topk (int, optional): Number of selected groups for each token. Defaults to None. + group_scoring_topk (int, optional): Number of top expert scores per group used to rank + groups. Defaults to None, meaning `topk // group_topk`. scaling_factor (float, optional): Scaling factor of routing score in top-k selection. Defaults to None. score_function (str, optional): The score function to use. 
Can be "softmax", "sigmoid" @@ -774,6 +780,7 @@ def _compute_topk( num_experts=num_experts, num_groups=num_groups, group_topk=group_topk, + group_scoring_topk=group_scoring_topk, ) else: # Sorting top-k turned off during inference diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 3e1fb01eb43..5484e612036 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -618,6 +618,7 @@ def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = N use_pre_softmax=self.config.moe_router_pre_softmax, num_groups=self.config.moe_router_num_groups, group_topk=self.config.moe_router_group_topk, + group_scoring_topk=self.config.moe_router_group_scoring_topk, scaling_factor=self.config.moe_router_topk_scaling_factor, score_function=self.score_function, expert_bias=self.expert_bias, @@ -763,6 +764,7 @@ def _compiled_topk_routing( use_pre_softmax, num_groups, group_topk, + group_scoring_topk, scaling_factor, score_function, expert_bias, @@ -776,6 +778,7 @@ def _compiled_topk_routing( use_pre_softmax=use_pre_softmax, num_groups=num_groups, group_topk=group_topk, + group_scoring_topk=group_scoring_topk, scaling_factor=scaling_factor, score_function=score_function, expert_bias=expert_bias, @@ -793,6 +796,7 @@ def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = N use_pre_softmax=self.config.moe_router_pre_softmax, num_groups=self.config.moe_router_num_groups, group_topk=self.config.moe_router_group_topk, + group_scoring_topk=self.config.moe_router_group_scoring_topk, scaling_factor=self.config.moe_router_topk_scaling_factor, score_function=self.score_function, expert_bias=self.expert_bias, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index bb044787b9c..c85e53918fc 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ 
-687,6 +687,11 @@ class TransformerConfig(ModelParallelConfig): moe_router_group_topk: Optional[int] = None """Number of selected groups for group-limited routing.""" + moe_router_group_scoring_topk: int | None = None + """Number of top expert scores per group used to rank groups in group-limited routing. + If None, defaults to `moe_router_topk // moe_router_group_topk`. + """ + moe_router_pre_softmax: bool = False """Enable pre-softmax(pre-sigmoid) routing for MoE, which means softmax is before the top-k selection. diff --git a/megatron/elastification/flextron_elasticity_hooks.py b/megatron/elastification/flextron_elasticity_hooks.py index da815043865..ebedc6ede5e 100644 --- a/megatron/elastification/flextron_elasticity_hooks.py +++ b/megatron/elastification/flextron_elasticity_hooks.py @@ -652,6 +652,7 @@ def topk_softmax_with_capacity( use_pre_softmax: bool = False, num_groups: Optional[int] = None, group_topk: Optional[int] = None, + group_scoring_topk: int | None = None, scaling_factor: Optional[float] = None, deterministic_mode: bool = False, score_function: str = "softmax", @@ -702,6 +703,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): num_experts=num_experts, num_groups=num_groups, group_topk=group_topk, + group_scoring_topk=group_scoring_topk, ) else: return torch.topk(scores, k=topk, dim=1) @@ -835,6 +837,7 @@ def wrapped_routing(logits, **kwargs): use_pre_softmax=self.config.moe_router_pre_softmax, num_groups=self.config.moe_router_num_groups, group_topk=self.config.moe_router_group_topk, + group_scoring_topk=self.config.moe_router_group_scoring_topk, scaling_factor=self.config.moe_router_topk_scaling_factor, deterministic_mode=self.config.deterministic_mode, score_function=self.config.moe_router_score_function,