perf(moe): triton biased grouped topk for deepseek-v3 routing #171

Open
roycho96 wants to merge 4 commits into lightseekorg:main from roycho96:perf/dsv3-grouped-topk-triton

@roycho96
Contributor

Summary

DeepSeek-V3/V4 expert routing (biased_grouped_topk_gpu with num_expert_group=8, topk_group=4, topk=8, num_fused_shared_experts ∈ {0,1}) runs as a @torch.compile reference: 8-10 launches per layer per decode step. The MiniMax sibling kernel at tokenspeed-kernel/ops/moe/triton.py:227 is trait-gated to num_expert_group=1, topk_group=1 and falls through for DSv3 shapes.

This PR adds a sibling Triton kernel _grouped_topk_biased_kernel plus triton_biased_grouped_topk wrapper registered on the DSv3 trait. The kernel folds sigmoid, bias, group-top-2-sum, group selection, per-token top-K, renorm, and the shared-expert weight into one program per token. Dispatch is handled by the existing trait filter; no runtime changes.
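
For readers unfamiliar with the routing steps being fused, here is a minimal pure-Python sketch of the per-token pipeline named above. It is not the Triton kernel from this PR. The function name `route_one_token`, the tie-breaking order, and the choice to take final weights from the unbiased sigmoid scores are assumptions for illustration; the shared-expert slot and output scaling are left out.

```python
import math


def route_one_token(logits, bias, num_expert_group=8, topk_group=4, topk=8,
                    renormalize=True):
    """Hypothetical CPU sketch of biased grouped top-k for a single token."""
    num_experts = len(logits)
    group_size = num_experts // num_expert_group

    # 1. Sigmoid gate scores.
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    # 2. Bias is added for selection only (assumption: weights stay unbiased).
    biased = [s + b for s, b in zip(scores, bias)]

    # 3. Group score = sum of the two largest biased scores in the group.
    group_scores = []
    for g in range(num_expert_group):
        chunk = biased[g * group_size:(g + 1) * group_size]
        group_scores.append(sum(sorted(chunk, reverse=True)[:2]))

    # 4. Keep the topk_group best groups.
    kept = set(sorted(range(num_expert_group),
                      key=lambda g: group_scores[g], reverse=True)[:topk_group])

    # 5. Top-K over experts that live in a kept group, ranked by biased score.
    candidates = [e for e in range(num_experts) if e // group_size in kept]
    ids = sorted(candidates, key=lambda e: biased[e], reverse=True)[:topk]

    # 6. Weights come from the unbiased scores, optionally renormalized.
    weights = [scores[e] for e in ids]
    if renormalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return ids, weights
```

The single function body mirrors the "one program per token" layout: every intermediate stays local to the token, which is what lets a fused kernel avoid separate launches for each step.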

Microbench (fp32 gating, 30 warmup + 400 timed iters, cudaEvent + cudaSync):

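| shared | batch | ref µs | trt µs | speedup |
|---|---|---|---|---|
| 0 | 1 | 100.51 | 18.41 | 5.46x |
| 0 | 32 | 103.53 | 17.94 | 5.77x |
| 0 | 256 | 100.52 | 17.76 | 5.66x |
| 1 | 1 | 130.49 | 17.74 | 7.36x |
| 1 | 32 | 129.69 | 18.12 | 7.16x |
| 1 | 256 | 129.54 | 17.74 | 7.30x |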
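
As a quick sanity check on the table above, the speedup column should equal `ref µs / trt µs`. The snippet below recomputes it from the reported timings; the variable and function names are illustrative only.

```python
# (shared, batch, ref_us, trt_us, reported_speedup) copied from the table.
ROWS = [
    (0, 1, 100.51, 18.41, 5.46),
    (0, 32, 103.53, 17.94, 5.77),
    (0, 256, 100.52, 17.76, 5.66),
    (1, 1, 130.49, 17.74, 7.36),
    (1, 32, 129.69, 18.12, 7.16),
    (1, 256, 129.54, 17.74, 7.30),
]


def speedup(ref_us, trt_us):
    """Speedup of the new path over the reference path."""
    return ref_us / trt_us


def max_speedup_error(rows):
    """Largest gap between a recomputed and a reported speedup."""
    return max(abs(speedup(ref, trt) - rep) for _, _, ref, trt, rep in rows)


if __name__ == "__main__":
    for shared, batch, ref, trt, rep in ROWS:
        print(f"shared={shared} batch={batch}: {speedup(ref, trt):.2f}x")
```

The new-path time is nearly flat across batch sizes in these rows, so the difference in speedup between `shared=0` and `shared=1` comes almost entirely from the reference path getting slower.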

bf16 gating falls back to reference (bitwise-identical, verified by test).
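
The dispatch behaviour described in this summary (the MiniMax kernel matching only `num_expert_group=1, topk_group=1`, the new kernel matching DSv3 shapes, bf16 falling back to the reference) can be pictured with a small trait-filter sketch. Everything here is hypothetical: `RoutingTraits`, `register`, `select_kernel`, and the kernel labels are invented for illustration, and whether the dtype check lives in the trait filter or inside the wrapper is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RoutingTraits:
    """Hypothetical description of a routing call's shape and dtype."""
    num_expert_group: int
    topk_group: int
    gating_dtype: str


_REGISTRY = []  # (label, predicate) pairs, checked in registration order


def register(label, predicate):
    _REGISTRY.append((label, predicate))


def select_kernel(traits):
    """Return the first kernel whose predicate accepts the traits."""
    for label, predicate in _REGISTRY:
        if predicate(traits):
            return label
    return "reference"  # nothing matched: fall through to the reference


# Assumed gating for the existing sibling kernel.
register("minimax_triton",
         lambda t: t.num_expert_group == 1 and t.topk_group == 1)
# Assumed gating for the kernel added here: DSv3 shape, fp32 gating only.
register("dsv3_triton",
         lambda t: (t.num_expert_group == 8 and t.topk_group == 4
                    and t.gating_dtype == "float32"))
```

Under these assumptions, `select_kernel(RoutingTraits(8, 4, "bfloat16"))` returns `"reference"`, which matches the fallback described above.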

Test Plan

pytest tokenspeed-kernel/test/ops/test_moe_routing.py

roycho96 added 2 commits May 17, 2026 19:44
The reference biased_grouped_topk_gpu runs 8-10 launches per layer per
decode step (sigmoid, bias, group-top-2-sum, group selection, per-token
top-K, renorm, shared-expert weight). The existing minimax sibling kernel
is trait-gated to num_expert_group=1 and falls through for DSv3 shapes.

Add a sibling Triton kernel that folds the full routing pipeline into one
program per token, plus the in-kernel shared-expert weight (saves a
randint + sum + divide + multiply chain on the host).

Dispatch is handled by the existing trait filter; no runtime changes.
fp32 only; bf16 falls back to the reference (sigmoid precision diverges
at the top-K boundary, verified bitwise-identical on the fallback path).

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Parametrize over batch, shared-expert slot, renormalize, output scaling,
padding, and gating dtype; assert routed id set-equality and per-id
weight equality. Adds explicit num_tokens=0 and bf16 bit-exact fallback
coverage.

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
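
The comparison described in the commit above (set-equality on routed ids, then per-id weight equality) can be sketched as follows. This is not the helper from `test_moe_routing.py`; the name `routing_matches` and the `atol` tolerance are assumptions. Comparing by id rather than by position makes the check independent of the order in which each implementation emits its top-K entries.

```python
def routing_matches(ids_a, weights_a, ids_b, weights_b, atol=1e-6):
    """Compare two routing results row by row, ignoring top-K ordering.

    Each argument is a list of per-token rows. `atol` is an assumed
    tolerance; pass 0.0 to require exact equality.
    """
    if len(ids_a) != len(ids_b):
        return False
    for row_ids_a, row_w_a, row_ids_b, row_w_b in zip(
            ids_a, weights_a, ids_b, weights_b):
        by_id_a = dict(zip(row_ids_a, row_w_a))
        by_id_b = dict(zip(row_ids_b, row_w_b))
        # Same set of experts selected for this token.
        if set(by_id_a) != set(by_id_b):
            return False
        # Same weight for each selected expert.
        if any(abs(by_id_a[e] - by_id_b[e]) > atol for e in by_id_a):
            return False
    return True
```

With no rows at all (the `num_tokens=0` case) the loop body never runs and the function returns `True`.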
@roycho96 roycho96 requested a review from a team as a code owner May 17, 2026 11:24
@lightseek-bot lightseek-bot requested a review from borontion May 17, 2026 16:37
static_logical_to_physical_map_ptr,
topk_weights_ptr,
topk_ids_ptr,
stride_gm,
Contributor


Should the stride args be constexpr?

Contributor Author


The output strides here are always (TOPK, 1) since we allocate the outputs inside the wrapper, so I dropped them instead of making them constexpr. I left the input strides as runtime args.

roycho96 added 2 commits May 18, 2026 11:53
topk_weights and topk_ids are allocated inside the wrapper as
torch.empty((num_tokens, topk), ...), so their strides are always
(TOPK, 1). Drop the four output stride args and index with
* TOPK + k directly. Input strides stay runtime args.

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
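
To see why dropping the output stride args is safe, note that a freshly allocated row-major `(num_tokens, topk)` buffer has strides `(TOPK, 1)`, so the general strided offset collapses to `token * TOPK + k`. The sketch below checks this on plain integers; the function names are illustrative and not taken from the kernel.

```python
def strided_offset(token, k, stride_m, stride_k):
    """General element offset for a 2-D buffer with explicit strides."""
    return token * stride_m + k * stride_k


def contiguous_offset(token, k, topk):
    """Offset when the buffer is known to be row-major (num_tokens, topk)."""
    return token * topk + k


def offsets_agree(num_tokens, topk):
    """True if both formulas match for every (token, k) with strides (topk, 1)."""
    return all(
        strided_offset(t, k, topk, 1) == contiguous_offset(t, k, topk)
        for t in range(num_tokens)
        for k in range(topk)
    )
```

The same shortcut would not hold for the gating input, which may arrive as a view with arbitrary strides, which is consistent with keeping the input strides as runtime args.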
@lightseek-bot lightseek-bot requested a review from borontion May 18, 2026 20:40

3 participants