cuda: F16 share-warp kernel for n_tok=2 combined-forward verifier (stacked on #8) #9

Closed
TrevorS wants to merge 1 commit into mtp-beats-plain-kernels-v7 from mtp-beats-plain-kernels-v8

TrevorS commented May 24, 2026

PR9: cuda: F16 share-warp kernel for n_tok=2 combined-forward verifier (stacked on #8)

Summary

MTP now beats plain decode by +1.0 t/s with --mtp-draft 2, the first measurable win on this PR stack.

Adds matmul_f16_share_warp_kernel<N_TOK>, dispatched at n_tok=2 in ds4_gpu_matmul_f16_tensor before the cuBLAS GemmEx path. Each warp reads one F16 weight row once into 32 lanes and computes N_TOK separate F32 dot-products against N_TOK token activation vectors, serial-reduced by lane 0.

This is the path combined-forward MTP K=1 (--mtp-draft 2) uses for every layer's Q/K/V/O/router/expert F16 projection in its batched verifier forward. cuBLAS GemmEx at M=2 wastes 14/16 of the inner-product work on M-tile padding AND requires an F32→F16 activation conversion pass. The share-warp kernel skips both.

Bench headline

| Mode | PR7/8 (t/s) | PR9 (t/s) | Δ vs plain |
| --- | --- | --- | --- |
| Plain decode | 16.11-16.13 | 16.11-16.12 | |
| `--mtp` default (canonical, no combined-forward) | 16.11-16.13 | 16.11-16.13 | parity |
| `--mtp --mtp-draft 2` (combined-forward K=1) | 16.29-16.30 | 17.14-17.15 | +1.02 t/s above plain |

The stack's name "mtp-beats-plain" is finally true at the default-precision path.
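
The deltas quoted for this PR can be re-derived from the reported ranges. The snippet below is only a sanity check on the arithmetic; the dictionary keys are hypothetical labels, and the numbers are copied from the bench table.

```python
# Reported throughput ranges in tokens/second (low, high), copied from the table.
bench = {
    "plain_pr9": (16.11, 16.12),
    "mtp_draft2_pr7_8": (16.29, 16.30),
    "mtp_draft2_pr9": (17.14, 17.15),
}

def delta_range(new, old):
    """Conservative and optimistic gain of `new` over `old`, rounded to 0.01 t/s."""
    conservative = round(new[0] - old[1], 2)  # slowest new run vs fastest old run
    optimistic = round(new[1] - old[0], 2)    # fastest new run vs slowest old run
    return conservative, optimistic

vs_plain = delta_range(bench["mtp_draft2_pr9"], bench["plain_pr9"])
vs_pr7_8 = delta_range(bench["mtp_draft2_pr9"], bench["mtp_draft2_pr7_8"])
pct_vs_plain = round(100 * vs_plain[0] / bench["plain_pr9"][1], 1)

print(vs_plain)      # (1.02, 1.04)
print(vs_pr7_8)      # (0.84, 0.86)
print(pct_vs_plain)  # 6.3
```

Read this way, the +1.02 t/s headline is the conservative end of the range: the slowest `--mtp-draft 2` run against the fastest plain-decode run.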

Why this works

Combined-forward MTP K=1 issues a batched-N=2 forward ([first_token, drafts[0]]) through the verifier graph. At each layer that pass triggers many F16 matmul calls with n_tok=2:

dispatcher trace, --mtp-draft 2, 8-token generation:
   1184 F16 share-warp hits at n_tok=2  (all decode-phase combined-forward)
      0 hits during plain decode        (n_tok==2 never reached without combined-forward)

cuBLAS GemmEx at M=2 must pad to the F16 tensor-core MMA M-tile (16): each tensor-core invocation computes 16 M-rows but only 2 of them are needed, so 14/16 = 87.5% of the M-axis arithmetic is wasted. cuBLAS also requires F16×F16 inputs, so the dispatcher allocates n_tok * in_dim * 2 bytes of scratch and runs f32_to_f16_kernel before the GEMM.
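
As a worked check of those two costs, the sketch below recomputes the padding waste and the scratch size. The helper names and the `IN_DIM` value are hypothetical and only illustrate the formulas stated above; they are not taken from the ds4 sources.

```python
MMA_M_TILE = 16  # F16 tensor-core M-tile size stated in the text

def padded_m_waste(n_tok, m_tile=MMA_M_TILE):
    """Fraction of M-axis rows computed but discarded after padding to whole tiles."""
    tiles = -(-n_tok // m_tile)      # ceiling division
    padded_rows = tiles * m_tile
    return (padded_rows - n_tok) / padded_rows

def f16_scratch_bytes(n_tok, in_dim):
    """Scratch needed for the F32->F16 activation copy: n_tok * in_dim * 2 bytes."""
    return n_tok * in_dim * 2

IN_DIM = 4096  # hypothetical projection width, chosen only for illustration

print(padded_m_waste(2))             # 0.875
print(f16_scratch_bytes(2, IN_DIM))  # 16384
```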

The share-warp kernel eliminates both:

  • M-axis: each warp computes exactly the requested 2 tokens, no padding waste
  • Activation conversion: reads F32 activations directly, paying only the on-the-fly __half2float of the weight bytes it'd touch anyway

Bit-equality with N=1

For every output (row, token), this kernel produces byte-identical output to running N separate N=1 matmul_f16_ordered_chunks_kernel calls — the kernel that fires for N=1 plain decode under the default ordered_router configuration.

The argument:

| Property | N=1 ordered_chunks_kernel | PR9 share_warp_kernel<N_TOK> |
| --- | --- | --- |
| Lane access pattern | `[lane*chunk .. lane*chunk+chunk-1]` | identical |
| Per-lane accumulation | `sum += __half2float(wr[i]) * xr[i]` (fma-contracted to one `fma(wval, xval, sum)`) | per-t: `acc[t] += wval * xval` (same fma) |
| Reduction order | serial loop in lane 0: `total += partial[i]` for i=0..31 | identical, per-t partials |

So for token t, this kernel's output for row row equals N=1 ordered_chunks_kernel(row, t) byte-for-byte.
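
The argument can be replayed on the host with a small reference model. This is a sketch, not the CUDA kernels: the function names are hypothetical, the row length is assumed to be a multiple of 32, and `mul_add_f32` only approximates a hardware fma by rounding a double-precision result to F32. The point it illustrates is structural: for each token, both models run the same operations in the same order, so the F32 bit patterns match.

```python
import random
import struct

WARP = 32  # lanes per warp

def f32(x):
    """Round a Python float to the nearest IEEE-754 binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def f16(x):
    """Round a Python float to the nearest IEEE-754 binary16 value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def mul_add_f32(w, x, acc):
    # Stand-in for fma(w, x, acc): the F16*F32 product is exact in double,
    # then the sum is rounded to F32 (an approximation of a true fma).
    return f32(w * x + acc)

def ordered_chunks_n1(w_row, x):
    """N=1 model: lane i owns a contiguous chunk; lane 0 sums the partials serially."""
    chunk = len(w_row) // WARP
    partial = []
    for lane in range(WARP):
        acc = 0.0
        for i in range(lane * chunk, lane * chunk + chunk):
            acc = mul_add_f32(w_row[i], x[i], acc)
        partial.append(acc)
    total = 0.0
    for p in partial:  # serial reduction, lanes 0..31 in order
        total = f32(total + p)
    return total

def share_warp(w_row, xs):
    """N_TOK model: each weight is read once and applied to every token's accumulator."""
    n_tok = len(xs)
    chunk = len(w_row) // WARP
    partial = [[0.0] * WARP for _ in range(n_tok)]
    for lane in range(WARP):
        for i in range(lane * chunk, lane * chunk + chunk):
            wval = w_row[i]  # one weight load shared by all tokens
            for t in range(n_tok):
                partial[t][lane] = mul_add_f32(wval, xs[t][i], partial[t][lane])
    out = []
    for t in range(n_tok):
        total = 0.0
        for lane in range(WARP):
            total = f32(total + partial[t][lane])
        out.append(total)
    return out

def bits(x):
    """Raw binary32 bytes, for a byte-for-byte comparison."""
    return struct.pack("f", x)

rng = random.Random(0)
IN_DIM = 256  # hypothetical row length, a multiple of WARP
w_row = [f16(rng.uniform(-1, 1)) for _ in range(IN_DIM)]
xs = [[f32(rng.uniform(-1, 1)) for _ in range(IN_DIM)] for _ in range(2)]

batched = share_warp(w_row, xs)
separate = [ordered_chunks_n1(w_row, x) for x in xs]
print([bits(a) == bits(b) for a, b in zip(batched, separate)])  # [True, True]
```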

Why n_tok==2 only (not 3..4)

Restricted to exactly n_tok==2 so the dispatch never fires for prefill MoE experts receiving 3-4 routed tokens. Those route through cuBLAS today; switching them to the share-warp kernel would change plain-decode greedy output. The F32 scalar accumulation in share-warp is more precise than cuBLAS F16 tensor-core accumulation, but the resulting ulp-scale shift can change argmax on near-tied tokens, a user-visible side effect we don't want to introduce for plain decode.
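
A synthetic example of how an accumulation-precision change can move argmax on near-tied values is sketched below. The inputs are made up for illustration and are not measured logits. The F16 path is a deliberately simplified model (activations rounded to F16, accumulator rounded to F16 after every add); real tensor-core accumulation may behave differently.

```python
import struct

def f32(x):
    """Round a Python float to the nearest IEEE-754 binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def f16(x):
    """Round a Python float to the nearest IEEE-754 binary16 value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def dot_f32_accum(w, x):
    """Share-warp style: F16 weights, F32 activations, F32 accumulator."""
    acc = 0.0
    for wv, xv in zip(w, x):
        acc = f32(acc + f16(wv) * f32(xv))
    return acc

def dot_f16_accum(w, x):
    """Simplified F16 model: F16 weights, F16 activations, F16 accumulator."""
    acc = 0.0
    for wv, xv in zip(w, x):
        acc = f16(acc + f16(wv) * f16(xv))
    return acc

w = [1.0, 1.0, 1.0]
# Two synthetic candidates whose exact sums are 2050 and 2049.75.
candidates = {
    "A": [2048.0, 1.0, 1.0],
    "B": [1025.0, 1024.75, 0.0],
}

def argmax(dot):
    return max(candidates, key=lambda name: dot(w, candidates[name]))

print({k: dot_f32_accum(w, v) for k, v in candidates.items()})  # {'A': 2050.0, 'B': 2049.75}
print({k: dot_f16_accum(w, v) for k, v in candidates.items()})  # {'A': 2048.0, 'B': 2050.0}
print(argmax(dot_f32_accum), argmax(dot_f16_accum))             # A B
```

In this toy case the F32 accumulator is closer to the exact sums, yet the winner differs between the two paths, which is the kind of output change the n_tok==2 restriction avoids for plain decode.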

Combined-forward K=1 is exactly n_tok=2. K=2 (combined-forward N=3) would be n_tok=3, but combined-K=2 is not the current default (drafts[1] staleness; see ds4_session_eval_speculative_argmax_combined comments). The n_tok=3..4 case is reserved for a future PR with a combined-forward-context gate that lets us safely fire the share-warp under that label only.

What this PR is and isn't

  • Is: a real +1.0 t/s perf win for MTP combined-forward
  • Is: bit-equal to N=1 plain decode by construction (no drift for the shape that fires)
  • Is: zero impact on plain decode (dispatch never reached without combined-forward)
  • Is: zero impact on default --mtp (canonical decode2 path; needs --mtp-draft 2 to activate combined-forward)
  • Is NOT: a change to the default mtp_draft_tokens (separately scoped — that's a CLI-level choice with its own correctness story around accept-rate)
  • Is NOT: a fix for the K=2 cascading combined-forward path (drafts[1] staleness is a separate known issue)

Tested

  • make clean && make cuda-spark — clean
  • make cpu — clean
  • ./ds4_test --all — only the pre-existing --logprob-vectors short_code_completion failure (same as upstream/main, PR5/6/7/8). metal-tensor-equivalence: OK. metal-kernels: OK.
  • Plain decode output byte-identical to PR7 (4 filtered diff lines, all timing/cache markers)
  • Dispatcher trace, plain decode: 0 share-warp hits
  • Dispatcher trace, --mtp --mtp-draft 2: 1184 hits during 8-token gen
  • Bench stability across 3 runs of --mtp-draft 2: 17.14 / 17.15 / 17.14 t/s
  • DS4_CUDA_NO_F16_SHARE_WARP=1 opt-out restores cuBLAS path
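
The dispatch rule exercised by these checks can be summarised as a small decision function. This is a hypothetical simplification, not the code in ds4_gpu_matmul_f16_tensor: the function name and return labels are invented, the opt-out is assumed to trigger only on the exact value `1`, and the n_tok=1 branch assumes the default ordered-router configuration described earlier.

```python
import os

def pick_f16_matmul_path(n_tok, env=None):
    """Return the path a dispatcher following this PR's stated rules would take."""
    env = os.environ if env is None else env
    # Assumption: only the exact value "1" disables the share-warp kernel.
    opted_out = env.get("DS4_CUDA_NO_F16_SHARE_WARP") == "1"
    if n_tok == 1:
        return "ordered_chunks"   # N=1 plain-decode kernel
    if n_tok == 2 and not opted_out:
        return "share_warp"       # new kernel, checked before cuBLAS
    return "cublas_gemmex"        # everything else, including n_tok=3..4

# Explicit env dicts keep the example independent of the real environment.
print(pick_f16_matmul_path(2, {}))                                   # share_warp
print(pick_f16_matmul_path(2, {"DS4_CUDA_NO_F16_SHARE_WARP": "1"}))  # cublas_gemmex
print(pick_f16_matmul_path(3, {}))                                   # cublas_gemmex
```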

Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf

AGENT.md compliance

  • "Preserve correctness before speed" — bit-equal to N=1 ordered_chunks by construction for the shape that fires. Plain decode unaffected.
  • "Do not add permanent semantic variants behind flags" — no new env knob to opt IN (the share-warp fires by default at its shape). DS4_CUDA_NO_F16_SHARE_WARP=1 is the diagnostic opt-out only.
  • "Diagnostic switches are fine when they validate the one release path": DS4_CUDA_NO_F16_SHARE_WARP is the validation knob.

Out of scope / follow-ups

  • F16 share-warp at n_tok=3..4 (needs combined-forward-context gate)
  • Default mtp_draft_tokens change (CLI-level, separate concern)
  • K=2 cascading combined-forward staleness fix
  • Captured-graph spec decode subsystem

Adds matmul_f16_share_warp_kernel<N_TOK>, dispatched at n_tok=2 in
ds4_gpu_matmul_f16_tensor before the cuBLAS GemmEx path.  Each warp
reads one F16 weight row once into 32 lanes and computes N_TOK
separate F32 dot-products against N_TOK token activation vectors,
serial-reduced by lane 0.

This is the path combined-forward MTP K=1 (--mtp-draft 2) uses for
every layer's Q/K/V/O/router/expert F16 projection in its batched
verifier forward.  cuBLAS GemmEx at M=2 pads to the F16 tensor-core
MMA M-tile (16) and wastes 14/16 of the inner-product work, AND
requires an F32->F16 activation conversion into a scratch buffer
before the GEMM.  The share-warp kernel skips both: it reads F32
activations directly and pays only the on-the-fly __half2float of
the weight bytes it touches anyway.

Bit-equality with N=1 matmul_f16_ordered_chunks_kernel (the kernel
that fires for N=1 plain decode under the default ordered-router
configuration) is preserved exactly: same contiguous-chunk lane
distribution `[lane*chunk .. lane*chunk+chunk-1]`, same serial
reduce by lane 0.  For every output (row, token) the kernel produces
byte-identical output to running N separate N=1 ordered_chunks
matmuls -- which lets it fire safely regardless of caller (combined-
forward verifier, OR any prefill-time MoE expert that receives
exactly 2 routed tokens) without changing plain-decode greedy output.

Restricted to n_tok==2 (not 3..4) so the dispatch never fires for
prefill MoE experts receiving 3-4 tokens, which would route through
cuBLAS today and would change plain-decode greedy output if rerouted
through the share-warp (the F32 scalar accum vs cuBLAS F16 tensor-
core accum differs at ulp scale -- a precision improvement, but a
user-visible argmax shift on near-tied tokens).  Combined-forward K=1
is exactly n_tok=2; K=2 (combined-forward N=3) would be n_tok=3, but
combined-K=2 is not the current default (drafts[1] staleness; see
ds4_session_eval_speculative_argmax_combined comments) and would
need a combined-forward-context gate before the share-warp can safely
fire for it.  The n_tok=3..4 case is reserved for a future PR with
that gating.

Bench impact (DGX Spark / GB10, ds4flash, n=256, "knight" prompt):

  Plain decode (no --mtp)                   16.11-16.12 t/s
  --mtp (default, --mtp-draft 1)            16.11-16.13 t/s  (canonical
                                            decode2_exact; combined-forward
                                            requires mtp_draft_tokens==2 so
                                            doesn't fire at the default)
  --mtp --mtp-draft 2                       17.14-17.15 t/s  (+1.0 t/s
                                            ABOVE plain decode!  PR7
                                            baseline was 16.29 t/s on this
                                            path -- this PR adds ~0.85 t/s
                                            to the combined-forward path)

This is the first time on this PR stack that MTP beats plain decode by
more than noise.  PR7's "MTP-beats-plain" stack name was aspirational
until now; this kernel makes the name true.

Tested:
  - `make clean && make cuda-spark` -- clean, no warnings
  - `make cpu` -- clean
  - `./ds4_test --all` -- only the pre-existing `--logprob-vectors
    short_code_completion` failure (same as upstream/main, PR5/6/7/8).
    `metal-tensor-equivalence`: OK.  `metal-kernels`: OK.
  - Plain decode output byte-identical to PR7 (only timing/cache lines
    differ in the diff -- 4 filtered lines, all `<` `>` `---` markers)
  - Dispatcher trace under plain decode: 0 share-warp hits (n_tok==2
    never reached without --mtp-draft 2)
  - Dispatcher trace under --mtp --mtp-draft 2: 1184 share-warp hits
    during 8-token gen -- combined-forward verifier saturates the path
  - Bench stability: 3 runs of mtp-draft=2 within 0.01 t/s, about 0.06%
    (17.14, 17.15, 17.14)
  - DS4_CUDA_NO_F16_SHARE_WARP=1 opt-out restores cuBLAS path

TrevorS force-pushed the mtp-beats-plain-kernels-v7 branch from fba7778 to ff5d8a4 May 24, 2026 17:14
TrevorS force-pushed the mtp-beats-plain-kernels-v8 branch from 120c033 to cfddd4b May 24, 2026 17:14

TrevorS commented May 24, 2026

Superseded by the reframed 2-PR stack (#11 + #12), which tells the same Spark/GB10 + MTP combined-forward story more concisely, rebased on current upstream/main, with the exploratory paths dropped.

TrevorS closed this May 24, 2026