Skip to content

cuda: replace K4 cuBLAS-cache gate with strict-mode gate (stacked on #6)#7

Closed
TrevorS wants to merge 1 commit into
mtp-beats-plain-kernels-v5from
mtp-beats-plain-kernels-v6
Closed

cuda: replace K4 cuBLAS-cache gate with strict-mode gate (stacked on #6)#7
TrevorS wants to merge 1 commit into
mtp-beats-plain-kernels-v5from
mtp-beats-plain-kernels-v6

Conversation

@TrevorS
Copy link
Copy Markdown
Owner

@TrevorS TrevorS commented May 24, 2026

PR7: cuda: replace K4 cuBLAS-cache gate with strict-mode gate (stacked on #6)

Summary

PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache-availability gate to the share-warp Q8 matmul dispatcher. The intent was correctness — prevent share-warp from displacing cuBLAS where cuBLAS would have handled the weight. On DGX Spark this gate is empirically always-false: every Q8 weight has a cached F16 copy by the time the dispatcher runs, so share-warp NEVER fires at n_tok=2..4. This silently routes the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with blocks=32) through cuBLAS' small-M tensor-core path which pads M=2..4 → M=16 and wastes ~7/8 of the inner-product work.

The actual correctness concern is narrower: under DS4_MTP_STRICT (or --quality), users require byte-equality with plain decode. Share-warp is not bit-identical to cuBLAS Gemm at small M (different reduction order), so strict-mode must fall through to cuBLAS. In non-strict mode this drift is acceptable — same Option-B pattern as the combined-forward gate in PR #6.

Replaces the cuBLAS-cache-availability check with !strict_mtp_env. The blocks <= 32u constraint is preserved — share-warp is only bit-equal to the N=1 reference at blocks ≤ 32 (verified empirically during bisect: dropping that constraint causes ds4_test --all to fail long_memory_archive greedy-equivalence).

Bench impact (DGX Spark, ds4flash, n=256, "knight")

Mode PR #6 baseline This PR Δ
Default --mtp (combined-forward fires) 15.56 16.20 +3.8% (+0.64 t/s)
DS4_MTP_STRICT=1 --mtp (canonical) 13.83 13.83 unchanged (strict refuses share-warp → cuBLAS)
Plain decode (no --mtp) 16.60 16.60 unchanged

How this PR was discovered

Bisect investigation triggered by noticing mtp-beats-plain @ 45ba7613 (downstream source) hits 19.7 t/s under combined-forward, while PR #6 (faithful cherry-pick port) only hits 15.6. Bisect localized 73% of the regression to this gate. The remaining 5% (16.2 → 19.7) is base-tree drift not localized to any single commit — likely from the captured-graph subsystem absent on this stack (which isn't a behavior we want to chase without proper investigation).

Why not drop blocks <= 32u too?

The bisect agent's full revert of K4 hit 18.6 t/s but didn't run ds4_test --all. Verified that dropping blocks <= 32u here causes share-warp to fire for matmuls where it's NOT bit-equal to the warp8 reference (different block-loop reduction order), failing ds4_test --all long_memory_archive with greedy_fail=4 top1_mismatch=1. The +2.5 t/s from dropping that constraint requires a kernel-level rewrite to make share-warp bit-equal at large block counts — separate PR.

Tested against

AGENT.md compliance

  • "Preserve correctness before speed" — strict-mode preserves byte-equality with PR mtp: combined-forward default + Option-B strict fallback (stacked on #5) #6 (canonical fallback). Non-strict mode change is documented; the drift between share-warp and cuBLAS at small M is bounded and surfaces only as different argmaxes under temp=0 in non-strict configurations
  • "Do not add permanent semantic variants behind flags" — no new flag. Uses the pre-existing DS4_MTP_STRICT env knob to gate the fast path
  • "Diagnostic switches are fine when they validate the one release path"DS4_CUDA_NO_Q8_SHARE_BATCH=1 opt-out is preserved as the diagnostic kill switch

Out of scope / follow-ups

  • The +2.5 t/s further win (16.2 → 19.7-ish) is locked behind kernel-level work to make share-warp bit-equal at large block counts
  • The captured-graph subsystem (held back from PR stack) may also be part of the residual gap

…rp dispatch

PR #3's K4 commit (584de5e on the v2 branch) added a cuBLAS-cache-
availability gate to the share-warp Q8 matmul dispatch in
cuda_matmul_q8_0_tensor_labeled. The intent was correctness: prevent
share-warp from displacing cuBLAS where cuBLAS would have handled the
weight. On DGX Spark this gate is empirically always-false: every Q8
weight has a cached F16 copy by the time the dispatcher runs, so
share-warp NEVER fires at n_tok=2..4 with blocks<=32 -- which means
the Q_B and output_proj_b matmuls (the only V4-Flash matmuls with
blocks=32) silently route through cuBLAS' small-M tensor-core path
which pads M=2..4 to M=16 and wastes ~7/8 of the inner-product work.

The actual correctness concern is narrower: under DS4_MTP_STRICT (or
--quality), users require byte-equality with plain decode. Share-warp
is not bit-identical to cuBLAS Gemm at small M (different reduction
order), so strict-mode must fall through to cuBLAS. In non-strict
mode this drift is acceptable -- it matches PR #6's combined-forward
Option-B pattern (same env knob selects strict vs perf).

Replace the cuBLAS-cache-availability check with `!strict_mtp_env`.
Same opt-out shape as the combined-forward gate in
ds4_session_eval_speculative_argmax. The `blocks <= 32u` constraint
is preserved (share-warp is bit-equal to N=1 warp8 only at blocks<=32;
larger block counts drift from the batch_warp8 reference and would
fail ds4_test --all long-context tensor equivalence -- verified
empirically during bisect).

Bench impact (DGX Spark, ds4flash, n=256, "knight" prompt):
  Default `--mtp` (combined-forward fires)   15.6 -> 16.20 t/s  (+3.8%)
  `DS4_MTP_STRICT=1 --mtp` (canonical)       13.83 -> 13.83     unchanged
  Plain decode                               16.60 -> 16.60     unchanged

Strict-mode byte-equality vs PR #6 baseline confirmed (diff empty).
`./ds4_test --all` shows the same 1 pre-existing failure as PR #6
(`logprob-vectors short_code_completion`, also fails on upstream/main).
@TrevorS TrevorS force-pushed the mtp-beats-plain-kernels-v5 branch from bb85073 to dcb5cc3 Compare May 24, 2026 17:14
@TrevorS TrevorS force-pushed the mtp-beats-plain-kernels-v6 branch from bb27595 to 28ff7ce Compare May 24, 2026 17:14
@TrevorS
Copy link
Copy Markdown
Owner Author

TrevorS commented May 24, 2026

Superseded by the reframed 2-PR stack (#11 + #12), which tells the same Spark/GB10 + MTP combined-forward story more concisely, rebased on current upstream/main, with the exploratory paths dropped.

@TrevorS TrevorS closed this May 24, 2026
@TrevorS TrevorS deleted the mtp-beats-plain-kernels-v6 branch May 24, 2026 22:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant