perf(moe): triton biased grouped topk for deepseek-v3 routing#171
Open
roycho96 wants to merge 4 commits into
Open
perf(moe): triton biased grouped topk for deepseek-v3 routing#171roycho96 wants to merge 4 commits into
roycho96 wants to merge 4 commits into
Conversation
The reference biased_grouped_topk_gpu runs 8-10 launches per layer per decode step (sigmoid, bias, group-top-2-sum, group selection, per-token top-K, renorm, shared-expert weight). The existing minimax sibling kernel is trait-gated to num_expert_group=1 and falls through for DSv3 shapes. Add a sibling Triton kernel that folds the full routing pipeline into one program per token, plus the in-kernel shared-expert weight (saves a randint + sum + divide + multiply chain on the host). Dispatch is handled by the existing trait filter; no runtime changes. fp32 only; bf16 falls back to the reference (sigmoid precision diverges at the top-K boundary, verified bitwise-identical on the fallback path). Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Parametrize over batch, shared-expert slot, renormalize, output scaling, padding, and gating dtype; assert routed id set-equality and per-id weight equality. Adds explicit num_tokens=0 and bf16 bit-exact fallback coverage. Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
LorrinWWW
approved these changes
May 17, 2026
borontion
reviewed
May 17, 2026
| static_logical_to_physical_map_ptr, | ||
| topk_weights_ptr, | ||
| topk_ids_ptr, | ||
| stride_gm, |
Contributor
There was a problem hiding this comment.
should stride be constexpr?
Contributor
Author
There was a problem hiding this comment.
output strides here are always (TOPK, 1) since we allocate the outputs inside the wrapper, so I just dropped them instead of constexpr. left the input strides.
topk_weights and topk_ids are allocated inside the wrapper as torch.empty((num_tokens, topk), ...), so their strides are always (TOPK, 1). Drop the four output stride args and index with * TOPK + k directly. Input strides stay runtime args. Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
DeepSeek-V3/V4 expert routing (
biased_grouped_topk_gpuwithnum_expert_group=8, topk_group=4, topk=8, num_fused_shared_experts ∈ {0,1}) runs as a@torch.compilereference: 8-10 launches per layer per decode step. The MiniMax sibling kernel attokenspeed-kernel/ops/moe/triton.py:227is trait-gated tonum_expert_group=1, topk_group=1and falls through for DSv3 shapes.This PR adds a sibling Triton kernel
_grouped_topk_biased_kernelplustriton_biased_grouped_topkwrapper registered on the DSv3 trait. The kernel folds sigmoid, bias, group-top-2-sum, group selection, per-token top-K, renorm, and the shared-expert weight into one program per token. Dispatch is handled by the existing trait filter; no runtime changes.Microbench (fp32 gating, 30 warmup + 400 timed iters, cudaEvent + cudaSync):
bf16 gating falls back to reference (bitwise-identical, verified by test).
Test Plan