From 8bda2e20386ed37497f03947d03303a45726b56a Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Tue, 20 Jan 2026 21:30:06 +0100 Subject: [PATCH 01/21] first attempt --- .gitignore | 1 + benchmarks/python/sdpa_vector_bench.py | 135 +++++---- mlx/backend/metal/kernels/fp_quantized_nax.h | 2 + .../metal/kernels/fp_quantized_nax.metal | 2 - .../scaled_dot_product_attention.metal | 22 ++ mlx/backend/metal/kernels/sdpa_vector.h | 239 +++++++++++++++ .../metal/scaled_dot_product_attention.cpp | 222 ++++++++++++++ mlx/fast.cpp | 277 ++++++++++++++++++ mlx/fast.h | 14 + mlx/fast_primitives.h | 41 +++ python/src/fast.cpp | 37 +++ python/tests/test_quantized.py | 34 +++ 12 files changed, 956 insertions(+), 70 deletions(-) diff --git a/.gitignore b/.gitignore index 1daaa46d12..3a1ff6a3a0 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ parts/ sdist/ var/ venv/ +.venv/ wheels/ share/python-wheels/ *.egg-info/ diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 546bff84c2..c98b2e35d9 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -1,95 +1,94 @@ -import argparse -import math - import mlx.core as mx +from mlx.utils import tree_map from time_utils import time_fn -L = 16384 +L = 32768 H = 32 H_k = H // 4 D = 128 -V = 128 dtype = mx.float16 -loops = 10 - +bits = 4 +mode = "mxfp8" if bits == 8 else "mxfp4" -def upproject(x, w): - if w is None: - return x - else: - return x @ w.T +loops = 20 -def attention(q, k, v, mask=None, w=None): - def _sdpa(q, k, v): - B, Hq, L, D = q.shape +def attention(q, k, v): + for _ in range(loops): + B, Hq, Lq, Dq = q.shape _, Hk, S, _ = k.shape - _, _, _, V = v.shape - q = q.reshape(B, Hk, Hq // Hk, L, D) - k = k[:, :, None, :, :] - v = v[:, :, None, :, :] - s = q @ k.transpose(0, 1, 2, 4, 3) - if mask is not None: - m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S) - s = mx.where(m, s, mx.finfo(s.dtype).min) - p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) - o = p @ v - return o.reshape(B, Hq, L, V) - - for i in range(loops): - q = _sdpa(q, k, v) - q = upproject(q, w) + q = q.reshape(B, Hk, Hq // Hk, Lq, Dq) + ke = k[:, :, None, :, :] + ve = v[:, :, None, :, :] + scores = q @ ke.transpose(0, 1, 2, 4, 3) + probs = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + q = probs @ ve + q = q.reshape(B, Hq, Lq, Dq) return q -def sdpa(q, k, v, mask=None, w=None): - for i in range(loops): - q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) - q = upproject(q, w) +def sdpa(q, k, v): + for _ in range(loops): + q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) return q -def time_self_attention_primitives(): - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) - k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) - w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None - mx.eval(q, k, v, w) - time_fn(attention, q, k, v, w=w) +def quant_sdpa(q, k, v, bits=4, mode="mxfp4"): + for _ in range(loops): + q = mx.fast.quantized_scaled_dot_product_attention( + q, *k, *v, scale=1.0, mask=None, bits=bits, mode=mode + ) + return q -def time_self_attention_sdpa(): - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) - k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) - w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None - mx.eval(q, k, v, w) - time_fn(sdpa, q, k, v, w=w) +def quant_attention(q, k, v, bits=4, mode="mxfp4"): + for _ in range(loops): + B, Hq, Lq, Dq = q.shape + Hk = k[0].shape[1] + q = q.reshape((B, Hk, Hq // Hk, Lq, Dq)) + ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k) + ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v) + + scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits, mode=mode) + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + + q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits, mode=mode) + q = q.reshape((B, Hq, Lq, Dq)) + return q + + +def time_self_attention_primitives(q, k, v): + time_fn(attention, q, k, v) -def time_self_attention_sdpa_with_mask(): - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) - k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) - w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None - mask = mx.full((L,), True) - mask[L // 2 :] = False - mx.eval(q, k, v, mask, w) - def sdpa_mask(*args): - return sdpa(*args, mask=mask, w=w) +def time_self_attention_sdpa(q, k, v): + time_fn(sdpa, q, k, v) - def attention_mask(*args): - return attention(*args, mask=mask, w=w) - time_fn(attention_mask, q, k, v) - time_fn(sdpa_mask, q, k, v) +def time_self_attention_quant_sdpa(q, k, v, bits, mode): + time_fn(quant_sdpa, q, k, v, bits, mode) + + +def time_self_attention_quant_primitives(q, k, v, bits, mode): + time_fn(quant_attention, q, k, v, bits, mode) if __name__ == "__main__": - time_self_attention_sdpa() - time_self_attention_primitives() - time_self_attention_sdpa_with_mask() + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype) + k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) + v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) + mx.eval(q, k, v) + + k_quant = mx.quantize(k, bits=bits, mode=mode) + v_quant = mx.quantize(v, bits=bits, mode=mode) + mx.eval(k_quant, v_quant) + + k = mx.dequantize(*k_quant, bits=bits, mode=mode) + v = mx.dequantize(*v_quant, bits=bits, mode=mode) + + time_self_attention_sdpa(q, k, v) + time_self_attention_quant_sdpa(q, k_quant, v_quant, bits, mode) + time_self_attention_primitives(q, k, v) + time_self_attention_quant_primitives(q, k_quant, v_quant, bits, mode) diff --git a/mlx/backend/metal/kernels/fp_quantized_nax.h b/mlx/backend/metal/kernels/fp_quantized_nax.h index 381bc6c7d3..6a660eb760 100644 --- a/mlx/backend/metal/kernels/fp_quantized_nax.h +++ b/mlx/backend/metal/kernels/fp_quantized_nax.h @@ -5,6 +5,8 @@ #include "mlx/backend/metal/kernels/fp4.h" #include "mlx/backend/metal/kernels/fp8.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/utils.h" constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; diff --git a/mlx/backend/metal/kernels/fp_quantized_nax.metal b/mlx/backend/metal/kernels/fp_quantized_nax.metal index 4d65a384d3..e96c508f7b 100644 --- a/mlx/backend/metal/kernels/fp_quantized_nax.metal +++ b/mlx/backend/metal/kernels/fp_quantized_nax.metal @@ -2,8 +2,6 @@ // clang-format off #include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" -#include "mlx/backend/metal/kernels/quantized_utils.h" #include "mlx/backend/metal/kernels/steel/gemm/nax.h" #include "mlx/backend/metal/kernels/fp_quantized_nax.h" diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index c668d9d8c5..ab18d0286f 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -41,4 +41,26 @@ using namespace metal; instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) + +// Quantized SDPA vector instantiations +#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \ + instantiate_kernel( \ + "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \ + quant_sdpa_vector_2pass_1, \ + type, \ + head_dim, \ + group_size, \ + bits) + +#define instantiate_quant_sdpa_vector_group_size(type, heads) \ + instantiate_quant_sdpa_vector(type, heads, 32, 4) \ + instantiate_quant_sdpa_vector(type, heads, 32, 8) + +#define instantiate_quant_sdpa_vector_heads(type) \ + instantiate_quant_sdpa_vector_group_size(type, 64) \ + instantiate_quant_sdpa_vector_group_size(type, 128) + +instantiate_quant_sdpa_vector_heads(float) +instantiate_quant_sdpa_vector_heads(bfloat16_t) +instantiate_quant_sdpa_vector_heads(float16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 1eec72be31..3dfdbe87c8 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -2,6 +2,8 @@ #include +#include "mlx/backend/metal/kernels/fp_quantized.h" + using namespace metal; constant bool has_mask [[function_constant(20)]]; @@ -176,6 +178,243 @@ template } } +template +METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { + for (int i = 0; i < elem_per_thread; i++) { + q[i] = scale * queries[i]; + } +} + +template +[[gnu::always_inline]] METAL_FUNC U +dot_key(const thread U* q, const device uint32_t* keys) { + U score = 0; + if (bits == 4) { + auto ks = (const device uint16_t*)keys; +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread / 4; j++) { + uint16_t p = ks[j]; + score += q[4 * j + 0] * Dequantize<4, U>{}(p & 0xF); + score += q[4 * j + 1] * Dequantize<4, U>{}((p >> 4) & 0xF); + score += q[4 * j + 2] * Dequantize<4, U>{}((p >> 8) & 0xF); + score += q[4 * j + 3] * Dequantize<4, U>{}(p >> 12); + } + } else { // 8-bit + constexpr int pack_factor = 32 / bits; +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread / pack_factor; j++) { +#pragma clang loop unroll(full) + for (int k = 0; k < pack_factor; k++) { + score += q[pack_factor * j + k] * + Dequantize{}((keys[j] >> (k * bits)) & 0x0f); + } + } + } + return score; +} + +template +[[gnu::always_inline]] METAL_FUNC void accumulate_values( + thread U* o, + const device uint32_t* values, + U factor, + U w_scale) { + if (bits == 4) { + auto vs = (const device uint16_t*)values; +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread / 4; j++) { + uint16_t p = vs[j]; + U v[] = { + Dequantize<4, U>{}(p & 0xF), + Dequantize<4, U>{}((p >> 4) & 0xF), + Dequantize<4, U>{}((p >> 8) & 0xF), + Dequantize<4, U>{}(p >> 12)}; +#pragma clang loop unroll(full) + for (int k = 0; k < 4; k++) + o[4 * j + k] = fma(o[4 * j + k], factor, v[k] * w_scale); + } + } else { // 8-bit + constexpr int pack_factor = 32 / bits; +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread / pack_factor; j++) { +#pragma clang loop unroll(full) + for (int k = 0; k < pack_factor; k++) { + o[j] = + fma(o[j], + factor, + Dequantize<8, U>{}((values[j] >> (k * bits)) & 0x0f) * w_scale); + } + } + } +} + +template +[[kernel]] void quant_sdpa_vector_2pass_1( + const device T* queries [[buffer(0)]], + const device uint32_t* keys [[buffer(1)]], + const device uint8_t* key_scales [[buffer(2)]], + const device uint32_t* values [[buffer(3)]], + const device uint8_t* value_scales [[buffer(4)]], + device float* out [[buffer(5)]], + device float* sums [[buffer(6)]], + device float* maxs [[buffer(7)]], + const constant int& gqa_factor [[buffer(8)]], + const constant int& N [[buffer(9)]], + const constant size_t& k_stride [[buffer(10)]], + const constant size_t& v_stride [[buffer(11)]], + const constant size_t& k_group_stride [[buffer(12)]], + const constant size_t& v_group_stride [[buffer(13)]], + const constant float& scale [[buffer(14)]], + const device bool* bmask [[buffer(15), function_constant(bool_mask)]], + const device T* fmask [[buffer(16), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(17), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(18), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(19), function_constant(has_mask)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int BN = 16; + constexpr int BD = 4; + constexpr int elem_per_thread = D / BD; + constexpr int blocks = 32; + constexpr int pack_factor = 32 / bits; + + const int stride = BN * D; + + typedef float U; + + thread U q[elem_per_thread]; + thread U o[elem_per_thread] = {0}; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + const int block_idx = tid.z; + const int q_batch_head_idx = tid.x; + const int q_seq_idx = tid.y; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int q_offset = + query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + + queries += q_offset * D + quad_lid * elem_per_thread; + + const int kv_idx = + (block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread; + const int packed_idx = kv_idx / pack_factor; + const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size; + const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size; + + keys += kv_head_idx * k_stride + packed_idx; + key_scales += k_group_idx; + values += kv_head_idx * v_stride + packed_idx; + value_scales += v_group_idx; + + out += o_offset * blocks * D + block_idx * D + quad_lid * elem_per_thread; + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; + + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + (block_idx * BN + quad_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + (block_idx * BN + quad_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + + load_queries(queries, q, static_cast(scale)); + + U max_score = Limits::finite_min; + U sum_exp_score = 0; + + for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); + } + + if (use_key) { + // Compute attention score: dot(q, dequantize(k)) * scale + U key_scale = dequantize_scale(key_scales[0]); + U score = dot_key(q, keys) * key_scale; + score = quad_sum(score); + + if (float_mask) { + score += static_cast(fmask[0]); + } + + // Online softmax update + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + U value_scale = dequantize_scale(value_scales[0]); + U weighted_val_scale = exp_score * value_scale; + accumulate_values( + o, values, factor, weighted_val_scale); + } + + keys += blocks * stride / pack_factor; + key_scales += blocks * stride / group_size; + values += blocks * stride / pack_factor; + value_scales += blocks * stride / group_size; + if (bool_mask) { + bmask += BN * blocks * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * blocks * mask_kv_seq_stride; + } + } + + if (quad_lid == 0) { + max_scores[quad_gid] = max_score; + sum_exp_scores[quad_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : Limits::finite_min; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + for (int i = 0; i < elem_per_thread; i++) { + outputs[quad_lid * BN + quad_gid] = + o[i] * fast::exp(max_scores[quad_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (quad_gid == 0) { + U output = outputs[quad_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[quad_lid * BN + j]; + } + out[i] = output; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } +} + template [[kernel]] void sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index c79cd51ff0..c8d3d9ceb9 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -583,6 +583,134 @@ void sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +void quant_sdpa_vector_2pass( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& k_scales, + const array& v, + const array& v_scales, + array& out, + float scale, + int group_size, + int bits, + bool do_causal, + const std::optional& mask) { + std::string kname; + kname.reserve(96); + kname += "quant_sdpa_vector_2pass_1_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(group_size); + kname += "_"; + kname += std::to_string(bits); + + int gqa_factor = q.shape(1) / k.shape(1); + int N = k.shape(2); + int blocks = 32; + int B = q.shape(0) * q.shape(1); + + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); + size_t k_group_stride = + k_scales.shape(1) == 1 ? k_scales.strides(0) : k_scales.strides(1); + size_t v_group_stride = + v_scales.shape(1) == 1 ? v_scales.strides(0) : v_scales.strides(1); + + MTL::Size group_dims(16 * 4, 1, 1); + MTL::Size grid_dims(B, q.shape(2), blocks); + + Shape intermediate_shape; + intermediate_shape.reserve(out.ndim() + 1); + intermediate_shape.insert( + intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1); + intermediate_shape.push_back(blocks); + intermediate_shape.push_back(out.shape().back()); + array intermediate(intermediate_shape, float32, nullptr, {}); + intermediate_shape.pop_back(); + array sums(intermediate_shape, float32, nullptr, {}); + array maxs(std::move(intermediate_shape), float32, nullptr, {}); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); + sums.set_data(allocator::malloc(sums.nbytes())); + maxs.set_data(allocator::malloc(maxs.nbytes())); + d.add_temporary(intermediate, s.index); + d.add_temporary(sums, s.index); + d.add_temporary(maxs, s.index); + + bool has_mask = mask.has_value(); + bool bool_mask = has_mask && (*mask).dtype() == bool_; + bool float_mask = has_mask && !bool_mask; + bool query_transposed = !q.flags().row_contiguous; + bool has_sinks = false; + metal::MTLFCList func_consts = { + {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, + {&do_causal, MTL::DataType::DataTypeBool, 22}, + {&bool_mask, MTL::DataType::DataTypeBool, 23}, + {&float_mask, MTL::DataType::DataTypeBool, 24}, + {&has_sinks, MTL::DataType::DataTypeBool, 25}, + }; + std::string hash_name = kname; + hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname, hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(k_scales, 2); + compute_encoder.set_input_array(v, 3); + compute_encoder.set_input_array(v_scales, 4); + compute_encoder.set_output_array(intermediate, 5); + compute_encoder.set_output_array(sums, 6); + compute_encoder.set_output_array(maxs, 7); + compute_encoder.set_bytes(gqa_factor, 8); + compute_encoder.set_bytes(N, 9); + compute_encoder.set_bytes(k_head_stride, 10); + compute_encoder.set_bytes(v_head_stride, 11); + compute_encoder.set_bytes(k_group_stride, 12); + compute_encoder.set_bytes(v_group_stride, 13); + compute_encoder.set_bytes(scale, 14); + + if (has_mask) { + auto& m = *mask; + compute_encoder.set_input_array(m, 15 + float_mask); + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); + compute_encoder.set_bytes(kv_seq_stride, 17); + compute_encoder.set_bytes(q_seq_stride, 18); + compute_encoder.set_bytes(head_stride, 19); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + kname.clear(); + kname += "sdpa_vector_2pass_2_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + + kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(intermediate, 0); + compute_encoder.set_input_array(sums, 1); + compute_encoder.set_input_array(maxs, 2); + compute_encoder.set_output_array(out, 3); + + group_dims = MTL::Size(1024, 1, 1); + grid_dims = MTL::Size(B, q.shape(2), 1); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + } // namespace bool ScaledDotProductAttention::use_fallback( @@ -785,6 +913,100 @@ void ScaledDotProductAttention::eval_gpu( metal::get_command_encoder(s).add_temporaries(std::move(copies)); } +void QuantizedScaledDotProductAttention::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& q_pre = inputs[0]; + auto& k_pre = inputs[1]; + auto& k_scales_pre = inputs[2]; + auto& v_pre = inputs[3]; + auto& v_scales_pre = inputs[4]; + auto& o = outputs[0]; + + std::vector copies; + copies.reserve(inputs.size()); + + auto copy_unless = [&copies, &s]( + auto predicate, const array& arr) -> const array& { + if (!predicate(arr)) { + array arr_copy = contiguous_copy_gpu(arr, s); + copies.push_back(std::move(arr_copy)); + return copies.back(); + } else { + return arr; + } + }; + + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + auto bidx = shape[0] == 1 ? 1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + const auto& k = copy_unless(kv_copy_unless, k_pre); + const auto& k_scales = copy_unless(kv_copy_unless, k_scales_pre); + const auto& v = copy_unless(kv_copy_unless, v_pre); + const auto& v_scales = copy_unless(kv_copy_unless, v_scales_pre); + + std::optional mask = std::nullopt; + if (needs_mask_) { + auto mask_copy_unless = [&q](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); + }; + mask = copy_unless(mask_copy_unless, inputs.back()); + } + + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { + o.copy_shared_buffer(q); + } else { + o.set_data(allocator::malloc(o.nbytes())); + } + + quant_sdpa_vector_2pass( + s, + d, + q, + k, + k_scales, + v, + v_scales, + o, + scale_, + group_size_, + bits_, + /* do_causal = */ false, + mask); + + d.add_temporaries(std::move(copies), s.index); +} + bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { return true; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..4cedac55cb 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -5,6 +5,7 @@ #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" +#include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" @@ -861,6 +862,273 @@ array scaled_dot_product_attention( return fallback(std::move(inputs))[0]; } +array quantized_scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& key_scales, + const array& values, + const array& value_scales, + const float scale, + const std::optional& mask /* = std::nullopt */, + const std::optional group_size_ /* = std::nullopt */, + const std::optional bits_ /* = std::nullopt */, + const std::string& mode /* = "mxfp4" */, + StreamOrDevice s /* = {} */) { + for (const auto& tensor : {queries, keys, key_scales, values, value_scales}) { + if (tensor.ndim() != 4) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] input with shape " + << tensor.shape() << " expected to be rank 4"; + throw std::invalid_argument(msg.str()); + } + } + + auto qmode = string_to_quantization_mode( + mode, "quantized_scaled_dot_product_attention"); + if (qmode == QuantizationMode::Affine) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported."); + } + if (qmode == QuantizationMode::Nvfp4) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention."); + } + + auto expected_params = [](QuantizationMode mode) -> std::pair { + switch (mode) { + case QuantizationMode::Mxfp4: + return {32, 4}; + case QuantizationMode::Mxfp8: + return {32, 8}; + case QuantizationMode::Nvfp4: + return {16, 4}; + default: + return {0, 0}; + } + }; + + auto [expected_group_size, expected_bits] = expected_params(qmode); + int group_size = group_size_.value_or(expected_group_size); + int bits = bits_.value_or(expected_bits); + if (group_size != expected_group_size || bits != expected_bits) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] Quantization mode '" + << mode << "' requires group_size " << expected_group_size + << " and bits " << expected_bits << " but received group_size " + << group_size << " and bits " << bits << "."; + throw std::invalid_argument(msg.str()); + } + + if (bits != 4 && bits != 8) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] Unsupported bits " << bits + << ". Supported bits are 4 and 8."; + throw std::invalid_argument(msg.str()); + } + + if (key_scales.dtype() != uint8 || value_scales.dtype() != uint8) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Scales must be uint8 for fp quantization."); + } + if (keys.dtype() != uint32 || values.dtype() != uint32) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Keys and values must be packed quantized arrays of type uint32."); + } + + auto el_per_int = 32 / bits; + + const size_t batch_dim = queries.shape(0); + for (const auto& tensor : {keys, values, key_scales, value_scales}) { + if (tensor.shape(0) != batch_dim) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] mismatching batch dimension for input with shape " + << tensor.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + + auto n_q_heads = queries.shape(-3); + auto n_kv_heads = keys.shape(-3); + if (n_q_heads % n_kv_heads != 0) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] n_heads must be a multiple of n_kv_heads, found n_heads " + << n_q_heads << " for n_kv_heads " << n_kv_heads << "."; + throw std::invalid_argument(msg.str()); + } + + if (keys.shape(-3) != values.shape(-3)) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] keys, values expected to have matching n_kv_heads; found keys with n_heads " + << keys.shape(-3) << " for values with n_heads " << values.shape(-3) + << "."; + throw std::invalid_argument(msg.str()); + } + + if (queries.shape(-1) != keys.shape(-1) * el_per_int) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " + << queries.shape() << " for keys shape " << keys.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + if (queries.shape(-1) != values.shape(-1) * el_per_int) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] query, values expected to have matching last dimension; found query shape " + << queries.shape() << " for values shape " << values.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + if (queries.shape(-1) % group_size != 0) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] head dimension " + << queries.shape(-1) << " must be divisible by group_size " + << group_size << "."; + throw std::invalid_argument(msg.str()); + } + + auto expected_scale_dim = queries.shape(-1) / group_size; + for (const auto& tensor : {key_scales, value_scales}) { + if (tensor.shape(-1) != expected_scale_dim) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] Scales expected to have " + << expected_scale_dim << " elements in the last dimension but found " + << tensor.shape(); + throw std::invalid_argument(msg.str()); + } + } + if (key_scales.shape(-3) != keys.shape(-3) || + key_scales.shape(-2) != keys.shape(-2) || + value_scales.shape(-3) != values.shape(-3) || + value_scales.shape(-2) != values.shape(-2)) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Scale shapes must match key/value batch, head, and sequence dimensions."); + } + + auto final_type = queries.dtype(); + if (!issubdtype(final_type, floating)) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] Received unsupported type " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + + bool needs_mask = mask.has_value(); + bool has_bool_mask = needs_mask && mask->dtype() == bool_; + if (needs_mask && mask->ndim() > 4) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] the mask with shape " + << mask->shape() << " expected to have at most rank 4."; + throw std::invalid_argument(msg.str()); + } + + auto q = astype(queries, final_type, s); + auto fallback = + [scale, n_q_heads, n_kv_heads, needs_mask, group_size, bits, mode, s]( + const std::vector& inputs) { + auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); + int n_repeats = n_q_heads / n_kv_heads; + + auto k = inputs[1]; + auto k_scales = inputs[2]; + auto v = inputs[3]; + auto v_scales = inputs[4]; + + std::optional mask = + needs_mask ? std::optional{inputs[5]} : std::nullopt; + + if (n_repeats > 1) { + q = reshape( + q, {q.shape(0), n_kv_heads, n_repeats, q.shape(2), -1}, s); + k = expand_dims(k, 2, s); + k_scales = expand_dims(k_scales, 2, s); + v = expand_dims(v, 2, s); + v_scales = expand_dims(v_scales, 2, s); + } + + auto scores = quantized_matmul( + q, + k, + k_scales, + std::nullopt, + /*transpose=*/true, + group_size, + bits, + mode, + s); + if (mask) { + auto m = *mask; + if (n_repeats > 1 && m.ndim() >= 3) { + if (m.shape(-3) == 1) { + m = expand_dims(m, -3, s); + } else { + m = unflatten(m, -3, {n_kv_heads, n_repeats}, s); + } + } + if (m.dtype() == bool_) { + scores = where( + m, scores, array(finfo(scores.dtype()).min, scores.dtype()), s); + } else { + scores = add(scores, m, s); + } + } + + scores = softmax(scores, std::vector{-1}, true, s); + auto out = quantized_matmul( + scores, + v, + v_scales, + std::nullopt, + /*transpose=*/false, + group_size, + bits, + mode, + s); + if (n_repeats > 1) { + out = reshape(out, {out.shape(0), n_q_heads, out.shape(2), -1}, s); + } + return std::vector{out}; + }; + + auto stream = to_stream(s); + std::vector inputs = {q, keys, key_scales, values, value_scales}; + if (needs_mask) { + if (promote_types(mask->dtype(), final_type) != final_type && + mask->dtype() != bool_) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] Mask type must promote to output type " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + if (!has_bool_mask && mask->dtype() != final_type) { + inputs.push_back(astype(*mask, final_type, stream)); + } else { + inputs.push_back(*mask); + } + auto mask_shape = queries.shape(); + mask_shape.back() = keys.shape(-2); + inputs.back() = broadcast_to(inputs.back(), mask_shape, stream); + } + + int out_dim = values.shape(-1) * el_per_int; + Shape out_shape{ + queries.shape(0), queries.shape(1), queries.shape(2), out_dim}; + + bool supported_type = (queries.dtype() == float32) || + (queries.dtype() == float16) || (queries.dtype() == bfloat16); + bool unsupported = detail::in_grad_tracing() || + stream.device == Device::cpu || queries.shape(2) > 8 || + (queries.shape(2) > keys.shape(2)) || + !(queries.shape(-1) == 64 || queries.shape(-1) == 128) || !supported_type; + + if (unsupported) { + return fallback(std::move(inputs))[0]; + } + + auto primitive = std::make_shared( + stream, fallback, scale, needs_mask, group_size, bits, qmode); + return array(std::move(out_shape), final_type, primitive, std::move(inputs)); +} + std::vector ScaledDotProductAttention::vjp( const std::vector& primals, const std::vector& cotangents, @@ -915,6 +1183,15 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { output_logsumexp_ == a_other.output_logsumexp_; } +bool QuantizedScaledDotProductAttention::is_equivalent( + const Primitive& other) const { + const QuantizedScaledDotProductAttention& a_other = + static_cast(other); + return scale_ == a_other.scale_ && needs_mask_ == a_other.needs_mask_ && + group_size_ == a_other.group_size_ && bits_ == a_other.bits_ && + mode_ == a_other.mode_; +} + bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const { const ScaledDotProductAttentionVJP& a_other = static_cast(other); diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..ffa1a034f6 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -54,6 +54,20 @@ MLX_API array scaled_dot_product_attention( const std::optional& sinks = {}, StreamOrDevice s = {}); +/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/ +array quantized_scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& key_scales, + const array& values, + const array& value_scales, + const float scale, + const std::optional& mask = std::nullopt, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "mxfp4", + StreamOrDevice s = {}); + using TemplateArg = std::variant; using ScalarArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..4ef110c855 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -260,6 +260,47 @@ class ScaledDotProductAttention : public Custom { bool output_logsumexp_; }; +class QuantizedScaledDotProductAttention : public Custom { + public: + QuantizedScaledDotProductAttention( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool needs_mask, + int group_size, + int bits, + QuantizationMode mode) + : Custom(stream, std::move(fallback)), + scale_(scale), + needs_mask_(needs_mask), + group_size_(group_size), + bits_(bits), + mode_(mode) {} + + void eval_cpu(const std::vector&, std::vector&) override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + bool is_equivalent(const Primitive& other) const override; + + DEFINE_NAME(QuantizedScaledDotProductAttention); + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple( + nullptr, scale_, needs_mask_, group_size_, bits_, mode_); + } + + private: + float scale_; + bool needs_mask_; + int group_size_; + int bits_; + QuantizationMode mode_; +}; + class ScaledDotProductAttentionVJP : public Custom { public: ScaledDotProductAttentionVJP( diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..7a42e6fe43 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -296,6 +296,43 @@ void init_fast(nb::module_& parent_module) { out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") )pbdoc"); + m.def( + "quantized_scaled_dot_product_attention", + &mx::fast::quantized_scaled_dot_product_attention, + "q"_a, + "k"_a, + "k_scales"_a, + "v"_a, + "v_scales"_a, + nb::kw_only(), + "scale"_a, + "mask"_a = nb::none(), + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "mxfp4", + "stream"_a = nb::none(), + nb::sig( + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + A fast implementation of multi-head attention where the keys and values are quantized. + + see :func:`scaled_dot_product_attention` for more details. + + Args: + q (array): Input query array. + k (array): Input keys array. + k_scales (array): ``uint8`` scales for the fp-quantized keys array. + v (array): Input values array. + v_scales (array): ``uint8`` scales for the fp-quantized values array. + scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) + mask (array, optional): An additive or boolean mask to apply to the query-key scores. + group_size (int, optional): The group size used in the KV quantization. Defaults follow the quantization ``mode``. + bits (int, optional): The bits used in the KV quantization. Defaults follow the quantization ``mode``. + mode (str, optional): The fp quantization mode, ``"mxfp4"`` or ``"mxfp8"``. + Returns: + array: The output array. + )pbdoc"); + m.def( "metal_kernel", [](const std::string& name, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index a7472e9920..87408a28a3 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -694,6 +694,7 @@ def test_non_multiples(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) +<<<<<<< HEAD def test_qmv_small_non_multiples(self): # Test very small K and N dimensions (e.g., [MxK] x [NxK].T = [MxN]) # Each tuple is (M, K, N) representing input rows, weight cols, weight rows @@ -753,6 +754,39 @@ def test_qmv_small_non_multiples(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_quantized_sdpa(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lq, Lk, D = 4, 640, 128 + + for mode in ["mxfp4", "mxfp8"]: + bits = 8 if mode == "mxfp8" else 4 + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + def test_gather_qmm(self): def quantize(w, transpose=True, group_size=None, bits=None, mode="affine"): if mode == "affine": From e585934f92f41bf7535a1f3eedda0b1ac2327d20 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Wed, 21 Jan 2026 02:59:48 +0100 Subject: [PATCH 02/21] fix --- benchmarks/python/sdpa_vector_bench.py | 3 --- mlx/backend/metal/kernels/fp_quantized.h | 2 ++ mlx/backend/metal/kernels/sdpa_vector.h | 10 +++++----- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index c98b2e35d9..836700fd03 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -85,9 +85,6 @@ def time_self_attention_quant_primitives(q, k, v, bits, mode): v_quant = mx.quantize(v, bits=bits, mode=mode) mx.eval(k_quant, v_quant) - k = mx.dequantize(*k_quant, bits=bits, mode=mode) - v = mx.dequantize(*v_quant, bits=bits, mode=mode) - time_self_attention_sdpa(q, k, v) time_self_attention_quant_sdpa(q, k_quant, v_quant, bits, mode) time_self_attention_primitives(q, k, v) diff --git a/mlx/backend/metal/kernels/fp_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h index cc9b68ade8..e6684b5880 100644 --- a/mlx/backend/metal/kernels/fp_quantized.h +++ b/mlx/backend/metal/kernels/fp_quantized.h @@ -5,6 +5,8 @@ #include "mlx/backend/metal/kernels/fp4.h" #include "mlx/backend/metal/kernels/fp8.h" +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 3dfdbe87c8..4b2f5f6efa 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -206,7 +206,7 @@ dot_key(const thread U* q, const device uint32_t* keys) { #pragma clang loop unroll(full) for (int k = 0; k < pack_factor; k++) { score += q[pack_factor * j + k] * - Dequantize{}((keys[j] >> (k * bits)) & 0x0f); + Dequantize{}((keys[j] >> (k * bits)) & 0xff); } } } @@ -239,10 +239,10 @@ template for (int j = 0; j < elem_per_thread / pack_factor; j++) { #pragma clang loop unroll(full) for (int k = 0; k < pack_factor; k++) { - o[j] = - fma(o[j], - factor, - Dequantize<8, U>{}((values[j] >> (k * bits)) & 0x0f) * w_scale); + o[pack_factor * j + k] = fma( + o[pack_factor * j + k], + factor, + Dequantize{}((values[j] >> (k * bits)) & 0xff) * w_scale); } } } From c1c5a4b0cf885bce940c6c873e5d754cad1c6f86 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Wed, 21 Jan 2026 23:30:06 +0100 Subject: [PATCH 03/21] Unify mxfp4/8 paths and optimize mxfp8 fused calculation --- mlx/backend/metal/kernels/sdpa_vector.h | 72 ++++++++++--------------- 1 file changed, 28 insertions(+), 44 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 4b2f5f6efa..36c65528a8 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -188,27 +188,20 @@ METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { template [[gnu::always_inline]] METAL_FUNC U dot_key(const thread U* q, const device uint32_t* keys) { + using LoadT = typename conditional::type; + + constexpr uint32_t mask = (1 << bits) - 1; + auto ks = (const device LoadT*)keys; U score = 0; - if (bits == 4) { - auto ks = (const device uint16_t*)keys; -#pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - uint16_t p = ks[j]; - score += q[4 * j + 0] * Dequantize<4, U>{}(p & 0xF); - score += q[4 * j + 1] * Dequantize<4, U>{}((p >> 4) & 0xF); - score += q[4 * j + 2] * Dequantize<4, U>{}((p >> 8) & 0xF); - score += q[4 * j + 3] * Dequantize<4, U>{}(p >> 12); - } - } else { // 8-bit - constexpr int pack_factor = 32 / bits; -#pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / pack_factor; j++) { + #pragma clang loop unroll(full) - for (int k = 0; k < pack_factor; k++) { - score += q[pack_factor * j + k] * - Dequantize{}((keys[j] >> (k * bits)) & 0xff); - } - } + for (int j = 0; j < elem_per_thread / 4; j++) { + LoadT p = ks[j]; + + score += q[4 * j + 0] * Dequantize{}(p & mask); + score += q[4 * j + 1] * Dequantize{}((p >> bits) & mask); + score += q[4 * j + 2] * Dequantize{}((p >> (2 * bits)) & mask); + score += q[4 * j + 3] * Dequantize{}((p >> (3 * bits)) & mask); } return score; } @@ -219,32 +212,23 @@ template const device uint32_t* values, U factor, U w_scale) { - if (bits == 4) { - auto vs = (const device uint16_t*)values; -#pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - uint16_t p = vs[j]; - U v[] = { - Dequantize<4, U>{}(p & 0xF), - Dequantize<4, U>{}((p >> 4) & 0xF), - Dequantize<4, U>{}((p >> 8) & 0xF), - Dequantize<4, U>{}(p >> 12)}; -#pragma clang loop unroll(full) - for (int k = 0; k < 4; k++) - o[4 * j + k] = fma(o[4 * j + k], factor, v[k] * w_scale); - } - } else { // 8-bit - constexpr int pack_factor = 32 / bits; -#pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / pack_factor; j++) { + using LoadT = typename conditional::type; + constexpr uint32_t mask = (1 << bits) - 1; + auto vs = (const device LoadT*)values; + #pragma clang loop unroll(full) - for (int k = 0; k < pack_factor; k++) { - o[pack_factor * j + k] = fma( - o[pack_factor * j + k], - factor, - Dequantize{}((values[j] >> (k * bits)) & 0xff) * w_scale); - } - } + for (int j = 0; j < elem_per_thread / 4; j++) { + LoadT p = vs[j]; + + U v0 = Dequantize{}(p & mask); + U v1 = Dequantize{}((p >> bits) & mask); + U v2 = Dequantize{}((p >> (2 * bits)) & mask); + U v3 = Dequantize{}((p >> (3 * bits)) & mask); + + o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); + o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); + o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); + o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); } } From 5c81b192a4f17e069487edd183060689704f66f0 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Thu, 22 Jan 2026 16:06:03 +0100 Subject: [PATCH 04/21] supports nvpf4 --- mlx/backend/metal/kernels/quantized_utils.h | 134 ++++++++++++++++++ .../scaled_dot_product_attention.metal | 27 ++-- mlx/backend/metal/kernels/sdpa_vector.h | 129 +++++++++++++---- .../metal/scaled_dot_product_attention.cpp | 24 +++- mlx/fast.cpp | 4 - python/tests/test_quantized.py | 47 +++--- 6 files changed, 292 insertions(+), 73 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index 38253f8fe9..fcb3246467 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -1,8 +1,142 @@ // Copyright © 2023-2024 Apple Inc. +#pragma once + #include #include +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + +enum class QuantMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; + +template +struct QuantTraits; + +// Affine quantization: scale * val + bias +template <> +struct QuantTraits { + static constant constexpr int default_group_size = 64; + static constant constexpr int default_bits = 4; + static constant constexpr int group_size = default_group_size; + static constant constexpr int bits = default_bits; + static constant constexpr bool has_bias = true; + + template + static inline T dequantize_scale(T s) { + return s; + } + + template + static inline T dequantize_value(uint8_t v, T scale, T bias) { + return fma(scale, T(v), bias); + } + + template + static inline T dequantize(uint8_t v, T scale, T bias) { + return fma(scale, T(v), bias); + } +}; + +// MXFP4: fp4_e2m1 data, fp8_e8m0 scale (power-of-2), group_size=32 +template <> +struct QuantTraits { + static constant constexpr int group_size = 32; + static constant constexpr int bits = 4; + static constant constexpr bool has_bias = false; + + template + static inline T dequantize_scale(uint8_t s) { + return T(*(thread fp8_e8m0*)(&s)); + } + + template + static inline T dequantize_value(uint8_t v) { + return T(*(thread fp4_e2m1*)(&v)); + } + + template + static inline T dequantize(uint8_t v, T scale, T /*bias*/) { + return scale * dequantize_value(v); + } +}; + +// NVFP4: fp4_e2m1 data, fp8_e4m3 scale (with mantissa), group_size=16 +template <> +struct QuantTraits { + static constant constexpr int group_size = 16; + static constant constexpr int bits = 4; + static constant constexpr bool has_bias = false; + + template + static inline T dequantize_scale(uint8_t s) { + return T(*(thread fp8_e4m3*)(&s)); + } + + template + static inline T dequantize_value(uint8_t v) { + return T(*(thread fp4_e2m1*)(&v)); + } + + template + static inline T dequantize(uint8_t v, T scale, T /*bias*/) { + return scale * dequantize_value(v); + } +}; + +// MXFP8: fp8_e4m3 data, fp8_e8m0 scale, group_size=32 +template <> +struct QuantTraits { + static constant constexpr int group_size = 32; + static constant constexpr int bits = 8; + static constant constexpr bool has_bias = false; + + template + static inline T dequantize_scale(uint8_t s) { + return T(*(thread fp8_e8m0*)(&s)); + } + + template + static inline T dequantize_value(uint8_t v) { + return T(*(thread fp8_e4m3*)(&v)); + } + + template + static inline T dequantize(uint8_t v, T scale, T /*bias*/) { + return scale * dequantize_value(v); + } +}; + +// Compile-time LoadType selector by bit-width +template +struct LoadType { + using type = uint32_t; +}; + +template <> +struct LoadType<4> { + using type = uint16_t; +}; + +// Helpers to fetch mode-specific defaults (affine uses default_* values) +template +constexpr int get_group_size() { + if constexpr (mode == QuantMode::Affine) { + return QuantTraits::default_group_size; + } else { + return QuantTraits::group_size; + } +} + +template +constexpr int get_bits() { + if constexpr (mode == QuantMode::Affine) { + return QuantTraits::default_bits; + } else { + return QuantTraits::bits; + } +} + template METAL_FUNC void gemm_loop_aligned( threadgroup T* As, diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index ab18d0286f..682febe132 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -43,22 +43,25 @@ instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) // Quantized SDPA vector instantiations -#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \ - instantiate_kernel( \ - "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \ - quant_sdpa_vector_2pass_1, \ - type, \ - head_dim, \ - group_size, \ +// Uses QuantMode enum for explicit mode selection +#define instantiate_quant_sdpa_vector(type, head_dim, mode, group_size, bits) \ + instantiate_kernel( \ + "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #mode, \ + quant_sdpa_vector_2pass_1, \ + type, \ + head_dim, \ + QuantMode::mode, \ + group_size, \ bits) -#define instantiate_quant_sdpa_vector_group_size(type, heads) \ - instantiate_quant_sdpa_vector(type, heads, 32, 4) \ - instantiate_quant_sdpa_vector(type, heads, 32, 8) +#define instantiate_quant_sdpa_vector_all_modes(type, head_dim) \ + instantiate_quant_sdpa_vector(type, head_dim, Mxfp4, 32, 4) \ + instantiate_quant_sdpa_vector(type, head_dim, Nvfp4, 16, 4) \ + instantiate_quant_sdpa_vector(type, head_dim, Mxfp8, 32, 8) #define instantiate_quant_sdpa_vector_heads(type) \ - instantiate_quant_sdpa_vector_group_size(type, 64) \ - instantiate_quant_sdpa_vector_group_size(type, 128) + instantiate_quant_sdpa_vector_all_modes(type, 64) \ + instantiate_quant_sdpa_vector_all_modes(type, 128) instantiate_quant_sdpa_vector_heads(float) instantiate_quant_sdpa_vector_heads(bfloat16_t) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 36c65528a8..5a6539f927 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -3,6 +3,7 @@ #include #include "mlx/backend/metal/kernels/fp_quantized.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" using namespace metal; @@ -185,12 +186,24 @@ METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { } } -template -[[gnu::always_inline]] METAL_FUNC U -dot_key(const thread U* q, const device uint32_t* keys) { - using LoadT = typename conditional::type; - +// Function constant for affine bias support +constant bool has_affine_bias [[function_constant(26)]]; + +// Unified dot product with keys across all QuantModes +template < + typename U, + int elem_per_thread, + QuantMode mode, + int bits = QuantTraits::bits> +[[gnu::always_inline]] METAL_FUNC U dot_key( + const thread U* q, + const device uint32_t* keys, + U scale, + U bias = U{0}) { + using Traits = QuantTraits; + using LoadT = typename LoadType::type; constexpr uint32_t mask = (1 << bits) - 1; + auto ks = (const device LoadT*)keys; U score = 0; @@ -198,41 +211,67 @@ dot_key(const thread U* q, const device uint32_t* keys) { for (int j = 0; j < elem_per_thread / 4; j++) { LoadT p = ks[j]; - score += q[4 * j + 0] * Dequantize{}(p & mask); - score += q[4 * j + 1] * Dequantize{}((p >> bits) & mask); - score += q[4 * j + 2] * Dequantize{}((p >> (2 * bits)) & mask); - score += q[4 * j + 3] * Dequantize{}((p >> (3 * bits)) & mask); + uint8_t v0 = uint8_t(p & mask); + uint8_t v1 = uint8_t((p >> bits) & mask); + uint8_t v2 = uint8_t((p >> (2 * bits)) & mask); + uint8_t v3 = uint8_t((p >> (3 * bits)) & mask); + + score += q[4 * j + 0] * Traits::template dequantize_value(v0); + score += q[4 * j + 1] * Traits::template dequantize_value(v1); + score += q[4 * j + 2] * Traits::template dequantize_value(v2); + score += q[4 * j + 3] * Traits::template dequantize_value(v3); } - return score; + + return score * scale; } -template +template < + typename U, + int elem_per_thread, + QuantMode mode, + int bits = QuantTraits::bits> [[gnu::always_inline]] METAL_FUNC void accumulate_values( thread U* o, const device uint32_t* values, U factor, - U w_scale) { - using LoadT = typename conditional::type; + U w_scale, + U bias = U{0}) { + using Traits = QuantTraits; + using LoadT = typename LoadType::type; constexpr uint32_t mask = (1 << bits) - 1; + auto vs = (const device LoadT*)values; #pragma clang loop unroll(full) for (int j = 0; j < elem_per_thread / 4; j++) { LoadT p = vs[j]; - U v0 = Dequantize{}(p & mask); - U v1 = Dequantize{}((p >> bits) & mask); - U v2 = Dequantize{}((p >> (2 * bits)) & mask); - U v3 = Dequantize{}((p >> (3 * bits)) & mask); + uint8_t v0 = uint8_t(p & mask); + uint8_t v1 = uint8_t((p >> bits) & mask); + uint8_t v2 = uint8_t((p >> (2 * bits)) & mask); + uint8_t v3 = uint8_t((p >> (3 * bits)) & mask); + + U dq0 = Traits::template dequantize(v0, w_scale, bias); + U dq1 = Traits::template dequantize(v1, w_scale, bias); + U dq2 = Traits::template dequantize(v2, w_scale, bias); + U dq3 = Traits::template dequantize(v3, w_scale, bias); - o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); - o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); - o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); - o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); + o[4 * j + 0] = fma(o[4 * j + 0], factor, dq0); + o[4 * j + 1] = fma(o[4 * j + 1], factor, dq1); + o[4 * j + 2] = fma(o[4 * j + 2], factor, dq2); + o[4 * j + 3] = fma(o[4 * j + 3], factor, dq3); } } -template +/////////////////////////////////////////////////////////////////////////////// +// Quantized SDPA kernel using QuantTraits +// +// This kernel supports all quantization modes (Mxfp4, Nvfp4, Mxfp8, Affine) +// through the QuantMode template parameter. For Affine mode, bias buffers +// are enabled via the has_affine_bias function constant. +/////////////////////////////////////////////////////////////////////////////// + +template [[kernel]] void quant_sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], const device uint32_t* keys [[buffer(1)]], @@ -257,12 +296,18 @@ template [[buffer(18), function_constant(has_mask)]], const constant int& mask_head_stride [[buffer(19), function_constant(has_mask)]], + const device T* key_biases + [[buffer(20), function_constant(has_affine_bias)]], + const device T* value_biases + [[buffer(21), function_constant(has_affine_bias)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]], uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { + using Traits = QuantTraits; + constexpr int BN = 16; constexpr int BD = 4; constexpr int elem_per_thread = D / BD; @@ -301,6 +346,11 @@ template values += kv_head_idx * v_stride + packed_idx; value_scales += v_group_idx; + if constexpr (Traits::has_bias) { + key_biases += k_group_idx; + value_biases += v_group_idx; + } + out += o_offset * blocks * D + block_idx * D + quad_lid * elem_per_thread; sums += o_offset * blocks + block_idx; maxs += o_offset * blocks + block_idx; @@ -332,9 +382,18 @@ template } if (use_key) { - // Compute attention score: dot(q, dequantize(k)) * scale - U key_scale = dequantize_scale(key_scales[0]); - U score = dot_key(q, keys) * key_scale; + U key_scale; + U key_bias = 0; + + if constexpr (Traits::has_bias) { + key_scale = U(((const device T*)key_scales)[0]); + key_bias = U(key_biases[0]); + } else { + key_scale = Traits::template dequantize_scale(key_scales[0]); + } + + U score = + dot_key(q, keys, key_scale, key_bias); score = quad_sum(score); if (float_mask) { @@ -349,16 +408,28 @@ template max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; - U value_scale = dequantize_scale(value_scales[0]); - U weighted_val_scale = exp_score * value_scale; - accumulate_values( - o, values, factor, weighted_val_scale); + U value_scale; + U value_bias = 0; + + if constexpr (Traits::has_bias) { + value_scale = U(((const device T*)value_scales)[0]); + value_bias = U(value_biases[0]); + } else { + value_scale = Traits::template dequantize_scale(value_scales[0]); + } + + accumulate_values( + o, values, factor, exp_score * value_scale, exp_score * value_bias); } keys += blocks * stride / pack_factor; key_scales += blocks * stride / group_size; values += blocks * stride / pack_factor; value_scales += blocks * stride / group_size; + if constexpr (Traits::has_bias) { + key_biases += blocks * stride / group_size; + value_biases += blocks * stride / group_size; + } if (bool_mask) { bmask += BN * blocks * mask_kv_seq_stride; } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index c8d3d9ceb9..1671d30472 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -583,6 +583,20 @@ void sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +std::string quant_mode_to_kernel_suffix(QuantizationMode mode) { + switch (mode) { + case QuantizationMode::Mxfp4: + return "Mxfp4"; + case QuantizationMode::Nvfp4: + return "Nvfp4"; + case QuantizationMode::Mxfp8: + return "Mxfp8"; + default: + throw std::invalid_argument( + "[quant_sdpa_vector_2pass] Unsupported quantization mode."); + } +} + void quant_sdpa_vector_2pass( const Stream& s, metal::Device& d, @@ -596,7 +610,8 @@ void quant_sdpa_vector_2pass( int group_size, int bits, bool do_causal, - const std::optional& mask) { + const std::optional& mask, + QuantizationMode mode) { std::string kname; kname.reserve(96); kname += "quant_sdpa_vector_2pass_1_"; @@ -604,9 +619,7 @@ void quant_sdpa_vector_2pass( kname += "_"; kname += std::to_string(q.shape(-1)); kname += "_"; - kname += std::to_string(group_size); - kname += "_"; - kname += std::to_string(bits); + kname += quant_mode_to_kernel_suffix(mode); int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); @@ -1002,7 +1015,8 @@ void QuantizedScaledDotProductAttention::eval_gpu( group_size_, bits_, /* do_causal = */ false, - mask); + mask, + mode_); d.add_temporaries(std::move(copies), s.index); } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 4cedac55cb..ea5c211bb4 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -889,10 +889,6 @@ array quantized_scaled_dot_product_attention( throw std::invalid_argument( "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported."); } - if (qmode == QuantizationMode::Nvfp4) { - throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention."); - } auto expected_params = [](QuantizationMode mode) -> std::pair { switch (mode) { diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 87408a28a3..7536b72ab0 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -762,30 +762,31 @@ def test_quantized_sdpa(self): B, Hq, Hkv = 1, 2, 1 Lq, Lk, D = 4, 640, 128 - for mode in ["mxfp4", "mxfp8"]: - bits = 8 if mode == "mxfp8" else 4 - q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) - k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) - v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) - - k_q, k_scales = mx.quantize(k, mode=mode) - v_q, v_scales = mx.quantize(v, mode=mode) - - ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) - out = mx.fast.quantized_scaled_dot_product_attention( - q, - k_q, - k_scales, - v_q, - v_scales, - scale=1.0, - mode=mode, - bits=bits, - ) + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + with self.subTest(mode=mode): + bits = 8 if mode == "mxfp8" else 4 + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + ) - self.assertEqual(out.shape, ref.shape) - tol = 5e-2 if bits == 4 else 2e-2 - self.assertLess((out - ref).abs().max(), tol) + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) def test_gather_qmm(self): def quantize(w, transpose=True, group_size=None, bits=None, mode="affine"): From 22358d29caa4173769bc50680eee6cada0b0cd6d Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 25 Jan 2026 16:13:35 +0100 Subject: [PATCH 05/21] supports affine 4/8 bits --- mlx/backend/metal/kernels/quantized_utils.h | 15 ++ .../scaled_dot_product_attention.metal | 24 +- mlx/backend/metal/kernels/sdpa_vector.h | 43 +++- .../metal/scaled_dot_product_attention.cpp | 44 +++- mlx/fast.cpp | 228 ++++++++++++------ mlx/fast.h | 2 + python/src/fast.cpp | 59 ++++- python/tests/test_quantized.py | 41 ++++ 8 files changed, 360 insertions(+), 96 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index fcb3246467..7741161248 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -21,12 +21,21 @@ struct QuantTraits { static constant constexpr int group_size = default_group_size; static constant constexpr int bits = default_bits; static constant constexpr bool has_bias = true; + template + using scale_type = T; template static inline T dequantize_scale(T s) { return s; } + // Single-arg version returns raw value (for use in dot_key where + // dequantization is applied separately) + template + static inline T dequantize_value(uint8_t v) { + return T(v); + } + template static inline T dequantize_value(uint8_t v, T scale, T bias) { return fma(scale, T(v), bias); @@ -44,6 +53,8 @@ struct QuantTraits { static constant constexpr int group_size = 32; static constant constexpr int bits = 4; static constant constexpr bool has_bias = false; + template + using scale_type = uint8_t; template static inline T dequantize_scale(uint8_t s) { @@ -67,6 +78,8 @@ struct QuantTraits { static constant constexpr int group_size = 16; static constant constexpr int bits = 4; static constant constexpr bool has_bias = false; + template + using scale_type = uint8_t; template static inline T dequantize_scale(uint8_t s) { @@ -90,6 +103,8 @@ struct QuantTraits { static constant constexpr int group_size = 32; static constant constexpr int bits = 8; static constant constexpr bool has_bias = false; + template + using scale_type = uint8_t; template static inline T dequantize_scale(uint8_t s) { diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 682febe132..5674d179c4 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -44,20 +44,26 @@ instantiate_sdpa_vector_heads(float16_t) // Quantized SDPA vector instantiations // Uses QuantMode enum for explicit mode selection -#define instantiate_quant_sdpa_vector(type, head_dim, mode, group_size, bits) \ - instantiate_kernel( \ - "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #mode, \ - quant_sdpa_vector_2pass_1, \ - type, \ - head_dim, \ - QuantMode::mode, \ - group_size, \ +#define instantiate_quant_sdpa_vector(type, head_dim, mode, group_size, bits) \ + instantiate_kernel( \ + "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #mode "_" #group_size "_" #bits, \ + quant_sdpa_vector_2pass_1, \ + type, \ + head_dim, \ + QuantMode::mode, \ + group_size, \ bits) #define instantiate_quant_sdpa_vector_all_modes(type, head_dim) \ instantiate_quant_sdpa_vector(type, head_dim, Mxfp4, 32, 4) \ instantiate_quant_sdpa_vector(type, head_dim, Nvfp4, 16, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Mxfp8, 32, 8) + instantiate_quant_sdpa_vector(type, head_dim, Mxfp8, 32, 8) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 4) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 64, 4) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 128, 4) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 8) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 64, 8) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 128, 8) #define instantiate_quant_sdpa_vector_heads(type) \ instantiate_quant_sdpa_vector_all_modes(type, 64) \ diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 5a6539f927..e2e552f17f 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -15,6 +15,9 @@ constant bool float_mask [[function_constant(24)]]; constant bool has_sinks [[function_constant(25)]]; constant int blocks [[function_constant(26)]]; +template +using ScaleTypeT = typename QuantTraits::template scale_type; + template [[kernel]] void sdpa_vector( const device T* queries [[buffer(0)]], @@ -216,13 +219,29 @@ template < uint8_t v2 = uint8_t((p >> (2 * bits)) & mask); uint8_t v3 = uint8_t((p >> (3 * bits)) & mask); - score += q[4 * j + 0] * Traits::template dequantize_value(v0); - score += q[4 * j + 1] * Traits::template dequantize_value(v1); - score += q[4 * j + 2] * Traits::template dequantize_value(v2); - score += q[4 * j + 3] * Traits::template dequantize_value(v3); + // For affine mode, apply full dequantization: scale * v + bias + // For other modes, dequantize_value returns the decoded value and we scale + // at the end + if constexpr (Traits::has_bias) { + score += q[4 * j + 0] * Traits::template dequantize(v0, scale, bias); + score += q[4 * j + 1] * Traits::template dequantize(v1, scale, bias); + score += q[4 * j + 2] * Traits::template dequantize(v2, scale, bias); + score += q[4 * j + 3] * Traits::template dequantize(v3, scale, bias); + } else { + score += q[4 * j + 0] * Traits::template dequantize_value(v0); + score += q[4 * j + 1] * Traits::template dequantize_value(v1); + score += q[4 * j + 2] * Traits::template dequantize_value(v2); + score += q[4 * j + 3] * Traits::template dequantize_value(v3); + } } - return score * scale; + // For non-affine modes, apply scale at the end + // For affine mode, dequantization is already applied inline + if constexpr (Traits::has_bias) { + return score; + } else { + return score * scale; + } } template < @@ -275,9 +294,9 @@ template [[kernel]] void quant_sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], const device uint32_t* keys [[buffer(1)]], - const device uint8_t* key_scales [[buffer(2)]], + const device ScaleTypeT* key_scales [[buffer(2)]], const device uint32_t* values [[buffer(3)]], - const device uint8_t* value_scales [[buffer(4)]], + const device ScaleTypeT* value_scales [[buffer(4)]], device float* out [[buffer(5)]], device float* sums [[buffer(6)]], device float* maxs [[buffer(7)]], @@ -342,10 +361,10 @@ template const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size; keys += kv_head_idx * k_stride + packed_idx; - key_scales += k_group_idx; values += kv_head_idx * v_stride + packed_idx; - value_scales += v_group_idx; + key_scales += k_group_idx; + value_scales += v_group_idx; if constexpr (Traits::has_bias) { key_biases += k_group_idx; value_biases += v_group_idx; @@ -386,7 +405,7 @@ template U key_bias = 0; if constexpr (Traits::has_bias) { - key_scale = U(((const device T*)key_scales)[0]); + key_scale = U(key_scales[0]); key_bias = U(key_biases[0]); } else { key_scale = Traits::template dequantize_scale(key_scales[0]); @@ -412,7 +431,7 @@ template U value_bias = 0; if constexpr (Traits::has_bias) { - value_scale = U(((const device T*)value_scales)[0]); + value_scale = U(value_scales[0]); value_bias = U(value_biases[0]); } else { value_scale = Traits::template dequantize_scale(value_scales[0]); @@ -423,8 +442,8 @@ template } keys += blocks * stride / pack_factor; - key_scales += blocks * stride / group_size; values += blocks * stride / pack_factor; + key_scales += blocks * stride / group_size; value_scales += blocks * stride / group_size; if constexpr (Traits::has_bias) { key_biases += blocks * stride / group_size; diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 1671d30472..32ee6be1a9 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -585,6 +585,8 @@ void sdpa_vector_2pass( std::string quant_mode_to_kernel_suffix(QuantizationMode mode) { switch (mode) { + case QuantizationMode::Affine: + return "Affine"; case QuantizationMode::Mxfp4: return "Mxfp4"; case QuantizationMode::Nvfp4: @@ -603,8 +605,10 @@ void quant_sdpa_vector_2pass( const array& q, const array& k, const array& k_scales, + const std::optional& k_biases, const array& v, const array& v_scales, + const std::optional& v_biases, array& out, float scale, int group_size, @@ -620,6 +624,10 @@ void quant_sdpa_vector_2pass( kname += std::to_string(q.shape(-1)); kname += "_"; kname += quant_mode_to_kernel_suffix(mode); + kname += "_"; + kname += std::to_string(group_size); + kname += "_"; + kname += std::to_string(bits); int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); @@ -658,6 +666,7 @@ void quant_sdpa_vector_2pass( bool float_mask = has_mask && !bool_mask; bool query_transposed = !q.flags().row_contiguous; bool has_sinks = false; + bool has_affine_bias = mode == QuantizationMode::Affine; metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, {&query_transposed, MTL::DataType::DataTypeBool, 21}, @@ -665,11 +674,13 @@ void quant_sdpa_vector_2pass( {&bool_mask, MTL::DataType::DataTypeBool, 23}, {&float_mask, MTL::DataType::DataTypeBool, 24}, {&has_sinks, MTL::DataType::DataTypeBool, 25}, + {&has_affine_bias, MTL::DataType::DataTypeBool, 26}, }; std::string hash_name = kname; hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += do_causal ? "_c" : "_nc"; + hash_name += has_affine_bias ? "_affine" : "_noaffine"; auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname, hash_name, func_consts); @@ -703,6 +714,11 @@ void quant_sdpa_vector_2pass( compute_encoder.set_bytes(head_stride, 19); } + if (has_affine_bias) { + compute_encoder.set_input_array(*k_biases, 20); + compute_encoder.set_input_array(*v_biases, 21); + } + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); kname.clear(); @@ -932,11 +948,25 @@ void QuantizedScaledDotProductAttention::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); + bool is_affine = mode_ == QuantizationMode::Affine; + + // Inputs layout: + // [q, k, k_scales, k_biases (if affine), v, v_scales, v_biases (if affine), + // mask (if present)] auto& q_pre = inputs[0]; auto& k_pre = inputs[1]; auto& k_scales_pre = inputs[2]; - auto& v_pre = inputs[3]; - auto& v_scales_pre = inputs[4]; + int idx = 3; + const array* k_biases_pre = nullptr; + if (is_affine) { + k_biases_pre = &inputs[idx++]; + } + auto& v_pre = inputs[idx++]; + auto& v_scales_pre = inputs[idx++]; + const array* v_biases_pre = nullptr; + if (is_affine) { + v_biases_pre = &inputs[idx++]; + } auto& o = outputs[0]; std::vector copies; @@ -982,8 +1012,16 @@ void QuantizedScaledDotProductAttention::eval_gpu( const auto& q = copy_unless(q_copy_unless, q_pre); const auto& k = copy_unless(kv_copy_unless, k_pre); const auto& k_scales = copy_unless(kv_copy_unless, k_scales_pre); + std::optional k_biases = std::nullopt; + if (is_affine) { + k_biases = copy_unless(kv_copy_unless, *k_biases_pre); + } const auto& v = copy_unless(kv_copy_unless, v_pre); const auto& v_scales = copy_unless(kv_copy_unless, v_scales_pre); + std::optional v_biases = std::nullopt; + if (is_affine) { + v_biases = copy_unless(kv_copy_unless, *v_biases_pre); + } std::optional mask = std::nullopt; if (needs_mask_) { @@ -1008,8 +1046,10 @@ void QuantizedScaledDotProductAttention::eval_gpu( q, k, k_scales, + k_biases, v, v_scales, + v_biases, o, scale_, group_size_, diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ea5c211bb4..97c95b9466 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -866,8 +866,10 @@ array quantized_scaled_dot_product_attention( const array& queries, const array& keys, const array& key_scales, + const std::optional& key_biases, const array& values, const array& value_scales, + const std::optional& value_biases, const float scale, const std::optional& mask /* = std::nullopt */, const std::optional group_size_ /* = std::nullopt */, @@ -885,13 +887,30 @@ array quantized_scaled_dot_product_attention( auto qmode = string_to_quantization_mode( mode, "quantized_scaled_dot_product_attention"); - if (qmode == QuantizationMode::Affine) { - throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported."); + + bool is_affine = qmode == QuantizationMode::Affine; + + // Validate biases for affine mode + if (is_affine) { + if (!key_biases.has_value() || !value_biases.has_value()) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Affine mode requires key_biases and value_biases."); + } + if (key_biases->ndim() != 4 || value_biases->ndim() != 4) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Biases must be rank 4."); + } + } else { + if (key_biases.has_value() || value_biases.has_value()) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Biases should only be provided for affine mode."); + } } auto expected_params = [](QuantizationMode mode) -> std::pair { switch (mode) { + case QuantizationMode::Affine: + return {64, 4}; // default for affine case QuantizationMode::Mxfp4: return {32, 4}; case QuantizationMode::Mxfp8: @@ -906,7 +925,10 @@ array quantized_scaled_dot_product_attention( auto [expected_group_size, expected_bits] = expected_params(qmode); int group_size = group_size_.value_or(expected_group_size); int bits = bits_.value_or(expected_bits); - if (group_size != expected_group_size || bits != expected_bits) { + + // For affine mode, allow flexible group_size and bits + if (!is_affine && + (group_size != expected_group_size || bits != expected_bits)) { std::ostringstream msg; msg << "[quantized_scaled_dot_product_attention] Quantization mode '" << mode << "' requires group_size " << expected_group_size @@ -915,6 +937,22 @@ array quantized_scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } + // Validate affine group_size and bits + if (is_affine) { + if (group_size != 32 && group_size != 64 && group_size != 128) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] Affine mode supports group_size 32, 64, or 128 but received " + << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits < 2 || bits > 8 || bits == 7) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] Affine mode supports bits 2-6 or 8 but received " + << bits << "."; + throw std::invalid_argument(msg.str()); + } + } + if (bits != 4 && bits != 8) { std::ostringstream msg; msg << "[quantized_scaled_dot_product_attention] Unsupported bits " << bits @@ -922,7 +960,10 @@ array quantized_scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - if (key_scales.dtype() != uint8 || value_scales.dtype() != uint8) { + // For affine mode, scales are stored as the query type + // (float16/bfloat16/float32) For fp modes, scales are uint8 + if (!is_affine && + (key_scales.dtype() != uint8 || value_scales.dtype() != uint8)) { throw std::invalid_argument( "[quantized_scaled_dot_product_attention] Scales must be uint8 for fp quantization."); } @@ -1000,6 +1041,24 @@ array quantized_scaled_dot_product_attention( "[quantized_scaled_dot_product_attention] Scale shapes must match key/value batch, head, and sequence dimensions."); } + // Validate biases have same shape as scales + if (is_affine) { + if (key_biases->shape() != key_scales.shape()) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] key_biases shape " + << key_biases->shape() << " must match key_scales shape " + << key_scales.shape() << "."; + throw std::invalid_argument(msg.str()); + } + if (value_biases->shape() != value_scales.shape()) { + std::ostringstream msg; + msg << "[quantized_scaled_dot_product_attention] value_biases shape " + << value_biases->shape() << " must match value_scales shape " + << value_scales.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + auto final_type = queries.dtype(); if (!issubdtype(final_type, floating)) { std::ostringstream msg; @@ -1018,75 +1077,106 @@ array quantized_scaled_dot_product_attention( } auto q = astype(queries, final_type, s); - auto fallback = - [scale, n_q_heads, n_kv_heads, needs_mask, group_size, bits, mode, s]( - const std::vector& inputs) { - auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); - int n_repeats = n_q_heads / n_kv_heads; - - auto k = inputs[1]; - auto k_scales = inputs[2]; - auto v = inputs[3]; - auto v_scales = inputs[4]; - - std::optional mask = - needs_mask ? std::optional{inputs[5]} : std::nullopt; - - if (n_repeats > 1) { - q = reshape( - q, {q.shape(0), n_kv_heads, n_repeats, q.shape(2), -1}, s); - k = expand_dims(k, 2, s); - k_scales = expand_dims(k_scales, 2, s); - v = expand_dims(v, 2, s); - v_scales = expand_dims(v_scales, 2, s); - } + // Inputs layout: + // [q, k, k_scales, k_biases (if affine), v, v_scales, v_biases (if affine), + // mask (if present)] + auto fallback = [scale, + n_q_heads, + n_kv_heads, + needs_mask, + is_affine, + group_size, + bits, + mode, + s](const std::vector& inputs) { + auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); + int n_repeats = n_q_heads / n_kv_heads; - auto scores = quantized_matmul( - q, - k, - k_scales, - std::nullopt, - /*transpose=*/true, - group_size, - bits, - mode, - s); - if (mask) { - auto m = *mask; - if (n_repeats > 1 && m.ndim() >= 3) { - if (m.shape(-3) == 1) { - m = expand_dims(m, -3, s); - } else { - m = unflatten(m, -3, {n_kv_heads, n_repeats}, s); - } - } - if (m.dtype() == bool_) { - scores = where( - m, scores, array(finfo(scores.dtype()).min, scores.dtype()), s); - } else { - scores = add(scores, m, s); - } - } + auto k = inputs[1]; + auto k_scales = inputs[2]; + std::optional k_biases = std::nullopt; + int idx = 3; + if (is_affine) { + k_biases = inputs[idx++]; + } + auto v = inputs[idx++]; + auto v_scales = inputs[idx++]; + std::optional v_biases = std::nullopt; + if (is_affine) { + v_biases = inputs[idx++]; + } - scores = softmax(scores, std::vector{-1}, true, s); - auto out = quantized_matmul( - scores, - v, - v_scales, - std::nullopt, - /*transpose=*/false, - group_size, - bits, - mode, - s); - if (n_repeats > 1) { - out = reshape(out, {out.shape(0), n_q_heads, out.shape(2), -1}, s); + std::optional mask = + needs_mask ? std::optional{inputs[idx]} : std::nullopt; + + if (n_repeats > 1) { + q = reshape(q, {q.shape(0), n_kv_heads, n_repeats, q.shape(2), -1}, s); + k = expand_dims(k, 2, s); + k_scales = expand_dims(k_scales, 2, s); + if (k_biases) { + k_biases = expand_dims(*k_biases, 2, s); + } + v = expand_dims(v, 2, s); + v_scales = expand_dims(v_scales, 2, s); + if (v_biases) { + v_biases = expand_dims(*v_biases, 2, s); + } + } + + auto scores = quantized_matmul( + q, + k, + k_scales, + k_biases, + /*transpose=*/true, + group_size, + bits, + mode, + s); + if (mask) { + auto m = *mask; + if (n_repeats > 1 && m.ndim() >= 3) { + if (m.shape(-3) == 1) { + m = expand_dims(m, -3, s); + } else { + m = unflatten(m, -3, {n_kv_heads, n_repeats}, s); } - return std::vector{out}; - }; + } + if (m.dtype() == bool_) { + scores = where( + m, scores, array(finfo(scores.dtype()).min, scores.dtype()), s); + } else { + scores = add(scores, m, s); + } + } + + scores = softmax(scores, std::vector{-1}, true, s); + auto out = quantized_matmul( + scores, + v, + v_scales, + v_biases, + /*transpose=*/false, + group_size, + bits, + mode, + s); + if (n_repeats > 1) { + out = reshape(out, {out.shape(0), n_q_heads, out.shape(2), -1}, s); + } + return std::vector{out}; + }; auto stream = to_stream(s); - std::vector inputs = {q, keys, key_scales, values, value_scales}; + std::vector inputs = {q, keys, key_scales}; + if (is_affine) { + inputs.push_back(*key_biases); + } + inputs.push_back(values); + inputs.push_back(value_scales); + if (is_affine) { + inputs.push_back(*value_biases); + } if (needs_mask) { if (promote_types(mask->dtype(), final_type) != final_type && mask->dtype() != bool_) { diff --git a/mlx/fast.h b/mlx/fast.h index ffa1a034f6..b1bc1e4ea8 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -59,8 +59,10 @@ array quantized_scaled_dot_product_attention( const array& queries, const array& keys, const array& key_scales, + const std::optional& key_biases, const array& values, const array& value_scales, + const std::optional& value_biases, const float scale, const std::optional& mask = std::nullopt, std::optional group_size = std::nullopt, diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 7a42e6fe43..d30ee2cfd8 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -296,14 +296,63 @@ void init_fast(nb::module_& parent_module) { out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") )pbdoc"); + m.def( + "quantized_scaled_dot_product_attention", + [](const mx::array& q, + const mx::array& k, + const mx::array& k_scales, + const mx::array& v, + const mx::array& v_scales, + const std::optional& k_biases, + const std::optional& v_biases, + const float scale, + const std::optional& mask, + std::optional group_size, + std::optional bits, + const std::string& mode, + mx::StreamOrDevice s) { + return mx::fast::quantized_scaled_dot_product_attention( + q, + k, + k_scales, + k_biases, + v, + v_scales, + v_biases, + scale, + mask, + group_size, + bits, + mode, + s); + }, + "q"_a, + "k"_a, + "k_scales"_a, + "v"_a, + "v_scales"_a, + nb::kw_only(), + "k_biases"_a = nb::none(), + "v_biases"_a = nb::none(), + "scale"_a, + "mask"_a = nb::none(), + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "mxfp4", + "stream"_a = nb::none(), + nb::sig( + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, k_biases: Optional[array] = None, v_biases: Optional[array] = None, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", stream: Union[None, Stream, Device] = None) -> array")); + m.def( "quantized_scaled_dot_product_attention", &mx::fast::quantized_scaled_dot_product_attention, "q"_a, "k"_a, "k_scales"_a, + "k_biases"_a = nb::none(), "v"_a, "v_scales"_a, + "v_biases"_a = nb::none(), nb::kw_only(), "scale"_a, "mask"_a = nb::none(), @@ -312,7 +361,7 @@ void init_fast(nb::module_& parent_module) { "mode"_a = "mxfp4", "stream"_a = nb::none(), nb::sig( - "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: Optional[array] = None, v: array, v_scales: array, v_biases: Optional[array] = None, *, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A fast implementation of multi-head attention where the keys and values are quantized. @@ -321,14 +370,16 @@ void init_fast(nb::module_& parent_module) { Args: q (array): Input query array. k (array): Input keys array. - k_scales (array): ``uint8`` scales for the fp-quantized keys array. + k_scales (array): Scales for the quantized keys array. + k_biases (array or None): Biases for the affine-quantized keys array. Required for affine mode, None for fp modes. v (array): Input values array. - v_scales (array): ``uint8`` scales for the fp-quantized values array. + v_scales (array): Scales for the quantized values array. + v_biases (array or None): Biases for the affine-quantized values array. Required for affine mode, None for fp modes. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) mask (array, optional): An additive or boolean mask to apply to the query-key scores. group_size (int, optional): The group size used in the KV quantization. Defaults follow the quantization ``mode``. bits (int, optional): The bits used in the KV quantization. Defaults follow the quantization ``mode``. - mode (str, optional): The fp quantization mode, ``"mxfp4"`` or ``"mxfp8"``. + mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, or ``"affine"``. Returns: array: The output array. )pbdoc"); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 7536b72ab0..691ebdde8f 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -788,6 +788,47 @@ def test_quantized_sdpa(self): tol = 5e-2 if bits == 4 else 2e-2 self.assertLess((out - ref).abs().max(), tol) + def test_quantized_sdpa_affine(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lq, Lk, D = 4, 640, 128 + + for group_size in [32, 64, 128]: + for bits in [4, 8]: + with self.subTest(group_size=group_size, bits=bits): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=group_size, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=group_size, bits=bits, mode="affine" + ) + + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=group_size, + bits=bits, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + def test_gather_qmm(self): def quantize(w, transpose=True, group_size=None, bits=None, mode="affine"): if mode == "affine": From cda85af654f73b0e36a0915527e1f845f0d03c67 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 25 Jan 2026 18:06:35 +0100 Subject: [PATCH 06/21] supports affine 2/3/5/6 bits --- mlx/backend/metal/kernels/quantized_utils.h | 123 ++++++++ .../scaled_dot_product_attention.metal | 17 +- mlx/backend/metal/kernels/sdpa_vector.h | 263 ++++++++++++------ mlx/fast.cpp | 11 +- python/tests/test_quantized.py | 13 +- 5 files changed, 334 insertions(+), 93 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index 7741161248..65649a4141 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -133,6 +133,129 @@ struct LoadType<4> { using type = uint16_t; }; +// Pack metadata and unpackers for arbitrary bit-widths (wsize fixed at 32 bits) +template +struct PackInfo { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "PackInfo only supports bits in {2,3,4,5,6,8}"); + + static constant constexpr int pack_factor = + (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : 32 / bits); + static constant constexpr int bytes_per_pack = + ((bits & (bits - 1)) == 0) ? 4 : (bits == 5 ? 5 : 3); +}; + +template +struct PackReader; + +template <> +struct PackReader<2> { + static constant constexpr int pack_factor = PackInfo<2>::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo<2>::bytes_per_pack; + + [[gnu::always_inline]] static void load( + const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { + uint32_t v = *(reinterpret_cast(p)); +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; ++i) { + out[i] = (v >> (2 * i)) & 0x03; + } + } +}; + +template <> +struct PackReader<3> { + static constant constexpr int pack_factor = PackInfo<3>::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo<3>::bytes_per_pack; + + [[gnu::always_inline]] static void load( + const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { + uint8_t w0 = p[0]; + uint8_t w1 = p[1]; + uint8_t w2 = p[2]; + out[0] = w0 & 0x07; + out[1] = (w0 >> 3) & 0x07; + out[2] = ((w0 >> 6) | ((w1 & 0x01) << 2)) & 0x07; + out[3] = (w1 >> 1) & 0x07; + out[4] = (w1 >> 4) & 0x07; + out[5] = ((w1 >> 7) | ((w2 & 0x03) << 1)) & 0x07; + out[6] = (w2 >> 2) & 0x07; + out[7] = (w2 >> 5) & 0x07; + } +}; + +template <> +struct PackReader<4> { + static constant constexpr int pack_factor = PackInfo<4>::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo<4>::bytes_per_pack; + + [[gnu::always_inline]] static void load( + const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { + uint32_t v = *(reinterpret_cast(p)); +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; ++i) { + out[i] = (v >> (4 * i)) & 0x0f; + } + } +}; + +template <> +struct PackReader<5> { + static constant constexpr int pack_factor = PackInfo<5>::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo<5>::bytes_per_pack; + + [[gnu::always_inline]] static void load( + const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { + uint8_t w0 = p[0]; + uint8_t w1 = p[1]; + uint8_t w2 = p[2]; + uint8_t w3 = p[3]; + uint8_t w4 = p[4]; + out[0] = w0 & 0x1f; + out[1] = ((w0 >> 5) | ((w1 & 0x03) << 3)) & 0x1f; + out[2] = (w1 >> 2) & 0x1f; + out[3] = ((w1 >> 7) | ((w2 & 0x0f) << 1)) & 0x1f; + out[4] = ((w2 >> 4) | ((w3 & 0x01) << 4)) & 0x1f; + out[5] = (w3 >> 1) & 0x1f; + out[6] = ((w3 >> 6) | ((w4 & 0x07) << 2)) & 0x1f; + out[7] = (w4 >> 3) & 0x1f; + } +}; + +template <> +struct PackReader<6> { + static constant constexpr int pack_factor = PackInfo<6>::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo<6>::bytes_per_pack; + + [[gnu::always_inline]] static void load( + const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { + uint8_t w0 = p[0]; + uint8_t w1 = p[1]; + uint8_t w2 = p[2]; + out[0] = w0 & 0x3f; + out[1] = ((w0 >> 6) | ((w1 & 0x0f) << 2)) & 0x3f; + out[2] = ((w1 >> 4) | ((w2 & 0x03) << 4)) & 0x3f; + out[3] = (w2 >> 2) & 0x3f; + } +}; + +template <> +struct PackReader<8> { + static constant constexpr int pack_factor = PackInfo<8>::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo<8>::bytes_per_pack; + + [[gnu::always_inline]] static void load( + const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { + uint32_t v = *(reinterpret_cast(p)); + out[0] = v & 0xff; + out[1] = (v >> 8) & 0xff; + out[2] = (v >> 16) & 0xff; + out[3] = (v >> 24) & 0xff; + } +}; + // Helpers to fetch mode-specific defaults (affine uses default_* values) template constexpr int get_group_size() { diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 5674d179c4..23161fd8c9 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -54,16 +54,21 @@ instantiate_sdpa_vector_heads(float16_t) group_size, \ bits) +#define instantiate_quant_sdpa_vector_affine(type, head_dim, group_size) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 2) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 3) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 4) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 5) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 6) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 8) + #define instantiate_quant_sdpa_vector_all_modes(type, head_dim) \ instantiate_quant_sdpa_vector(type, head_dim, Mxfp4, 32, 4) \ instantiate_quant_sdpa_vector(type, head_dim, Nvfp4, 16, 4) \ instantiate_quant_sdpa_vector(type, head_dim, Mxfp8, 32, 8) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 64, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 128, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 8) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 64, 8) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 128, 8) + instantiate_quant_sdpa_vector_affine(type, head_dim, 32) \ + instantiate_quant_sdpa_vector_affine(type, head_dim, 64) \ + instantiate_quant_sdpa_vector_affine(type, head_dim, 128) #define instantiate_quant_sdpa_vector_heads(type) \ instantiate_quant_sdpa_vector_all_modes(type, 64) \ diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index e2e552f17f..7711e14082 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -192,95 +192,186 @@ METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { // Function constant for affine bias support constant bool has_affine_bias [[function_constant(26)]]; -// Unified dot product with keys across all QuantModes -template < - typename U, - int elem_per_thread, - QuantMode mode, - int bits = QuantTraits::bits> -[[gnu::always_inline]] METAL_FUNC U dot_key( - const thread U* q, - const device uint32_t* keys, - U scale, - U bias = U{0}) { - using Traits = QuantTraits; - using LoadT = typename LoadType::type; - constexpr uint32_t mask = (1 << bits) - 1; +/////////////////////////////////////////////////////////////////////////////// +// Fused Quantized SDPA Operations +// +// SdpaQuantOps provides compile-time specialized fused dequant+compute +// operations for SDPA kernels. Uses optimal memory access patterns +// (uint32_t/uint16_t loads) and avoids intermediate arrays. +/////////////////////////////////////////////////////////////////////////////// + +template +struct SdpaQuantOps; - auto ks = (const device LoadT*)keys; - U score = 0; +// Optimized path for MXFP4: uint16_t loads, 4 elements at a time +// Uses Dequantize functor for consistent codegen with other kernels +template <> +struct SdpaQuantOps { + // Fused dot product: sum(q[i] * dequant(packed_k[i])) + template + [[gnu::always_inline]] static U + dot(const thread U* q, const device uint32_t* keys, U scale, U /*bias*/) { + auto ks = reinterpret_cast(keys); + U score = 0; #pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - LoadT p = ks[j]; + for (int j = 0; j < elem_per_thread / 4; j++) { + uint16_t p = ks[j]; - uint8_t v0 = uint8_t(p & mask); - uint8_t v1 = uint8_t((p >> bits) & mask); - uint8_t v2 = uint8_t((p >> (2 * bits)) & mask); - uint8_t v3 = uint8_t((p >> (3 * bits)) & mask); + score += q[4 * j + 0] * Dequantize<4, U>{}(p & 0xf); + score += q[4 * j + 1] * Dequantize<4, U>{}((p >> 4) & 0xf); + score += q[4 * j + 2] * Dequantize<4, U>{}((p >> 8) & 0xf); + score += q[4 * j + 3] * Dequantize<4, U>{}((p >> 12) & 0xf); + } + return score * scale; + } - // For affine mode, apply full dequantization: scale * v + bias - // For other modes, dequantize_value returns the decoded value and we scale - // at the end - if constexpr (Traits::has_bias) { - score += q[4 * j + 0] * Traits::template dequantize(v0, scale, bias); - score += q[4 * j + 1] * Traits::template dequantize(v1, scale, bias); - score += q[4 * j + 2] * Traits::template dequantize(v2, scale, bias); - score += q[4 * j + 3] * Traits::template dequantize(v3, scale, bias); - } else { - score += q[4 * j + 0] * Traits::template dequantize_value(v0); - score += q[4 * j + 1] * Traits::template dequantize_value(v1); - score += q[4 * j + 2] * Traits::template dequantize_value(v2); - score += q[4 * j + 3] * Traits::template dequantize_value(v3); + // Fused accumulate: o[i] = o[i] * factor + dequant(packed_v[i]) * w_scale + template + [[gnu::always_inline]] static void accumulate( + thread U* o, + const device uint32_t* values, + U factor, + U w_scale, + U /*bias*/) { + auto vs = reinterpret_cast(values); + +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread / 4; j++) { + uint16_t p = vs[j]; + + U v0 = Dequantize<4, U>{}(p & 0xf); + U v1 = Dequantize<4, U>{}((p >> 4) & 0xf); + U v2 = Dequantize<4, U>{}((p >> 8) & 0xf); + U v3 = Dequantize<4, U>{}((p >> 12) & 0xf); + + o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); + o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); + o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); + o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); } } +}; + +// Optimized path for NVFP4: same as MXFP4 (both use fp4_e2m1) +template <> +struct SdpaQuantOps { + template + [[gnu::always_inline]] static U + dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { + return SdpaQuantOps::template dot( + q, keys, scale, bias); + } + + template + [[gnu::always_inline]] static void accumulate( + thread U* o, + const device uint32_t* values, + U factor, + U w_scale, + U bias) { + SdpaQuantOps::template accumulate( + o, values, factor, w_scale, bias); + } +}; + +// Optimized path for MXFP8: uint32_t loads, 4 elements at a time +// Uses Dequantize functor for consistent codegen with other kernels +template <> +struct SdpaQuantOps { + template + [[gnu::always_inline]] static U + dot(const thread U* q, const device uint32_t* keys, U scale, U /*bias*/) { + U score = 0; - // For non-affine modes, apply scale at the end - // For affine mode, dequantization is already applied inline - if constexpr (Traits::has_bias) { - return score; - } else { +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread / 4; j++) { + uint32_t p = keys[j]; + + score += q[4 * j + 0] * Dequantize<8, U>{}(p & 0xff); + score += q[4 * j + 1] * Dequantize<8, U>{}((p >> 8) & 0xff); + score += q[4 * j + 2] * Dequantize<8, U>{}((p >> 16) & 0xff); + score += q[4 * j + 3] * Dequantize<8, U>{}((p >> 24) & 0xff); + } return score * scale; } -} -template < - typename U, - int elem_per_thread, - QuantMode mode, - int bits = QuantTraits::bits> -[[gnu::always_inline]] METAL_FUNC void accumulate_values( - thread U* o, - const device uint32_t* values, - U factor, - U w_scale, - U bias = U{0}) { - using Traits = QuantTraits; - using LoadT = typename LoadType::type; - constexpr uint32_t mask = (1 << bits) - 1; + template + [[gnu::always_inline]] static void accumulate( + thread U* o, + const device uint32_t* values, + U factor, + U w_scale, + U /*bias*/) { +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread / 4; j++) { + uint32_t p = values[j]; + + U v0 = Dequantize<8, U>{}(p & 0xff); + U v1 = Dequantize<8, U>{}((p >> 8) & 0xff); + U v2 = Dequantize<8, U>{}((p >> 16) & 0xff); + U v3 = Dequantize<8, U>{}((p >> 24) & 0xff); + + o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); + o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); + o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); + o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); + } + } +}; - auto vs = (const device LoadT*)values; +// Generic path for Affine quantization (supports all bit widths) +// Uses PackReader for non-power-of-2 bit widths (3, 5, 6) +template +struct SdpaQuantOps { + static constant constexpr int pack_factor = PackInfo::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; + + template + [[gnu::always_inline]] static U + dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { + auto ks = reinterpret_cast(keys); + thread uint8_t raw[pack_factor]; + U score = 0; #pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - LoadT p = vs[j]; + for (int j = 0; j < elem_per_thread; j += pack_factor) { + PackReader::load(ks, raw); - uint8_t v0 = uint8_t(p & mask); - uint8_t v1 = uint8_t((p >> bits) & mask); - uint8_t v2 = uint8_t((p >> (2 * bits)) & mask); - uint8_t v3 = uint8_t((p >> (3 * bits)) & mask); +#pragma clang loop unroll(full) + for (int t = 0; t < pack_factor; ++t) { + score += q[j + t] * fma(scale, U(raw[t]), bias); + } - U dq0 = Traits::template dequantize(v0, w_scale, bias); - U dq1 = Traits::template dequantize(v1, w_scale, bias); - U dq2 = Traits::template dequantize(v2, w_scale, bias); - U dq3 = Traits::template dequantize(v3, w_scale, bias); + ks += bytes_per_pack; + } + return score; + } - o[4 * j + 0] = fma(o[4 * j + 0], factor, dq0); - o[4 * j + 1] = fma(o[4 * j + 1], factor, dq1); - o[4 * j + 2] = fma(o[4 * j + 2], factor, dq2); - o[4 * j + 3] = fma(o[4 * j + 3], factor, dq3); + template + [[gnu::always_inline]] static void accumulate( + thread U* o, + const device uint32_t* values, + U factor, + U w_scale, + U bias) { + auto vs = reinterpret_cast(values); + thread uint8_t raw[pack_factor]; + +#pragma clang loop unroll(full) + for (int j = 0; j < elem_per_thread; j += pack_factor) { + PackReader::load(vs, raw); + +#pragma clang loop unroll(full) + for (int t = 0; t < pack_factor; ++t) { + U dq = fma(w_scale, U(raw[t]), bias); + o[j + t] = fma(o[j + t], factor, dq); + } + + vs += bytes_per_pack; + } } -} +}; /////////////////////////////////////////////////////////////////////////////// // Quantized SDPA kernel using QuantTraits @@ -331,7 +422,8 @@ template constexpr int BD = 4; constexpr int elem_per_thread = D / BD; constexpr int blocks = 32; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = PackInfo::pack_factor; + constexpr int bytes_per_pack = PackInfo::bytes_per_pack; const int stride = BN * D; @@ -360,8 +452,13 @@ template const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size; const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size; - keys += kv_head_idx * k_stride + packed_idx; - values += kv_head_idx * v_stride + packed_idx; + // Use uint32_t pointers for optimal memory bandwidth + // For 4-bit: pack_factor=8, so we advance by elem_per_thread/8 uint32_t words + // For 8-bit: pack_factor=4, so we advance by elem_per_thread/4 uint32_t words + auto key_ptr = + keys + kv_head_idx * k_stride + packed_idx * bytes_per_pack / 4; + auto value_ptr = + values + kv_head_idx * v_stride + packed_idx * bytes_per_pack / 4; key_scales += k_group_idx; value_scales += v_group_idx; @@ -411,8 +508,8 @@ template key_scale = Traits::template dequantize_scale(key_scales[0]); } - U score = - dot_key(q, keys, key_scale, key_bias); + U score = SdpaQuantOps::template dot( + q, key_ptr, key_scale, key_bias); score = quad_sum(score); if (float_mask) { @@ -437,12 +534,18 @@ template value_scale = Traits::template dequantize_scale(value_scales[0]); } - accumulate_values( - o, values, factor, exp_score * value_scale, exp_score * value_bias); + SdpaQuantOps::template accumulate( + o, + value_ptr, + factor, + exp_score * value_scale, + exp_score * value_bias); } - keys += blocks * stride / pack_factor; - values += blocks * stride / pack_factor; + // Advance pointers by blocks * BN * D elements + // For uint32_t*, advance by (blocks * stride * bits) / 32 + key_ptr += blocks * stride * bits / 32; + value_ptr += blocks * stride * bits / 32; key_scales += blocks * stride / group_size; value_scales += blocks * stride / group_size; if constexpr (Traits::has_bias) { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 97c95b9466..06e5ad598c 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -953,7 +953,7 @@ array quantized_scaled_dot_product_attention( } } - if (bits != 4 && bits != 8) { + if (!is_affine && bits != 4 && bits != 8) { std::ostringstream msg; msg << "[quantized_scaled_dot_product_attention] Unsupported bits " << bits << ". Supported bits are 4 and 8."; @@ -972,7 +972,8 @@ array quantized_scaled_dot_product_attention( "[quantized_scaled_dot_product_attention] Keys and values must be packed quantized arrays of type uint32."); } - auto el_per_int = 32 / bits; + auto key_head_dim = (keys.shape(-1) * 32) / bits; + auto value_head_dim = (values.shape(-1) * 32) / bits; const size_t batch_dim = queries.shape(0); for (const auto& tensor : {keys, values, key_scales, value_scales}) { @@ -1001,14 +1002,14 @@ array quantized_scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - if (queries.shape(-1) != keys.shape(-1) * el_per_int) { + if (queries.shape(-1) != key_head_dim) { std::ostringstream msg; msg << "[quantized_scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " << queries.shape() << " for keys shape " << keys.shape() << "."; throw std::invalid_argument(msg.str()); } - if (queries.shape(-1) != values.shape(-1) * el_per_int) { + if (queries.shape(-1) != value_head_dim) { std::ostringstream msg; msg << "[quantized_scaled_dot_product_attention] query, values expected to have matching last dimension; found query shape " << queries.shape() << " for values shape " << values.shape() << "."; @@ -1195,7 +1196,7 @@ array quantized_scaled_dot_product_attention( inputs.back() = broadcast_to(inputs.back(), mask_shape, stream); } - int out_dim = values.shape(-1) * el_per_int; + int out_dim = value_head_dim; Shape out_shape{ queries.shape(0), queries.shape(1), queries.shape(2), out_dim}; diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 691ebdde8f..c61707a01b 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -797,7 +797,7 @@ def test_quantized_sdpa_affine(self): Lq, Lk, D = 4, 640, 128 for group_size in [32, 64, 128]: - for bits in [4, 8]: + for bits in [2, 3, 4, 5, 6, 8]: with self.subTest(group_size=group_size, bits=bits): q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) @@ -826,7 +826,16 @@ def test_quantized_sdpa_affine(self): ) self.assertEqual(out.shape, ref.shape) - tol = 5e-2 if bits == 4 else 2e-2 + if bits <= 3: + tol = 3e-1 + elif bits == 5: + tol = 1.5e-1 + elif bits == 6: + tol = 1e-1 + elif bits == 4: + tol = 5e-2 + else: + tol = 2e-2 self.assertLess((out - ref).abs().max(), tol) def test_gather_qmm(self): From 4c796a29a825ac791b4fc28f4f0cd01acd4d5683 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Thu, 29 Jan 2026 18:01:09 +0100 Subject: [PATCH 07/21] clean up --- mlx/backend/metal/kernels/quantized_utils.h | 299 +++++++------------- mlx/backend/metal/kernels/sdpa_vector.h | 281 +++++++++--------- mlx/fast.cpp | 226 ++++++--------- mlx/ops.cpp | 29 -- mlx/primitives.cpp | 28 ++ mlx/primitives.h | 7 + 6 files changed, 349 insertions(+), 521 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index 65649a4141..f517a68c3f 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -10,127 +10,92 @@ enum class QuantMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; +template +struct DecodeValue { + [[clang::always_inline]] OutT operator()(uint8_t v) const { + return OutT(*(thread EncodedT*)(&v)); + } +}; + +// Specialization for Affine (plain integer cast) +template +struct DecodeValue { + [[clang::always_inline]] OutT operator()(uint8_t v) const { + return OutT(v); + } +}; + template -struct QuantTraits; +struct QuantConfig; -// Affine quantization: scale * val + bias template <> -struct QuantTraits { - static constant constexpr int default_group_size = 64; - static constant constexpr int default_bits = 4; - static constant constexpr int group_size = default_group_size; - static constant constexpr int bits = default_bits; +struct QuantConfig { static constant constexpr bool has_bias = true; - template - using scale_type = T; - template - static inline T dequantize_scale(T s) { - return s; - } + using value_type = void; + using scale_type = void; - // Single-arg version returns raw value (for use in dot_key where - // dequantization is applied separately) template - static inline T dequantize_value(uint8_t v) { - return T(v); - } - - template - static inline T dequantize_value(uint8_t v, T scale, T bias) { - return fma(scale, T(v), bias); - } - - template - static inline T dequantize(uint8_t v, T scale, T bias) { - return fma(scale, T(v), bias); - } + using scale_storage_t = T; }; -// MXFP4: fp4_e2m1 data, fp8_e8m0 scale (power-of-2), group_size=32 template <> -struct QuantTraits { - static constant constexpr int group_size = 32; - static constant constexpr int bits = 4; +struct QuantConfig { static constant constexpr bool has_bias = false; - template - using scale_type = uint8_t; - template - static inline T dequantize_scale(uint8_t s) { - return T(*(thread fp8_e8m0*)(&s)); - } + using value_type = fp4_e2m1; + using scale_type = fp8_e8m0; template - static inline T dequantize_value(uint8_t v) { - return T(*(thread fp4_e2m1*)(&v)); - } - - template - static inline T dequantize(uint8_t v, T scale, T /*bias*/) { - return scale * dequantize_value(v); - } + using scale_storage_t = uint8_t; }; -// NVFP4: fp4_e2m1 data, fp8_e4m3 scale (with mantissa), group_size=16 template <> -struct QuantTraits { - static constant constexpr int group_size = 16; - static constant constexpr int bits = 4; +struct QuantConfig { static constant constexpr bool has_bias = false; - template - using scale_type = uint8_t; - - template - static inline T dequantize_scale(uint8_t s) { - return T(*(thread fp8_e4m3*)(&s)); - } - template - static inline T dequantize_value(uint8_t v) { - return T(*(thread fp4_e2m1*)(&v)); - } + using value_type = fp4_e2m1; + using scale_type = fp8_e4m3; template - static inline T dequantize(uint8_t v, T scale, T /*bias*/) { - return scale * dequantize_value(v); - } + using scale_storage_t = uint8_t; }; -// MXFP8: fp8_e4m3 data, fp8_e8m0 scale, group_size=32 template <> -struct QuantTraits { - static constant constexpr int group_size = 32; - static constant constexpr int bits = 8; +struct QuantConfig { static constant constexpr bool has_bias = false; - template - using scale_type = uint8_t; - template - static inline T dequantize_scale(uint8_t s) { - return T(*(thread fp8_e8m0*)(&s)); - } + using value_type = fp8_e4m3; + using scale_type = fp8_e8m0; template - static inline T dequantize_value(uint8_t v) { - return T(*(thread fp8_e4m3*)(&v)); - } + using scale_storage_t = uint8_t; +}; - template - static inline T dequantize(uint8_t v, T scale, T /*bias*/) { - return scale * dequantize_value(v); +template +struct Dequant { + using Cfg = QuantConfig; + + [[clang::always_inline]] T raw(uint8_t v) const { + return DecodeValue{}(v); } -}; -// Compile-time LoadType selector by bit-width -template -struct LoadType { - using type = uint32_t; -}; + [[clang::always_inline]] T scale( + typename Cfg::template scale_storage_t s) const { + if constexpr (metal::is_same_v) { + return s; + } else { + return DecodeValue{}(s); + } + } -template <> -struct LoadType<4> { - using type = uint16_t; + [[clang::always_inline]] T operator()(uint8_t v, T s, T bias) const { + if constexpr (Cfg::has_bias) { + return fma(s, raw(v), bias); + } else { + return s * raw(v); + } + } }; // Pack metadata and unpackers for arbitrary bit-widths (wsize fixed at 32 bits) @@ -148,133 +113,73 @@ struct PackInfo { }; template -struct PackReader; - -template <> -struct PackReader<2> { - static constant constexpr int pack_factor = PackInfo<2>::pack_factor; - static constant constexpr int bytes_per_pack = PackInfo<2>::bytes_per_pack; +struct PackReader { + static constant constexpr int pack_factor = PackInfo::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; + static constant constexpr uint64_t mask = (uint64_t(1) << bits) - 1; [[gnu::always_inline]] static void load( - const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { - uint32_t v = *(reinterpret_cast(p)); + const device uint8_t* p, + thread uint8_t (&out)[pack_factor]) { + uint64_t packed = load_packed(p); #pragma clang loop unroll(full) for (int i = 0; i < pack_factor; ++i) { - out[i] = (v >> (2 * i)) & 0x03; + out[i] = static_cast((packed >> (bits * i)) & mask); } } -}; - -template <> -struct PackReader<3> { - static constant constexpr int pack_factor = PackInfo<3>::pack_factor; - static constant constexpr int bytes_per_pack = PackInfo<3>::bytes_per_pack; - - [[gnu::always_inline]] static void load( - const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { - uint8_t w0 = p[0]; - uint8_t w1 = p[1]; - uint8_t w2 = p[2]; - out[0] = w0 & 0x07; - out[1] = (w0 >> 3) & 0x07; - out[2] = ((w0 >> 6) | ((w1 & 0x01) << 2)) & 0x07; - out[3] = (w1 >> 1) & 0x07; - out[4] = (w1 >> 4) & 0x07; - out[5] = ((w1 >> 7) | ((w2 & 0x03) << 1)) & 0x07; - out[6] = (w2 >> 2) & 0x07; - out[7] = (w2 >> 5) & 0x07; - } -}; -template <> -struct PackReader<4> { - static constant constexpr int pack_factor = PackInfo<4>::pack_factor; - static constant constexpr int bytes_per_pack = PackInfo<4>::bytes_per_pack; - - [[gnu::always_inline]] static void load( - const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { - uint32_t v = *(reinterpret_cast(p)); + private: + [[gnu::always_inline]] static uint64_t load_packed(const device uint8_t* p) { + if constexpr (bytes_per_pack == 4) { + return static_cast( + *(reinterpret_cast(p))); + } else { + uint64_t packed = 0; #pragma clang loop unroll(full) - for (int i = 0; i < pack_factor; ++i) { - out[i] = (v >> (4 * i)) & 0x0f; + for (int i = 0; i < bytes_per_pack; ++i) { + packed |= static_cast(p[i]) << (8 * i); + } + return packed; } } }; -template <> -struct PackReader<5> { - static constant constexpr int pack_factor = PackInfo<5>::pack_factor; - static constant constexpr int bytes_per_pack = PackInfo<5>::bytes_per_pack; - - [[gnu::always_inline]] static void load( - const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { - uint8_t w0 = p[0]; - uint8_t w1 = p[1]; - uint8_t w2 = p[2]; - uint8_t w3 = p[3]; - uint8_t w4 = p[4]; - out[0] = w0 & 0x1f; - out[1] = ((w0 >> 5) | ((w1 & 0x03) << 3)) & 0x1f; - out[2] = (w1 >> 2) & 0x1f; - out[3] = ((w1 >> 7) | ((w2 & 0x0f) << 1)) & 0x1f; - out[4] = ((w2 >> 4) | ((w3 & 0x01) << 4)) & 0x1f; - out[5] = (w3 >> 1) & 0x1f; - out[6] = ((w3 >> 6) | ((w4 & 0x07) << 2)) & 0x1f; - out[7] = (w4 >> 3) & 0x1f; +// Pointer wrapper for quantized data that handles byte-level addressing +// correctly for all bit widths. For non-4-byte-aligned packs (3, 5, 6-bit), +// simple uint32_t pointer arithmetic truncates and causes misalignment. +// This class uses byte-level arithmetic internally to ensure correctness. +template +class QuantDataPtr { + const device uint8_t* byte_ptr_; + + public: + static constant constexpr int pack_factor = PackInfo::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; + + // Initialize from base pointer, head stride (in uint32 units), head index, + // and element index + [[clang::always_inline]] QuantDataPtr( + const device uint32_t* base, + size_t head_stride, + int head_idx, + int elem_idx) { + int packed_idx = elem_idx / pack_factor; + byte_ptr_ = reinterpret_cast(base) + + head_idx * head_stride * 4 + // head_stride is in uint32 units + packed_idx * bytes_per_pack; } -}; -template <> -struct PackReader<6> { - static constant constexpr int pack_factor = PackInfo<6>::pack_factor; - static constant constexpr int bytes_per_pack = PackInfo<6>::bytes_per_pack; - - [[gnu::always_inline]] static void load( - const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { - uint8_t w0 = p[0]; - uint8_t w1 = p[1]; - uint8_t w2 = p[2]; - out[0] = w0 & 0x3f; - out[1] = ((w0 >> 6) | ((w1 & 0x0f) << 2)) & 0x3f; - out[2] = ((w1 >> 4) | ((w2 & 0x03) << 4)) & 0x3f; - out[3] = (w2 >> 2) & 0x3f; + // Advance by number of elements + [[clang::always_inline]] void advance(int num_elements) { + byte_ptr_ += num_elements * bits / 8; } -}; - -template <> -struct PackReader<8> { - static constant constexpr int pack_factor = PackInfo<8>::pack_factor; - static constant constexpr int bytes_per_pack = PackInfo<8>::bytes_per_pack; - [[gnu::always_inline]] static void load( - const device uint8_t* p, thread uint8_t (&out)[pack_factor]) { - uint32_t v = *(reinterpret_cast(p)); - out[0] = v & 0xff; - out[1] = (v >> 8) & 0xff; - out[2] = (v >> 16) & 0xff; - out[3] = (v >> 24) & 0xff; + // Get raw pointer for passing to dot/accumulate functions + [[clang::always_inline]] const device uint32_t* ptr() const { + return reinterpret_cast(byte_ptr_); } }; -// Helpers to fetch mode-specific defaults (affine uses default_* values) -template -constexpr int get_group_size() { - if constexpr (mode == QuantMode::Affine) { - return QuantTraits::default_group_size; - } else { - return QuantTraits::group_size; - } -} - -template -constexpr int get_bits() { - if constexpr (mode == QuantMode::Affine) { - return QuantTraits::default_bits; - } else { - return QuantTraits::bits; - } -} - template METAL_FUNC void gemm_loop_aligned( threadgroup T* As, diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 7711e14082..607e7dc2c5 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -15,9 +15,6 @@ constant bool float_mask [[function_constant(24)]]; constant bool has_sinks [[function_constant(25)]]; constant int blocks [[function_constant(26)]]; -template -using ScaleTypeT = typename QuantTraits::template scale_type; - template [[kernel]] void sdpa_vector( const device T* queries [[buffer(0)]], @@ -189,197 +186,181 @@ METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { } } -// Function constant for affine bias support constant bool has_affine_bias [[function_constant(26)]]; -/////////////////////////////////////////////////////////////////////////////// -// Fused Quantized SDPA Operations -// -// SdpaQuantOps provides compile-time specialized fused dequant+compute -// operations for SDPA kernels. Uses optimal memory access patterns -// (uint32_t/uint16_t loads) and avoids intermediate arrays. -/////////////////////////////////////////////////////////////////////////////// - template -struct SdpaQuantOps; +struct QuantFastOps { + using Cfg = QuantConfig; + using load_t = metal::conditional_t; + static_assert( + bits == 4 || bits == 8, + "QuantFastOps only supports 4/8-bit packing"); + static constant constexpr uint32_t mask = (1u << bits) - 1; -// Optimized path for MXFP4: uint16_t loads, 4 elements at a time -// Uses Dequantize functor for consistent codegen with other kernels -template <> -struct SdpaQuantOps { - // Fused dot product: sum(q[i] * dequant(packed_k[i])) template - [[gnu::always_inline]] static U - dot(const thread U* q, const device uint32_t* keys, U scale, U /*bias*/) { - auto ks = reinterpret_cast(keys); + [[clang::always_inline]] static U + dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { + Dequant dequant; + auto ks = reinterpret_cast(keys); U score = 0; + [[maybe_unused]] U bias_acc = 0; #pragma clang loop unroll(full) for (int j = 0; j < elem_per_thread / 4; j++) { - uint16_t p = ks[j]; - - score += q[4 * j + 0] * Dequantize<4, U>{}(p & 0xf); - score += q[4 * j + 1] * Dequantize<4, U>{}((p >> 4) & 0xf); - score += q[4 * j + 2] * Dequantize<4, U>{}((p >> 8) & 0xf); - score += q[4 * j + 3] * Dequantize<4, U>{}((p >> 12) & 0xf); + load_t p = ks[j]; + U v0 = dequant.raw(p & mask); + U v1 = dequant.raw((p >> (bits * 1)) & mask); + U v2 = dequant.raw((p >> (bits * 2)) & mask); + U v3 = dequant.raw((p >> (bits * 3)) & mask); + + score += q[4 * j + 0] * v0; + score += q[4 * j + 1] * v1; + score += q[4 * j + 2] * v2; + score += q[4 * j + 3] * v3; + + if constexpr (Cfg::has_bias) { + bias_acc += (q[4 * j + 0] + q[4 * j + 1] + q[4 * j + 2] + q[4 * j + 3]); + } } - return score * scale; - } - - // Fused accumulate: o[i] = o[i] * factor + dequant(packed_v[i]) * w_scale - template - [[gnu::always_inline]] static void accumulate( - thread U* o, - const device uint32_t* values, - U factor, - U w_scale, - U /*bias*/) { - auto vs = reinterpret_cast(values); - -#pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - uint16_t p = vs[j]; - - U v0 = Dequantize<4, U>{}(p & 0xf); - U v1 = Dequantize<4, U>{}((p >> 4) & 0xf); - U v2 = Dequantize<4, U>{}((p >> 8) & 0xf); - U v3 = Dequantize<4, U>{}((p >> 12) & 0xf); - o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); - o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); - o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); - o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); + if constexpr (Cfg::has_bias) { + return fma(scale, score, bias * bias_acc); + } else { + return scale * score; } } -}; - -// Optimized path for NVFP4: same as MXFP4 (both use fp4_e2m1) -template <> -struct SdpaQuantOps { - template - [[gnu::always_inline]] static U - dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { - return SdpaQuantOps::template dot( - q, keys, scale, bias); - } template - [[gnu::always_inline]] static void accumulate( + [[clang::always_inline]] static void accumulate( thread U* o, const device uint32_t* values, U factor, U w_scale, U bias) { - SdpaQuantOps::template accumulate( - o, values, factor, w_scale, bias); - } -}; - -// Optimized path for MXFP8: uint32_t loads, 4 elements at a time -// Uses Dequantize functor for consistent codegen with other kernels -template <> -struct SdpaQuantOps { - template - [[gnu::always_inline]] static U - dot(const thread U* q, const device uint32_t* keys, U scale, U /*bias*/) { - U score = 0; - + Dequant dequant; + auto vs = reinterpret_cast(values); #pragma clang loop unroll(full) for (int j = 0; j < elem_per_thread / 4; j++) { - uint32_t p = keys[j]; - - score += q[4 * j + 0] * Dequantize<8, U>{}(p & 0xff); - score += q[4 * j + 1] * Dequantize<8, U>{}((p >> 8) & 0xff); - score += q[4 * j + 2] * Dequantize<8, U>{}((p >> 16) & 0xff); - score += q[4 * j + 3] * Dequantize<8, U>{}((p >> 24) & 0xff); - } - return score * scale; - } - - template - [[gnu::always_inline]] static void accumulate( - thread U* o, - const device uint32_t* values, - U factor, - U w_scale, - U /*bias*/) { -#pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - uint32_t p = values[j]; - - U v0 = Dequantize<8, U>{}(p & 0xff); - U v1 = Dequantize<8, U>{}((p >> 8) & 0xff); - U v2 = Dequantize<8, U>{}((p >> 16) & 0xff); - U v3 = Dequantize<8, U>{}((p >> 24) & 0xff); - - o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); - o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); - o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); - o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); + load_t p = vs[j]; + U v0 = dequant.raw(p & mask); + U v1 = dequant.raw((p >> (bits * 1)) & mask); + U v2 = dequant.raw((p >> (bits * 2)) & mask); + U v3 = dequant.raw((p >> (bits * 3)) & mask); + + if constexpr (Cfg::has_bias) { + o[4 * j + 0] = fma(o[4 * j + 0], factor, fma(w_scale, v0, bias)); + o[4 * j + 1] = fma(o[4 * j + 1], factor, fma(w_scale, v1, bias)); + o[4 * j + 2] = fma(o[4 * j + 2], factor, fma(w_scale, v2, bias)); + o[4 * j + 3] = fma(o[4 * j + 3], factor, fma(w_scale, v3, bias)); + } else { + o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); + o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); + o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); + o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); + } } } }; -// Generic path for Affine quantization (supports all bit widths) -// Uses PackReader for non-power-of-2 bit widths (3, 5, 6) -template -struct SdpaQuantOps { +// (Generic Path for 2, 3, 5, 6 bits) +template +struct QuantOps { + using Cfg = QuantConfig; static constant constexpr int pack_factor = PackInfo::pack_factor; static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; template - [[gnu::always_inline]] static U + [[clang::always_inline]] static U dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { + Dequant dequant; auto ks = reinterpret_cast(keys); thread uint8_t raw[pack_factor]; U score = 0; + [[maybe_unused]] U bias_acc = 0; #pragma clang loop unroll(full) for (int j = 0; j < elem_per_thread; j += pack_factor) { PackReader::load(ks, raw); - #pragma clang loop unroll(full) for (int t = 0; t < pack_factor; ++t) { - score += q[j + t] * fma(scale, U(raw[t]), bias); + U decoded = dequant.raw(raw[t]); + score += q[j + t] * decoded; + if constexpr (Cfg::has_bias) + bias_acc += q[j + t]; } - ks += bytes_per_pack; } - return score; + return Cfg::has_bias ? fma(scale, score, bias * bias_acc) : scale * score; } template - [[gnu::always_inline]] static void accumulate( + [[clang::always_inline]] static void accumulate( thread U* o, const device uint32_t* values, U factor, U w_scale, U bias) { + Dequant dequant; auto vs = reinterpret_cast(values); thread uint8_t raw[pack_factor]; - #pragma clang loop unroll(full) for (int j = 0; j < elem_per_thread; j += pack_factor) { PackReader::load(vs, raw); - #pragma clang loop unroll(full) for (int t = 0; t < pack_factor; ++t) { - U dq = fma(w_scale, U(raw[t]), bias); - o[j + t] = fma(o[j + t], factor, dq); + U decoded = dequant.raw(raw[t]); + if constexpr (Cfg::has_bias) { + o[j + t] = fma(o[j + t], factor, fma(w_scale, decoded, bias)); + } else { + o[j + t] = fma(o[j + t], factor, decoded * w_scale); + } } - vs += bytes_per_pack; } } }; -/////////////////////////////////////////////////////////////////////////////// -// Quantized SDPA kernel using QuantTraits -// -// This kernel supports all quantization modes (Mxfp4, Nvfp4, Mxfp8, Affine) -// through the QuantMode template parameter. For Affine mode, bias buffers -// are enabled via the has_affine_bias function constant. -/////////////////////////////////////////////////////////////////////////////// +template +struct QuantOps { + template + [[clang::always_inline]] static U + dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { + return QuantFastOps::template dot( + q, keys, scale, bias); + } + template + [[clang::always_inline]] static void accumulate( + thread U* o, + const device uint32_t* values, + U factor, + U w_scale, + U bias) { + QuantFastOps::template accumulate( + o, values, factor, w_scale, bias); + } +}; + +template +struct QuantOps { + template + [[clang::always_inline]] static U + dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { + return QuantFastOps::template dot( + q, keys, scale, bias); + } + template + [[clang::always_inline]] static void accumulate( + thread U* o, + const device uint32_t* values, + U factor, + U w_scale, + U bias) { + QuantFastOps::template accumulate( + o, values, factor, w_scale, bias); + } +}; +template +using ScaleTypeT = typename QuantConfig::template scale_storage_t; template [[kernel]] void quant_sdpa_vector_2pass_1( @@ -416,18 +397,17 @@ template uint simd_lid [[thread_index_in_simdgroup]], uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { - using Traits = QuantTraits; + using Cfg = QuantConfig; constexpr int BN = 16; constexpr int BD = 4; constexpr int elem_per_thread = D / BD; constexpr int blocks = 32; - constexpr int pack_factor = PackInfo::pack_factor; - constexpr int bytes_per_pack = PackInfo::bytes_per_pack; const int stride = BN * D; typedef float U; + [[maybe_unused]] Dequant dequant; thread U q[elem_per_thread]; thread U o[elem_per_thread] = {0}; @@ -448,21 +428,15 @@ template const int kv_idx = (block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread; - const int packed_idx = kv_idx / pack_factor; const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size; const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size; - // Use uint32_t pointers for optimal memory bandwidth - // For 4-bit: pack_factor=8, so we advance by elem_per_thread/8 uint32_t words - // For 8-bit: pack_factor=4, so we advance by elem_per_thread/4 uint32_t words - auto key_ptr = - keys + kv_head_idx * k_stride + packed_idx * bytes_per_pack / 4; - auto value_ptr = - values + kv_head_idx * v_stride + packed_idx * bytes_per_pack / 4; + QuantDataPtr key_ptr(keys, k_stride, kv_head_idx, kv_idx); + QuantDataPtr value_ptr(values, v_stride, kv_head_idx, kv_idx); key_scales += k_group_idx; value_scales += v_group_idx; - if constexpr (Traits::has_bias) { + if constexpr (Cfg::has_bias) { key_biases += k_group_idx; value_biases += v_group_idx; } @@ -501,15 +475,15 @@ template U key_scale; U key_bias = 0; - if constexpr (Traits::has_bias) { + if constexpr (Cfg::has_bias) { key_scale = U(key_scales[0]); key_bias = U(key_biases[0]); } else { - key_scale = Traits::template dequantize_scale(key_scales[0]); + key_scale = dequant.scale(key_scales[0]); } - U score = SdpaQuantOps::template dot( - q, key_ptr, key_scale, key_bias); + U score = QuantOps::template dot( + q, key_ptr.ptr(), key_scale, key_bias); score = quad_sum(score); if (float_mask) { @@ -527,28 +501,27 @@ template U value_scale; U value_bias = 0; - if constexpr (Traits::has_bias) { + if constexpr (Cfg::has_bias) { value_scale = U(value_scales[0]); value_bias = U(value_biases[0]); } else { - value_scale = Traits::template dequantize_scale(value_scales[0]); + value_scale = dequant.scale(value_scales[0]); } - SdpaQuantOps::template accumulate( + QuantOps::template accumulate( o, - value_ptr, + value_ptr.ptr(), factor, exp_score * value_scale, exp_score * value_bias); } // Advance pointers by blocks * BN * D elements - // For uint32_t*, advance by (blocks * stride * bits) / 32 - key_ptr += blocks * stride * bits / 32; - value_ptr += blocks * stride * bits / 32; + key_ptr.advance(blocks * stride); + value_ptr.advance(blocks * stride); key_scales += blocks * stride / group_size; value_scales += blocks * stride / group_size; - if constexpr (Traits::has_bias) { + if constexpr (Cfg::has_bias) { key_biases += blocks * stride / group_size; value_biases += blocks * stride / group_size; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 06e5ad598c..2c623f6583 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -876,204 +876,148 @@ array quantized_scaled_dot_product_attention( const std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "mxfp4" */, StreamOrDevice s /* = {} */) { - for (const auto& tensor : {queries, keys, key_scales, values, value_scales}) { - if (tensor.ndim() != 4) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] input with shape " - << tensor.shape() << " expected to be rank 4"; - throw std::invalid_argument(msg.str()); - } - } - - auto qmode = string_to_quantization_mode( - mode, "quantized_scaled_dot_product_attention"); + constexpr const char* tag = "quantized_scaled_dot_product_attention"; + // Parse mode and get parameters + auto qmode = string_to_quantization_mode(mode, tag); bool is_affine = qmode == QuantizationMode::Affine; + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); - // Validate biases for affine mode + // Validate mode-specific group_size and bits if (is_affine) { - if (!key_biases.has_value() || !value_biases.has_value()) { - throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Affine mode requires key_biases and value_biases."); + if (group_size != 32 && group_size != 64 && group_size != 128) { + std::ostringstream msg; + msg << "[" << tag << "] Affine mode supports group_size 32, 64, or 128 " + << "but received " << group_size << "."; + throw std::invalid_argument(msg.str()); } - if (key_biases->ndim() != 4 || value_biases->ndim() != 4) { + if (bits < 2 || bits > 8 || bits == 7) { + std::ostringstream msg; + msg << "[" << tag << "] Affine mode supports bits 2-6 or 8 but received " + << bits << "."; + throw std::invalid_argument(msg.str()); + } + if (!key_biases.has_value() || !value_biases.has_value()) { throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Biases must be rank 4."); + "[quantized_scaled_dot_product_attention] Affine mode requires " + "key_biases and value_biases."); } } else { + // FP modes have fixed params - verify if user overrode them incorrectly + auto [expected_gs, expected_bits] = + quantization_params_from_mode(qmode, std::nullopt, std::nullopt); + if (group_size != expected_gs || bits != expected_bits) { + std::ostringstream msg; + msg << "[" << tag << "] Mode '" << mode << "' requires group_size " + << expected_gs << " and bits " << expected_bits << "."; + throw std::invalid_argument(msg.str()); + } if (key_biases.has_value() || value_biases.has_value()) { throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Biases should only be provided for affine mode."); + "[quantized_scaled_dot_product_attention] Biases should only be " + "provided for affine mode."); } } - auto expected_params = [](QuantizationMode mode) -> std::pair { - switch (mode) { - case QuantizationMode::Affine: - return {64, 4}; // default for affine - case QuantizationMode::Mxfp4: - return {32, 4}; - case QuantizationMode::Mxfp8: - return {32, 8}; - case QuantizationMode::Nvfp4: - return {16, 4}; - default: - return {0, 0}; - } - }; - - auto [expected_group_size, expected_bits] = expected_params(qmode); - int group_size = group_size_.value_or(expected_group_size); - int bits = bits_.value_or(expected_bits); - - // For affine mode, allow flexible group_size and bits - if (!is_affine && - (group_size != expected_group_size || bits != expected_bits)) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] Quantization mode '" - << mode << "' requires group_size " << expected_group_size - << " and bits " << expected_bits << " but received group_size " - << group_size << " and bits " << bits << "."; - throw std::invalid_argument(msg.str()); - } - - // Validate affine group_size and bits - if (is_affine) { - if (group_size != 32 && group_size != 64 && group_size != 128) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] Affine mode supports group_size 32, 64, or 128 but received " - << group_size << "."; - throw std::invalid_argument(msg.str()); - } - if (bits < 2 || bits > 8 || bits == 7) { + // Validate rank 4 for all inputs + for (const auto& t : {queries, keys, key_scales, values, value_scales}) { + if (t.ndim() != 4) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] Affine mode supports bits 2-6 or 8 but received " - << bits << "."; + msg << "[" << tag << "] input with shape " << t.shape() + << " expected to be rank 4."; throw std::invalid_argument(msg.str()); } } + if (is_affine && (key_biases->ndim() != 4 || value_biases->ndim() != 4)) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Biases must be rank 4."); + } - if (!is_affine && bits != 4 && bits != 8) { + // Validate dtypes + auto final_type = queries.dtype(); + if (!issubdtype(final_type, floating)) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] Unsupported bits " << bits - << ". Supported bits are 4 and 8."; + msg << "[" << tag << "] queries must be floating type but got " + << final_type << "."; throw std::invalid_argument(msg.str()); } - - // For affine mode, scales are stored as the query type - // (float16/bfloat16/float32) For fp modes, scales are uint8 - if (!is_affine && - (key_scales.dtype() != uint8 || value_scales.dtype() != uint8)) { + if (keys.dtype() != uint32 || values.dtype() != uint32) { throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Scales must be uint8 for fp quantization."); + "[quantized_scaled_dot_product_attention] Keys and values must be " + "uint32."); } - if (keys.dtype() != uint32 || values.dtype() != uint32) { + if (!is_affine && + (key_scales.dtype() != uint8 || value_scales.dtype() != uint8)) { throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Keys and values must be packed quantized arrays of type uint32."); + "[quantized_scaled_dot_product_attention] Scales must be uint8 for fp " + "quantization."); } + // Compute and validate dimensions auto key_head_dim = (keys.shape(-1) * 32) / bits; auto value_head_dim = (values.shape(-1) * 32) / bits; - - const size_t batch_dim = queries.shape(0); - for (const auto& tensor : {keys, values, key_scales, value_scales}) { - if (tensor.shape(0) != batch_dim) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] mismatching batch dimension for input with shape " - << tensor.shape() << "."; - throw std::invalid_argument(msg.str()); - } - } - auto n_q_heads = queries.shape(-3); auto n_kv_heads = keys.shape(-3); + + if (queries.shape(0) != keys.shape(0) || + queries.shape(0) != values.shape(0)) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Batch dimensions must match."); + } if (n_q_heads % n_kv_heads != 0) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] n_heads must be a multiple of n_kv_heads, found n_heads " - << n_q_heads << " for n_kv_heads " << n_kv_heads << "."; + msg << "[" << tag << "] n_heads must be a multiple of n_kv_heads, found " + << n_q_heads << " vs " << n_kv_heads << "."; throw std::invalid_argument(msg.str()); } - if (keys.shape(-3) != values.shape(-3)) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] keys, values expected to have matching n_kv_heads; found keys with n_heads " - << keys.shape(-3) << " for values with n_heads " << values.shape(-3) - << "."; - throw std::invalid_argument(msg.str()); - } - - if (queries.shape(-1) != key_head_dim) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " - << queries.shape() << " for keys shape " << keys.shape() << "."; - throw std::invalid_argument(msg.str()); + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Keys and values must have " + "matching n_kv_heads."); } - - if (queries.shape(-1) != value_head_dim) { + if (queries.shape(-1) != key_head_dim || + queries.shape(-1) != value_head_dim) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] query, values expected to have matching last dimension; found query shape " - << queries.shape() << " for values shape " << values.shape() << "."; + msg << "[" << tag << "] Query head dim " << queries.shape(-1) + << " must match key (" << key_head_dim << ") and value (" + << value_head_dim << ")."; throw std::invalid_argument(msg.str()); } - if (queries.shape(-1) % group_size != 0) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] head dimension " - << queries.shape(-1) << " must be divisible by group_size " - << group_size << "."; + msg << "[" << tag << "] Head dim " << queries.shape(-1) + << " must be divisible by group_size " << group_size << "."; throw std::invalid_argument(msg.str()); } + // Validate scale/bias shapes auto expected_scale_dim = queries.shape(-1) / group_size; - for (const auto& tensor : {key_scales, value_scales}) { - if (tensor.shape(-1) != expected_scale_dim) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] Scales expected to have " - << expected_scale_dim << " elements in the last dimension but found " - << tensor.shape(); - throw std::invalid_argument(msg.str()); - } - } - if (key_scales.shape(-3) != keys.shape(-3) || - key_scales.shape(-2) != keys.shape(-2) || - value_scales.shape(-3) != values.shape(-3) || - value_scales.shape(-2) != values.shape(-2)) { - throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Scale shapes must match key/value batch, head, and sequence dimensions."); - } - - // Validate biases have same shape as scales - if (is_affine) { - if (key_biases->shape() != key_scales.shape()) { + for (const auto& [qdata, sc, bias, name] : + {std::tuple{&keys, &key_scales, &key_biases, "key"}, + std::tuple{&values, &value_scales, &value_biases, "value"}}) { + if (sc->shape(-1) != expected_scale_dim || + sc->shape(-3) != qdata->shape(-3) || + sc->shape(-2) != qdata->shape(-2)) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] key_biases shape " - << key_biases->shape() << " must match key_scales shape " - << key_scales.shape() << "."; + msg << "[" << tag << "] " << name << " scale shape mismatch."; throw std::invalid_argument(msg.str()); } - if (value_biases->shape() != value_scales.shape()) { + if (is_affine && bias->has_value() && (*bias)->shape() != sc->shape()) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] value_biases shape " - << value_biases->shape() << " must match value_scales shape " - << value_scales.shape() << "."; + msg << "[" << tag << "] " << name + << " bias shape must match scale shape."; throw std::invalid_argument(msg.str()); } } - auto final_type = queries.dtype(); - if (!issubdtype(final_type, floating)) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] Received unsupported type " - << final_type << "."; - throw std::invalid_argument(msg.str()); - } - + // Validate mask bool needs_mask = mask.has_value(); bool has_bool_mask = needs_mask && mask->dtype() == bool_; if (needs_mask && mask->ndim() > 4) { std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] the mask with shape " - << mask->shape() << " expected to have at most rank 4."; + msg << "[" << tag << "] Mask with shape " << mask->shape() + << " expected to have at most rank 4."; throw std::invalid_argument(msg.str()); } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index defcc2f6e0..e4639923c5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4365,35 +4365,6 @@ array conv_general( {in, wt}); } -std::pair quantization_params_from_mode( - QuantizationMode mode, - std::optional group_size_, - std::optional bits_) { - int default_group_size; - int default_bits; - switch (mode) { - case QuantizationMode::Affine: - default_group_size = 64; - default_bits = 4; - break; - case QuantizationMode::Nvfp4: - default_group_size = 16; - default_bits = 4; - break; - case QuantizationMode::Mxfp4: - default_group_size = 32; - default_bits = 4; - break; - case QuantizationMode::Mxfp8: - default_group_size = 32; - default_bits = 8; - break; - } - return { - group_size_.has_value() ? *group_size_ : default_group_size, - bits_.has_value() ? *bits_ : default_bits}; -} - std::pair validate_mode_with_type( std::string_view tag, const array& scales, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 8e209eeb26..46e5295ca0 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3450,6 +3450,34 @@ QuantizationMode string_to_quantization_mode( throw std::invalid_argument(msg); } +std::pair quantization_params_from_mode( + QuantizationMode mode, + std::optional group_size_, + std::optional bits_) { + int default_group_size; + int default_bits; + switch (mode) { + case QuantizationMode::Affine: + default_group_size = 64; + default_bits = 4; + break; + case QuantizationMode::Nvfp4: + default_group_size = 16; + default_bits = 4; + break; + case QuantizationMode::Mxfp4: + default_group_size = 32; + default_bits = 4; + break; + case QuantizationMode::Mxfp8: + default_group_size = 32; + default_bits = 8; + break; + } + return { + group_size_.value_or(default_group_size), bits_.value_or(default_bits)}; +} + std::pair, std::vector> QuantizedMatmul::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..2b3a6e4719 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -159,6 +159,13 @@ QuantizationMode string_to_quantization_mode( const std::string& mode, std::string_view error_tag = ""); +// Returns (group_size, bits) for a given quantization mode. +// Uses provided values if given, otherwise returns mode defaults. +std::pair quantization_params_from_mode( + QuantizationMode mode, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt); + class Abs : public UnaryPrimitive { public: explicit Abs(Stream stream) : UnaryPrimitive(stream) {} From ea338b6e9043fac8b7657cf36bca182fc2b7f9f6 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Thu, 29 Jan 2026 22:40:56 +0100 Subject: [PATCH 08/21] adapt #3023 --- benchmarks/python/sdpa_vector_bench.py | 223 ++++++++++++------ mlx/backend/metal/kernels/sdpa_vector.h | 89 ++++--- .../metal/scaled_dot_product_attention.cpp | 64 ++++- mlx/fast.h | 2 +- 4 files changed, 248 insertions(+), 130 deletions(-) diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 836700fd03..693816e5d0 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -1,39 +1,23 @@ +import argparse +import time + import mlx.core as mx -from mlx.utils import tree_map -from time_utils import time_fn -L = 32768 -H = 32 -H_k = H // 4 -D = 128 -dtype = mx.float16 -bits = 4 -mode = "mxfp8" if bits == 8 else "mxfp4" -loops = 20 +def time_fn(fn, *args, warmup=5, iters=100, **kwargs): + """Time a function, return milliseconds per call.""" + for _ in range(warmup): + mx.eval(fn(*args, **kwargs)) + tic = time.perf_counter() + for _ in range(iters): + mx.eval(fn(*args, **kwargs)) + toc = time.perf_counter() -def attention(q, k, v): - for _ in range(loops): - B, Hq, Lq, Dq = q.shape - _, Hk, S, _ = k.shape - q = q.reshape(B, Hk, Hq // Hk, Lq, Dq) - ke = k[:, :, None, :, :] - ve = v[:, :, None, :, :] - scores = q @ ke.transpose(0, 1, 2, 4, 3) - probs = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - q = probs @ ve - q = q.reshape(B, Hq, Lq, Dq) - return q + return 1e3 * (toc - tic) / iters -def sdpa(q, k, v): - for _ in range(loops): - q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) - return q - - -def quant_sdpa(q, k, v, bits=4, mode="mxfp4"): +def quant_sdpa(q, k, v, bits, mode, loops=1): for _ in range(loops): q = mx.fast.quantized_scaled_dot_product_attention( q, *k, *v, scale=1.0, mask=None, bits=bits, mode=mode @@ -41,51 +25,150 @@ def quant_sdpa(q, k, v, bits=4, mode="mxfp4"): return q -def quant_attention(q, k, v, bits=4, mode="mxfp4"): +def sdpa(q, k, v, loops=1): for _ in range(loops): - B, Hq, Lq, Dq = q.shape - Hk = k[0].shape[1] - - q = q.reshape((B, Hk, Hq // Hk, Lq, Dq)) - ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k) - ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v) - - scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits, mode=mode) - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - - q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits, mode=mode) - q = q.reshape((B, Hq, Lq, Dq)) + q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) return q -def time_self_attention_primitives(q, k, v): - time_fn(attention, q, k, v) - - -def time_self_attention_sdpa(q, k, v): - time_fn(sdpa, q, k, v) - - -def time_self_attention_quant_sdpa(q, k, v, bits, mode): - time_fn(quant_sdpa, q, k, v, bits, mode) - - -def time_self_attention_quant_primitives(q, k, v, bits, mode): - time_fn(quant_attention, q, k, v, bits, mode) +def run_benchmark( + seq_lengths, + modes, + H=32, + H_k=8, + D=128, + dtype=mx.float16, + loops=20, + warmup=5, + iters=100, +): + """Run benchmarks across sequence lengths and modes.""" + results = {} + + print(f"\n{'=' * 70}") + print(f"Quant SDPA Benchmark: H={H}, H_k={H_k}, D={D}, GQA={H // H_k}x") + print(f"{'=' * 70}") + + # Header + header = f"{'SeqLen':>8}" + for mode, bits in modes: + header += f" | {mode}({bits}b):ms" + header += " | fp16:ms" + print(header) + print("-" * len(header)) + + for L in seq_lengths: + mx.random.seed(42) + q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype) + k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) + v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) + mx.eval(q, k, v) + + row = f"{L:>8}" + results[L] = {} + + # Benchmark each quant mode + for mode, bits in modes: + k_quant = mx.quantize(k, bits=bits, mode=mode) + v_quant = mx.quantize(v, bits=bits, mode=mode) + mx.eval(k_quant, v_quant) + + ms = time_fn( + quant_sdpa, + q, + k_quant, + v_quant, + bits, + mode, + loops=loops, + warmup=warmup, + iters=iters, + ) + ms_per_call = ms / loops + results[L][(mode, bits)] = ms_per_call + row += f" | {ms_per_call:8.4f}" + + # Benchmark fp16 baseline + ms = time_fn(sdpa, q, k, v, loops=loops, warmup=warmup, iters=iters) + ms_per_call = ms / loops + results[L]["fp16"] = ms_per_call + row += f" | {ms_per_call:8.4f}" + + print(row) + + return results + + +def print_speedup_table(results, modes): + """Print speedup vs fp16 baseline.""" + print(f"\n{'=' * 60}") + print("Speedup vs fp16") + print(f"{'=' * 60}") + + header = f"{'SeqLen':>8}" + for mode, bits in modes: + header += f" | {mode}({bits}b)" + print(header) + print("-" * len(header)) + + for L, data in results.items(): + fp16_ms = data["fp16"] + row = f"{L:>8}" + for mode, bits in modes: + quant_ms = data[(mode, bits)] + speedup = fp16_ms / quant_ms + row += f" | {speedup:5.2f}x" + print(row) + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Quant SDPA") + parser.add_argument("--heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv-heads", type=int, default=8, help="Number of KV heads") + parser.add_argument("--dim", type=int, default=128, help="Head dimension") + parser.add_argument("--loops", type=int, default=20, help="Loops per timing call") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--iters", type=int, default=100, help="Timing iterations") + parser.add_argument( + "--seq-lengths", + type=int, + nargs="+", + default=[2048, 4096, 8192, 16384, 32768, 65536, 131072], + help="Sequence lengths to test", + ) + parser.add_argument( + "--modes", + type=str, + nargs="+", + default=["mxfp4", "mxfp8"], + help="Quantization modes to test", + ) + args = parser.parse_args() + + # Map mode names to (mode, bits) + mode_map = { + "mxfp4": ("mxfp4", 4), + "mxfp8": ("mxfp8", 8), + "affine4": ("affine", 4), + "affine8": ("affine", 8), + "nvfp4": ("nvfp4", 4), + } + + modes = [mode_map[m] for m in args.modes if m in mode_map] + + results = run_benchmark( + seq_lengths=args.seq_lengths, + modes=modes, + H=args.heads, + H_k=args.kv_heads, + D=args.dim, + loops=args.loops, + warmup=args.warmup, + iters=args.iters, + ) + + print_speedup_table(results, modes) if __name__ == "__main__": - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype) - k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) - v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) - mx.eval(q, k, v) - - k_quant = mx.quantize(k, bits=bits, mode=mode) - v_quant = mx.quantize(v, bits=bits, mode=mode) - mx.eval(k_quant, v_quant) - - time_self_attention_sdpa(q, k, v) - time_self_attention_quant_sdpa(q, k_quant, v_quant, bits, mode) - time_self_attention_primitives(q, k, v) - time_self_attention_quant_primitives(q, k_quant, v_quant, bits, mode) + main() diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 607e7dc2c5..0cda30fba6 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -186,7 +186,7 @@ METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { } } -constant bool has_affine_bias [[function_constant(26)]]; +constant bool has_affine_bias [[function_constant(27)]]; template struct QuantFastOps { @@ -369,7 +369,7 @@ template const device ScaleTypeT* key_scales [[buffer(2)]], const device uint32_t* values [[buffer(3)]], const device ScaleTypeT* value_scales [[buffer(4)]], - device float* out [[buffer(5)]], + device T* out [[buffer(5)]], device float* sums [[buffer(6)]], device float* maxs [[buffer(7)]], const constant int& gqa_factor [[buffer(8)]], @@ -393,18 +393,17 @@ template [[buffer(21), function_constant(has_affine_bias)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]], uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { + // Quadgroup approach: BN=8 quads × BD=4 lanes = 32 threads = 1 simdgroup + // Each quad processes one key, lanes split D dimension. + // elem_per_thread=D/4 is large enough for all pack_factors (max 8). using Cfg = QuantConfig; - constexpr int BN = 16; + constexpr int BN = 8; constexpr int BD = 4; constexpr int elem_per_thread = D / BD; - constexpr int blocks = 32; - - const int stride = BN * D; typedef float U; [[maybe_unused]] Dequant dequant; @@ -412,10 +411,6 @@ template thread U q[elem_per_thread]; thread U o[elem_per_thread] = {0}; - threadgroup U outputs[BN * BD]; - threadgroup U max_scores[BN]; - threadgroup U sum_exp_scores[BN]; - const int block_idx = tid.z; const int q_batch_head_idx = tid.x; const int q_seq_idx = tid.y; @@ -458,9 +453,15 @@ template load_queries(queries, q, static_cast(scale)); + constexpr int stride = BN * D; + const int data_step = blocks * stride; + const int scale_step = data_step / group_size; + const int mask_step = BN * blocks * mask_kv_seq_stride; + U max_score = Limits::finite_min; U sum_exp_score = 0; + // Main loop: each quad processes one key at a time for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) { bool use_key = true; if (do_causal) { @@ -472,9 +473,7 @@ template } if (use_key) { - U key_scale; - U key_bias = 0; - + U key_scale, key_bias = 0; if constexpr (Cfg::has_bias) { key_scale = U(key_scales[0]); key_bias = U(key_biases[0]); @@ -498,9 +497,7 @@ template max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; - U value_scale; - U value_bias = 0; - + U value_scale, value_bias = 0; if constexpr (Cfg::has_bias) { value_scale = U(value_scales[0]); value_bias = U(value_biases[0]); @@ -516,52 +513,46 @@ template exp_score * value_bias); } - // Advance pointers by blocks * BN * D elements - key_ptr.advance(blocks * stride); - value_ptr.advance(blocks * stride); - key_scales += blocks * stride / group_size; - value_scales += blocks * stride / group_size; + // Advance pointers + key_ptr.advance(data_step); + value_ptr.advance(data_step); + key_scales += scale_step; + value_scales += scale_step; if constexpr (Cfg::has_bias) { - key_biases += blocks * stride / group_size; - value_biases += blocks * stride / group_size; + key_biases += scale_step; + value_biases += scale_step; } if (bool_mask) { - bmask += BN * blocks * mask_kv_seq_stride; + bmask += mask_step; } if (float_mask) { - fmask += BN * blocks * mask_kv_seq_stride; + fmask += mask_step; } } - if (quad_lid == 0) { - max_scores[quad_gid] = max_score; - sum_exp_scores[quad_gid] = sum_exp_score; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - max_score = (simd_lid < BN) ? max_scores[simd_lid] : Limits::finite_min; - U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); - sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; - sum_exp_score = simd_sum(sum_exp_score * factor); + U sg_max = (quad_lid == 0) ? max_score : Limits::finite_min; + U global_max = simd_max(sg_max); - if (simd_gid == 0) { - sums[0] = sum_exp_score; - maxs[0] = new_max; + U sg_sum = + (quad_lid == 0) ? sum_exp_score * fast::exp(max_score - global_max) : 0; + U global_sum = simd_sum(sg_sum); + + if (simd_lid == 0) { + sums[0] = global_sum; + maxs[0] = global_max; } + // Output reduction: sum across quads (same quad_lid only) + U rescale = fast::exp(max_score - global_max); for (int i = 0; i < elem_per_thread; i++) { - outputs[quad_lid * BN + quad_gid] = - o[i] * fast::exp(max_scores[quad_gid] - new_max); - threadgroup_barrier(mem_flags::mem_threadgroup); - + U val = o[i] * rescale; + val += simd_shuffle_xor(val, 4); // sum quads 0+1, 2+3, 4+5, 6+7 + val += simd_shuffle_xor(val, 8); // sum quads 0-3, 4-7 + val += simd_shuffle_xor(val, 16); // sum quads 0-7 + // All lanes with same quad_lid now have the full sum; quad_gid=0 writes if (quad_gid == 0) { - U output = outputs[quad_lid * BN]; - for (int j = 1; j < BN; j++) { - output += outputs[quad_lid * BN + j]; - } - out[i] = output; + out[i] = static_cast(val); } - threadgroup_barrier(mem_flags::mem_threadgroup); } } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 32ee6be1a9..23e93da546 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -631,17 +631,53 @@ void quant_sdpa_vector_2pass( int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); - int blocks = 32; + int n_simds = gqa_factor * q.shape(2); int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); - size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); + // TODO: tune block sizes for different devices + char devc = d.get_architecture().back(); + int blocks; + if (devc == 's') { + blocks = 64; + if (N > 1024 && n_simds > 4) { + if (N <= 8192) { + blocks = 128; + } else if (N <= 32768) { + blocks = 256; + } else if (N <= 65536) { + blocks = 512; + } else { + blocks = 1024; + } + } + } else if (devc == 'd') { + blocks = 128; + if (n_simds <= 2 && N > 8192) { + blocks = 256; + } else if (n_simds >= 6) { + if (N >= 16384 && N < 65536) { + blocks = 512; + } else if (N >= 65536) { + blocks = 1024; + } + } + } else { + if (n_simds >= 4) { + blocks = 64; + } else { + blocks = 32; + } + } + + // Head strides for quantized data (in uint32 units) and scales + size_t k_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); + size_t v_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t k_group_stride = k_scales.shape(1) == 1 ? k_scales.strides(0) : k_scales.strides(1); size_t v_group_stride = v_scales.shape(1) == 1 ? v_scales.strides(0) : v_scales.strides(1); - MTL::Size group_dims(16 * 4, 1, 1); + MTL::Size group_dims(32, 1, 1); // 1 simdgroup, like non-quant MTL::Size grid_dims(B, q.shape(2), blocks); Shape intermediate_shape; @@ -650,7 +686,7 @@ void quant_sdpa_vector_2pass( intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1); intermediate_shape.push_back(blocks); intermediate_shape.push_back(out.shape().back()); - array intermediate(intermediate_shape, float32, nullptr, {}); + array intermediate(intermediate_shape, q.dtype(), nullptr, {}); intermediate_shape.pop_back(); array sums(intermediate_shape, float32, nullptr, {}); array maxs(std::move(intermediate_shape), float32, nullptr, {}); @@ -674,13 +710,15 @@ void quant_sdpa_vector_2pass( {&bool_mask, MTL::DataType::DataTypeBool, 23}, {&float_mask, MTL::DataType::DataTypeBool, 24}, {&has_sinks, MTL::DataType::DataTypeBool, 25}, - {&has_affine_bias, MTL::DataType::DataTypeBool, 26}, + {&blocks, MTL::DataType::DataTypeInt, 26}, + {&has_affine_bias, MTL::DataType::DataTypeBool, 27}, }; std::string hash_name = kname; hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += do_causal ? "_c" : "_nc"; - hash_name += has_affine_bias ? "_affine" : "_noaffine"; + hash_name += has_affine_bias ? "_affine_" : "_noaffine_"; + hash_name += std::to_string(blocks); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname, hash_name, func_consts); @@ -696,8 +734,8 @@ void quant_sdpa_vector_2pass( compute_encoder.set_output_array(maxs, 7); compute_encoder.set_bytes(gqa_factor, 8); compute_encoder.set_bytes(N, 9); - compute_encoder.set_bytes(k_head_stride, 10); - compute_encoder.set_bytes(v_head_stride, 11); + compute_encoder.set_bytes(k_stride, 10); + compute_encoder.set_bytes(v_stride, 11); compute_encoder.set_bytes(k_group_stride, 12); compute_encoder.set_bytes(v_group_stride, 13); compute_encoder.set_bytes(scale, 14); @@ -721,13 +759,19 @@ void quant_sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + // Second pass kernel kname.clear(); kname += "sdpa_vector_2pass_2_"; kname += get_type_string(q.dtype()); kname += "_"; kname += std::to_string(q.shape(-1)); - kernel = d.get_kernel(kname); + func_consts = { + {&blocks, MTL::DataType::DataTypeInt, 26}, + }; + hash_name = kname + "_" + std::to_string(blocks); + + kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(intermediate, 0); diff --git a/mlx/fast.h b/mlx/fast.h index b1bc1e4ea8..6d82a4208a 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -55,7 +55,7 @@ MLX_API array scaled_dot_product_attention( StreamOrDevice s = {}); /** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/ -array quantized_scaled_dot_product_attention( +MLX_API array quantized_scaled_dot_product_attention( const array& queries, const array& keys, const array& key_scales, From e9a61f8e71a2a99ead2c04e912811cb64eb072d2 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Fri, 30 Jan 2026 00:02:35 +0100 Subject: [PATCH 09/21] Limit affine SDPA to group_size=32 and bits={4,6,8} --- .../scaled_dot_product_attention.metal | 15 ++-- mlx/fast.cpp | 11 +-- python/tests/test_quantized.py | 73 +++++++++---------- 3 files changed, 45 insertions(+), 54 deletions(-) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 23161fd8c9..dd566bd1a4 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -54,21 +54,16 @@ instantiate_sdpa_vector_heads(float16_t) group_size, \ bits) -#define instantiate_quant_sdpa_vector_affine(type, head_dim, group_size) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 2) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 3) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 5) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 6) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, group_size, 8) +#define instantiate_quant_sdpa_vector_affine(type, head_dim) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 4) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 6) \ + instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 8) #define instantiate_quant_sdpa_vector_all_modes(type, head_dim) \ instantiate_quant_sdpa_vector(type, head_dim, Mxfp4, 32, 4) \ instantiate_quant_sdpa_vector(type, head_dim, Nvfp4, 16, 4) \ instantiate_quant_sdpa_vector(type, head_dim, Mxfp8, 32, 8) \ - instantiate_quant_sdpa_vector_affine(type, head_dim, 32) \ - instantiate_quant_sdpa_vector_affine(type, head_dim, 64) \ - instantiate_quant_sdpa_vector_affine(type, head_dim, 128) + instantiate_quant_sdpa_vector_affine(type, head_dim) #define instantiate_quant_sdpa_vector_heads(type) \ instantiate_quant_sdpa_vector_all_modes(type, 64) \ diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 2c623f6583..d86862e5f7 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -886,16 +886,17 @@ array quantized_scaled_dot_product_attention( // Validate mode-specific group_size and bits if (is_affine) { - if (group_size != 32 && group_size != 64 && group_size != 128) { + if (group_size != 32) { std::ostringstream msg; - msg << "[" << tag << "] Affine mode supports group_size 32, 64, or 128 " + msg << "[" << tag << "] Affine mode supports group_size 32 " << "but received " << group_size << "."; throw std::invalid_argument(msg.str()); } - if (bits < 2 || bits > 8 || bits == 7) { + if (bits != 4 && bits != 6 && bits != 8) { std::ostringstream msg; - msg << "[" << tag << "] Affine mode supports bits 2-6 or 8 but received " - << bits << "."; + msg << "[" << tag + << "] Affine mode supports bits 4, 6, or 8 but received " << bits + << "."; throw std::invalid_argument(msg.str()); } if (!key_biases.has_value() || !value_biases.has_value()) { diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index c61707a01b..9e3859772d 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -796,47 +796,42 @@ def test_quantized_sdpa_affine(self): B, Hq, Hkv = 1, 2, 1 Lq, Lk, D = 4, 640, 128 - for group_size in [32, 64, 128]: - for bits in [2, 3, 4, 5, 6, 8]: - with self.subTest(group_size=group_size, bits=bits): - q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) - k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) - v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) - - k_q, k_scales, k_biases = mx.quantize( - k, group_size=group_size, bits=bits, mode="affine" - ) - v_q, v_scales, v_biases = mx.quantize( - v, group_size=group_size, bits=bits, mode="affine" - ) + for bits in [4, 6, 8]: + with self.subTest(bits=bits): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) - ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) - out = mx.fast.quantized_scaled_dot_product_attention( - q, - k_q, - k_scales, - k_biases, - v_q, - v_scales, - v_biases, - scale=1.0, - mode="affine", - group_size=group_size, - bits=bits, - ) + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + ) - self.assertEqual(out.shape, ref.shape) - if bits <= 3: - tol = 3e-1 - elif bits == 5: - tol = 1.5e-1 - elif bits == 6: - tol = 1e-1 - elif bits == 4: - tol = 5e-2 - else: - tol = 2e-2 - self.assertLess((out - ref).abs().max(), tol) + self.assertEqual(out.shape, ref.shape) + if bits == 6: + tol = 1e-1 + elif bits == 4: + tol = 5e-2 + else: + tol = 2e-2 + self.assertLess((out - ref).abs().max(), tol) def test_gather_qmm(self): def quantize(w, transpose=True, group_size=None, bits=None, mode="affine"): From 7896f6d7efec06ef508d38aaa69eb4dc245a6efa Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Thu, 5 Feb 2026 12:52:54 +0800 Subject: [PATCH 10/21] fix group_size for nvfp4 and simplify code --- mlx/backend/metal/kernels/sdpa_vector.h | 352 +++++++++--------- .../metal/scaled_dot_product_attention.cpp | 2 + 2 files changed, 181 insertions(+), 173 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 0cda30fba6..91d3f90d9a 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -188,177 +188,185 @@ METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { constant bool has_affine_bias [[function_constant(27)]]; -template -struct QuantFastOps { - using Cfg = QuantConfig; - using load_t = metal::conditional_t; +template +struct GroupSlice { + enum : int { + value = (elem_per_thread < group_size) ? elem_per_thread : group_size, + num_groups = elem_per_thread / value, + iters_per_group = value / granularity + }; + static_assert( + (value % granularity) == 0, + "group slice must be divisible by granularity"); static_assert( - bits == 4 || bits == 8, - "QuantFastOps only supports 4/8-bit packing"); - static constant constexpr uint32_t mask = (1u << bits) - 1; + (elem_per_thread % value) == 0, + "elem_per_thread must be divisible by group slice"); +}; - template - [[clang::always_inline]] static U - dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { +template +struct QuantOps { + using Cfg = QuantConfig; + static constant constexpr bool is_fast_path = (bits == 4 || bits == 8); + static constant constexpr int pack_factor = PackInfo::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; + static constant constexpr int granularity = is_fast_path ? 4 : pack_factor; + using fast_load_t = metal::conditional_t; + static constant constexpr uint32_t fast_mask = (1u << bits) - 1; + static_assert(bits == 4 || bits == 6 || bits == 8, "unsupported quant bits"); + static_assert( + !is_fast_path || (group_size % 4) == 0, + "group_size must be divisible by 4 for 4/8-bit fast path"); + + template + [[clang::always_inline]] static U dot( + const thread U* q, + const device uint32_t* keys, + const device ScaleT* scales, + [[maybe_unused]] const device ScaleT* biases) { + static_assert( + (elem_per_thread % granularity) == 0, + "elem_per_thread must be divisible by the granularity"); Dequant dequant; - auto ks = reinterpret_cast(keys); U score = 0; - [[maybe_unused]] U bias_acc = 0; -#pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - load_t p = ks[j]; - U v0 = dequant.raw(p & mask); - U v1 = dequant.raw((p >> (bits * 1)) & mask); - U v2 = dequant.raw((p >> (bits * 2)) & mask); - U v3 = dequant.raw((p >> (bits * 3)) & mask); - - score += q[4 * j + 0] * v0; - score += q[4 * j + 1] * v1; - score += q[4 * j + 2] * v2; - score += q[4 * j + 3] * v3; + using Slice = GroupSlice; + constexpr int group_slice = Slice::value; + constexpr int num_groups = Slice::num_groups; + constexpr int iters_per_group = Slice::iters_per_group; - if constexpr (Cfg::has_bias) { - bias_acc += (q[4 * j + 0] + q[4 * j + 1] + q[4 * j + 2] + q[4 * j + 3]); - } - } +#pragma clang loop unroll(full) + for (int g = 0; g < num_groups; g++) { + U scale = dequant.scale(scales[g]); + U bias = 0; + if constexpr (Cfg::has_bias) + bias = static_cast(biases[g]); - if constexpr (Cfg::has_bias) { - return fma(scale, score, bias * bias_acc); - } else { - return scale * score; - } - } + U group_score = 0; + U bias_acc = 0; - template - [[clang::always_inline]] static void accumulate( - thread U* o, - const device uint32_t* values, - U factor, - U w_scale, - U bias) { - Dequant dequant; - auto vs = reinterpret_cast(values); + if constexpr (is_fast_path) { + auto ks = reinterpret_cast(keys); #pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread / 4; j++) { - load_t p = vs[j]; - U v0 = dequant.raw(p & mask); - U v1 = dequant.raw((p >> (bits * 1)) & mask); - U v2 = dequant.raw((p >> (bits * 2)) & mask); - U v3 = dequant.raw((p >> (bits * 3)) & mask); - - if constexpr (Cfg::has_bias) { - o[4 * j + 0] = fma(o[4 * j + 0], factor, fma(w_scale, v0, bias)); - o[4 * j + 1] = fma(o[4 * j + 1], factor, fma(w_scale, v1, bias)); - o[4 * j + 2] = fma(o[4 * j + 2], factor, fma(w_scale, v2, bias)); - o[4 * j + 3] = fma(o[4 * j + 3], factor, fma(w_scale, v3, bias)); + for (int j = 0; j < iters_per_group; j++) { + fast_load_t p = ks[g * iters_per_group + j]; + int base = g * group_slice + 4 * j; + + U v0 = dequant.raw(p & fast_mask); + U v1 = dequant.raw((p >> (bits * 1)) & fast_mask); + U v2 = dequant.raw((p >> (bits * 2)) & fast_mask); + U v3 = dequant.raw((p >> (bits * 3)) & fast_mask); + + group_score += q[base + 0] * v0; + group_score += q[base + 1] * v1; + group_score += q[base + 2] * v2; + group_score += q[base + 3] * v3; + if constexpr (Cfg::has_bias) { + bias_acc += q[base + 0] + q[base + 1] + q[base + 2] + q[base + 3]; + } + } } else { - o[4 * j + 0] = fma(o[4 * j + 0], factor, v0 * w_scale); - o[4 * j + 1] = fma(o[4 * j + 1], factor, v1 * w_scale); - o[4 * j + 2] = fma(o[4 * j + 2], factor, v2 * w_scale); - o[4 * j + 3] = fma(o[4 * j + 3], factor, v3 * w_scale); - } - } - } -}; - -// (Generic Path for 2, 3, 5, 6 bits) -template -struct QuantOps { - using Cfg = QuantConfig; - static constant constexpr int pack_factor = PackInfo::pack_factor; - static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; - - template - [[clang::always_inline]] static U - dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { - Dequant dequant; - auto ks = reinterpret_cast(keys); - thread uint8_t raw[pack_factor]; - U score = 0; - [[maybe_unused]] U bias_acc = 0; + auto ks = reinterpret_cast(keys) + + g * (group_slice / pack_factor) * bytes_per_pack; + thread uint8_t raw[pack_factor]; #pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread; j += pack_factor) { - PackReader::load(ks, raw); + for (int j = 0; j < group_slice; j += pack_factor) { + PackReader::load(ks, raw); #pragma clang loop unroll(full) - for (int t = 0; t < pack_factor; ++t) { - U decoded = dequant.raw(raw[t]); - score += q[j + t] * decoded; - if constexpr (Cfg::has_bias) - bias_acc += q[j + t]; + for (int t = 0; t < pack_factor; ++t) { + U decoded = dequant.raw(raw[t]); + int q_idx = g * group_slice + j + t; + group_score += q[q_idx] * decoded; + if constexpr (Cfg::has_bias) + bias_acc += q[q_idx]; + } + ks += bytes_per_pack; + } + } + + if constexpr (Cfg::has_bias) { + score += fma(scale, group_score, bias * bias_acc); + } else { + score += scale * group_score; } - ks += bytes_per_pack; } - return Cfg::has_bias ? fma(scale, score, bias * bias_acc) : scale * score; + return score; } - template + // ACCUMULATE + template [[clang::always_inline]] static void accumulate( thread U* o, const device uint32_t* values, U factor, - U w_scale, - U bias) { + U exp_score, + const device ScaleT* scales, + [[maybe_unused]] const device ScaleT* biases) { + static_assert( + (elem_per_thread % granularity) == 0, + "elem_per_thread must be divisible by the granularity"); Dequant dequant; - auto vs = reinterpret_cast(values); - thread uint8_t raw[pack_factor]; + + using Slice = GroupSlice; + constexpr int group_slice = Slice::value; + constexpr int num_groups = Slice::num_groups; + constexpr int iters_per_group = Slice::iters_per_group; + +#pragma clang loop unroll(full) + for (int g = 0; g < num_groups; g++) { + U w_scale = exp_score * dequant.scale(scales[g]); + U bias = 0; + if constexpr (Cfg::has_bias) + bias = exp_score * static_cast(biases[g]); + + if constexpr (is_fast_path) { + auto vs = reinterpret_cast(values); +#pragma clang loop unroll(full) + for (int j = 0; j < iters_per_group; j++) { + fast_load_t p = vs[g * iters_per_group + j]; + int base = g * group_slice + 4 * j; + + U v0 = dequant.raw(p & fast_mask); + U v1 = dequant.raw((p >> (bits * 1)) & fast_mask); + U v2 = dequant.raw((p >> (bits * 2)) & fast_mask); + U v3 = dequant.raw((p >> (bits * 3)) & fast_mask); + + if constexpr (Cfg::has_bias) { + o[base + 0] = fma(o[base + 0], factor, fma(w_scale, v0, bias)); + o[base + 1] = fma(o[base + 1], factor, fma(w_scale, v1, bias)); + o[base + 2] = fma(o[base + 2], factor, fma(w_scale, v2, bias)); + o[base + 3] = fma(o[base + 3], factor, fma(w_scale, v3, bias)); + } else { + o[base + 0] = fma(o[base + 0], factor, v0 * w_scale); + o[base + 1] = fma(o[base + 1], factor, v1 * w_scale); + o[base + 2] = fma(o[base + 2], factor, v2 * w_scale); + o[base + 3] = fma(o[base + 3], factor, v3 * w_scale); + } + } + } else { + auto vs = reinterpret_cast(values) + + g * (group_slice / pack_factor) * bytes_per_pack; + thread uint8_t raw[pack_factor]; + #pragma clang loop unroll(full) - for (int j = 0; j < elem_per_thread; j += pack_factor) { - PackReader::load(vs, raw); + for (int j = 0; j < group_slice; j += pack_factor) { + PackReader::load(vs, raw); #pragma clang loop unroll(full) - for (int t = 0; t < pack_factor; ++t) { - U decoded = dequant.raw(raw[t]); - if constexpr (Cfg::has_bias) { - o[j + t] = fma(o[j + t], factor, fma(w_scale, decoded, bias)); - } else { - o[j + t] = fma(o[j + t], factor, decoded * w_scale); + for (int t = 0; t < pack_factor; ++t) { + U decoded = dequant.raw(raw[t]); + int idx = g * group_slice + j + t; + if constexpr (Cfg::has_bias) { + o[idx] = fma(o[idx], factor, fma(w_scale, decoded, bias)); + } else { + o[idx] = fma(o[idx], factor, decoded * w_scale); + } + } + vs += bytes_per_pack; } } - vs += bytes_per_pack; } } }; - -template -struct QuantOps { - template - [[clang::always_inline]] static U - dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { - return QuantFastOps::template dot( - q, keys, scale, bias); - } - template - [[clang::always_inline]] static void accumulate( - thread U* o, - const device uint32_t* values, - U factor, - U w_scale, - U bias) { - QuantFastOps::template accumulate( - o, values, factor, w_scale, bias); - } -}; - -template -struct QuantOps { - template - [[clang::always_inline]] static U - dot(const thread U* q, const device uint32_t* keys, U scale, U bias) { - return QuantFastOps::template dot( - q, keys, scale, bias); - } - template - [[clang::always_inline]] static void accumulate( - thread U* o, - const device uint32_t* values, - U factor, - U w_scale, - U bias) { - QuantFastOps::template accumulate( - o, values, factor, w_scale, bias); - } -}; template using ScaleTypeT = typename QuantConfig::template scale_storage_t; @@ -401,12 +409,13 @@ template // elem_per_thread=D/4 is large enough for all pack_factors (max 8). using Cfg = QuantConfig; + static_assert( + (D % group_size) == 0, "group_size must divide the head dimension"); constexpr int BN = 8; constexpr int BD = 4; constexpr int elem_per_thread = D / BD; typedef float U; - [[maybe_unused]] Dequant dequant; thread U q[elem_per_thread]; thread U o[elem_per_thread] = {0}; @@ -431,9 +440,15 @@ template key_scales += k_group_idx; value_scales += v_group_idx; + const device ScaleTypeT* key_bias_ptr = nullptr; + const device ScaleTypeT* value_bias_ptr = nullptr; if constexpr (Cfg::has_bias) { - key_biases += k_group_idx; - value_biases += v_group_idx; + key_bias_ptr = + reinterpret_cast*>(key_biases) + + k_group_idx; + value_bias_ptr = + reinterpret_cast*>(value_biases) + + v_group_idx; } out += o_offset * blocks * D + block_idx * D + quad_lid * elem_per_thread; @@ -451,13 +466,17 @@ template q_seq_idx * mask_q_seq_stride; } - load_queries(queries, q, static_cast(scale)); - constexpr int stride = BN * D; const int data_step = blocks * stride; const int scale_step = data_step / group_size; const int mask_step = BN * blocks * mask_kv_seq_stride; + // Read the query +#pragma clang loop unroll(full) + for (int i = 0; i < elem_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + U max_score = Limits::finite_min; U sum_exp_score = 0; @@ -473,16 +492,9 @@ template } if (use_key) { - U key_scale, key_bias = 0; - if constexpr (Cfg::has_bias) { - key_scale = U(key_scales[0]); - key_bias = U(key_biases[0]); - } else { - key_scale = dequant.scale(key_scales[0]); - } - - U score = QuantOps::template dot( - q, key_ptr.ptr(), key_scale, key_bias); + U score = QuantOps:: + template dot, elem_per_thread>( + q, key_ptr.ptr(), key_scales, key_bias_ptr); score = quad_sum(score); if (float_mask) { @@ -497,20 +509,14 @@ template max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; - U value_scale, value_bias = 0; - if constexpr (Cfg::has_bias) { - value_scale = U(value_scales[0]); - value_bias = U(value_biases[0]); - } else { - value_scale = dequant.scale(value_scales[0]); - } - - QuantOps::template accumulate( - o, - value_ptr.ptr(), - factor, - exp_score * value_scale, - exp_score * value_bias); + QuantOps:: + template accumulate, elem_per_thread>( + o, + value_ptr.ptr(), + factor, + exp_score, + value_scales, + value_bias_ptr); } // Advance pointers @@ -519,8 +525,8 @@ template key_scales += scale_step; value_scales += scale_step; if constexpr (Cfg::has_bias) { - key_biases += scale_step; - value_biases += scale_step; + key_bias_ptr += scale_step; + value_bias_ptr += scale_step; } if (bool_mask) { bmask += mask_step; diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 23e93da546..80e828c262 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -443,6 +443,7 @@ void sdpa_vector_2pass( char devc = d.get_architecture().back(); int N = k.shape(2); int blocks; + if (devc == 's') { blocks = 64; if (N > 1024 && n_simds > 4) { @@ -474,6 +475,7 @@ void sdpa_vector_2pass( blocks = 32; } } + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); From 146dcdcd43213e7234b1e7d5e77f009e3aea18e8 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sat, 7 Feb 2026 19:13:20 +0800 Subject: [PATCH 11/21] refactor: use function_constant to reduce template specializations(binary size) --- benchmarks/python/sdpa_vector_bench.py | 14 +- .../scaled_dot_product_attention.metal | 32 +-- mlx/backend/metal/kernels/sdpa_vector.h | 224 +++++++++++++----- .../metal/scaled_dot_product_attention.cpp | 37 ++- mlx/fast.cpp | 4 +- 5 files changed, 201 insertions(+), 110 deletions(-) diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 693816e5d0..4904519510 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -17,10 +17,10 @@ def time_fn(fn, *args, warmup=5, iters=100, **kwargs): return 1e3 * (toc - tic) / iters -def quant_sdpa(q, k, v, bits, mode, loops=1): +def quant_sdpa(q, k, v, bits, mode, group_size=None, loops=1): for _ in range(loops): q = mx.fast.quantized_scaled_dot_product_attention( - q, *k, *v, scale=1.0, mask=None, bits=bits, mode=mode + q, *k, *v, scale=1.0, mask=None, bits=bits, mode=mode, group_size=group_size ) return q @@ -69,8 +69,13 @@ def run_benchmark( # Benchmark each quant mode for mode, bits in modes: - k_quant = mx.quantize(k, bits=bits, mode=mode) - v_quant = mx.quantize(v, bits=bits, mode=mode) + gs = 32 if mode == "affine" else None + k_quant = mx.quantize( + k, bits=bits, mode=mode, **({"group_size": gs} if gs else {}) + ) + v_quant = mx.quantize( + v, bits=bits, mode=mode, **({"group_size": gs} if gs else {}) + ) mx.eval(k_quant, v_quant) ms = time_fn( @@ -80,6 +85,7 @@ def run_benchmark( v_quant, bits, mode, + group_size=gs, loops=loops, warmup=warmup, iters=iters, diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index dd566bd1a4..e71e4395aa 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -43,31 +43,17 @@ instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) // Quantized SDPA vector instantiations -// Uses QuantMode enum for explicit mode selection -#define instantiate_quant_sdpa_vector(type, head_dim, mode, group_size, bits) \ - instantiate_kernel( \ - "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #mode "_" #group_size "_" #bits, \ - quant_sdpa_vector_2pass_1, \ - type, \ - head_dim, \ - QuantMode::mode, \ - group_size, \ - bits) - -#define instantiate_quant_sdpa_vector_affine(type, head_dim) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 6) \ - instantiate_quant_sdpa_vector(type, head_dim, Affine, 32, 8) - -#define instantiate_quant_sdpa_vector_all_modes(type, head_dim) \ - instantiate_quant_sdpa_vector(type, head_dim, Mxfp4, 32, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Nvfp4, 16, 4) \ - instantiate_quant_sdpa_vector(type, head_dim, Mxfp8, 32, 8) \ - instantiate_quant_sdpa_vector_affine(type, head_dim) +// mode/bits/group_size are now function constants (indices 28-30) +#define instantiate_quant_sdpa_vector(type, head_dim) \ + instantiate_kernel( \ + "quant_sdpa_vector_2pass_1_" #type "_" #head_dim, \ + quant_sdpa_vector_2pass_1, \ + type, \ + head_dim) #define instantiate_quant_sdpa_vector_heads(type) \ - instantiate_quant_sdpa_vector_all_modes(type, 64) \ - instantiate_quant_sdpa_vector_all_modes(type, 128) + instantiate_quant_sdpa_vector(type, 64) \ + instantiate_quant_sdpa_vector(type, 128) instantiate_quant_sdpa_vector_heads(float) instantiate_quant_sdpa_vector_heads(bfloat16_t) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 91d3f90d9a..aa0bb65bee 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -187,6 +187,9 @@ METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { } constant bool has_affine_bias [[function_constant(27)]]; +constant int quant_mode_int [[function_constant(28)]]; +constant int quant_bits [[function_constant(29)]]; +constant int quant_group_size [[function_constant(30)]]; template struct GroupSlice { @@ -371,43 +374,44 @@ template using ScaleTypeT = typename QuantConfig::template scale_storage_t; template -[[kernel]] void quant_sdpa_vector_2pass_1( - const device T* queries [[buffer(0)]], - const device uint32_t* keys [[buffer(1)]], - const device ScaleTypeT* key_scales [[buffer(2)]], - const device uint32_t* values [[buffer(3)]], - const device ScaleTypeT* value_scales [[buffer(4)]], - device T* out [[buffer(5)]], - device float* sums [[buffer(6)]], - device float* maxs [[buffer(7)]], - const constant int& gqa_factor [[buffer(8)]], - const constant int& N [[buffer(9)]], - const constant size_t& k_stride [[buffer(10)]], - const constant size_t& v_stride [[buffer(11)]], - const constant size_t& k_group_stride [[buffer(12)]], - const constant size_t& v_group_stride [[buffer(13)]], - const constant float& scale [[buffer(14)]], - const device bool* bmask [[buffer(15), function_constant(bool_mask)]], - const device T* fmask [[buffer(16), function_constant(float_mask)]], - const constant int& mask_kv_seq_stride - [[buffer(17), function_constant(has_mask)]], - const constant int& mask_q_seq_stride - [[buffer(18), function_constant(has_mask)]], - const constant int& mask_head_stride - [[buffer(19), function_constant(has_mask)]], - const device T* key_biases - [[buffer(20), function_constant(has_affine_bias)]], - const device T* value_biases - [[buffer(21), function_constant(has_affine_bias)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 tpg [[threadgroups_per_grid]], - uint simd_lid [[thread_index_in_simdgroup]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { +METAL_FUNC void quant_sdpa_inner( + const device T* queries, + const device uint32_t* keys, + const device uint8_t* key_scales_raw, + const device uint32_t* values, + const device uint8_t* value_scales_raw, + device T* out, + device float* sums, + device float* maxs, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& v_stride, + const constant size_t& k_group_stride, + const constant size_t& v_group_stride, + const constant float& scale, + const device bool* bmask, + const device T* fmask, + const constant int& mask_kv_seq_stride, + const constant int& mask_q_seq_stride, + const constant int& mask_head_stride, + const device uint8_t* key_biases_raw, + const device uint8_t* value_biases_raw, + uint3 tid, + uint3 tpg, + uint3 tptg, + uint3 tidtg, + uint simd_lid) { // Quadgroup approach: BN=8 quads × BD=4 lanes = 32 threads = 1 simdgroup // Each quad processes one key, lanes split D dimension. // elem_per_thread=D/4 is large enough for all pack_factors (max 8). + // + // GQA: multiple query heads sharing the same KV head are packed into the + // same threadgroup (along with q_seq_len). This lets them share L2 cache + // for KV data. + // Grid: (num_kv_heads, batch, blocks) + // Group: (32, gqa_factor, q_seq_len) using Cfg = QuantConfig; + using ScaleT = ScaleTypeT; static_assert( (D % group_size) == 0, "group_size must divide the head dimension"); @@ -415,54 +419,72 @@ template constexpr int BD = 4; constexpr int elem_per_thread = D / BD; + // Derive quad indices from simd_lid (replaces quad_gid/quad_lid attributes) + const int local_quad_gid = simd_lid / 4; // 0-7 + const int local_quad_lid = simd_lid % 4; // 0-3 + typedef float U; + // Cast raw byte pointers to typed scale pointers + auto key_scales = reinterpret_cast(key_scales_raw); + auto value_scales = reinterpret_cast(value_scales_raw); + thread U q[elem_per_thread]; thread U o[elem_per_thread] = {0}; + // Head/batch from grid + threadgroup position + const int kv_head_idx = tid.x; + const int batch_idx = tid.y; const int block_idx = tid.z; - const int q_batch_head_idx = tid.x; - const int q_seq_idx = tid.y; - const int kv_head_idx = q_batch_head_idx / gqa_factor; - const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int gqa_factor = tptg.y; + const int q_seq_len = tptg.z; + const int gqa_offset = tidtg.y; + const int q_seq_idx = tidtg.z; + const int num_kv_heads = tpg.x; + const int num_q_heads = num_kv_heads * gqa_factor; + const int q_head_idx = gqa_factor * kv_head_idx + gqa_offset; + const int q_batch_head_idx = batch_idx * num_q_heads + q_head_idx; + const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx; const int q_offset = - query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset; - queries += q_offset * D + quad_lid * elem_per_thread; + queries += q_offset * D + local_quad_lid * elem_per_thread; + const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx; const int kv_idx = - (block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread; - const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size; - const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size; + (block_idx * BN + local_quad_gid) * D + local_quad_lid * elem_per_thread; + const int k_group_idx = + kv_batch_head_idx * k_group_stride + kv_idx / group_size; + const int v_group_idx = + kv_batch_head_idx * v_group_stride + kv_idx / group_size; - QuantDataPtr key_ptr(keys, k_stride, kv_head_idx, kv_idx); - QuantDataPtr value_ptr(values, v_stride, kv_head_idx, kv_idx); + QuantDataPtr key_ptr(keys, k_stride, kv_batch_head_idx, kv_idx); + QuantDataPtr value_ptr(values, v_stride, kv_batch_head_idx, kv_idx); key_scales += k_group_idx; value_scales += v_group_idx; - const device ScaleTypeT* key_bias_ptr = nullptr; - const device ScaleTypeT* value_bias_ptr = nullptr; + const device ScaleT* key_bias_ptr = nullptr; + const device ScaleT* value_bias_ptr = nullptr; if constexpr (Cfg::has_bias) { key_bias_ptr = - reinterpret_cast*>(key_biases) + - k_group_idx; + reinterpret_cast(key_biases_raw) + k_group_idx; value_bias_ptr = - reinterpret_cast*>(value_biases) + - v_group_idx; + reinterpret_cast(value_biases_raw) + v_group_idx; } - out += o_offset * blocks * D + block_idx * D + quad_lid * elem_per_thread; + out += + o_offset * blocks * D + block_idx * D + local_quad_lid * elem_per_thread; sums += o_offset * blocks + block_idx; maxs += o_offset * blocks + block_idx; if (bool_mask) { bmask += q_batch_head_idx * mask_head_stride + - (block_idx * BN + quad_gid) * mask_kv_seq_stride + + (block_idx * BN + local_quad_gid) * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } if (float_mask) { fmask += q_batch_head_idx * mask_head_stride + - (block_idx * BN + quad_gid) * mask_kv_seq_stride + + (block_idx * BN + local_quad_gid) * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } @@ -481,10 +503,10 @@ template U sum_exp_score = 0; // Main loop: each quad processes one key at a time - for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) { + for (int i = block_idx * BN + local_quad_gid; i < N; i += blocks * BN) { bool use_key = true; if (do_causal) { - use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + use_key = i <= (N - q_seq_len + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; } else if (float_mask) { @@ -493,7 +515,7 @@ template if (use_key) { U score = QuantOps:: - template dot, elem_per_thread>( + template dot( q, key_ptr.ptr(), key_scales, key_bias_ptr); score = quad_sum(score); @@ -510,7 +532,7 @@ template sum_exp_score = sum_exp_score * factor + exp_score; QuantOps:: - template accumulate, elem_per_thread>( + template accumulate( o, value_ptr.ptr(), factor, @@ -536,11 +558,12 @@ template } } - U sg_max = (quad_lid == 0) ? max_score : Limits::finite_min; + U sg_max = (local_quad_lid == 0) ? max_score : Limits::finite_min; U global_max = simd_max(sg_max); - U sg_sum = - (quad_lid == 0) ? sum_exp_score * fast::exp(max_score - global_max) : 0; + U sg_sum = (local_quad_lid == 0) + ? sum_exp_score * fast::exp(max_score - global_max) + : 0; U global_sum = simd_sum(sg_sum); if (simd_lid == 0) { @@ -548,20 +571,95 @@ template maxs[0] = global_max; } - // Output reduction: sum across quads (same quad_lid only) + // Output reduction: sum across quads (same local_quad_lid only) U rescale = fast::exp(max_score - global_max); for (int i = 0; i < elem_per_thread; i++) { U val = o[i] * rescale; val += simd_shuffle_xor(val, 4); // sum quads 0+1, 2+3, 4+5, 6+7 val += simd_shuffle_xor(val, 8); // sum quads 0-3, 4-7 val += simd_shuffle_xor(val, 16); // sum quads 0-7 - // All lanes with same quad_lid now have the full sum; quad_gid=0 writes - if (quad_gid == 0) { + // All lanes with same local_quad_lid now have the full sum; + // local_quad_gid=0 writes + if (local_quad_gid == 0) { out[i] = static_cast(val); } } } +template +[[kernel]] void quant_sdpa_vector_2pass_1( + const device T* queries [[buffer(0)]], + const device uint32_t* keys [[buffer(1)]], + const device uint8_t* key_scales [[buffer(2)]], + const device uint32_t* values [[buffer(3)]], + const device uint8_t* value_scales [[buffer(4)]], + device T* out [[buffer(5)]], + device float* sums [[buffer(6)]], + device float* maxs [[buffer(7)]], + const constant int& N [[buffer(9)]], + const constant size_t& k_stride [[buffer(10)]], + const constant size_t& v_stride [[buffer(11)]], + const constant size_t& k_group_stride [[buffer(12)]], + const constant size_t& v_group_stride [[buffer(13)]], + const constant float& scale [[buffer(14)]], + const device bool* bmask [[buffer(15), function_constant(bool_mask)]], + const device T* fmask [[buffer(16), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(17), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(18), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(19), function_constant(has_mask)]], + const device uint8_t* key_biases + [[buffer(20), function_constant(has_affine_bias)]], + const device uint8_t* value_biases + [[buffer(21), function_constant(has_affine_bias)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint3 tptg [[threads_per_threadgroup]], + uint3 tidtg [[thread_position_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { +#define QUANT_SDPA_DISPATCH(MODE, GS, B) \ + if (quant_mode_int == int(QuantMode::MODE) && quant_group_size == GS && \ + quant_bits == B) { \ + quant_sdpa_inner( \ + queries, \ + keys, \ + key_scales, \ + values, \ + value_scales, \ + out, \ + sums, \ + maxs, \ + N, \ + k_stride, \ + v_stride, \ + k_group_stride, \ + v_group_stride, \ + scale, \ + bmask, \ + fmask, \ + mask_kv_seq_stride, \ + mask_q_seq_stride, \ + mask_head_stride, \ + key_biases, \ + value_biases, \ + tid, \ + tpg, \ + tptg, \ + tidtg, \ + simd_lid); \ + return; \ + } + QUANT_SDPA_DISPATCH(Affine, 32, 4) + QUANT_SDPA_DISPATCH(Affine, 32, 6) + QUANT_SDPA_DISPATCH(Affine, 32, 8) + QUANT_SDPA_DISPATCH(Mxfp4, 32, 4) + QUANT_SDPA_DISPATCH(Nvfp4, 16, 4) + QUANT_SDPA_DISPATCH(Mxfp8, 32, 8) +#undef QUANT_SDPA_DISPATCH +} + template [[kernel]] void sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 80e828c262..e9ba12090e 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -585,16 +585,16 @@ void sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -std::string quant_mode_to_kernel_suffix(QuantizationMode mode) { +int quant_mode_to_int(QuantizationMode mode) { switch (mode) { case QuantizationMode::Affine: - return "Affine"; + return 0; case QuantizationMode::Mxfp4: - return "Mxfp4"; - case QuantizationMode::Nvfp4: - return "Nvfp4"; + return 1; case QuantizationMode::Mxfp8: - return "Mxfp8"; + return 2; + case QuantizationMode::Nvfp4: + return 3; default: throw std::invalid_argument( "[quant_sdpa_vector_2pass] Unsupported quantization mode."); @@ -619,22 +619,15 @@ void quant_sdpa_vector_2pass( const std::optional& mask, QuantizationMode mode) { std::string kname; - kname.reserve(96); + kname.reserve(64); kname += "quant_sdpa_vector_2pass_1_"; kname += get_type_string(q.dtype()); kname += "_"; kname += std::to_string(q.shape(-1)); - kname += "_"; - kname += quant_mode_to_kernel_suffix(mode); - kname += "_"; - kname += std::to_string(group_size); - kname += "_"; - kname += std::to_string(bits); - int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); + int gqa_factor = q.shape(1) / k.shape(1); int n_simds = gqa_factor * q.shape(2); - int B = q.shape(0) * q.shape(1); // TODO: tune block sizes for different devices char devc = d.get_architecture().back(); @@ -679,8 +672,8 @@ void quant_sdpa_vector_2pass( size_t v_group_stride = v_scales.shape(1) == 1 ? v_scales.strides(0) : v_scales.strides(1); - MTL::Size group_dims(32, 1, 1); // 1 simdgroup, like non-quant - MTL::Size grid_dims(B, q.shape(2), blocks); + MTL::Size group_dims(32, gqa_factor, q.shape(2)); + MTL::Size grid_dims(k.shape(1), q.shape(0), blocks); Shape intermediate_shape; intermediate_shape.reserve(out.ndim() + 1); @@ -705,6 +698,7 @@ void quant_sdpa_vector_2pass( bool query_transposed = !q.flags().row_contiguous; bool has_sinks = false; bool has_affine_bias = mode == QuantizationMode::Affine; + int quant_mode_int = quant_mode_to_int(mode); metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, {&query_transposed, MTL::DataType::DataTypeBool, 21}, @@ -714,12 +708,18 @@ void quant_sdpa_vector_2pass( {&has_sinks, MTL::DataType::DataTypeBool, 25}, {&blocks, MTL::DataType::DataTypeInt, 26}, {&has_affine_bias, MTL::DataType::DataTypeBool, 27}, + {&quant_mode_int, MTL::DataType::DataTypeInt, 28}, + {&bits, MTL::DataType::DataTypeInt, 29}, + {&group_size, MTL::DataType::DataTypeInt, 30}, }; std::string hash_name = kname; hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += do_causal ? "_c" : "_nc"; hash_name += has_affine_bias ? "_affine_" : "_noaffine_"; + hash_name += std::to_string(quant_mode_int) + "_"; + hash_name += std::to_string(bits) + "_"; + hash_name += std::to_string(group_size) + "_"; hash_name += std::to_string(blocks); auto& compute_encoder = d.get_command_encoder(s.index); @@ -734,7 +734,6 @@ void quant_sdpa_vector_2pass( compute_encoder.set_output_array(intermediate, 5); compute_encoder.set_output_array(sums, 6); compute_encoder.set_output_array(maxs, 7); - compute_encoder.set_bytes(gqa_factor, 8); compute_encoder.set_bytes(N, 9); compute_encoder.set_bytes(k_stride, 10); compute_encoder.set_bytes(v_stride, 11); @@ -782,7 +781,7 @@ void quant_sdpa_vector_2pass( compute_encoder.set_output_array(out, 3); group_dims = MTL::Size(1024, 1, 1); - grid_dims = MTL::Size(B, q.shape(2), 1); + grid_dims = MTL::Size(q.shape(0) * q.shape(1), q.shape(2), 1); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index d86862e5f7..f971095c53 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1147,10 +1147,12 @@ array quantized_scaled_dot_product_attention( bool supported_type = (queries.dtype() == float32) || (queries.dtype() == float16) || (queries.dtype() == bfloat16); + int gqa_factor = queries.shape(1) / keys.shape(1); bool unsupported = detail::in_grad_tracing() || stream.device == Device::cpu || queries.shape(2) > 8 || (queries.shape(2) > keys.shape(2)) || - !(queries.shape(-1) == 64 || queries.shape(-1) == 128) || !supported_type; + !(queries.shape(-1) == 64 || queries.shape(-1) == 128) || + !supported_type || (queries.shape(2) * gqa_factor > 32); if (unsupported) { return fallback(std::move(inputs))[0]; From 4c9287f3c989705e48737f8d5eda1c1064198418 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sat, 7 Feb 2026 21:11:56 +0800 Subject: [PATCH 12/21] tune blocks --- .../metal/scaled_dot_product_attention.cpp | 132 +++++++++--------- mlx/fast.cpp | 45 ++++-- 2 files changed, 102 insertions(+), 75 deletions(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index e9ba12090e..b494fe3483 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -15,6 +15,68 @@ namespace mlx::core::fast { namespace { +// Select block count for vector 2-pass attention kernels. +int select_sdpa_blocks( + char devc, + int N, + int n_simds, + int head_dim, + [[maybe_unused]] bool quantized) { + if (devc == 's') { + int blocks = 64; + if (N > 1024 && n_simds > 4) { + if (N <= 8192) { + blocks = 128; + } else if (N <= 32768) { + blocks = 256; + } else if (N <= 65536) { + blocks = 512; + } else { + blocks = 1024; + } + } + return blocks; + } + + if (devc == 'd') { + int blocks = 128; + if (n_simds <= 2 && N > 8192) { + blocks = 256; + } else if (n_simds >= 6) { + if (N >= 16384 && N < 65536) { + blocks = 512; + } else if (N >= 65536) { + blocks = 1024; + } + } + return blocks; + } + + if (devc == 'g' || devc == 'p') { + if (n_simds <= 1) { + if (N <= 2048) { + return 32; + } else if (N <= 8192) { + return 64; + } else { + return 128; + } + } + if (head_dim >= 128) { + return 32; + } + if (N <= 8192) { + return 32; + } else if (N <= 32768) { + return 64; + } else { + return 128; + } + } + + return (n_simds >= 4) ? 64 : 32; +} + void sdpa_full_self_attention_nax( const Stream& s, metal::Device& d, @@ -442,39 +504,8 @@ void sdpa_vector_2pass( char devc = d.get_architecture().back(); int N = k.shape(2); - int blocks; - - if (devc == 's') { - blocks = 64; - if (N > 1024 && n_simds > 4) { - if (N <= 8192) { - blocks = 128; - } else if (N <= 32768) { - blocks = 256; - } else if (N <= 65536) { - blocks = 512; - } else { - blocks = 1024; - } - } - } else if (devc == 'd') { - blocks = 128; - if (n_simds <= 2 && N > 8192) { - blocks = 256; - } else if (n_simds >= 6) { - if (N >= 16384 && N < 65536) { - blocks = 512; - } else if (N >= 65536) { - blocks = 1024; - } - } - } else { - if (n_simds >= 4) { - blocks = 64; - } else { - blocks = 32; - } - } + int blocks = + select_sdpa_blocks(devc, N, n_simds, q.shape(-1), /*quantized=*/false); size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; @@ -629,40 +660,9 @@ void quant_sdpa_vector_2pass( int gqa_factor = q.shape(1) / k.shape(1); int n_simds = gqa_factor * q.shape(2); - // TODO: tune block sizes for different devices char devc = d.get_architecture().back(); - int blocks; - if (devc == 's') { - blocks = 64; - if (N > 1024 && n_simds > 4) { - if (N <= 8192) { - blocks = 128; - } else if (N <= 32768) { - blocks = 256; - } else if (N <= 65536) { - blocks = 512; - } else { - blocks = 1024; - } - } - } else if (devc == 'd') { - blocks = 128; - if (n_simds <= 2 && N > 8192) { - blocks = 256; - } else if (n_simds >= 6) { - if (N >= 16384 && N < 65536) { - blocks = 512; - } else if (N >= 65536) { - blocks = 1024; - } - } - } else { - if (n_simds >= 4) { - blocks = 64; - } else { - blocks = 32; - } - } + int blocks = + select_sdpa_blocks(devc, N, n_simds, q.shape(-1), /*quantized=*/true); // Head strides for quantized data (in uint32 units) and scales size_t k_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index f971095c53..2532cf04bd 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -943,6 +943,15 @@ array quantized_scaled_dot_product_attention( << final_type << "."; throw std::invalid_argument(msg.str()); } + if (!(final_type == float16 || final_type == bfloat16 || + final_type == float32)) { + std::ostringstream msg; + msg << "[" << tag + << "] queries must be float16, bfloat16, or float32 for quantized " + "attention; received " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } if (keys.dtype() != uint32 || values.dtype() != uint32) { throw std::invalid_argument( "[quantized_scaled_dot_product_attention] Keys and values must be " @@ -1114,6 +1123,30 @@ array quantized_scaled_dot_product_attention( }; auto stream = to_stream(s); + Shape full_mask_shape{ + queries.shape(0), queries.shape(1), queries.shape(2), keys.shape(-2)}; + + auto normalize_mask = [&](const array& raw_mask) { + array m = raw_mask; + switch (m.ndim()) { + case 1: // [K] + m = reshape(m, {1, 1, 1, m.shape(0)}, stream); + break; + case 2: // [B, K] + m = reshape(m, {m.shape(0), 1, 1, m.shape(1)}, stream); + break; + case 3: // [B, L_q, K] + m = reshape(m, {m.shape(0), 1, m.shape(1), m.shape(2)}, stream); + break; + case 4: // [B, H, L_q, K] + break; + default: + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Mask rank must be <= 4."); + } + return broadcast_to(m, full_mask_shape, stream); + }; + std::vector inputs = {q, keys, key_scales}; if (is_affine) { inputs.push_back(*key_biases); @@ -1131,14 +1164,9 @@ array quantized_scaled_dot_product_attention( << final_type << "."; throw std::invalid_argument(msg.str()); } - if (!has_bool_mask && mask->dtype() != final_type) { - inputs.push_back(astype(*mask, final_type, stream)); - } else { - inputs.push_back(*mask); - } - auto mask_shape = queries.shape(); - mask_shape.back() = keys.shape(-2); - inputs.back() = broadcast_to(inputs.back(), mask_shape, stream); + array normalized = + has_bool_mask ? *mask : astype(*mask, final_type, stream); + inputs.push_back(normalize_mask(normalized)); } int out_dim = value_head_dim; @@ -1150,7 +1178,6 @@ array quantized_scaled_dot_product_attention( int gqa_factor = queries.shape(1) / keys.shape(1); bool unsupported = detail::in_grad_tracing() || stream.device == Device::cpu || queries.shape(2) > 8 || - (queries.shape(2) > keys.shape(2)) || !(queries.shape(-1) == 64 || queries.shape(-1) == 128) || !supported_type || (queries.shape(2) * gqa_factor > 32); From 2433b68c12b5bfa63f515bbd4bc29035c8224996 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sat, 7 Feb 2026 21:23:50 +0800 Subject: [PATCH 13/21] fix --- mlx/backend/metal/scaled_dot_product_attention.cpp | 1 + python/tests/test_quantized.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index b494fe3483..41eef25a9e 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -779,6 +779,7 @@ void quant_sdpa_vector_2pass( compute_encoder.set_input_array(sums, 1); compute_encoder.set_input_array(maxs, 2); compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(blocks, 4); group_dims = MTL::Size(1024, 1, 1); grid_dims = MTL::Size(q.shape(0) * q.shape(1), q.shape(2), 1); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 9e3859772d..de804e22d4 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -694,7 +694,6 @@ def test_non_multiples(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) -<<<<<<< HEAD def test_qmv_small_non_multiples(self): # Test very small K and N dimensions (e.g., [MxK] x [NxK].T = [MxN]) # Each tuple is (M, K, N) representing input rows, weight cols, weight rows From 3856b0d700eddb233cc16ffa145b03597004cde5 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sat, 7 Feb 2026 22:58:08 +0800 Subject: [PATCH 14/21] cleanup --- benchmarks/python/sdpa_vector_bench.py | 243 ++++++------------ mlx/backend/metal/kernels/fp_quantized_nax.h | 2 - .../metal/kernels/fp_quantized_nax.metal | 2 + mlx/backend/metal/kernels/quantized_utils.h | 2 - .../scaled_dot_product_attention.metal | 2 - mlx/backend/metal/kernels/sdpa_vector.h | 17 +- mlx/fast.cpp | 10 +- 7 files changed, 90 insertions(+), 188 deletions(-) diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 4904519510..546bff84c2 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -1,180 +1,95 @@ import argparse -import time +import math import mlx.core as mx +from time_utils import time_fn + +L = 16384 +H = 32 +H_k = H // 4 +D = 128 +V = 128 +dtype = mx.float16 +loops = 10 + + +def upproject(x, w): + if w is None: + return x + else: + return x @ w.T + + +def attention(q, k, v, mask=None, w=None): + def _sdpa(q, k, v): + B, Hq, L, D = q.shape + _, Hk, S, _ = k.shape + _, _, _, V = v.shape + q = q.reshape(B, Hk, Hq // Hk, L, D) + k = k[:, :, None, :, :] + v = v[:, :, None, :, :] + s = q @ k.transpose(0, 1, 2, 4, 3) + if mask is not None: + m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S) + s = mx.where(m, s, mx.finfo(s.dtype).min) + p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) + o = p @ v + return o.reshape(B, Hq, L, V) + + for i in range(loops): + q = _sdpa(q, k, v) + q = upproject(q, w) + return q -def time_fn(fn, *args, warmup=5, iters=100, **kwargs): - """Time a function, return milliseconds per call.""" - for _ in range(warmup): - mx.eval(fn(*args, **kwargs)) +def sdpa(q, k, v, mask=None, w=None): + for i in range(loops): + q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + q = upproject(q, w) + return q - tic = time.perf_counter() - for _ in range(iters): - mx.eval(fn(*args, **kwargs)) - toc = time.perf_counter() - return 1e3 * (toc - tic) / iters +def time_self_attention_primitives(): + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) + k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) + v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) + w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None + mx.eval(q, k, v, w) + time_fn(attention, q, k, v, w=w) -def quant_sdpa(q, k, v, bits, mode, group_size=None, loops=1): - for _ in range(loops): - q = mx.fast.quantized_scaled_dot_product_attention( - q, *k, *v, scale=1.0, mask=None, bits=bits, mode=mode, group_size=group_size - ) - return q +def time_self_attention_sdpa(): + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) + k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) + v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) + w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None + mx.eval(q, k, v, w) + time_fn(sdpa, q, k, v, w=w) -def sdpa(q, k, v, loops=1): - for _ in range(loops): - q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) - return q +def time_self_attention_sdpa_with_mask(): + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) + k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) + v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) + w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None + mask = mx.full((L,), True) + mask[L // 2 :] = False + mx.eval(q, k, v, mask, w) + + def sdpa_mask(*args): + return sdpa(*args, mask=mask, w=w) + def attention_mask(*args): + return attention(*args, mask=mask, w=w) -def run_benchmark( - seq_lengths, - modes, - H=32, - H_k=8, - D=128, - dtype=mx.float16, - loops=20, - warmup=5, - iters=100, -): - """Run benchmarks across sequence lengths and modes.""" - results = {} - - print(f"\n{'=' * 70}") - print(f"Quant SDPA Benchmark: H={H}, H_k={H_k}, D={D}, GQA={H // H_k}x") - print(f"{'=' * 70}") - - # Header - header = f"{'SeqLen':>8}" - for mode, bits in modes: - header += f" | {mode}({bits}b):ms" - header += " | fp16:ms" - print(header) - print("-" * len(header)) - - for L in seq_lengths: - mx.random.seed(42) - q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype) - k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) - v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype) - mx.eval(q, k, v) - - row = f"{L:>8}" - results[L] = {} - - # Benchmark each quant mode - for mode, bits in modes: - gs = 32 if mode == "affine" else None - k_quant = mx.quantize( - k, bits=bits, mode=mode, **({"group_size": gs} if gs else {}) - ) - v_quant = mx.quantize( - v, bits=bits, mode=mode, **({"group_size": gs} if gs else {}) - ) - mx.eval(k_quant, v_quant) - - ms = time_fn( - quant_sdpa, - q, - k_quant, - v_quant, - bits, - mode, - group_size=gs, - loops=loops, - warmup=warmup, - iters=iters, - ) - ms_per_call = ms / loops - results[L][(mode, bits)] = ms_per_call - row += f" | {ms_per_call:8.4f}" - - # Benchmark fp16 baseline - ms = time_fn(sdpa, q, k, v, loops=loops, warmup=warmup, iters=iters) - ms_per_call = ms / loops - results[L]["fp16"] = ms_per_call - row += f" | {ms_per_call:8.4f}" - - print(row) - - return results - - -def print_speedup_table(results, modes): - """Print speedup vs fp16 baseline.""" - print(f"\n{'=' * 60}") - print("Speedup vs fp16") - print(f"{'=' * 60}") - - header = f"{'SeqLen':>8}" - for mode, bits in modes: - header += f" | {mode}({bits}b)" - print(header) - print("-" * len(header)) - - for L, data in results.items(): - fp16_ms = data["fp16"] - row = f"{L:>8}" - for mode, bits in modes: - quant_ms = data[(mode, bits)] - speedup = fp16_ms / quant_ms - row += f" | {speedup:5.2f}x" - print(row) - - -def main(): - parser = argparse.ArgumentParser(description="Benchmark Quant SDPA") - parser.add_argument("--heads", type=int, default=32, help="Number of query heads") - parser.add_argument("--kv-heads", type=int, default=8, help="Number of KV heads") - parser.add_argument("--dim", type=int, default=128, help="Head dimension") - parser.add_argument("--loops", type=int, default=20, help="Loops per timing call") - parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") - parser.add_argument("--iters", type=int, default=100, help="Timing iterations") - parser.add_argument( - "--seq-lengths", - type=int, - nargs="+", - default=[2048, 4096, 8192, 16384, 32768, 65536, 131072], - help="Sequence lengths to test", - ) - parser.add_argument( - "--modes", - type=str, - nargs="+", - default=["mxfp4", "mxfp8"], - help="Quantization modes to test", - ) - args = parser.parse_args() - - # Map mode names to (mode, bits) - mode_map = { - "mxfp4": ("mxfp4", 4), - "mxfp8": ("mxfp8", 8), - "affine4": ("affine", 4), - "affine8": ("affine", 8), - "nvfp4": ("nvfp4", 4), - } - - modes = [mode_map[m] for m in args.modes if m in mode_map] - - results = run_benchmark( - seq_lengths=args.seq_lengths, - modes=modes, - H=args.heads, - H_k=args.kv_heads, - D=args.dim, - loops=args.loops, - warmup=args.warmup, - iters=args.iters, - ) - - print_speedup_table(results, modes) + time_fn(attention_mask, q, k, v) + time_fn(sdpa_mask, q, k, v) if __name__ == "__main__": - main() + time_self_attention_sdpa() + time_self_attention_primitives() + time_self_attention_sdpa_with_mask() diff --git a/mlx/backend/metal/kernels/fp_quantized_nax.h b/mlx/backend/metal/kernels/fp_quantized_nax.h index 6a660eb760..381bc6c7d3 100644 --- a/mlx/backend/metal/kernels/fp_quantized_nax.h +++ b/mlx/backend/metal/kernels/fp_quantized_nax.h @@ -5,8 +5,6 @@ #include "mlx/backend/metal/kernels/fp4.h" #include "mlx/backend/metal/kernels/fp8.h" -#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" -#include "mlx/backend/metal/kernels/steel/utils.h" constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; diff --git a/mlx/backend/metal/kernels/fp_quantized_nax.metal b/mlx/backend/metal/kernels/fp_quantized_nax.metal index e96c508f7b..4d65a384d3 100644 --- a/mlx/backend/metal/kernels/fp_quantized_nax.metal +++ b/mlx/backend/metal/kernels/fp_quantized_nax.metal @@ -2,6 +2,8 @@ // clang-format off #include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" #include "mlx/backend/metal/kernels/steel/gemm/nax.h" #include "mlx/backend/metal/kernels/fp_quantized_nax.h" diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index f517a68c3f..86aa8b75d0 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -146,8 +146,6 @@ struct PackReader { // Pointer wrapper for quantized data that handles byte-level addressing // correctly for all bit widths. For non-4-byte-aligned packs (3, 5, 6-bit), -// simple uint32_t pointer arithmetic truncates and causes misalignment. -// This class uses byte-level arithmetic internally to ensure correctness. template class QuantDataPtr { const device uint8_t* byte_ptr_; diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index e71e4395aa..45d7b1a7e4 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -42,8 +42,6 @@ instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) -// Quantized SDPA vector instantiations -// mode/bits/group_size are now function constants (indices 28-30) #define instantiate_quant_sdpa_vector(type, head_dim) \ instantiate_kernel( \ "quant_sdpa_vector_2pass_1_" #type "_" #head_dim, \ diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index aa0bb65bee..5c9f58db06 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -179,13 +179,6 @@ template } } -template -METAL_FUNC void load_queries(const device T* queries, thread U* q, U scale) { - for (int i = 0; i < elem_per_thread; i++) { - q[i] = scale * queries[i]; - } -} - constant bool has_affine_bias [[function_constant(27)]]; constant int quant_mode_int [[function_constant(28)]]; constant int quant_bits [[function_constant(29)]]; @@ -296,7 +289,6 @@ struct QuantOps { return score; } - // ACCUMULATE template [[clang::always_inline]] static void accumulate( thread U* o, @@ -370,11 +362,12 @@ struct QuantOps { } } }; + template using ScaleTypeT = typename QuantConfig::template scale_storage_t; template -METAL_FUNC void quant_sdpa_inner( +METAL_FUNC void quant_sdpa_vector_2pass_1_impl( const device T* queries, const device uint32_t* keys, const device uint8_t* key_scales_raw, @@ -406,8 +399,7 @@ METAL_FUNC void quant_sdpa_inner( // elem_per_thread=D/4 is large enough for all pack_factors (max 8). // // GQA: multiple query heads sharing the same KV head are packed into the - // same threadgroup (along with q_seq_len). This lets them share L2 cache - // for KV data. + // same threadgroup (along with q_seq_len) to share L2 cache for KV data. // Grid: (num_kv_heads, batch, blocks) // Group: (32, gqa_factor, q_seq_len) using Cfg = QuantConfig; @@ -579,7 +571,6 @@ METAL_FUNC void quant_sdpa_inner( val += simd_shuffle_xor(val, 8); // sum quads 0-3, 4-7 val += simd_shuffle_xor(val, 16); // sum quads 0-7 // All lanes with same local_quad_lid now have the full sum; - // local_quad_gid=0 writes if (local_quad_gid == 0) { out[i] = static_cast(val); } @@ -622,7 +613,7 @@ template #define QUANT_SDPA_DISPATCH(MODE, GS, B) \ if (quant_mode_int == int(QuantMode::MODE) && quant_group_size == GS && \ quant_bits == B) { \ - quant_sdpa_inner( \ + quant_sdpa_vector_2pass_1_impl( \ queries, \ keys, \ key_scales, \ diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 2532cf04bd..5d14934843 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1003,17 +1003,17 @@ array quantized_scaled_dot_product_attention( // Validate scale/bias shapes auto expected_scale_dim = queries.shape(-1) / group_size; - for (const auto& [qdata, sc, bias, name] : + for (const auto& [qdata, scale, bias, name] : {std::tuple{&keys, &key_scales, &key_biases, "key"}, std::tuple{&values, &value_scales, &value_biases, "value"}}) { - if (sc->shape(-1) != expected_scale_dim || - sc->shape(-3) != qdata->shape(-3) || - sc->shape(-2) != qdata->shape(-2)) { + if (scale->shape(-1) != expected_scale_dim || + scale->shape(-3) != qdata->shape(-3) || + scale->shape(-2) != qdata->shape(-2)) { std::ostringstream msg; msg << "[" << tag << "] " << name << " scale shape mismatch."; throw std::invalid_argument(msg.str()); } - if (is_affine && bias->has_value() && (*bias)->shape() != sc->shape()) { + if (is_affine && bias->has_value() && (*bias)->shape() != scale->shape()) { std::ostringstream msg; msg << "[" << tag << "] " << name << " bias shape must match scale shape."; From c550db985cddc07b212cf5c1948a026a21b526fb Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 8 Feb 2026 00:16:59 +0800 Subject: [PATCH 15/21] cleanup + refactor --- .../metal/scaled_dot_product_attention.cpp | 23 ++++ mlx/backend/no_gpu/primitives.cpp | 9 ++ mlx/fast.cpp | 87 +++++--------- mlx/fast_primitives.h | 3 + python/tests/test_quantized.py | 109 ++++++++++++++++++ 5 files changed, 176 insertions(+), 55 deletions(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 41eef25a9e..f1eff133dc 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -843,6 +843,29 @@ bool ScaledDotProductAttention::supports_bool_mask() { return true; } +bool QuantizedScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + bool is_training, + Stream s) { + if (is_training || s.device == Device::cpu) { + return true; + } + + bool supported_type = (q.dtype() == float32) || (q.dtype() == float16) || + (q.dtype() == bfloat16); + if (!supported_type) { + return true; + } + + int query_sequence_length = q.shape(2); + int query_head_dim = q.shape(-1); + int gqa_factor = q.shape(1) / k.shape(1); + return query_sequence_length > 8 || + !(query_head_dim == 64 || query_head_dim == 128) || + (query_sequence_length * gqa_factor > 32); +} + void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..085f8331d7 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -40,6 +40,14 @@ bool fast::ScaledDotProductAttention::supports_bool_mask() { return false; } +bool fast::QuantizedScaledDotProductAttention::use_fallback( + const array&, + const array&, + bool, + Stream) { + return true; +} + bool fast::ScaledDotProductAttentionVJP::use_fallback( const array& q, Stream s) { @@ -168,6 +176,7 @@ NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) NO_GPU_USE_FALLBACK(RoPE) NO_GPU_MULTI(ScaledDotProductAttention) +NO_GPU_MULTI(QuantizedScaledDotProductAttention) NO_GPU_MULTI(ScaledDotProductAttentionVJP) NO_GPU_MULTI(ConvertFP8) NO_GPU_MULTI(Quantize) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 5d14934843..f82c848c4f 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include #include +#include #include "mlx/fast.h" #include "mlx/fast_primitives.h" @@ -610,6 +611,23 @@ bool RoPE::is_equivalent(const Primitive& other) const { forward_ == a_other.forward_); } +std::pair prepare_sdpa_array_mask( + const array& mask, + Dtype out_type, + const Shape& full_mask_shape, + std::string_view tag, + Stream s) { + bool has_bool_mask = mask.dtype() == bool_; + if (!has_bool_mask && promote_types(mask.dtype(), out_type) != out_type) { + std::ostringstream msg; + msg << "[" << tag << "] Mask type must promote to output type " << out_type + << "."; + throw std::invalid_argument(msg.str()); + } + auto prepared_mask = has_bool_mask ? mask : astype(mask, out_type, s); + return {broadcast_to(prepared_mask, full_mask_shape, s), has_bool_mask}; +} + /** Computes: O = softmax(Q @ K.T) @ V **/ array scaled_dot_product_attention( const array& queries, @@ -654,7 +672,6 @@ array scaled_dot_product_attention( } else if (mask_arr) { has_mask = true; has_arr_mask = true; - has_bool_mask = mask_arr->dtype() == bool_; } if (has_arr_mask && mask_arr->ndim() > 4) { @@ -792,20 +809,16 @@ array scaled_dot_product_attention( auto stream = to_stream(s); std::vector inputs = {q, k, v}; if (has_arr_mask) { - // Check type - has_bool_mask = mask_arr->dtype() == bool_; - if (promote_types(mask_arr->dtype(), final_type) != final_type) { - std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask type must promote to output type " - << final_type << "."; - throw std::invalid_argument(msg.str()); - } else if (!has_bool_mask) { - mask_arr = astype(*mask_arr, final_type, stream); - } - // Broadcast mask auto mask_shape = queries.shape(); mask_shape.back() = keys.shape(-2); - inputs.push_back(broadcast_to(*mask_arr, mask_shape, stream)); + auto [prepared_mask, prepared_bool_mask] = prepare_sdpa_array_mask( + *mask_arr, + final_type, + mask_shape, + "scaled_dot_product_attention", + stream); + has_bool_mask = prepared_bool_mask; + inputs.push_back(std::move(prepared_mask)); } if (has_sinks) { if (promote_types(sinks->dtype(), final_type) != final_type) { @@ -1023,7 +1036,6 @@ array quantized_scaled_dot_product_attention( // Validate mask bool needs_mask = mask.has_value(); - bool has_bool_mask = needs_mask && mask->dtype() == bool_; if (needs_mask && mask->ndim() > 4) { std::ostringstream msg; msg << "[" << tag << "] Mask with shape " << mask->shape() @@ -1117,7 +1129,7 @@ array quantized_scaled_dot_product_attention( mode, s); if (n_repeats > 1) { - out = reshape(out, {out.shape(0), n_q_heads, out.shape(2), -1}, s); + out = flatten(out, 1, 2, s); } return std::vector{out}; }; @@ -1126,27 +1138,6 @@ array quantized_scaled_dot_product_attention( Shape full_mask_shape{ queries.shape(0), queries.shape(1), queries.shape(2), keys.shape(-2)}; - auto normalize_mask = [&](const array& raw_mask) { - array m = raw_mask; - switch (m.ndim()) { - case 1: // [K] - m = reshape(m, {1, 1, 1, m.shape(0)}, stream); - break; - case 2: // [B, K] - m = reshape(m, {m.shape(0), 1, 1, m.shape(1)}, stream); - break; - case 3: // [B, L_q, K] - m = reshape(m, {m.shape(0), 1, m.shape(1), m.shape(2)}, stream); - break; - case 4: // [B, H, L_q, K] - break; - default: - throw std::invalid_argument( - "[quantized_scaled_dot_product_attention] Mask rank must be <= 4."); - } - return broadcast_to(m, full_mask_shape, stream); - }; - std::vector inputs = {q, keys, key_scales}; if (is_affine) { inputs.push_back(*key_biases); @@ -1157,31 +1148,17 @@ array quantized_scaled_dot_product_attention( inputs.push_back(*value_biases); } if (needs_mask) { - if (promote_types(mask->dtype(), final_type) != final_type && - mask->dtype() != bool_) { - std::ostringstream msg; - msg << "[quantized_scaled_dot_product_attention] Mask type must promote to output type " - << final_type << "."; - throw std::invalid_argument(msg.str()); - } - array normalized = - has_bool_mask ? *mask : astype(*mask, final_type, stream); - inputs.push_back(normalize_mask(normalized)); + auto prepared_mask = prepare_sdpa_array_mask( + *mask, final_type, full_mask_shape, tag, stream); + inputs.push_back(std::move(prepared_mask.first)); } int out_dim = value_head_dim; Shape out_shape{ queries.shape(0), queries.shape(1), queries.shape(2), out_dim}; - bool supported_type = (queries.dtype() == float32) || - (queries.dtype() == float16) || (queries.dtype() == bfloat16); - int gqa_factor = queries.shape(1) / keys.shape(1); - bool unsupported = detail::in_grad_tracing() || - stream.device == Device::cpu || queries.shape(2) > 8 || - !(queries.shape(-1) == 64 || queries.shape(-1) == 128) || - !supported_type || (queries.shape(2) * gqa_factor > 32); - - if (unsupported) { + if (QuantizedScaledDotProductAttention::use_fallback( + q, keys, detail::in_grad_tracing(), stream)) { return fallback(std::move(inputs))[0]; } diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4ef110c855..18ef172348 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -284,6 +284,9 @@ class QuantizedScaledDotProductAttention : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; + static bool + use_fallback(const array& q, const array& k, bool is_training, Stream s); + bool is_equivalent(const Primitive& other) const override; DEFINE_NAME(QuantizedScaledDotProductAttention); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index de804e22d4..8969f34dcb 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -832,6 +832,115 @@ def test_quantized_sdpa_affine(self): tol = 2e-2 self.assertLess((out - ref).abs().max(), tol) + def test_quantized_sdpa_masked(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + bits = 8 if mode == "mxfp8" else 4 + for Lq in [4, 9]: + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + additive_mask = mx.where( + bool_mask, + mx.zeros((Lq, Lk), dtype=mx.float32), + mx.full((Lq, Lk), -1e9, dtype=mx.float32), + ) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + for mask_name, mask in { + "bool": bool_mask, + "additive": additive_mask, + }.items(): + with self.subTest(mode=mode, bits=bits, Lq=Lq, mask=mask_name): + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + mask=mask, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_affine_masked(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for bits in [4, 6, 8]: + for Lq in [4, 9]: + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + additive_mask = mx.where( + bool_mask, + mx.zeros((Lq, Lk), dtype=mx.float32), + mx.full((Lq, Lk), -1e9, dtype=mx.float32), + ) + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + for mask_name, mask in { + "bool": bool_mask, + "additive": additive_mask, + }.items(): + with self.subTest(bits=bits, Lq=Lq, mask=mask_name): + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + mask=mask, + ) + + self.assertEqual(out.shape, ref.shape) + if bits == 6: + tol = 1e-1 + elif bits == 4: + tol = 5e-2 + else: + tol = 2e-2 + self.assertLess((out - ref).abs().max(), tol) + def test_gather_qmm(self): def quantize(w, transpose=True, group_size=None, bits=None, mode="affine"): if mode == "affine": From 8a764fdd3b415dc74427fa03d390479ea0e936b7 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 8 Feb 2026 00:26:48 +0800 Subject: [PATCH 16/21] enable causal --- .../metal/scaled_dot_product_attention.cpp | 4 +- mlx/fast.cpp | 52 ++++++-- mlx/fast.h | 1 + mlx/fast_primitives.h | 11 +- python/src/fast.cpp | 10 +- python/tests/test_quantized.py | 115 ++++++++++++++++++ 6 files changed, 173 insertions(+), 20 deletions(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index f1eff133dc..db82a01762 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1093,7 +1093,7 @@ void QuantizedScaledDotProductAttention::eval_gpu( } std::optional mask = std::nullopt; - if (needs_mask_) { + if (has_arr_mask_) { auto mask_copy_unless = [&q](const array& arr) { auto& strides = arr.strides(); auto& shape = arr.shape(); @@ -1123,7 +1123,7 @@ void QuantizedScaledDotProductAttention::eval_gpu( scale_, group_size_, bits_, - /* do_causal = */ false, + do_causal_, mask, mode_); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index f82c848c4f..ebe08e4d00 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -888,6 +888,7 @@ array quantized_scaled_dot_product_attention( const std::optional group_size_ /* = std::nullopt */, const std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "mxfp4" */, + bool causal /* = false */, StreamOrDevice s /* = {} */) { constexpr const char* tag = "quantized_scaled_dot_product_attention"; @@ -1035,8 +1036,14 @@ array quantized_scaled_dot_product_attention( } // Validate mask - bool needs_mask = mask.has_value(); - if (needs_mask && mask->ndim() > 4) { + bool do_causal = causal; + bool has_arr_mask = mask.has_value(); + if (do_causal && has_arr_mask) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Received both causal=true " + "and an array mask. Please provide only one mask type."); + } + if (has_arr_mask && mask->ndim() > 4) { std::ostringstream msg; msg << "[" << tag << "] Mask with shape " << mask->shape() << " expected to have at most rank 4."; @@ -1050,7 +1057,8 @@ array quantized_scaled_dot_product_attention( auto fallback = [scale, n_q_heads, n_kv_heads, - needs_mask, + do_causal, + has_arr_mask, is_affine, group_size, bits, @@ -1073,8 +1081,8 @@ array quantized_scaled_dot_product_attention( v_biases = inputs[idx++]; } - std::optional mask = - needs_mask ? std::optional{inputs[idx]} : std::nullopt; + std::optional arr_mask = + has_arr_mask ? std::optional{inputs[idx]} : std::nullopt; if (n_repeats > 1) { q = reshape(q, {q.shape(0), n_kv_heads, n_repeats, q.shape(2), -1}, s); @@ -1100,8 +1108,21 @@ array quantized_scaled_dot_product_attention( bits, mode, s); - if (mask) { - auto m = *mask; + if (has_arr_mask || do_causal) { + auto make_or_fetch_mask = [&]() { + if (do_causal) { + int kL = k.shape(-2); + int qL = q.shape(-2); + int offset = kL - qL; + auto q_idx = arange(offset, qL + offset, s); + auto k_idx = arange(0, kL, s); + q_idx = expand_dims(q_idx, 1, s); + k_idx = expand_dims(k_idx, 0, s); + return greater_equal(q_idx, k_idx, s); + } + return *arr_mask; + }; + auto m = make_or_fetch_mask(); if (n_repeats > 1 && m.ndim() >= 3) { if (m.shape(-3) == 1) { m = expand_dims(m, -3, s); @@ -1147,7 +1168,7 @@ array quantized_scaled_dot_product_attention( if (is_affine) { inputs.push_back(*value_biases); } - if (needs_mask) { + if (has_arr_mask) { auto prepared_mask = prepare_sdpa_array_mask( *mask, final_type, full_mask_shape, tag, stream); inputs.push_back(std::move(prepared_mask.first)); @@ -1163,7 +1184,14 @@ array quantized_scaled_dot_product_attention( } auto primitive = std::make_shared( - stream, fallback, scale, needs_mask, group_size, bits, qmode); + stream, + fallback, + scale, + has_arr_mask, + do_causal, + group_size, + bits, + qmode); return array(std::move(out_shape), final_type, primitive, std::move(inputs)); } @@ -1225,9 +1253,9 @@ bool QuantizedScaledDotProductAttention::is_equivalent( const Primitive& other) const { const QuantizedScaledDotProductAttention& a_other = static_cast(other); - return scale_ == a_other.scale_ && needs_mask_ == a_other.needs_mask_ && - group_size_ == a_other.group_size_ && bits_ == a_other.bits_ && - mode_ == a_other.mode_; + return scale_ == a_other.scale_ && has_arr_mask_ == a_other.has_arr_mask_ && + do_causal_ == a_other.do_causal_ && group_size_ == a_other.group_size_ && + bits_ == a_other.bits_ && mode_ == a_other.mode_; } bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const { diff --git a/mlx/fast.h b/mlx/fast.h index 6d82a4208a..cbafd12973 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -68,6 +68,7 @@ MLX_API array quantized_scaled_dot_product_attention( std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "mxfp4", + bool causal = false, StreamOrDevice s = {}); using TemplateArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 18ef172348..9517fbc620 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -266,13 +266,15 @@ class QuantizedScaledDotProductAttention : public Custom { Stream stream, std::function(std::vector)> fallback, float scale, - bool needs_mask, + bool has_arr_mask, + bool do_causal, int group_size, int bits, QuantizationMode mode) : Custom(stream, std::move(fallback)), scale_(scale), - needs_mask_(needs_mask), + has_arr_mask_(has_arr_mask), + do_causal_(do_causal), group_size_(group_size), bits_(bits), mode_(mode) {} @@ -293,12 +295,13 @@ class QuantizedScaledDotProductAttention : public Custom { DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return std::make_tuple( - nullptr, scale_, needs_mask_, group_size_, bits_, mode_); + nullptr, scale_, has_arr_mask_, do_causal_, group_size_, bits_, mode_); } private: float scale_; - bool needs_mask_; + bool has_arr_mask_; + bool do_causal_; int group_size_; int bits_; QuantizationMode mode_; diff --git a/python/src/fast.cpp b/python/src/fast.cpp index d30ee2cfd8..c9be77c7ba 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -310,6 +310,7 @@ void init_fast(nb::module_& parent_module) { std::optional group_size, std::optional bits, const std::string& mode, + bool causal, mx::StreamOrDevice s) { return mx::fast::quantized_scaled_dot_product_attention( q, @@ -324,6 +325,7 @@ void init_fast(nb::module_& parent_module) { group_size, bits, mode, + causal, s); }, "q"_a, @@ -339,9 +341,10 @@ void init_fast(nb::module_& parent_module) { "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "mxfp4", + "causal"_a = false, "stream"_a = nb::none(), nb::sig( - "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, k_biases: Optional[array] = None, v_biases: Optional[array] = None, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", stream: Union[None, Stream, Device] = None) -> array")); + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, k_biases: Optional[array] = None, v_biases: Optional[array] = None, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array")); m.def( "quantized_scaled_dot_product_attention", @@ -359,9 +362,10 @@ void init_fast(nb::module_& parent_module) { "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "mxfp4", + "causal"_a = false, "stream"_a = nb::none(), nb::sig( - "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: Optional[array] = None, v: array, v_scales: array, v_biases: Optional[array] = None, *, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: Optional[array] = None, v: array, v_scales: array, v_biases: Optional[array] = None, *, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A fast implementation of multi-head attention where the keys and values are quantized. @@ -380,6 +384,8 @@ void init_fast(nb::module_& parent_module) { group_size (int, optional): The group size used in the KV quantization. Defaults follow the quantization ``mode``. bits (int, optional): The bits used in the KV quantization. Defaults follow the quantization ``mode``. mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, or ``"affine"``. + causal (bool, optional): Whether to apply lower-right aligned causal masking. + Cannot be used together with ``mask``. Returns: array: The output array. )pbdoc"); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 8969f34dcb..81cafd1695 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -941,6 +941,121 @@ def test_quantized_sdpa_affine_masked(self): tol = 2e-2 self.assertLess((out - ref).abs().max(), tol) + def test_quantized_sdpa_causal(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + bits = 8 if mode == "mxfp8" else 4 + for Lq in [4, 9]: + with self.subTest(mode=mode, bits=bits, Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask="causal" + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + causal=True, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_affine_causal(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for bits in [4, 6, 8]: + for Lq in [4, 9]: + with self.subTest(bits=bits, Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask="causal" + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + causal=True, + ) + + self.assertEqual(out.shape, ref.shape) + if bits == 6: + tol = 1e-1 + elif bits == 4: + tol = 5e-2 + else: + tol = 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_causal_with_array_mask_error(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + B, Hq, Hkv = 1, 2, 1 + Lq, Lk, D = 4, 640, 128 + q = mx.random.normal(shape=(B, Hq, Lq, D)) + k = mx.random.normal(shape=(B, Hkv, Lk, D)) + v = mx.random.normal(shape=(B, Hkv, Lk, D)) + mask = mx.ones(shape=(Lq, Lk), dtype=mx.bool_) + + k_q, k_scales = mx.quantize(k, mode="mxfp4") + v_q, v_scales = mx.quantize(v, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode="mxfp4", + bits=4, + mask=mask, + causal=True, + ) + def test_gather_qmm(self): def quantize(w, transpose=True, group_size=None, bits=None, mode="affine"): if mode == "affine": From 4ce3d5d9067a246b1a491bf216154463d1cd8667 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 8 Feb 2026 00:45:09 +0800 Subject: [PATCH 17/21] support sinks --- mlx/backend/metal/kernels/sdpa_vector.h | 7 + .../metal/scaled_dot_product_attention.cpp | 21 ++- mlx/fast.cpp | 49 ++++++- mlx/fast.h | 1 + mlx/fast_primitives.h | 12 +- python/src/fast.cpp | 9 +- python/tests/test_quantized.py | 136 ++++++++++++++++++ 7 files changed, 224 insertions(+), 11 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 5c9f58db06..d6ca2764fa 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -389,6 +389,7 @@ METAL_FUNC void quant_sdpa_vector_2pass_1_impl( const constant int& mask_head_stride, const device uint8_t* key_biases_raw, const device uint8_t* value_biases_raw, + const device T* sinks, uint3 tid, uint3 tpg, uint3 tptg, @@ -493,6 +494,10 @@ METAL_FUNC void quant_sdpa_vector_2pass_1_impl( U max_score = Limits::finite_min; U sum_exp_score = 0; + if (has_sinks && block_idx == 0 && local_quad_gid == 0) { + max_score = static_cast(sinks[q_head_idx]); + sum_exp_score = 1; + } // Main loop: each quad processes one key at a time for (int i = block_idx * BN + local_quad_gid; i < N; i += blocks * BN) { @@ -605,6 +610,7 @@ template [[buffer(20), function_constant(has_affine_bias)]], const device uint8_t* value_biases [[buffer(21), function_constant(has_affine_bias)]], + const device T* sinks [[buffer(22), function_constant(has_sinks)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint3 tptg [[threads_per_threadgroup]], @@ -635,6 +641,7 @@ template mask_head_stride, \ key_biases, \ value_biases, \ + sinks, \ tid, \ tpg, \ tptg, \ diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index db82a01762..8b984cb78f 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -648,6 +648,7 @@ void quant_sdpa_vector_2pass( int bits, bool do_causal, const std::optional& mask, + const std::optional& sinks, QuantizationMode mode) { std::string kname; kname.reserve(64); @@ -696,7 +697,7 @@ void quant_sdpa_vector_2pass( bool bool_mask = has_mask && (*mask).dtype() == bool_; bool float_mask = has_mask && !bool_mask; bool query_transposed = !q.flags().row_contiguous; - bool has_sinks = false; + bool has_sinks = sinks.has_value(); bool has_affine_bias = mode == QuantizationMode::Affine; int quant_mode_int = quant_mode_to_int(mode); metal::MTLFCList func_consts = { @@ -716,6 +717,7 @@ void quant_sdpa_vector_2pass( hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += do_causal ? "_c" : "_nc"; + hash_name += has_sinks ? "_s" : "_ns"; hash_name += has_affine_bias ? "_affine_" : "_noaffine_"; hash_name += std::to_string(quant_mode_int) + "_"; hash_name += std::to_string(bits) + "_"; @@ -757,6 +759,9 @@ void quant_sdpa_vector_2pass( compute_encoder.set_input_array(*k_biases, 20); compute_encoder.set_input_array(*v_biases, 21); } + if (has_sinks) { + compute_encoder.set_input_array(*sinks, 22); + } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -1021,7 +1026,7 @@ void QuantizedScaledDotProductAttention::eval_gpu( // Inputs layout: // [q, k, k_scales, k_biases (if affine), v, v_scales, v_biases (if affine), - // mask (if present)] + // mask (if present), sinks (if present)] auto& q_pre = inputs[0]; auto& k_pre = inputs[1]; auto& k_scales_pre = inputs[2]; @@ -1052,6 +1057,10 @@ void QuantizedScaledDotProductAttention::eval_gpu( } }; + auto is_matrix_contiguous = [](const array& arr) { + return arr.strides(-1) == 1; + }; + auto q_copy_unless = [](const array& arr) { if (arr.flags().row_contiguous) { return true; @@ -1100,7 +1109,12 @@ void QuantizedScaledDotProductAttention::eval_gpu( return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || (strides[0] == strides[1] * shape[1]); }; - mask = copy_unless(mask_copy_unless, inputs.back()); + mask = copy_unless(mask_copy_unless, inputs[idx++]); + } + + std::optional sinks = std::nullopt; + if (has_sinks_) { + sinks = copy_unless(is_matrix_contiguous, inputs[idx++]); } if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { @@ -1125,6 +1139,7 @@ void QuantizedScaledDotProductAttention::eval_gpu( bits_, do_causal_, mask, + sinks, mode_); d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ebe08e4d00..aa3f88fd36 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -885,6 +885,7 @@ array quantized_scaled_dot_product_attention( const std::optional& value_biases, const float scale, const std::optional& mask /* = std::nullopt */, + const std::optional& sinks /* = std::nullopt */, const std::optional group_size_ /* = std::nullopt */, const std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "mxfp4" */, @@ -1038,6 +1039,7 @@ array quantized_scaled_dot_product_attention( // Validate mask bool do_causal = causal; bool has_arr_mask = mask.has_value(); + bool has_sinks = sinks.has_value(); if (do_causal && has_arr_mask) { throw std::invalid_argument( "[quantized_scaled_dot_product_attention] Received both causal=true " @@ -1053,12 +1055,13 @@ array quantized_scaled_dot_product_attention( auto q = astype(queries, final_type, s); // Inputs layout: // [q, k, k_scales, k_biases (if affine), v, v_scales, v_biases (if affine), - // mask (if present)] + // mask (if present), sinks (if present)] auto fallback = [scale, n_q_heads, n_kv_heads, do_causal, has_arr_mask, + has_sinks, is_affine, group_size, bits, @@ -1082,10 +1085,12 @@ array quantized_scaled_dot_product_attention( } std::optional arr_mask = - has_arr_mask ? std::optional{inputs[idx]} : std::nullopt; + has_arr_mask ? std::optional{inputs[idx++]} : std::nullopt; + std::optional sinks_opt = + has_sinks ? std::optional{inputs[idx++]} : std::nullopt; if (n_repeats > 1) { - q = reshape(q, {q.shape(0), n_kv_heads, n_repeats, q.shape(2), -1}, s); + q = unflatten(q, 1, {n_kv_heads, n_repeats}, s); k = expand_dims(k, 2, s); k_scales = expand_dims(k_scales, 2, s); if (k_biases) { @@ -1138,7 +1143,24 @@ array quantized_scaled_dot_product_attention( } } + if (has_sinks) { + auto sinks = *sinks_opt; + // scores has shape B N_q N_k L_q L_k + sinks = expand_dims(sinks, {0, 2, 3}, s); + if (scores.ndim() == 5) { + sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s); + } + auto bsx_shape = scores.shape(); + bsx_shape.back() = 1; + scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s); + } scores = softmax(scores, std::vector{-1}, true, s); + if (has_sinks) { + auto start = Shape(scores.ndim(), 0); + start.back() = 1; + auto stop = scores.shape(); + scores = slice(scores, std::move(start), std::move(stop), s); + } auto out = quantized_matmul( scores, v, @@ -1173,6 +1195,21 @@ array quantized_scaled_dot_product_attention( *mask, final_type, full_mask_shape, tag, stream); inputs.push_back(std::move(prepared_mask.first)); } + if (has_sinks) { + if (promote_types(sinks->dtype(), final_type) != final_type) { + std::ostringstream msg; + msg << "[" << tag << "] Type of sinks must promote to output type " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) { + std::ostringstream msg; + msg << "[" << tag << "] Received invalid shape for sinks " + << sinks->shape() << "."; + throw std::invalid_argument(msg.str()); + } + inputs.push_back(astype(*sinks, final_type, stream)); + } int out_dim = value_head_dim; Shape out_shape{ @@ -1188,6 +1225,7 @@ array quantized_scaled_dot_product_attention( fallback, scale, has_arr_mask, + has_sinks, do_causal, group_size, bits, @@ -1254,8 +1292,9 @@ bool QuantizedScaledDotProductAttention::is_equivalent( const QuantizedScaledDotProductAttention& a_other = static_cast(other); return scale_ == a_other.scale_ && has_arr_mask_ == a_other.has_arr_mask_ && - do_causal_ == a_other.do_causal_ && group_size_ == a_other.group_size_ && - bits_ == a_other.bits_ && mode_ == a_other.mode_; + has_sinks_ == a_other.has_sinks_ && do_causal_ == a_other.do_causal_ && + group_size_ == a_other.group_size_ && bits_ == a_other.bits_ && + mode_ == a_other.mode_; } bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const { diff --git a/mlx/fast.h b/mlx/fast.h index cbafd12973..b176e52e03 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -65,6 +65,7 @@ MLX_API array quantized_scaled_dot_product_attention( const std::optional& value_biases, const float scale, const std::optional& mask = std::nullopt, + const std::optional& sinks = std::nullopt, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "mxfp4", diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 9517fbc620..961a6a139e 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -267,6 +267,7 @@ class QuantizedScaledDotProductAttention : public Custom { std::function(std::vector)> fallback, float scale, bool has_arr_mask, + bool has_sinks, bool do_causal, int group_size, int bits, @@ -274,6 +275,7 @@ class QuantizedScaledDotProductAttention : public Custom { : Custom(stream, std::move(fallback)), scale_(scale), has_arr_mask_(has_arr_mask), + has_sinks_(has_sinks), do_causal_(do_causal), group_size_(group_size), bits_(bits), @@ -295,12 +297,20 @@ class QuantizedScaledDotProductAttention : public Custom { DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return std::make_tuple( - nullptr, scale_, has_arr_mask_, do_causal_, group_size_, bits_, mode_); + nullptr, + scale_, + has_arr_mask_, + has_sinks_, + do_causal_, + group_size_, + bits_, + mode_); } private: float scale_; bool has_arr_mask_; + bool has_sinks_; bool do_causal_; int group_size_; int bits_; diff --git a/python/src/fast.cpp b/python/src/fast.cpp index c9be77c7ba..c982f6b3e6 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -307,6 +307,7 @@ void init_fast(nb::module_& parent_module) { const std::optional& v_biases, const float scale, const std::optional& mask, + const std::optional& sinks, std::optional group_size, std::optional bits, const std::string& mode, @@ -322,6 +323,7 @@ void init_fast(nb::module_& parent_module) { v_biases, scale, mask, + sinks, group_size, bits, mode, @@ -338,13 +340,14 @@ void init_fast(nb::module_& parent_module) { "v_biases"_a = nb::none(), "scale"_a, "mask"_a = nb::none(), + "sinks"_a = nb::none(), "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "mxfp4", "causal"_a = false, "stream"_a = nb::none(), nb::sig( - "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, k_biases: Optional[array] = None, v_biases: Optional[array] = None, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array")); + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, k_biases: Optional[array] = None, v_biases: Optional[array] = None, scale: float, mask: Optional[array] = None, sinks: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array")); m.def( "quantized_scaled_dot_product_attention", @@ -359,13 +362,14 @@ void init_fast(nb::module_& parent_module) { nb::kw_only(), "scale"_a, "mask"_a = nb::none(), + "sinks"_a = nb::none(), "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "mxfp4", "causal"_a = false, "stream"_a = nb::none(), nb::sig( - "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: Optional[array] = None, v: array, v_scales: array, v_biases: Optional[array] = None, *, scale: float, mask: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: Optional[array] = None, v: array, v_scales: array, v_biases: Optional[array] = None, *, scale: float, mask: Optional[array] = None, sinks: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A fast implementation of multi-head attention where the keys and values are quantized. @@ -381,6 +385,7 @@ void init_fast(nb::module_& parent_module) { v_biases (array or None): Biases for the affine-quantized values array. Required for affine mode, None for fp modes. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) mask (array, optional): An additive or boolean mask to apply to the query-key scores. + sinks (array, optional): An optional array of attention sinks with shape ``[N_q]``. group_size (int, optional): The group size used in the KV quantization. Defaults follow the quantization ``mode``. bits (int, optional): The bits used in the KV quantization. Defaults follow the quantization ``mode``. mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, or ``"affine"``. diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 81cafd1695..b680a90a6b 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -941,6 +941,142 @@ def test_quantized_sdpa_affine_masked(self): tol = 2e-2 self.assertLess((out - ref).abs().max(), tol) + def test_quantized_sdpa_sinks(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + sinks = mx.array([0.7, -0.4], dtype=mx.float32) + + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + bits = 8 if mode == "mxfp8" else 4 + for Lq in [4, 9]: + with self.subTest(mode=mode, bits=bits, Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, sinks=sinks + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + sinks=sinks, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_masked_with_sinks(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + sinks = mx.array([0.5, -0.3], dtype=mx.float32) + + for Lq in [4, 9]: + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + additive_mask = mx.where( + bool_mask, + mx.zeros((Lq, Lk), dtype=mx.float32), + mx.full((Lq, Lk), -1e9, dtype=mx.float32), + ) + + mode = "mxfp4" + bits = 4 + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + for mask_name, mask in { + "bool": bool_mask, + "additive": additive_mask, + }.items(): + with self.subTest(Lq=Lq, mask=mask_name): + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask, sinks=sinks + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + mask=mask, + sinks=sinks, + ) + + self.assertEqual(out.shape, ref.shape) + self.assertLess((out - ref).abs().max(), 5e-2) + + def test_quantized_sdpa_affine_masked_with_sinks(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + bits = 4 + sinks = mx.array([0.2, -0.1], dtype=mx.float32) + + for Lq in [4, 9]: + with self.subTest(Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=bool_mask, sinks=sinks + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + mask=bool_mask, + sinks=sinks, + ) + + self.assertEqual(out.shape, ref.shape) + self.assertLess((out - ref).abs().max(), 5e-2) + def test_quantized_sdpa_causal(self): if mx.default_device() == mx.cpu: self.skipTest("Quantized fast attention is only available on GPU.") From 0732449f0fe117dcd34957d92d871b20f2de5736 Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 8 Feb 2026 01:04:41 +0800 Subject: [PATCH 18/21] cleanup --- .../kernels/scaled_dot_product_attention.metal | 17 +++++++++-------- .../metal/scaled_dot_product_attention.cpp | 14 +++++++++++--- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 45d7b1a7e4..4de30fccdd 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -42,16 +42,17 @@ instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) -#define instantiate_quant_sdpa_vector(type, head_dim) \ - instantiate_kernel( \ - "quant_sdpa_vector_2pass_1_" #type "_" #head_dim, \ - quant_sdpa_vector_2pass_1, \ - type, \ - head_dim) +#define instantiate_quant_sdpa_vector(type, qk_dim, value_dim) \ + instantiate_kernel( \ + "quant_sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim, \ + quant_sdpa_vector_2pass_1, \ + type, \ + qk_dim) #define instantiate_quant_sdpa_vector_heads(type) \ - instantiate_quant_sdpa_vector(type, 64) \ - instantiate_quant_sdpa_vector(type, 128) + instantiate_quant_sdpa_vector(type, 64, 64) \ + instantiate_quant_sdpa_vector(type, 128, 128) \ + instantiate_quant_sdpa_vector(type, 256, 256) instantiate_quant_sdpa_vector_heads(float) instantiate_quant_sdpa_vector_heads(bfloat16_t) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 8b984cb78f..ee1cbb6285 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -656,6 +656,8 @@ void quant_sdpa_vector_2pass( kname += get_type_string(q.dtype()); kname += "_"; kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(q.shape(-1)); int N = k.shape(2); int gqa_factor = q.shape(1) / k.shape(1); @@ -726,6 +728,7 @@ void quant_sdpa_vector_2pass( auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname, hash_name, func_consts); + check_kernel_threadgroup_size(kernel, group_dims, hash_name); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(q, 0); @@ -770,7 +773,7 @@ void quant_sdpa_vector_2pass( kname += "sdpa_vector_2pass_2_"; kname += get_type_string(q.dtype()); kname += "_"; - kname += std::to_string(q.shape(-1)); + kname += std::to_string(out.shape(-1)); func_consts = { {&blocks, MTL::DataType::DataTypeInt, 26}, @@ -788,6 +791,7 @@ void quant_sdpa_vector_2pass( group_dims = MTL::Size(1024, 1, 1); grid_dims = MTL::Size(q.shape(0) * q.shape(1), q.shape(2), 1); + check_kernel_threadgroup_size(kernel, group_dims, kname); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -864,10 +868,13 @@ bool QuantizedScaledDotProductAttention::use_fallback( } int query_sequence_length = q.shape(2); + int key_sequence_length = k.shape(2); int query_head_dim = q.shape(-1); int gqa_factor = q.shape(1) / k.shape(1); return query_sequence_length > 8 || - !(query_head_dim == 64 || query_head_dim == 128) || + query_sequence_length > key_sequence_length || + !(query_head_dim == 64 || query_head_dim == 128 || + query_head_dim == 256) || (query_sequence_length * gqa_factor > 32); } @@ -1123,6 +1130,7 @@ void QuantizedScaledDotProductAttention::eval_gpu( o.set_data(allocator::malloc(o.nbytes())); } + bool do_causal = do_causal_ && q.shape(2) > 1; quant_sdpa_vector_2pass( s, d, @@ -1137,7 +1145,7 @@ void QuantizedScaledDotProductAttention::eval_gpu( scale_, group_size_, bits_, - do_causal_, + do_causal, mask, sinks, mode_); From 21388b1142a401983c53ab64264ba31bfc631a3e Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Sun, 8 Feb 2026 11:19:42 +0800 Subject: [PATCH 19/21] cleanup --- mlx/backend/metal/scaled_dot_product_attention.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index ee1cbb6285..2f3f941954 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -657,7 +657,7 @@ void quant_sdpa_vector_2pass( kname += "_"; kname += std::to_string(q.shape(-1)); kname += "_"; - kname += std::to_string(q.shape(-1)); + kname += std::to_string(v.shape(-1)); int N = k.shape(2); int gqa_factor = q.shape(1) / k.shape(1); From 0c58fff181104fac7daa9a318231bfdb4dd900bb Mon Sep 17 00:00:00 2001 From: Dan Yeh Date: Thu, 16 Apr 2026 23:10:15 +0200 Subject: [PATCH 20/21] improvements --- .../scaled_dot_product_attention.metal | 6 ++- mlx/backend/metal/kernels/sdpa_vector.h | 21 +++++---- .../metal/scaled_dot_product_attention.cpp | 14 +++--- python/tests/test_quantized.py | 47 ++++++++++--------- 4 files changed, 46 insertions(+), 42 deletions(-) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 4de30fccdd..9a628c3851 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -36,7 +36,8 @@ using namespace metal; instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 96) \ instantiate_sdpa_vector_aggregation(type, 128) \ - instantiate_sdpa_vector_aggregation(type, 256) + instantiate_sdpa_vector_aggregation(type, 256) \ + instantiate_sdpa_vector_aggregation(type, 512) instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) @@ -52,7 +53,8 @@ instantiate_sdpa_vector_heads(float16_t) #define instantiate_quant_sdpa_vector_heads(type) \ instantiate_quant_sdpa_vector(type, 64, 64) \ instantiate_quant_sdpa_vector(type, 128, 128) \ - instantiate_quant_sdpa_vector(type, 256, 256) + instantiate_quant_sdpa_vector(type, 256, 256) \ + instantiate_quant_sdpa_vector(type, 512, 512) instantiate_quant_sdpa_vector_heads(float) instantiate_quant_sdpa_vector_heads(bfloat16_t) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index d6ca2764fa..bf7b0bafd6 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -408,13 +408,12 @@ METAL_FUNC void quant_sdpa_vector_2pass_1_impl( static_assert( (D % group_size) == 0, "group_size must divide the head dimension"); - constexpr int BN = 8; - constexpr int BD = 4; + constexpr int BD = (D > 256) ? 8 : 4; + constexpr int BN = 32 / BD; constexpr int elem_per_thread = D / BD; - // Derive quad indices from simd_lid (replaces quad_gid/quad_lid attributes) - const int local_quad_gid = simd_lid / 4; // 0-7 - const int local_quad_lid = simd_lid % 4; // 0-3 + const int local_quad_gid = simd_lid / BD; + const int local_quad_lid = simd_lid % BD; typedef float U; @@ -515,6 +514,9 @@ METAL_FUNC void quant_sdpa_vector_2pass_1_impl( template dot( q, key_ptr.ptr(), key_scales, key_bias_ptr); score = quad_sum(score); + for (int s = 4; s < BD; s <<= 1) { + score += simd_shuffle_xor(score, s); + } if (float_mask) { score += static_cast(fmask[0]); @@ -568,14 +570,13 @@ METAL_FUNC void quant_sdpa_vector_2pass_1_impl( maxs[0] = global_max; } - // Output reduction: sum across quads (same local_quad_lid only) + // Output reduction: sum across groups (same local_quad_lid only) U rescale = fast::exp(max_score - global_max); for (int i = 0; i < elem_per_thread; i++) { U val = o[i] * rescale; - val += simd_shuffle_xor(val, 4); // sum quads 0+1, 2+3, 4+5, 6+7 - val += simd_shuffle_xor(val, 8); // sum quads 0-3, 4-7 - val += simd_shuffle_xor(val, 16); // sum quads 0-7 - // All lanes with same local_quad_lid now have the full sum; + for (int s = BD; s < 32; s <<= 1) { + val += simd_shuffle_xor(val, s); + } if (local_quad_gid == 0) { out[i] = static_cast(val); } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 2f3f941954..fae6188ed5 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -657,7 +657,7 @@ void quant_sdpa_vector_2pass( kname += "_"; kname += std::to_string(q.shape(-1)); kname += "_"; - kname += std::to_string(v.shape(-1)); + kname += std::to_string(q.shape(-1)); int N = k.shape(2); int gqa_factor = q.shape(1) / k.shape(1); @@ -691,9 +691,6 @@ void quant_sdpa_vector_2pass( intermediate.set_data(allocator::malloc(intermediate.nbytes())); sums.set_data(allocator::malloc(sums.nbytes())); maxs.set_data(allocator::malloc(maxs.nbytes())); - d.add_temporary(intermediate, s.index); - d.add_temporary(sums, s.index); - d.add_temporary(maxs, s.index); bool has_mask = mask.has_value(); bool bool_mask = has_mask && (*mask).dtype() == bool_; @@ -726,7 +723,10 @@ void quant_sdpa_vector_2pass( hash_name += std::to_string(group_size) + "_"; hash_name += std::to_string(blocks); - auto& compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = metal::get_command_encoder(s); + compute_encoder.add_temporary(intermediate); + compute_encoder.add_temporary(sums); + compute_encoder.add_temporary(maxs); auto kernel = d.get_kernel(kname, hash_name, func_consts); check_kernel_threadgroup_size(kernel, group_dims, hash_name); compute_encoder.set_compute_pipeline_state(kernel); @@ -874,7 +874,7 @@ bool QuantizedScaledDotProductAttention::use_fallback( return query_sequence_length > 8 || query_sequence_length > key_sequence_length || !(query_head_dim == 64 || query_head_dim == 128 || - query_head_dim == 256) || + query_head_dim == 256 || query_head_dim == 512) || (query_sequence_length * gqa_factor > 32); } @@ -1150,7 +1150,7 @@ void QuantizedScaledDotProductAttention::eval_gpu( sinks, mode_); - d.add_temporaries(std::move(copies), s.index); + metal::get_command_encoder(s).add_temporaries(std::move(copies)); } bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index b680a90a6b..8c46e7e9e9 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -759,33 +759,34 @@ def test_quantized_sdpa(self): mx.random.seed(0) B, Hq, Hkv = 1, 2, 1 - Lq, Lk, D = 4, 640, 128 + Lq, Lk = 4, 640 - for mode in ["mxfp4", "mxfp8", "nvfp4"]: - with self.subTest(mode=mode): - bits = 8 if mode == "mxfp8" else 4 - q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) - k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) - v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + for D in [128, 512]: + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + with self.subTest(D=D, mode=mode): + bits = 8 if mode == "mxfp8" else 4 + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) - k_q, k_scales = mx.quantize(k, mode=mode) - v_q, v_scales = mx.quantize(v, mode=mode) + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) - ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) - out = mx.fast.quantized_scaled_dot_product_attention( - q, - k_q, - k_scales, - v_q, - v_scales, - scale=1.0, - mode=mode, - bits=bits, - ) + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + ) - self.assertEqual(out.shape, ref.shape) - tol = 5e-2 if bits == 4 else 2e-2 - self.assertLess((out - ref).abs().max(), tol) + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) def test_quantized_sdpa_affine(self): if mx.default_device() == mx.cpu: From e1c923e5ea9c4c65aa16c03d709013e9319c6fec Mon Sep 17 00:00:00 2001 From: dogukanveziroglu Date: Tue, 14 Apr 2026 12:58:14 +0300 Subject: [PATCH 21/21] Support group_size=64 for affine quantized SDPA Add Affine dispatch entries for group_size=64 at bits={4,6,8} and relax the validation in quantized_scaled_dot_product_attention. This matches the default produced by mx.quantize(mode="affine") and the kv_group_size=64 default used by mlx-lm, so users following the MLX/mlx-lm conventions no longer hit an error when using fused quantized attention. Benchmarks (M4, B=1 H=32 D=128 Lq=1, affine 4-bit): Context gs=32 fused gs=64 fused speedup 32K 50 us 41 us +22% 64K 95 us 82 us +16% 128K 176 us 152 us +16% gs=64 is faster at long context because it has half the scale/bias memory traffic. Costs: mlx.metallib: 128,161,428 -> 128,233,236 bytes (+0.056%) libmlx.dylib: unchanged Existing 10 test_quantized_sdpa* tests continue to pass (54 subtests). --- mlx/backend/metal/kernels/sdpa_vector.h | 3 +++ mlx/fast.cpp | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index bf7b0bafd6..a823915123 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -653,6 +653,9 @@ template QUANT_SDPA_DISPATCH(Affine, 32, 4) QUANT_SDPA_DISPATCH(Affine, 32, 6) QUANT_SDPA_DISPATCH(Affine, 32, 8) + QUANT_SDPA_DISPATCH(Affine, 64, 4) + QUANT_SDPA_DISPATCH(Affine, 64, 6) + QUANT_SDPA_DISPATCH(Affine, 64, 8) QUANT_SDPA_DISPATCH(Mxfp4, 32, 4) QUANT_SDPA_DISPATCH(Nvfp4, 16, 4) QUANT_SDPA_DISPATCH(Mxfp8, 32, 8) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index aa3f88fd36..bddc569b55 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -901,9 +901,9 @@ array quantized_scaled_dot_product_attention( // Validate mode-specific group_size and bits if (is_affine) { - if (group_size != 32) { + if (group_size != 32 && group_size != 64) { std::ostringstream msg; - msg << "[" << tag << "] Affine mode supports group_size 32 " + msg << "[" << tag << "] Affine mode supports group_size 32 or 64 " << "but received " << group_size << "."; throw std::invalid_argument(msg.str()); }