cuda: F16 share-warp kernel for n_tok=2 combined-forward verifier (stacked on #8) #9
Closed
TrevorS wants to merge 1 commit into
Adds matmul_f16_share_warp_kernel<N_TOK>, dispatched at n_tok=2 in
ds4_gpu_matmul_f16_tensor before the cuBLAS GemmEx path. Each warp
reads one F16 weight row once into 32 lanes and computes N_TOK
separate F32 dot-products against N_TOK token activation vectors,
serial-reduced by lane 0.
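The per-warp work described above can be modelled on the host. This is a minimal C++ sketch, not the CUDA kernel from this PR: `share_warp_row` is a hypothetical name, the weight row is passed as already-converted F32 (standing in for the on-the-fly `__half2float`), and the chunk size `ceil(in_dim / 32)` is an assumption about how the contiguous chunks are sized.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

constexpr int kWarpSize = 32;

// Host-side model of one warp producing one output row for N_TOK tokens.
// `w` is one weight row, already widened to F32 for this sketch.
template <int N_TOK>
std::array<float, N_TOK> share_warp_row(
    const std::vector<float>& w,
    const std::array<std::vector<float>, N_TOK>& x) {
    const std::size_t in_dim = w.size();
    // Assumption: each lane owns ceil(in_dim / 32) contiguous elements.
    const std::size_t chunk = (in_dim + kWarpSize - 1) / kWarpSize;

    // partial[lane][t] is the lane-local F32 accumulator for token t.
    std::array<std::array<float, N_TOK>, kWarpSize> partial{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        const std::size_t begin =
            std::min(in_dim, static_cast<std::size_t>(lane) * chunk);
        const std::size_t end = std::min(in_dim, begin + chunk);
        for (std::size_t i = begin; i < end; ++i) {
            const float wval = w[i];           // weight element read once
            for (int t = 0; t < N_TOK; ++t) {  // reused for every token
                partial[lane][t] += wval * x[t][i];
            }
        }
    }

    // "Lane 0" adds the 32 partials strictly in lane order.
    std::array<float, N_TOK> out{};
    for (int t = 0; t < N_TOK; ++t) {
        for (int lane = 0; lane < kWarpSize; ++lane) {
            out[t] += partial[lane][t];
        }
    }
    return out;
}
```

Calling `share_warp_row<2>(w, x)` with two activation vectors returns both dot-products while touching each weight element once, which is the sharing the kernel name refers to.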
This is the path combined-forward MTP K=1 (--mtp-draft 2) uses for
every layer's Q/K/V/O/router/expert F16 projection in its batched
verifier forward. cuBLAS GemmEx at M=2 pads to the F16 tensor-core
MMA M-tile (16) and wastes 14/16 of the inner-product work, AND
requires an F32->F16 activation conversion into a scratch buffer
before the GEMM. The share-warp kernel skips both: it reads F32
activations directly and pays only the on-the-fly __half2float of
the weight bytes it touches anyway.
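The two costs named above reduce to simple arithmetic. The sketch below assumes C++14 `constexpr`; `padded_waste` and `f16_scratch_bytes` are hypothetical helper names, and the scratch size uses the `n_tok * in_dim * 2` formula from the PR description.

```cpp
#include <cstddef>

// Fraction of M-tile rows that are padding when `m` rows are rounded up
// to a multiple of `tile` (16 for the F16 tensor-core case described above).
constexpr double padded_waste(int m, int tile) {
    const int padded = ((m + tile - 1) / tile) * tile;
    return static_cast<double>(padded - m) / padded;
}

// Bytes of F16 scratch needed to hold the converted activations.
constexpr std::size_t f16_scratch_bytes(std::size_t n_tok, std::size_t in_dim) {
    return n_tok * in_dim * 2;  // 2 bytes per F16 element
}

// At M=2 with a 16-row tile, 14 of 16 rows are padding.
static_assert(padded_waste(2, 16) == 0.875, "14/16 wasted at M=2");
```

For an illustrative `in_dim` of 4096 (an assumed value, not one taken from the model), `f16_scratch_bytes(2, 4096)` is 16384 bytes per call, all of which the share-warp path avoids writing.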
Bit-equality with N=1 matmul_f16_ordered_chunks_kernel (the kernel
that fires for N=1 plain decode under the default ordered-router
configuration) is preserved exactly: same contiguous-chunk lane
distribution `[lane*chunk .. lane*chunk+chunk-1]`, same serial
reduce by lane 0. For every output (row, token) the kernel produces
byte-identical output to running N separate N=1 ordered_chunks
matmuls -- which lets it fire safely regardless of caller (combined-
forward verifier, OR any prefill-time MoE expert that receives
exactly 2 routed tokens) without changing plain-decode greedy output.
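Matching the reduction order matters because F32 addition is not associative. The following C++ sketch contrasts the serial lane-0 order with a pairwise tree order; the tree variant is only an illustrative alternative and is not claimed to be what any kernel in this PR uses.

```cpp
#include <array>
#include <cstring>

// Serial left-to-right sum of 32 partials, the order lane 0 is described as using.
float serial_reduce(const std::array<float, 32>& p) {
    float total = 0.0f;
    for (int i = 0; i < 32; ++i) total += p[i];
    return total;
}

// Pairwise tree sum, shown only as a contrasting order.
float tree_reduce(std::array<float, 32> p) {
    for (int stride = 16; stride > 0; stride /= 2) {
        for (int i = 0; i < stride; ++i) p[i] += p[i + stride];
    }
    return p[0];
}

// Bitwise comparison, which is what "byte-identical" requires.
bool same_bits(float a, float b) {
    return std::memcmp(&a, &b, sizeof a) == 0;
}
```

With `p[0] = 16777216.0f` (2^24) and the other 31 partials set to `1.0f`, the serial order absorbs every `+1` into rounding while the tree order first combines the small values, so the two results differ. Keeping both the chunk layout and the serial order identical is what lets the N_TOK kernel reproduce the N=1 bits.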
Restricted to n_tok==2 (not 3..4) so the dispatch never fires for
prefill MoE experts receiving 3-4 tokens, which would route through
cuBLAS today and would change plain-decode greedy output if rerouted
through the share-warp (the F32 scalar accum vs cuBLAS F16 tensor-
core accum differs at ulp scale -- a precision improvement, but a
user-visible argmax shift on near-tied tokens). Combined-forward K=1
is exactly n_tok=2; K=2 (combined-forward N=3) would be n_tok=3, but
combined-K=2 is not the current default (drafts[1] staleness; see
ds4_session_eval_speculative_argmax_combined comments) and would
need a combined-forward-context gate before the share-warp can safely
fire for it. The n_tok=3..4 case is reserved for a future PR with
that gating.
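The dispatch rule in this paragraph can be written as a small decision function. This is a hedged C++ sketch: `F16Path`, `choose_f16_path`, and `share_warp_disabled_from_env` are hypothetical names, and the way the real code parses `DS4_CUDA_NO_F16_SHARE_WARP` is not shown in the source, so the parsing rule below is an assumption.

```cpp
#include <cstdlib>

enum class F16Path {
    ShareWarp,     // the new kernel
    ExistingPath,  // whatever the dispatcher used before (ordered chunks, cuBLAS, ...)
};

// Pure decision function: share-warp fires only for exactly two tokens
// and only when the opt-out is not set.
F16Path choose_f16_path(int n_tok, bool share_warp_disabled) {
    if (!share_warp_disabled && n_tok == 2) return F16Path::ShareWarp;
    return F16Path::ExistingPath;
}

// Assumption: the opt-out counts as set when the variable exists and is
// neither empty nor exactly "0".
bool share_warp_disabled_from_env() {
    const char* v = std::getenv("DS4_CUDA_NO_F16_SHARE_WARP");
    return v != nullptr && v[0] != '\0' && !(v[0] == '0' && v[1] == '\0');
}
```

Keeping the test as an exact `n_tok == 2` comparison, rather than a range such as `n_tok <= 4`, is what leaves 3-4 token prefill experts on their current path.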
Bench impact (DGX Spark / GB10, ds4flash, n=256, "knight" prompt):
  Plain decode (no --mtp)          16.11-16.12 t/s
  --mtp (default, --mtp-draft 1)   16.11-16.13 t/s  (canonical decode2_exact; combined-forward
                                                     requires mtp_draft_tokens==2, so it
                                                     doesn't fire at the default)
  --mtp --mtp-draft 2              17.14-17.15 t/s  (+1.0 t/s ABOVE plain decode! PR7
                                                     baseline was 16.29 t/s on this path --
                                                     this PR adds ~0.85 t/s to the
                                                     combined-forward path)
This is the first time on this PR stack that MTP beats plain decode by
more than noise. PR7's "MTP-beats-plain" stack name was aspirational
until now; this kernel makes the name true.
Tested:
- `make clean && make cuda-spark` -- clean, no warnings
- `make cpu` -- clean
- `./ds4_test --all` -- only the pre-existing `--logprob-vectors
short_code_completion` failure (same as upstream/main, PR5/6/7/8).
`metal-tensor-equivalence`: OK. `metal-kernels`: OK.
- Plain decode output byte-identical to PR7 (only timing/cache lines
differ in the diff -- 4 filtered lines, all `<` `>` `---` markers)
- Dispatcher trace under plain decode: 0 share-warp hits (n_tok==2
never reached without --mtp-draft 2)
- Dispatcher trace under --mtp --mtp-draft 2: 1184 share-warp hits
during 8-token gen -- combined-forward verifier saturates the path
- Bench stability: 3 runs of mtp-draft=2 within 0.06 t/s (17.14, 17.15,
17.14)
- DS4_CUDA_NO_F16_SHARE_WARP=1 opt-out restores cuBLAS path
Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf
PR9: cuda: F16 share-warp kernel for n_tok=2 combined-forward verifier (stacked on #8)
Summary
MTP now beats plain decode by +1.0 t/s with `--mtp-draft 2`, the first measurable win on this PR stack.
Adds `matmul_f16_share_warp_kernel<N_TOK>`, dispatched at `n_tok=2` in `ds4_gpu_matmul_f16_tensor` before the cuBLAS GemmEx path. Each warp reads one F16 weight row once into 32 lanes and computes `N_TOK` separate F32 dot-products against `N_TOK` token activation vectors, serial-reduced by lane 0.
This is the path combined-forward MTP K=1 (`--mtp-draft 2`) uses for every layer's Q/K/V/O/router/expert F16 projection in its batched verifier forward. cuBLAS GemmEx at M=2 wastes 14/16 of the inner-product work on M-tile padding and requires an F32→F16 activation conversion pass. The share-warp kernel skips both.
Bench headline
- `--mtp` default (canonical, no combined-forward): 16.11-16.13 t/s
- `--mtp --mtp-draft 2` (combined-forward K=1): 17.14-17.15 t/s
The stack's name "mtp-beats-plain" is finally true at the default-precision path.
Why this works
Combined-forward MTP K=1 issues a batched-N=2 forward (`[first_token, drafts[0]]`) through the verifier graph. At each layer that pass triggers many F16 matmul calls with `n_tok=2`: the Q/K/V/O, router, and expert projections.
cuBLAS GemmEx at M=2 must pad to the F16 tensor-core MMA M-tile (16): each tensor-core invocation computes 16 M-rows but the dispatcher only needs 2, so 14/16 = 87.5% of M-axis arithmetic is wasted. cuBLAS also requires F16×F16 inputs, so the dispatcher allocates `n_tok * in_dim * 2` bytes of scratch and runs `f32_to_f16_kernel` before the GEMM.
The share-warp kernel eliminates both:
- no M-tile padding: each warp computes exactly the `N_TOK` dot-products it needs
- no F32→F16 conversion pass: it reads F32 activations directly and pays only the on-the-fly `__half2float` of the weight bytes it'd touch anyway
Bit-equality with N=1
For every output `(row, token)`, this kernel produces byte-identical output to running N separate N=1 `matmul_f16_ordered_chunks_kernel` calls (the kernel that fires for N=1 plain decode under the default `ordered_router` configuration).
The argument, comparing `ordered_chunks_kernel` with `share_warp_kernel<N_TOK>` step by step:
- Lane distribution: both give each lane the same contiguous chunk `[lane*chunk .. lane*chunk+chunk-1]`.
- Per-lane accumulation: `ordered_chunks_kernel` does `sum += __half2float(wr[i]) * xr[i]` (fma-contracted to one `fma(wval, xval, sum)`); `share_warp_kernel<N_TOK>` does `acc[t] += wval * xval` (same fma).
- Final reduce: both have lane 0 run the serial `total += partial[i]` for `i=0..31`.
So for token `t`, this kernel's output for row `row` equals N=1 `ordered_chunks_kernel(row, t)` byte-for-byte.
Why n_tok==2 only (not 3..4)
Restricted to exact `n_tok==2` so the dispatch never fires for prefill MoE experts receiving 3-4 routed tokens. Those would route through cuBLAS today; switching them to the share-warp would change plain-decode greedy output. The F32 scalar accumulation in share-warp is more precise than cuBLAS F16 tensor-core accumulation, but the resulting ulp-scale shift can change argmax on near-tied tokens, a user-visible side-effect we don't want to introduce for plain decode.
Combined-forward K=1 is exactly `n_tok=2`. K=2 (combined-forward N=3) would be `n_tok=3`, but combined-K=2 is not the current default (`drafts[1]` staleness; see `ds4_session_eval_speculative_argmax_combined` comments). The `n_tok=3..4` case is reserved for a future PR with a combined-forward-context gate that lets us safely fire the share-warp under that label only.
What this PR is and isn't
- Unaffected: default `--mtp` (canonical decode2 path; needs `--mtp-draft 2` to activate combined-forward)
- Not included: a `mtp_draft_tokens` change (separately scoped; that's a CLI-level choice with its own correctness story around accept-rate)
Tested
- `make clean && make cuda-spark` -- clean
- `make cpu` -- clean
- `./ds4_test --all` -- only the pre-existing `--logprob-vectors short_code_completion` failure (same as `upstream/main`, PR5/6/7/8). `metal-tensor-equivalence`: OK. `metal-kernels`: OK.
- Dispatcher trace under `--mtp --mtp-draft 2`: 1184 hits during 8-token gen
- Bench stability, 3 runs of `--mtp-draft 2`: 17.14 / 17.15 / 17.14 t/s
- `DS4_CUDA_NO_F16_SHARE_WARP=1` opt-out restores cuBLAS path
Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: `DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf`
MTP: `DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf`
AGENT.md compliance
- `DS4_CUDA_NO_F16_SHARE_WARP=1` is the diagnostic opt-out only.
- `DS4_CUDA_NO_F16_SHARE_WARP` is the validation knob.
Out of scope / follow-ups
- `mtp_draft_tokens` change (CLI-level, separate concern)