Quantized SDPA #3026
The numbers seem quite good, a little too good to be true 😅 What's the difference between SDPA and Attention in the benchmark? Also, what query sequence length does the benchmark use?
Totally agree, must be missing something 🤔
Attention is a simple reference implementation built from unfused ops. The query sequence length here is 1 (q.shape = (1, 32, 1, 128)), so this benchmark measures the single-token decode case, where one new token attends to a long KV cache (L = 32768).
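To make the single-token decode case concrete, here is a minimal pure-Python sketch of unfused attention for one head and one query vector. The function name `decode_attention` and the list-based layout are assumptions for illustration, not the benchmark's actual reference code.

```python
import math


def decode_attention(q, keys, values, scale=None):
    """Attention for one query vector against a KV cache (single head).

    q: list of D floats; keys and values: lists of L vectors of D floats.
    """
    d = len(q)
    if scale is None:
        scale = 1.0 / math.sqrt(d)
    # One dot product per cached key.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    # Numerically stable softmax over the L cached positions.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the cached values.
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d)]


out = decode_attention([1.0, 0.0], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]])
```

With all-zero keys the softmax weights are uniform, so `out` is the plain average of the two value vectors.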
@awni
So if I’m understanding correctly, the fused implementation is slower in the quantized case than the unfused ops-based one?
Fused SDPA is faster:
Very nice!!
if (qmode == QuantizationMode::Nvfp4) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention.");
}
It’s on the way! I just wanted to make sure the PR structure was okay first.
if (qmode == QuantizationMode::Affine) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported.");
}
Btw, not suggesting we necessarily do it. Maybe it's better to be more limited in the quants we support here. Maybe fp8 and fp4 are fine to start?
For example, I don't think it's necessary to support every bit width, because in practice no one will ever use 2 or 3 bits for KV cache quantization.
Added initial support; there is still room for tuning at 2/3/5/6 bits.
@CC-Yeh I'm interested in this PR moving forward. Let me know if you have questions. Also, no need to support everything on a first pass. I think doing one 8-bit (fp8 / int8) quant well for Metal / CUDA is probably already good enough to start.
What group sizes did you do for that? I'm not convinced we need broad support for bit width × group size. I expect bits < 4 to be used rarely, if ever.
What group sizes do you think we should support for affine? Currently it's templated, so it can handle various group sizes:
template <typename T, int D, QuantMode mode, int group_size, int bits>
[[kernel]] void quant_sdpa_vector_2pass_1(
Yes, totally. I think it's good to keep it generic. But it's probably better to limit initial support and grow than vice versa. I would maybe start with bits = {4, 6, 8} and just group_size = 32. I think 32 is most flexible for the head dimension, right?
Limited the affine support. Yeah, 32 is most flexible for head dim.
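For readers unfamiliar with the terms, the sketch below shows a generic min/max affine scheme with one scale and one bias per group. It is an illustrative assumption, not MLX's exact quantization routine, and `affine_quantize` / `affine_dequantize` are hypothetical names.

```python
def affine_quantize(x, group_size=32, bits=4):
    """Quantize a flat list of floats group by group.

    Returns (codes, scales, biases) with one scale and one bias per group.
    """
    if len(x) % group_size != 0:
        raise ValueError("length must be a multiple of group_size")
    levels = (1 << bits) - 1  # largest integer code, e.g. 15 for 4 bits
    codes, scales, biases = [], [], []
    for start in range(0, len(x), group_size):
        group = x[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / levels if hi > lo else 1.0
        scales.append(scale)
        biases.append(lo)
        codes.extend(round((v - lo) / scale) for v in group)
    return codes, scales, biases


def affine_dequantize(codes, scales, biases, group_size=32):
    """Map integer codes back to floats using each group's scale and bias."""
    return [codes[i] * scales[i // group_size] + biases[i // group_size]
            for i in range(len(codes))]


x = [i / 10 for i in range(64)]
codes, scales, biases = affine_quantize(x, group_size=32, bits=4)
x_hat = affine_dequantize(codes, scales, biases, group_size=32)
```

The length check is where group size meets head dimension: a vector can only be split into whole groups when its length is a multiple of `group_size`. A length of 96, for example, divides evenly by 32 but not by 64.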
Hey @awni, I just fine-tuned the block sizes and GQA factors, and switched from template kernels to a function_constant approach to trade some cold-start latency for reduced binary size. Ready for review!
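As a rough illustration of why template kernels inflate binary size, the sketch below counts how many separate kernels would be compiled ahead of time if every template parameter combination were instantiated. All parameter lists are hypothetical and are not the PR's actual dispatch table.

```python
from itertools import product

# Hypothetical parameter sets, for illustration only.
dtypes = ["float16", "bfloat16"]
head_dims = [64, 128, 256]
modes = ["affine", "fp"]
group_sizes = [32]
bits = [4, 6, 8]

# One precompiled kernel per combination of template arguments.
variants = list(product(dtypes, head_dims, modes, group_sizes, bits))
print(len(variants))  # 2 * 3 * 2 * 1 * 3 = 36
```

With function constants, fewer kernels are compiled ahead of time and the remaining values are specialized when the pipeline is created, which is where the extra cold-start latency comes from.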
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(q.shape(-1));
Yeah, in the 2-pass kernels both values are the same.
Just wanted to gently bump this PR in case it got buried. Thanks!
Looping in @angeloskath @jagrit06 for visibility. Happy to split it up or adjust anything if that helps with review. Thanks!
This is exciting: fusing KV dequantization into SDPA is exactly what large-context quantized inference needs. I run Qwen3.5-122B (5-bit, 10B active MoE) daily on an M2 Ultra 128GB with 192K enforced context. Long-context decode is the primary bottleneck once SpecPrefill handles TTFT. The 1.6–1.8x decode speedup at 32K KV cache would be transformative for my use case.

Happy to benchmark this on the M2 Ultra once it's ready for testing. I can provide before/after numbers at 16K, 32K, 64K, and 128K KV cache lengths on the 122B model. Let me know if there's anything specific you'd like measured.

+1 for merge: this fills a significant performance gap vs llama.cpp for quantized long-context inference.
…(-1) Update the mlx submodule to the quantized-sdpa branch, which contains the Metal quantized SDPA kernel from PR ml-explore/mlx#3026 with a fix for the kernel-name dispatch bug (v.shape(-1) → q.shape(-1)). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Nicolas-nwb, @Thump604 thanks for the suggestions! Improvements pushed.
Add Affine dispatch entries for group_size=64 at bits={4,6,8} and
relax the validation in quantized_scaled_dot_product_attention.
This matches the default produced by mx.quantize(mode="affine") and
the kv_group_size=64 default used by mlx-lm, so users following the
MLX/mlx-lm conventions no longer hit an error when using fused
quantized attention.
Benchmarks (M4, B=1 H=32 D=128 Lq=1, affine 4-bit):
  Context   gs=32 fused   gs=64 fused   speedup
  32K       50 us         41 us         +22%
  64K       95 us         82 us         +16%
  128K      176 us        152 us        +16%
gs=64 is faster at long context because it has half the scale/bias
memory traffic.
Costs:
mlx.metallib: 128,161,428 -> 128,233,236 bytes (+0.056%)
libmlx.dylib: unchanged
Existing 10 test_quantized_sdpa* tests continue to pass (54 subtests).
Support group_size=64 for affine quantized SDPA
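The statement that gs=64 has half the scale/bias memory traffic can be checked with simple arithmetic. This is a minimal sketch that assumes each group stores one scale and one bias as 16-bit values (2 bytes each); `bytes_per_token` is a hypothetical helper, not an MLX function, and it counts bytes only rather than modelling run time.

```python
def bytes_per_token(head_dim, bits, group_size, scale_bytes=2, bias_bytes=2):
    """Bytes read for one quantized K (or V) vector of one head."""
    packed = head_dim * bits // 8        # packed integer codes
    groups = head_dim // group_size      # one scale and one bias per group
    return packed + groups * (scale_bytes + bias_bytes)


# D=128, affine 4-bit, as in the benchmark configuration above.
gs32 = bytes_per_token(128, 4, 32)  # 64 + 4 * 4 = 80
gs64 = bytes_per_token(128, 4, 64)  # 64 + 2 * 4 = 72
```

Under these assumptions the scale/bias portion drops from 16 to 8 bytes per vector, while the 64 bytes of packed codes are unchanged.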
Builds on ml-explore/mlx#3026 (Dan Yeh), the generic quantized_scaled_dot_product_attention API with pluggable modes. Companion PR: ml-explore/mlx#XXXX (adds turbo3/turbo4 to mlx core).

New files:
- mlx_lm/models/turbo_cache.py: TurboQuantKVCache(_BaseCache). Two-phase: prefill stores float16, generation returns (packed_uint32, float16_scales) for the fused kernel. On the first generation step, prefill tokens are batch-compressed. K is WHT-rotated before quantization; V is not.
- mlx_lm/models/turbo_metal.py: two fused Metal kernels via mx.fast.metal_kernel, one thread per token, float32 registers: turbo_encode_metal (WHT + norm + codebook + pack) and wht_rotate_metal (used to pre-rotate Q before SDPA).

Modified files:
- mlx_lm/models/base.py: detects TurboQuantKVCache via isinstance, routes to _turbo_scaled_dot_product_attention. Q rotation in float32 (bfloat16 butterfly shifts softmax peaks on models with large key scales).
- mlx_lm/models/cache.py: make_turbo_cache() replaces KVCache with TurboQuantKVCache; leaves ArraysCache/DeltaNet layers untouched. fp16_layers= keeps first/last N attention layers in float16.
- mlx_lm/generate.py: generate_step gains kv_cache_type= and turbo_fp16_layers=; --kv-cache-type and --turbo-fp16-layers in CLI.
- tests/test_turbo_cache.py: 22 sub-cases covering WHT isometry, encode shapes, two-phase cache, D in {64,128,256}, gqa_factor=6, B=2, bfloat16, fp16_layers boundary.

Supported head dims: 64, 128, 256. Requires Metal GPU. Tested on Qwen3.6-27B (head_dim=256, 24Q/4KV, GQA=6), 24 GB unified memory Mac.

Enables longer context generation on memory-constrained hardware by compressing the KV cache ~5x.
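The reason Q has to be rotated when K is stored WHT-rotated is that an orthonormal Walsh-Hadamard transform preserves dot products, so rotating both sides leaves the attention scores unchanged. Below is a minimal pure-Python sketch of that property; `wht` and `dot` are illustrative helpers and do not model the Metal kernels' norm, codebook, or packing steps.

```python
import math


def wht(x):
    """Orthonormal fast Walsh-Hadamard transform (length must be a power of two)."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = list(x)
    h = 1
    while h < n:
        # Butterfly stage: combine elements that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    norm = 1.0 / math.sqrt(n)  # scaling that makes the transform an isometry
    return [v * norm for v in out]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
score_plain = dot(q, k)
score_rotated = dot(wht(q), wht(k))
```

Here `score_plain` and `score_rotated` agree up to floating-point rounding, and applying `wht` twice returns the original vector.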
Hi @CC-Yeh and @awni — we've been building on top of this PR for Multi-row dispatch (qsl up to 32) — commit f0ac19b9 on our fork:
New tests added:
Debug spy: The changes live on the
Happy to split these into a follow-up PR if that makes review easier, or
Benchmarks: Quantized SDPA + TurboQuant modes on M3 Pro (18 GB)
Built from the
Setup: Apple M3 Pro, 18 GB, macOS 26, MLX
Llama-3-8B config (D=128, GQA 4:1, 32 query / 8 KV heads)
Gemma-4 config (D=256, GQA 6:1, 24 query / 4 KV heads)
Memory compression (32K context)
Key takeaways
Happy to run additional configs (D=512 for Gemma-4 31B, different batch sizes, M-series comparisons). The benchmark script is at joelnishanth/mlx-swift-turboquant and I'll push the Python version to a branch. This PR + @dedalien's turbo modes together provide a compelling quantized attention stack. Would love to see it land.


Proposed changes
Add Metal quantized SDPA vector kernels based on #1515
Speedup vs fp16
TODO:
Affine and NVFP4
What improves performance:
k/v
clang loop optimizer
Checklist
pre-commit run --all-files to format my code / installed pre-commit prior to committing changes