feat(kv): fused quantized decode SDPA — make KV codecs live at decode (read quant store directly) #45

@Pushkinist

Problem

18 of 25 KV-cache quant codecs are inert at decode. exit_prefill materialises a bf16 K/V decode seed (crates/rmlx-kv-quant/src/kvcache/update.rs:2245-2250), and every per-codec update_* early-returns to update_decode_fp16 whenever that seed is present. Per the cache contract (update.rs:2920-2925), the quantized store is frozen at the prefill length and never read back at decode, so decode SDPA reads only bf16.

Consequences, confirmed by the Gemma4 e2b/e4b dual-axis sweep:

  • KV quant leaves decode TPS within ±1% of none at every context length, because the quantized store is not on the decode read path.
  • KV quant inflates resident KV (k8v4 1.20×, k8v8 1.25×, planar 1.47× vs none at 64k) because the full bf16 seed mirror coexists with the frozen quantized prefix and its scales. A quantized global layer can therefore be larger than its bf16 equivalent.

Only Mixed / RotK (fused quantized_matmul) genuinely read the quant store at decode. rot_k_tq4v also reads it, but slowly: it dequantizes the full prefix on every step, which costs O(seq) per step. The remaining codecs are cosmetic at decode.
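
The resident-KV inflation described above can be illustrated with a rough byte-accounting sketch. Everything here is an assumption made for illustration: the function names are hypothetical, and the layout (group-affine quantization with one 2-byte scale and one 2-byte bias per group of 64 elements) is not confirmed as the crate's storage format. Under this model a single q8 layer that keeps its bf16 mirror occupies about 1.53× the bf16 size. The sweep ratios (1.20× to 1.47×) are whole-cache measurements and are not reproduced by this sketch.

```rust
/// Bytes for `elems` KV elements stored at `bits` per element, plus one scale
/// and one bias (each `scale_bytes` wide) per group of `group` elements.
/// ASSUMPTION: group-affine layout; not the crate's confirmed format.
fn quant_bytes(elems: u64, bits: u64, group: u64, scale_bytes: u64) -> u64 {
    let payload = elems * bits / 8;
    let groups = (elems + group - 1) / group; // round up to whole groups
    payload + groups * 2 * scale_bytes
}

/// Resident bytes when the full bf16 decode seed (2 bytes per element)
/// coexists with the frozen quantized prefix.
fn resident_with_mirror(elems: u64, bits: u64, group: u64) -> u64 {
    elems * 2 + quant_bytes(elems, bits, group, 2)
}

/// Resident bytes if decode reads the quantized store directly and the
/// bf16 mirror is never materialised.
fn resident_fused(elems: u64, bits: u64, group: u64) -> u64 {
    quant_bytes(elems, bits, group, 2)
}
```

For 2^20 elements at q8 with groups of 64, `resident_with_mirror` exceeds the bf16 size of 2 × 2^20 bytes, while `resident_fused` stays below it.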

Goal

A decode-time fused quantized SDPA path that reads the quantized KV store directly via MLX quantized_matmul (dequant inside the attention matmul), never materialising the full bf16 prefix. This is the only architecture in which a KV codec reduces decode-resident KV and decode KV-read bandwidth.
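
As a minimal CPU sketch of what "dequant inside the attention matmul" means, the following computes one attention score q · k directly from 8-bit group-affine codes, without ever building a dequantized copy of k. It is not MLX's quantized_matmul kernel and not the crate's code. The `QGroup` type and all function names are hypothetical, and the 8-bit affine scheme (code × scale + bias) is an assumption chosen for clarity.

```rust
/// One quantized group. ASSUMPTION: 8-bit affine codes, value = code * scale + bias.
struct QGroup {
    codes: Vec<u8>,
    scale: f32,
    bias: f32,
}

/// Quantize one group of values to 8-bit affine codes.
fn quantize_group(x: &[f32]) -> QGroup {
    let min = x.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // A constant group gets scale 1.0 so the division below stays finite.
    let scale = if max > min { (max - min) / 255.0 } else { 1.0 };
    let codes = x
        .iter()
        .map(|v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
        .collect();
    QGroup { codes, scale, bias: min }
}

/// Quantize one token's K vector in groups of `group` elements.
fn quantize_row(row: &[f32], group: usize) -> Vec<QGroup> {
    row.chunks(group).map(quantize_group).collect()
}

/// Fused score: q · dequant(k), evaluated group by group using
/// sum_i q_i * (c_i * s + b) = s * sum_i(q_i * c_i) + b * sum_i(q_i).
/// No dequantized copy of k is allocated.
fn fused_dot(q: &[f32], k: &[QGroup], group: usize) -> f32 {
    q.chunks(group)
        .zip(k.iter())
        .map(|(qg, kg)| {
            let qc: f32 = qg.iter().zip(&kg.codes).map(|(a, &c)| a * c as f32).sum();
            let qs: f32 = qg.iter().sum();
            kg.scale * qc + kg.bias * qs
        })
        .sum()
}
```

The algebraic identity in `fused_dot` is what lets the scale and bias be applied once per group rather than once per element, which is the property a fused kernel relies on.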

Scope

  • Target the growing attention layers (full-attention / "global" layers; on Gemma4 these are the 7 global layers that grow KV O(ctx)). SWA windowed layers stay bf16 — their rotating ring is already window-bounded and tiny.
  • Start with one codec (q8 or q4 global K/V) as the proof of mechanism, behind the existing codec plumbing — not all 25 at once.
  • Per-step append = GPU-quantize the single new token's K/V into the quant ring; attention matmul runs over (quantized prefix + recent tokens).
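
The per-step append in the last bullet can be sketched on the CPU as follows. This is a reference illustration only, under loud assumptions: `QuantRing` and `decode_attention` are hypothetical names, the store uses symmetric per-token int8 quantization (one f32 scale per token) purely for brevity, and the real path would run on the GPU through the existing codec plumbing. The point it shows is that each step quantizes only the new token, and attention then reads the codes with inline dequantization.

```rust
/// Hypothetical append-only quantized store for one head.
/// ASSUMPTION: symmetric int8 codes with one f32 scale per token.
struct QuantRing {
    dim: usize,
    codes: Vec<i8>,   // tokens * dim entries
    scales: Vec<f32>, // one entry per token
}

impl QuantRing {
    fn new(dim: usize) -> Self {
        QuantRing { dim, codes: Vec::new(), scales: Vec::new() }
    }

    fn len(&self) -> usize {
        self.scales.len()
    }

    /// Per-step append: quantize only the new token's vector.
    fn append(&mut self, x: &[f32]) {
        assert_eq!(x.len(), self.dim);
        let amax = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        let scale = if amax > 0.0 { amax / 127.0 } else { 1.0 };
        self.codes.extend(x.iter().map(|v| (v / scale).round() as i8));
        self.scales.push(scale);
    }

    /// Dot product of `q` with stored token `t`, dequantizing inline.
    fn dot(&self, t: usize, q: &[f32]) -> f32 {
        let row = &self.codes[t * self.dim..(t + 1) * self.dim];
        let acc: f32 = row.iter().zip(q).map(|(&c, a)| c as f32 * a).sum();
        acc * self.scales[t]
    }
}

/// Single-head decode attention over quantized K and V stores:
/// softmax(q K^T / sqrt(d)) V, with no dequantized prefix materialised.
fn decode_attention(q: &[f32], k: &QuantRing, v: &QuantRing) -> Vec<f32> {
    let inv_sqrt_d = 1.0 / (k.dim as f32).sqrt();
    let scores: Vec<f32> = (0..k.len()).map(|t| k.dot(t, q) * inv_sqrt_d).collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    let mut out = vec![0.0f32; v.dim];
    for (t, e) in exps.iter().enumerate() {
        // Fold the softmax weight and the token's V scale into one factor.
        let w = e / denom * v.scales[t];
        let row = &v.codes[t * v.dim..(t + 1) * v.dim];
        for (o, &c) in out.iter_mut().zip(row) {
            *o += w * c as f32;
        }
    }
    out
}
```

A decode loop would call `append` once per generated token on both stores and then `decode_attention` with the new query, so the per-step quantization cost stays O(1) in sequence length.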

Expected gain

  • Memory: global KV stored q4/q8 instead of bf16 → ~2–4× smaller resident KV on the decode window (e2b 64k: 780 MB → ~200–390 MB).
  • Bandwidth: decode reads q4/q8 KV instead of bf16 on the global layers → up to ~4× less KV-read traffic → faster long-context decode. This is the only path to beating bf16-KV backends (mlx-lm) on long context rather than tying them.
  • Makes the "widest weight × KV quant matrix" a real decode feature instead of a prefill-only / cosmetic one.
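
The memory estimate above can be refined slightly by accounting for scale and bias overhead. The sketch below is hypothetical arithmetic, not a measurement. It assumes group-affine quantization with a 16-bit scale and a 16-bit bias per group of 64 elements, and it assumes the whole 780 MB figure is global-layer KV that gets quantized. Under those assumptions the e2b 64k cache would land near 414 MB at q8 and near 219 MB at q4, slightly above the ideal 2× and 4× bounds.

```rust
/// Effective bits per stored element: `bits` of payload plus a 16-bit scale
/// and a 16-bit bias amortised over `group` elements.
/// ASSUMPTION: scale/bias width and group size are illustrative.
fn effective_bits(bits: f64, group: f64) -> f64 {
    bits + 32.0 / group
}

/// Projected size of a cache that occupies `bf16_mb` megabytes in bf16
/// (16 bits per element) once it is stored quantized.
fn projected_mb(bf16_mb: f64, bits: f64, group: f64) -> f64 {
    bf16_mb * effective_bits(bits, group) / 16.0
}
```

The same ratio, `effective_bits / 16`, bounds the KV bytes read per decode step on the quantized layers, which is where the bandwidth estimate comes from.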

References

  • Working fused path to model after: Mixed / RotK routing at crates/rmlx-kv-quant/src/kvcache/update.rs:583, fused SDPA at crates/rmlx-kv-quant/src/kvcache/sdpa.rs:250.
  • Anti-pattern (do not copy): rot_k_tq4v full-prefix dequant per step (sdpa.rs:269,295) — O(seq), materialises the bf16 prefix.
  • Complementary, separate work: casting the none-path f32 global decode seed to bf16 (~74% of the current Gemma4 decode gap, codec-independent).

Acceptance criteria

  • At least one codec stores global-layer KV quantized and runs fused quantized decode SDPA (no full-prefix bf16 materialisation).
  • Coherence (temp=0 greedy) and a long-context retrieval check (NIAH-style) pass.
  • Measured resident-KV reduction AND decode TPS ≥ bf16 at long context on a global-layer-heavy model (Gemma4 e2b/e4b).
  • No regression on the SWA windowed path or on none.

Out of scope

  • Per-step CPU dequant (the slow rot_k_tq4v pattern).
  • SWA windowed layers (already bounded).
  • The f32→bf16 global-seed change (tracked separately).

Metadata

Labels

enhancement (New feature or request)