Quantized SDPA #3026

Open
CC-Yeh wants to merge 22 commits into ml-explore:main from CC-Yeh:quantized_sdpa

@CC-Yeh
Contributor

@CC-Yeh CC-Yeh commented Jan 20, 2026

Proposed changes

Add Metal quantized SDPA vector kernels based on #1515

Speedup vs fp16

[Image: speedup_main]

TODO:

What improved performance:

  • Removed per-thread storage of k and v to reduce register pressure (the kernel was waiting on synchronization).
  • Fused computation with dequantization
  • Tuned reading size ('uint16_t'/'uint32_t') for quantized k/v
  • Manual unrolling performed better than the clang loop optimizer
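
The "fused computation with dequantization" point can be illustrated in plain Python. This is a minimal sketch that assumes affine quantization (value = scale * code + bias within each group); the function names and data layout are hypothetical and do not mirror the Metal kernels in this PR.

```python
def dequantize(codes, scales, biases, group_size):
    """Materialize the full-precision vector: scale * code + bias per group."""
    return [
        scales[i // group_size] * c + biases[i // group_size]
        for i, c in enumerate(codes)
    ]


def dot_unfused(q, codes, scales, biases, group_size):
    """Dequantize first (temporary buffer), then take the dot product."""
    k = dequantize(codes, scales, biases, group_size)
    return sum(qi * ki for qi, ki in zip(q, k))


def dot_fused(q, codes, scales, biases, group_size):
    """Dequantize inside the accumulation, without a temporary k vector.

    Per group: sum(q_i * (s * c_i + b)) == s * sum(q_i * c_i) + b * sum(q_i).
    """
    total = 0.0
    for g in range(len(codes) // group_size):
        lo, hi = g * group_size, (g + 1) * group_size
        q_dot_codes = sum(q[i] * codes[i] for i in range(lo, hi))
        q_sum = sum(q[i] for i in range(lo, hi))
        total += scales[g] * q_dot_codes + biases[g] * q_sum
    return total


# Illustrative values only (two groups of four 4-bit codes).
q = [0.5, -1.0, 2.0, 0.25, 1.0, 1.5, -0.5, 0.75]
codes = [3, 0, 15, 7, 1, 2, 8, 4]
scales = [0.5, 0.25]
biases = [-1.0, 2.0]
print(dot_unfused(q, codes, scales, biases, 4), dot_fused(q, codes, scales, biases, 4))
```

Both functions compute the same value; the fused form applies the scale and bias once per group instead of once per element.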

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@awni
Member

awni commented Jan 21, 2026

The numbers seem quite good.. a little too good to be true 😅

What's the difference between SDPA and Attention in the benchmark? Also what's the query sequence length used for the benchmark?

@CC-Yeh
Contributor Author

CC-Yeh commented Jan 21, 2026

> The numbers seem quite good.. a little too good to be true 😅

Totally agree, must be missing something 🤔

> What's the difference between SDPA and Attention in the benchmark? Also what's the query sequence length used for the benchmark?

Attention is a simple reference implementation built from matmul + softmax + matmul (Maybe too naive?).
SDPA uses mx.fast.scaled_dot_product_attention, which hits the sdpa_vector_2pass kernels when Lq ≤ 8 (this case).

The query sequence length here is 1 (q.shape = (1, 32, 1, 128)), so this benchmark is measuring the single-token decode case, where one new token attends to a long KV cache (L = 32768).
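
For readers unfamiliar with the "matmul + softmax + matmul" reference, here is a minimal pure-Python sketch of single-query decode attention. It is not the benchmark code from this PR; the function name, the default 1/sqrt(d) scale, and the list-of-lists layout are assumptions made for illustration.

```python
import math


def reference_decode_attention(q, keys, values, scale=None):
    """Single-query attention: softmax(scale * q @ K^T) @ V."""
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))
    # First "matmul": one score per cached key.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    # Numerically stable softmax.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    # Second "matmul": weighted sum of the cached values.
    dv = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(dv)]


# Tiny illustrative cache of two tokens with identical keys.
out = reference_decode_attention(
    [1.0, 0.0], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]]
)
print(out)
```

With identical keys the softmax weights are uniform, so the output is the mean of the value rows.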

@CC-Yeh
Contributor Author

CC-Yeh commented Jan 21, 2026

@awni
Fixed some bugs in 8-bit dequantization and in the benchmark (unnecessary dequantization steps).
Now the numbers make more sense 😃

@awni
Member

awni commented Jan 21, 2026

So if I’m understanding correctly the fused implementation is slower in the quantized case than the unfused ops-based one?

@CC-Yeh
Contributor Author

CC-Yeh commented Jan 21, 2026

Fused SDPA is faster: MXFP4 15.33 ms vs 24.71 ms, and MXFP8 26.09 ms vs 46.48 ms to decode a single query.
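
As a quick check of the timings quoted above, the speedup is the unfused time divided by the fused time. The snippet below is only arithmetic on those four numbers; the helper name is hypothetical.

```python
def speedup(fused_ms, unfused_ms):
    """How many times faster the fused kernel is than the unfused path."""
    return unfused_ms / fused_ms


# (fused ms, unfused ms) as quoted in the comment above.
timings_ms = {"MXFP4": (15.33, 24.71), "MXFP8": (26.09, 46.48)}
for mode, (fused, unfused) in timings_ms.items():
    print(f"{mode}: {speedup(fused, unfused):.2f}x")
```

This works out to roughly 1.6x for MXFP4 and 1.8x for MXFP8.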

@awni
Member

awni commented Jan 21, 2026

Very nice!!

Comment thread mlx/fast.cpp Outdated
Comment on lines +875 to +878
if (qmode == QuantizationMode::Nvfp4) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention.");
}
Member

Why not nvfp4?

Contributor Author

It’s on the way! I just wanted to make sure the PR structure was okay first.

Contributor Author

Added support

Comment thread mlx/fast.cpp Outdated
Comment on lines +871 to +874
if (qmode == QuantizationMode::Affine) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported.");
}
Member

Why not affine?

Member

Btw not suggesting we necessarily do it. Maybe it's better to be more limited in the quants we support here. Maybe fp8, fp4 are fine to start?

For example I don't think it's necessary to support every bit width because in practice no-one will ever use 2, 3 for KV cache quantization.

Contributor Author

Added initial support; there is still room for tuning the 2/3/5/6-bit paths.

@awni
Member

awni commented Jan 27, 2026

@CC-Yeh I'm interested in this PR moving forward. Let me know if you have questions. Also no need to support everything on a first pass. I think doing one 8-bit (fp8 / int8) quant well for Metal / CUDA is already probably good enough to start.

@CC-Yeh CC-Yeh changed the title from "[WIP] Quantized SDPA" to "Quantized SDPA" Jan 29, 2026
@CC-Yeh CC-Yeh marked this pull request as ready for review January 29, 2026 22:04
@CC-Yeh
Contributor Author

CC-Yeh commented Jan 29, 2026

@awni

I’ve added the Metal paths for mxfp4/8, nvfp4, and affine (2/3/4/5/6/8); the affine path is not optimized yet.
Further tuning likely needs validation on other machines.

For the CUDA path (maybe in the next PR), Colab doesn’t support NVFP4, so I would need help with that.

[Image: quant_sdpa_speedup_vs_seqlen]

@CC-Yeh CC-Yeh requested a review from awni January 29, 2026 22:22
@awni
Member

awni commented Jan 29, 2026

> affine(2/3/4/5/6/8)

What group sizes did you do for that? I'm not convinced we need broad support for bitwidth × group size. I expect bits < 4 to be used rarely if ever.

@CC-Yeh
Contributor Author

CC-Yeh commented Jan 29, 2026

> affine(2/3/4/5/6/8)
>
> What group sizes did you do for that? I'm not convinced we need broad support there. I expect < 4 to be used rarely.

What group sizes do you think we should support for affine? Currently it's templated so it can handle various
sizes, but I can limit the instantiations if there's a specific set that's practical.

template <typename T, int D, QuantMode mode, int group_size, int bits>
[[kernel]] void quant_sdpa_vector_2pass_1(

@awni
Member

awni commented Jan 29, 2026

Yes totally. I think it's good to keep it generic. But probably better to limit initial support and grow than vice versa.

I would maybe start with bits = {4, 6, 8} and just group_size = 32. I think 32 is most flexible for the head dimension right?

@CC-Yeh
Contributor Author

CC-Yeh commented Jan 30, 2026

> Yes totally. I think it's good to keep it generic. But probably better to limit initial support and grow than vice versa.
>
> I would maybe start with bits = {4, 6, 8} and just group_size = 32. I think 32 is most flexible for the head dimension right?

Limited the affine support.

Yeah, 32 is most flexible for head dim.
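
One way to read "most flexible" is divisibility: a group size only tiles a head dimension cleanly if it divides it. The sketch below checks that under this assumption. The head dimensions 64, 128, 256 and 512 appear elsewhere in this thread, while 96 is a hypothetical value added purely to show a case where larger groups stop fitting.

```python
def compatible_head_dims(group_size, head_dims):
    """Head dims that split into a whole number of quantization groups."""
    return [d for d in head_dims if d % group_size == 0]


# 96 is a hypothetical head dim included for illustration only.
head_dims = [64, 96, 128, 256, 512]
for gs in (32, 64, 128):
    print(gs, compatible_head_dims(gs, head_dims))
```

Under this reading, group_size = 32 covers every listed head dimension, and larger group sizes cover strictly fewer.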

@CC-Yeh
Contributor Author

CC-Yeh commented Feb 7, 2026

Hey @awni

Just fine-tuned the block sizes and GQA factors, and switched from template kernels to a function_constant approach to trade some cold-start latency for reduced binary size. Ready for review!

[Image: speedup_main]

kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(q.shape(-1));

I think you meant v.shape(-1) here?

Contributor Author

Yeah, in 2 pass kernels both values are the same.
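
The exchange above can be made concrete with a small Python sketch of the name construction. The helper names are hypothetical, and the base name is borrowed from the kernel signature quoted earlier in the thread; this is not the actual dispatch code.

```python
def kernel_name_as_written(base, q_head_dim, v_head_dim):
    # Mirrors the reviewed snippet: the query head dim is appended twice.
    return f"{base}_{q_head_dim}_{q_head_dim}"


def kernel_name_suggested(base, q_head_dim, v_head_dim):
    # Reviewer's suggestion: append the value head dim second.
    return f"{base}_{q_head_dim}_{v_head_dim}"


base = "quant_sdpa_vector_2pass_1"
# Equal head dims (the 2-pass case described above): both forms agree.
print(kernel_name_as_written(base, 128, 128) == kernel_name_suggested(base, 128, 128))
# Unequal head dims (hypothetical): the two forms would select different names.
print(kernel_name_as_written(base, 128, 64) == kernel_name_suggested(base, 128, 64))
```

The two forms only diverge if the query and value head dimensions ever differ.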

@CC-Yeh
Contributor Author

CC-Yeh commented Feb 25, 2026

@awni

Just wanted to gently bump this PR in case it got buried.
Happy to split it up or adjust anything if that helps with review.

Thanks!

@CC-Yeh
Contributor Author

CC-Yeh commented Mar 2, 2026

Looping in @angeloskath @jagrit06 for visibility.

Happy to split it up or adjust anything if that helps with review.

Thanks!

@Thump604

This is exciting — fusing KV dequantization into SDPA is exactly what large-context quantized inference needs.

I run Qwen3.5-122B (5-bit, 10B active MoE) daily on M2 Ultra 128GB with 192K enforced context. Long-context decode is the primary bottleneck once SpecPrefill handles TTFT. The 1.6–1.8x decode speedup at 32K KV cache would be transformative for my use case.

Happy to benchmark this on M2 Ultra once it's ready for testing — I can provide before/after numbers at 16K, 32K, 64K, and 128K KV cache lengths on the 122B model. Let me know if there's anything specific you'd like measured.

+1 for merge — this fills a significant performance gap vs llama.cpp for quantized long-context inference.

Nicolas-nwb added a commit to Nicolas-nwb/mlx-swift that referenced this pull request Apr 16, 2026
…(-1)

Update the mlx submodule to the quantized-sdpa branch, which contains the
Metal quantized SDPA kernel from PR ml-explore/mlx#3026 with a fix for the
kernel-name dispatch bug (v.shape(-1) → q.shape(-1)).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@CC-Yeh
Contributor Author

CC-Yeh commented Apr 16, 2026

@Nicolas-nwb, @Thump604 thanks for the suggestions!

Improvements pushed.

dogukanveziroglu and others added 2 commits April 18, 2026 19:43
Add Affine dispatch entries for group_size=64 at bits={4,6,8} and
relax the validation in quantized_scaled_dot_product_attention.

This matches the default produced by mx.quantize(mode="affine") and
the kv_group_size=64 default used by mlx-lm, so users following the
MLX/mlx-lm conventions no longer hit an error when using fused
quantized attention.

Benchmarks (M4, B=1 H=32 D=128 Lq=1, affine 4-bit):
  Context  gs=32 fused   gs=64 fused   speedup
  32K      50 us         41 us         +22%
  64K      95 us         82 us         +16%
  128K     176 us        152 us        +16%

gs=64 is faster at long context because it has half the scale/bias
memory traffic.

Costs:
  mlx.metallib: 128,161,428 -> 128,233,236 bytes (+0.056%)
  libmlx.dylib: unchanged

Existing 10 test_quantized_sdpa* tests continue to pass (54 subtests).
Support group_size=64 for affine quantized SDPA
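
The "half the scale/bias memory traffic" remark in the commit message above can be checked with simple arithmetic. This sketch assumes each group stores one fp16 scale and one fp16 bias (32 bits per group); that storage format is an assumption rather than something stated in the commit, and the helper names are hypothetical.

```python
# Assumption: one fp16 scale plus one fp16 bias per group.
SCALE_BIAS_BITS = 16 + 16


def overhead_bits_per_element(group_size):
    """Scale/bias bits amortized over the elements of one group."""
    return SCALE_BIAS_BITS / group_size


def total_bits_per_element(bits, group_size):
    """Quantized payload plus amortized scale/bias overhead."""
    return bits + overhead_bits_per_element(group_size)


for gs in (32, 64):
    print(gs, overhead_bits_per_element(gs), total_bits_per_element(4, gs))
```

Under that assumption the per-element overhead drops from 1.0 bit at gs=32 to 0.5 bit at gs=64, so 4-bit affine goes from 5.0 to 4.5 total bits per element.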
dedalien added a commit to dedalien/mlx-lm that referenced this pull request Apr 26, 2026
Builds on ml-explore/mlx#3026 (Dan Yeh) — the generic
quantized_scaled_dot_product_attention API with pluggable modes.
Companion PR: ml-explore/mlx#XXXX (adds turbo3/turbo4 to mlx core).

New files:
- mlx_lm/models/turbo_cache.py: TurboQuantKVCache(_BaseCache).
  Two-phase: prefill stores float16, generation returns
  (packed_uint32, float16_scales) for the fused kernel.
  On first generation step, prefill tokens are batch-compressed.
  K is WHT-rotated before quantization; V is not.
- mlx_lm/models/turbo_metal.py: two fused Metal kernels via
  mx.fast.metal_kernel, one thread per token, float32 registers:
  turbo_encode_metal (WHT + norm + codebook + pack) and
  wht_rotate_metal (used to pre-rotate Q before SDPA).

Modified files:
- mlx_lm/models/base.py: detects TurboQuantKVCache via isinstance,
  routes to _turbo_scaled_dot_product_attention. Q rotation in
  float32 (bfloat16 butterfly shifts softmax peaks on models with
  large key scales).
- mlx_lm/models/cache.py: make_turbo_cache() replaces KVCache with
  TurboQuantKVCache; leaves ArraysCache/DeltaNet layers untouched.
  fp16_layers= keeps first/last N attention layers in float16.
- mlx_lm/generate.py: generate_step gains kv_cache_type= and
  turbo_fp16_layers=; --kv-cache-type and --turbo-fp16-layers in CLI.
- tests/test_turbo_cache.py: 22 sub-cases covering WHT isometry,
  encode shapes, two-phase cache, D in {64,128,256}, gqa_factor=6,
  B=2, bfloat16, fp16_layers boundary.

Supported head dims: 64, 128, 256. Requires Metal GPU.

Tested on Qwen3.6-27B (head_dim=256, 24Q/4KV, GQA=6), 24 GB unified
memory Mac. Enables longer context generation on memory-constrained
hardware by compressing the KV cache ~5x.
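
The commit message above relies on the Walsh-Hadamard transform (WHT) being an isometry: rotating both Q and K with the same normalized WHT leaves their dot products unchanged. The following is a minimal pure-Python sketch of that property; it is not the Metal kernel from the commit, and the function names are hypothetical.

```python
import math


def wht(x):
    """Normalized fast Walsh-Hadamard transform (length must be a power of two)."""
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    y = list(x)
    h = 1
    while h < n:
        # Butterfly stage: combine elements that are h apart.
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    norm = 1.0 / math.sqrt(n)  # makes the transform orthonormal
    return [v * norm for v in y]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


# Illustrative vectors only.
q = [0.5, -1.0, 2.0, 0.25, 1.0, 1.5, -0.5, 0.75]
k = [1.0, 0.0, -2.0, 3.0, 0.5, 0.5, 1.0, -1.0]
print(dot(q, k), dot(wht(q), wht(k)))
```

Because the normalized transform is its own inverse, applying it twice recovers the input, which is a convenient sanity check.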
@Nicolas-nwb

Hi @CC-Yeh and @awni — we've been building on top of this PR for
transformer inference on Apple Silicon and wanted to share a few
additions from our fork that might be useful upstream.

Multi-row dispatch (qsl up to 32) — commit f0ac19b9 on our fork:

  • Moves qsl from the threadgroup z-dim to the grid z-dim. Threadgroup
    becomes (32, gqa_factor, 1) = 256 threads (well under the 1024 hw
    limit); grid z becomes blocks * qsl. The kernel decodes
    block_idx = tid.z / qsl and q_seq_idx = tid.z % qsl via a new
    function_constant(31) for q_seq_len.
  • Removes the qsl * gqa <= 32 constraint from the use_fallback gate,
    allowing qsl up to 32. This unblocks speculative decoding verifier
    batches (e.g. GQA 8:1 with K=8 → qsl=9) without falling back to FP16.
  • Adds head_dim == 512 to the fast-path gate (instantiations were
    already present in scaled_dot_product_attention.metal).
  • Backports dynamic BD=8 head_dim selection to sdpa_vector.h.
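
The index decoding and the gate change described in the list above can be sketched in Python. The helpers are hypothetical and deliberately simplified: they model only the qsl-related condition, not the full use_fallback logic in the fork.

```python
def decode_grid_z(tid_z, qsl):
    """Split a flat grid-z index into (block_idx, q_seq_idx)."""
    return tid_z // qsl, tid_z % qsl


def old_gate_uses_fallback(qsl, gqa_factor):
    # Simplified model of the removed constraint: qsl * gqa <= 32.
    return qsl * gqa_factor > 32


def new_gate_uses_fallback(qsl):
    # Simplified model of the relaxed constraint: qsl up to 32.
    return qsl > 32


# Speculative-decoding example from the list: GQA 8:1 with K=8 gives qsl=9.
blocks, qsl, gqa_factor = 4, 9, 8
pairs = [decode_grid_z(z, qsl) for z in range(blocks * qsl)]
print(pairs[:3], pairs[-1])
print(old_gate_uses_fallback(qsl, gqa_factor), new_gate_uses_fallback(qsl))
```

Every (block, query row) pair is visited exactly once over a grid z extent of blocks * qsl, and the qsl=9 case passes the relaxed gate where the old product constraint would have rejected it.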

New tests added:

  • TEST 4: qsl ∈ {2, 4, 8, 16, 32} × D ∈ {128, 256, 512} × {fp16, bf16}
    with CHECK_FALSE(use_fallback) + max_err < 0.06 vs FP32 ref.
  • TEST 5: pure gate-logic contract for qsl=1..8, GQA 8:1, post-patch.

Debug spy: MLX_DEBUG_QUANT_SDPA=1 env var logs q/k shapes and
KERNEL/FALLBACK decisions to stderr — useful for Swift integration tests.

The changes live on the quantized-sdpa branch of our forks:

  • mlx (submodule): f0ac19b9
  • mlx-swift: 509a729

Happy to split these into a follow-up PR if that makes review easier, or
to rebase on top of the final merged state of this PR. Let us know what
works best.

@joelnishanth

Benchmarks: Quantized SDPA + TurboQuant modes on M3 Pro (18 GB)

Built from the quantized_sdpa branch with @dedalien's TurboQuant3/4 modes (CC-Yeh#3) cherry-picked on top. All 12 quantized SDPA tests pass. Benchmark script measures single-token decode (qsl=1) with realistic GQA configs.

Setup: Apple M3 Pro, 18 GB, macOS 26, MLX 0.31.2.dev+f983ccad

Llama-3-8B config (D=128, GQA 4:1, 32 query / 8 KV heads)

  Context  FP16     affine-4b       affine-8b       turbo3          turbo4
  1K       0.224ms  0.307ms (0.7x)  0.322ms (0.7x)  0.313ms (0.7x)  0.332ms (0.7x)
  2K       0.246ms  0.193ms (1.3x)  0.197ms (1.3x)  0.403ms (0.6x)  0.442ms (0.6x)
  4K       0.301ms  0.266ms (1.1x)  0.270ms (1.1x)  0.609ms (0.5x)  0.378ms (0.8x)
  8K       0.493ms  0.359ms (1.4x)  0.395ms (1.2x)  0.487ms (1.0x)  0.503ms (1.0x)
  16K      0.881ms  0.550ms (1.6x)  0.603ms (1.5x)  0.544ms (1.6x)  0.864ms (1.0x)
  32K      1.740ms  0.946ms (1.8x)  1.084ms (1.6x)  0.943ms (1.8x)  1.267ms (1.4x)

Gemma-4 config (D=256, GQA 6:1, 24 query / 4 KV heads)

  Context  FP16     affine-4b       affine-8b       turbo3          turbo4
  1K       0.170ms  0.247ms (0.7x)  0.267ms (0.6x)  0.426ms (0.4x)  0.657ms (0.3x)
  4K       0.340ms  0.487ms (0.7x)  0.501ms (0.7x)  0.975ms (0.3x)  0.646ms (0.5x)
  8K       0.598ms  0.731ms (0.8x)  0.734ms (0.8x)  1.068ms (0.6x)  0.794ms (0.8x)
  32K      1.429ms  2.272ms (0.6x)  2.385ms (0.6x)  3.798ms (0.4x)  2.710ms (0.5x)

Memory compression (32K context)

  Mode       Llama (128MB fp16)  Gemma (128MB fp16)
  affine-4b  40.0MB (3.2x)       40.0MB (3.2x)
  affine-8b  72.0MB (1.8x)       72.0MB (1.8x)
  turbo3     25.0MB (5.1x)       24.5MB (5.2x)
  turbo4     33.0MB (3.9x)       32.5MB (3.9x)
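
The memory figures above can be sanity-checked with a few lines of arithmetic. This sketch assumes the 128MB baseline is K plus V for a single layer at batch size 1 with 2-byte fp16 elements; that interpretation is an assumption, and the helper names are hypothetical.

```python
MIB = 1024 * 1024


def fp16_kv_bytes(context_len, n_kv_heads, head_dim):
    # K and V tensors, 2 bytes per fp16 element (assumed: one layer, batch 1).
    return 2 * context_len * n_kv_heads * head_dim * 2


def effective_bits(quantized_mb, fp16_mb):
    """Average stored bits per element implied by a reported cache size."""
    return 16 * quantized_mb / fp16_mb


# Configurations from the two benchmark headings, at 32K context.
print(fp16_kv_bytes(32 * 1024, 8, 128) / MIB)  # Llama-3-8B config
print(fp16_kv_bytes(32 * 1024, 4, 256) / MIB)  # Gemma-4 config
# Llama column of the memory table.
for mode, mb in [("affine-4b", 40.0), ("affine-8b", 72.0), ("turbo3", 25.0), ("turbo4", 33.0)]:
    print(mode, effective_bits(mb, 128.0), round(128.0 / mb, 1))
```

Under these assumptions both configurations come to 128 MiB in fp16, and the affine rows correspond to 5 and 9 effective bits per element, that is, the nominal 4 or 8 bits plus about one bit of per-group metadata.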

Key takeaways

  1. Affine-4b is the speed winner for decode. 1.8x speedup at 32K on Llama config — the fused dequant kernel clearly benefits from reduced memory bandwidth.

  2. TurboQuant turbo3 matches affine-4b speed at long context (32K) while delivering 60% more compression (5.1x vs 3.2x). The crossover happens around 16K tokens.

  3. Short-context overhead is real for TurboQuant — the codebook lookup path has higher per-dispatch cost. This is consistent with @Landon-Molt's findings on mlx#3404 about codebook vs scalar dequant.

  4. Gemma-4 (D=256) shows weaker speedups across the board — affine is slower than fp16 at all context lengths. This may be a tuning opportunity for the D=256 kernel specialization.

  5. TurboQuant's value proposition is memory, not speed. For memory-constrained inference (long context on 8-16 GB devices), turbo3's 5x compression enables contexts that wouldn't fit in fp16. The speed cost is manageable above 8K tokens.
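
The crossover and the "60% more compression" figure in the takeaways can be recomputed directly from the Llama-config tables. The snippet below only re-reads the reported numbers; the helper name is hypothetical.

```python
contexts = ["1K", "2K", "4K", "8K", "16K", "32K"]
# Latencies in ms, copied from the Llama-3-8B config table.
affine_4b_ms = [0.307, 0.193, 0.266, 0.359, 0.550, 0.946]
turbo3_ms = [0.313, 0.403, 0.609, 0.487, 0.544, 0.943]


def first_crossover(labels, baseline, candidate):
    """First context length at which the candidate is at least as fast."""
    for label, b, c in zip(labels, baseline, candidate):
        if c <= b:
            return label
    return None


crossover = first_crossover(contexts, affine_4b_ms, turbo3_ms)
extra_compression = 5.1 / 3.2 - 1  # turbo3 vs affine-4b compression ratios
print(crossover, f"{extra_compression:.0%}")
```

On these numbers turbo3 first matches affine-4b at 16K, and its compression ratio is about 59% higher, consistent with the rounded figures in the takeaways.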

Happy to run additional configs (D=512 for Gemma-4 31B, different batch sizes, M-series comparisons). The benchmark script is at joelnishanth/mlx-swift-turboquant and I'll push the Python version to a branch.

This PR + @dedalien's turbo modes together provide a compelling quantized attention stack. Would love to see it land.

Joel Nishanth · offlyn.AI
