diff --git a/.gitignore b/.gitignore index 1daaa46d12..3a1ff6a3a0 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ parts/ sdist/ var/ venv/ +.venv/ wheels/ share/python-wheels/ *.egg-info/ diff --git a/mlx/backend/metal/kernels/fp_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h index cc9b68ade8..e6684b5880 100644 --- a/mlx/backend/metal/kernels/fp_quantized.h +++ b/mlx/backend/metal/kernels/fp_quantized.h @@ -5,6 +5,8 @@ #include "mlx/backend/metal/kernels/fp4.h" #include "mlx/backend/metal/kernels/fp8.h" +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index 38253f8fe9..9484391e48 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -1,8 +1,267 @@ // Copyright © 2023-2024 Apple Inc. +#pragma once + #include #include +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + +enum class QuantMode { Affine, Mxfp4, Mxfp8, Nvfp4, TurboQuant3, TurboQuant4 }; + +template +struct DecodeValue { + [[clang::always_inline]] OutT operator()(uint8_t v) const { + return OutT(*(thread EncodedT*)(&v)); + } +}; + +// Specialization for Affine (plain integer cast) +template +struct DecodeValue { + [[clang::always_inline]] OutT operator()(uint8_t v) const { + return OutT(v); + } +}; + +template +struct QuantConfig; + +template <> +struct QuantConfig { + static constant constexpr bool has_bias = true; + + using value_type = void; + using scale_type = void; + + template + using scale_storage_t = T; +}; + +template <> +struct QuantConfig { + static constant constexpr bool has_bias = false; + + using value_type = fp4_e2m1; + using scale_type = fp8_e8m0; + + template + using scale_storage_t = uint8_t; +}; + +template <> +struct QuantConfig { + static constant constexpr bool has_bias = false; + + using value_type = fp4_e2m1; + using scale_type = fp8_e4m3; + + template + using scale_storage_t = uint8_t; +}; + +template <> +struct QuantConfig { + static constant constexpr bool has_bias = false; + + using value_type = fp8_e4m3; + using scale_type = fp8_e8m0; + + template + using scale_storage_t = uint8_t; +}; + +// TurboQuant: codebook-based quantization with per-vector float scales. +// Keys/values are packed bit indices; scales are per-vector L2 norms / sqrt(D). +// Codebooks are Lloyd-Max optimal for N(0,1) (distribution of rotated, +// norm-normalized key coordinates scaled by sqrt(D)). +template <> +struct QuantConfig { + static constant constexpr bool has_bias = false; + + using value_type = void; + using scale_type = void; + + template + using scale_storage_t = T; +}; + +template <> +struct QuantConfig { + static constant constexpr bool has_bias = false; + + using value_type = void; + using scale_type = void; + + template + using scale_storage_t = T; +}; + +// N(0,1) Lloyd-Max 3-bit codebook (8 reconstruction levels). +// Boundaries: 0, ±0.332, ±0.776, ±1.399 (midpoints of adjacent centroids). +constant float turbo3_codebook[8] = { + -1.7481f, + -1.0498f, + -0.5012f, + -0.1624f, + 0.1624f, + 0.5012f, + 1.0498f, + 1.7481f}; + +// N(0,1) equal-probability 4-bit codebook (16 reconstruction levels). +// Computed as E[X | X in (b_i, b_{i+1})] for N(0,1) with equiprobable bins. +constant float turbo4_codebook[16] = { + -1.9672f, + -1.3305f, + -1.0130f, + -0.7811f, + -0.5714f, + -0.4053f, + -0.2382f, + -0.0784f, + 0.0784f, + 0.2382f, + 0.4053f, + 0.5714f, + 0.7811f, + 1.0130f, + 1.3305f, + 1.9672f}; + +template +struct Dequant { + using Cfg = QuantConfig; + + [[clang::always_inline]] T raw(uint8_t v) const { + return DecodeValue{}(v); + } + + [[clang::always_inline]] T scale( + typename Cfg::template scale_storage_t s) const { + if constexpr (metal::is_same_v) { + return s; + } else { + return DecodeValue{}(s); + } + } + + [[clang::always_inline]] T operator()(uint8_t v, T s, T bias) const { + if constexpr (Cfg::has_bias) { + return fma(s, raw(v), bias); + } else { + return s * raw(v); + } + } +}; + +template +struct Dequant { + [[clang::always_inline]] T raw(uint8_t v) const { + return T(turbo3_codebook[v & 7u]); + } + [[clang::always_inline]] T scale(T s) const { + return s; + } + [[clang::always_inline]] T operator()(uint8_t v, T s, T) const { + return s * raw(v); + } +}; + +template +struct Dequant { + [[clang::always_inline]] T raw(uint8_t v) const { + return T(turbo4_codebook[v & 15u]); + } + [[clang::always_inline]] T scale(T s) const { + return s; + } + [[clang::always_inline]] T operator()(uint8_t v, T s, T) const { + return s * raw(v); + } +}; + +// Pack metadata and unpackers for arbitrary bit-widths (wsize fixed at 32 bits) +template +struct PackInfo { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "PackInfo only supports bits in {2,3,4,5,6,8}"); + + static constant constexpr int pack_factor = + (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : 32 / bits); + static constant constexpr int bytes_per_pack = + ((bits & (bits - 1)) == 0) ? 4 : (bits == 5 ? 5 : 3); +}; + +template +struct PackReader { + static constant constexpr int pack_factor = PackInfo::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; + static constant constexpr uint64_t mask = (uint64_t(1) << bits) - 1; + + [[gnu::always_inline]] static void load( + const device uint8_t* p, + thread uint8_t (&out)[pack_factor]) { + uint64_t packed = load_packed(p); +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; ++i) { + out[i] = static_cast((packed >> (bits * i)) & mask); + } + } + + private: + [[gnu::always_inline]] static uint64_t load_packed(const device uint8_t* p) { + if constexpr (bytes_per_pack == 4) { + return static_cast( + *(reinterpret_cast(p))); + } else { + uint64_t packed = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < bytes_per_pack; ++i) { + packed |= static_cast(p[i]) << (8 * i); + } + return packed; + } + } +}; + +// Pointer wrapper for quantized data that handles byte-level addressing +// correctly for all bit widths. For non-4-byte-aligned packs (3, 5, 6-bit), +template +class QuantDataPtr { + const device uint8_t* byte_ptr_; + + public: + static constant constexpr int pack_factor = PackInfo::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; + + // Initialize from base pointer, head stride (in uint32 units), head index, + // and element index + [[clang::always_inline]] QuantDataPtr( + const device uint32_t* base, + size_t head_stride, + int head_idx, + int elem_idx) { + int packed_idx = elem_idx / pack_factor; + byte_ptr_ = reinterpret_cast(base) + + head_idx * head_stride * 4 + // head_stride is in uint32 units + packed_idx * bytes_per_pack; + } + + // Advance by number of elements + [[clang::always_inline]] void advance(int num_elements) { + byte_ptr_ += num_elements * bits / 8; + } + + // Get raw pointer for passing to dot/accumulate functions + [[clang::always_inline]] const device uint32_t* ptr() const { + return reinterpret_cast(byte_ptr_); + } +}; + template METAL_FUNC void gemm_loop_aligned( threadgroup T* As, diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index c668d9d8c5..9a628c3851 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -36,9 +36,27 @@ using namespace metal; instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 96) \ instantiate_sdpa_vector_aggregation(type, 128) \ - instantiate_sdpa_vector_aggregation(type, 256) + instantiate_sdpa_vector_aggregation(type, 256) \ + instantiate_sdpa_vector_aggregation(type, 512) instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) + +#define instantiate_quant_sdpa_vector(type, qk_dim, value_dim) \ + instantiate_kernel( \ + "quant_sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim, \ + quant_sdpa_vector_2pass_1, \ + type, \ + qk_dim) + +#define instantiate_quant_sdpa_vector_heads(type) \ + instantiate_quant_sdpa_vector(type, 64, 64) \ + instantiate_quant_sdpa_vector(type, 128, 128) \ + instantiate_quant_sdpa_vector(type, 256, 256) \ + instantiate_quant_sdpa_vector(type, 512, 512) + +instantiate_quant_sdpa_vector_heads(float) +instantiate_quant_sdpa_vector_heads(bfloat16_t) +instantiate_quant_sdpa_vector_heads(float16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 1eec72be31..5a6d3050ab 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -2,6 +2,9 @@ #include +#include "mlx/backend/metal/kernels/fp_quantized.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" + using namespace metal; constant bool has_mask [[function_constant(20)]]; @@ -176,6 +179,505 @@ template } } +constant bool has_affine_bias [[function_constant(27)]]; +constant int quant_mode_int [[function_constant(28)]]; +constant int quant_bits [[function_constant(29)]]; +constant int quant_group_size [[function_constant(30)]]; + +template +struct GroupSlice { + enum : int { + value = (elem_per_thread < group_size) ? elem_per_thread : group_size, + num_groups = elem_per_thread / value, + iters_per_group = value / granularity + }; + static_assert( + (value % granularity) == 0, + "group slice must be divisible by granularity"); + static_assert( + (elem_per_thread % value) == 0, + "elem_per_thread must be divisible by group slice"); +}; + +template +struct QuantOps { + using Cfg = QuantConfig; + static constant constexpr bool is_fast_path = (bits == 4 || bits == 8); + static constant constexpr int pack_factor = PackInfo::pack_factor; + static constant constexpr int bytes_per_pack = PackInfo::bytes_per_pack; + static constant constexpr int granularity = is_fast_path ? 4 : pack_factor; + using fast_load_t = metal::conditional_t; + static constant constexpr uint32_t fast_mask = (1u << bits) - 1; + static_assert( + bits == 3 || bits == 4 || bits == 6 || bits == 8, + "unsupported quant bits"); + static_assert( + !is_fast_path || (group_size % 4) == 0, + "group_size must be divisible by 4 for 4/8-bit fast path"); + + template + [[clang::always_inline]] static U dot( + const thread U* q, + const device uint32_t* keys, + const device ScaleT* scales, + [[maybe_unused]] const device ScaleT* biases) { + static_assert( + (elem_per_thread % granularity) == 0, + "elem_per_thread must be divisible by the granularity"); + Dequant dequant; + U score = 0; + + using Slice = GroupSlice; + constexpr int group_slice = Slice::value; + constexpr int num_groups = Slice::num_groups; + constexpr int iters_per_group = Slice::iters_per_group; + +#pragma clang loop unroll(full) + for (int g = 0; g < num_groups; g++) { + U scale = dequant.scale(scales[g]); + U bias = 0; + if constexpr (Cfg::has_bias) + bias = static_cast(biases[g]); + + U group_score = 0; + U bias_acc = 0; + + if constexpr (is_fast_path) { + auto ks = reinterpret_cast(keys); +#pragma clang loop unroll(full) + for (int j = 0; j < iters_per_group; j++) { + fast_load_t p = ks[g * iters_per_group + j]; + int base = g * group_slice + 4 * j; + + U v0 = dequant.raw(p & fast_mask); + U v1 = dequant.raw((p >> (bits * 1)) & fast_mask); + U v2 = dequant.raw((p >> (bits * 2)) & fast_mask); + U v3 = dequant.raw((p >> (bits * 3)) & fast_mask); + + group_score += q[base + 0] * v0; + group_score += q[base + 1] * v1; + group_score += q[base + 2] * v2; + group_score += q[base + 3] * v3; + if constexpr (Cfg::has_bias) { + bias_acc += q[base + 0] + q[base + 1] + q[base + 2] + q[base + 3]; + } + } + } else { + auto ks = reinterpret_cast(keys) + + g * (group_slice / pack_factor) * bytes_per_pack; + thread uint8_t raw[pack_factor]; + +#pragma clang loop unroll(full) + for (int j = 0; j < group_slice; j += pack_factor) { + PackReader::load(ks, raw); +#pragma clang loop unroll(full) + for (int t = 0; t < pack_factor; ++t) { + U decoded = dequant.raw(raw[t]); + int q_idx = g * group_slice + j + t; + group_score += q[q_idx] * decoded; + if constexpr (Cfg::has_bias) + bias_acc += q[q_idx]; + } + ks += bytes_per_pack; + } + } + + if constexpr (Cfg::has_bias) { + score += fma(scale, group_score, bias * bias_acc); + } else { + score += scale * group_score; + } + } + return score; + } + + template + [[clang::always_inline]] static void accumulate( + thread U* o, + const device uint32_t* values, + U factor, + U exp_score, + const device ScaleT* scales, + [[maybe_unused]] const device ScaleT* biases) { + static_assert( + (elem_per_thread % granularity) == 0, + "elem_per_thread must be divisible by the granularity"); + Dequant dequant; + + using Slice = GroupSlice; + constexpr int group_slice = Slice::value; + constexpr int num_groups = Slice::num_groups; + constexpr int iters_per_group = Slice::iters_per_group; + +#pragma clang loop unroll(full) + for (int g = 0; g < num_groups; g++) { + U w_scale = exp_score * dequant.scale(scales[g]); + U bias = 0; + if constexpr (Cfg::has_bias) + bias = exp_score * static_cast(biases[g]); + + if constexpr (is_fast_path) { + auto vs = reinterpret_cast(values); +#pragma clang loop unroll(full) + for (int j = 0; j < iters_per_group; j++) { + fast_load_t p = vs[g * iters_per_group + j]; + int base = g * group_slice + 4 * j; + + U v0 = dequant.raw(p & fast_mask); + U v1 = dequant.raw((p >> (bits * 1)) & fast_mask); + U v2 = dequant.raw((p >> (bits * 2)) & fast_mask); + U v3 = dequant.raw((p >> (bits * 3)) & fast_mask); + + if constexpr (Cfg::has_bias) { + o[base + 0] = fma(o[base + 0], factor, fma(w_scale, v0, bias)); + o[base + 1] = fma(o[base + 1], factor, fma(w_scale, v1, bias)); + o[base + 2] = fma(o[base + 2], factor, fma(w_scale, v2, bias)); + o[base + 3] = fma(o[base + 3], factor, fma(w_scale, v3, bias)); + } else { + o[base + 0] = fma(o[base + 0], factor, v0 * w_scale); + o[base + 1] = fma(o[base + 1], factor, v1 * w_scale); + o[base + 2] = fma(o[base + 2], factor, v2 * w_scale); + o[base + 3] = fma(o[base + 3], factor, v3 * w_scale); + } + } + } else { + auto vs = reinterpret_cast(values) + + g * (group_slice / pack_factor) * bytes_per_pack; + thread uint8_t raw[pack_factor]; + +#pragma clang loop unroll(full) + for (int j = 0; j < group_slice; j += pack_factor) { + PackReader::load(vs, raw); +#pragma clang loop unroll(full) + for (int t = 0; t < pack_factor; ++t) { + U decoded = dequant.raw(raw[t]); + int idx = g * group_slice + j + t; + if constexpr (Cfg::has_bias) { + o[idx] = fma(o[idx], factor, fma(w_scale, decoded, bias)); + } else { + o[idx] = fma(o[idx], factor, decoded * w_scale); + } + } + vs += bytes_per_pack; + } + } + } + } +}; + +template +using ScaleTypeT = typename QuantConfig::template scale_storage_t; + +template +METAL_FUNC void quant_sdpa_vector_2pass_1_impl( + const device T* queries, + const device uint32_t* keys, + const device uint8_t* key_scales_raw, + const device uint32_t* values, + const device uint8_t* value_scales_raw, + device T* out, + device float* sums, + device float* maxs, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& v_stride, + const constant size_t& k_group_stride, + const constant size_t& v_group_stride, + const constant float& scale, + const device bool* bmask, + const device T* fmask, + const constant int& mask_kv_seq_stride, + const constant int& mask_q_seq_stride, + const constant int& mask_head_stride, + const device uint8_t* key_biases_raw, + const device uint8_t* value_biases_raw, + const device T* sinks, + uint3 tid, + uint3 tpg, + uint3 tptg, + uint3 tidtg, + uint simd_lid) { + // Quadgroup approach: BN=8 quads × BD=4 lanes = 32 threads = 1 simdgroup + // Each quad processes one key, lanes split D dimension. + // elem_per_thread=D/4 is large enough for all pack_factors (max 8). + // + // GQA: multiple query heads sharing the same KV head are packed into the + // same threadgroup (along with q_seq_len) to share L2 cache for KV data. + // Grid: (num_kv_heads, batch, blocks) + // Group: (32, gqa_factor, q_seq_len) + using Cfg = QuantConfig; + using ScaleT = ScaleTypeT; + + static_assert( + (D % group_size) == 0, "group_size must divide the head dimension"); + constexpr int BD = (D > 256) ? 8 : 4; + constexpr int BN = 32 / BD; + constexpr int elem_per_thread = D / BD; + + const int local_quad_gid = simd_lid / BD; + const int local_quad_lid = simd_lid % BD; + + typedef float U; + + // Cast raw byte pointers to typed scale pointers + auto key_scales = reinterpret_cast(key_scales_raw); + auto value_scales = reinterpret_cast(value_scales_raw); + + thread U q[elem_per_thread]; + thread U o[elem_per_thread] = {0}; + + // Head/batch from grid + threadgroup position + const int kv_head_idx = tid.x; + const int batch_idx = tid.y; + const int block_idx = tid.z; + const int gqa_factor = tptg.y; + const int q_seq_len = tptg.z; + const int gqa_offset = tidtg.y; + const int q_seq_idx = tidtg.z; + const int num_kv_heads = tpg.x; + const int num_q_heads = num_kv_heads * gqa_factor; + const int q_head_idx = gqa_factor * kv_head_idx + gqa_offset; + const int q_batch_head_idx = batch_idx * num_q_heads + q_head_idx; + const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx; + const int q_offset = + query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset; + + queries += q_offset * D + local_quad_lid * elem_per_thread; + + const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx; + const int kv_idx = + (block_idx * BN + local_quad_gid) * D + local_quad_lid * elem_per_thread; + const int k_group_idx = + kv_batch_head_idx * k_group_stride + kv_idx / group_size; + const int v_group_idx = + kv_batch_head_idx * v_group_stride + kv_idx / group_size; + + QuantDataPtr key_ptr(keys, k_stride, kv_batch_head_idx, kv_idx); + QuantDataPtr value_ptr(values, v_stride, kv_batch_head_idx, kv_idx); + + key_scales += k_group_idx; + value_scales += v_group_idx; + const device ScaleT* key_bias_ptr = nullptr; + const device ScaleT* value_bias_ptr = nullptr; + if constexpr (Cfg::has_bias) { + key_bias_ptr = + reinterpret_cast(key_biases_raw) + k_group_idx; + value_bias_ptr = + reinterpret_cast(value_biases_raw) + v_group_idx; + } + + out += + o_offset * blocks * D + block_idx * D + local_quad_lid * elem_per_thread; + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; + + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + (block_idx * BN + local_quad_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + (block_idx * BN + local_quad_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + + constexpr int stride = BN * D; + const int data_step = blocks * stride; + const int scale_step = data_step / group_size; + const int mask_step = BN * blocks * mask_kv_seq_stride; + + // Read the query +#pragma clang loop unroll(full) + for (int i = 0; i < elem_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + + U max_score = Limits::finite_min; + U sum_exp_score = 0; + if (has_sinks && block_idx == 0 && local_quad_gid == 0) { + max_score = static_cast(sinks[q_head_idx]); + sum_exp_score = 1; + } + + // Main loop: each quad processes one key at a time + for (int i = block_idx * BN + local_quad_gid; i < N; i += blocks * BN) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - q_seq_len + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); + } + + if (use_key) { + U score = QuantOps:: + template dot( + q, key_ptr.ptr(), key_scales, key_bias_ptr); + score = quad_sum(score); + for (int s = 4; s < BD; s <<= 1) { + score += simd_shuffle_xor(score, s); + } + + if (float_mask) { + score += static_cast(fmask[0]); + } + + // Online softmax update + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + QuantOps:: + template accumulate( + o, + value_ptr.ptr(), + factor, + exp_score, + value_scales, + value_bias_ptr); + } + + // Advance pointers + key_ptr.advance(data_step); + value_ptr.advance(data_step); + key_scales += scale_step; + value_scales += scale_step; + if constexpr (Cfg::has_bias) { + key_bias_ptr += scale_step; + value_bias_ptr += scale_step; + } + if (bool_mask) { + bmask += mask_step; + } + if (float_mask) { + fmask += mask_step; + } + } + + U sg_max = (local_quad_lid == 0) ? max_score : Limits::finite_min; + U global_max = simd_max(sg_max); + + U sg_sum = (local_quad_lid == 0) + ? sum_exp_score * fast::exp(max_score - global_max) + : 0; + U global_sum = simd_sum(sg_sum); + + if (simd_lid == 0) { + sums[0] = global_sum; + maxs[0] = global_max; + } + + // Output reduction: sum across groups (same local_quad_lid only) + U rescale = fast::exp(max_score - global_max); + for (int i = 0; i < elem_per_thread; i++) { + U val = o[i] * rescale; + for (int s = BD; s < 32; s <<= 1) { + val += simd_shuffle_xor(val, s); + } + if (local_quad_gid == 0) { + out[i] = static_cast(val); + } + } +} + +template +[[kernel]] void quant_sdpa_vector_2pass_1( + const device T* queries [[buffer(0)]], + const device uint32_t* keys [[buffer(1)]], + const device uint8_t* key_scales [[buffer(2)]], + const device uint32_t* values [[buffer(3)]], + const device uint8_t* value_scales [[buffer(4)]], + device T* out [[buffer(5)]], + device float* sums [[buffer(6)]], + device float* maxs [[buffer(7)]], + const constant int& N [[buffer(9)]], + const constant size_t& k_stride [[buffer(10)]], + const constant size_t& v_stride [[buffer(11)]], + const constant size_t& k_group_stride [[buffer(12)]], + const constant size_t& v_group_stride [[buffer(13)]], + const constant float& scale [[buffer(14)]], + const device bool* bmask [[buffer(15), function_constant(bool_mask)]], + const device T* fmask [[buffer(16), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(17), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(18), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(19), function_constant(has_mask)]], + const device uint8_t* key_biases + [[buffer(20), function_constant(has_affine_bias)]], + const device uint8_t* value_biases + [[buffer(21), function_constant(has_affine_bias)]], + const device T* sinks [[buffer(22), function_constant(has_sinks)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint3 tptg [[threads_per_threadgroup]], + uint3 tidtg [[thread_position_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { +#define QUANT_SDPA_DISPATCH(MODE, GS, B) \ + if (quant_mode_int == int(QuantMode::MODE) && quant_group_size == GS && \ + quant_bits == B) { \ + quant_sdpa_vector_2pass_1_impl( \ + queries, \ + keys, \ + key_scales, \ + values, \ + value_scales, \ + out, \ + sums, \ + maxs, \ + N, \ + k_stride, \ + v_stride, \ + k_group_stride, \ + v_group_stride, \ + scale, \ + bmask, \ + fmask, \ + mask_kv_seq_stride, \ + mask_q_seq_stride, \ + mask_head_stride, \ + key_biases, \ + value_biases, \ + sinks, \ + tid, \ + tpg, \ + tptg, \ + tidtg, \ + simd_lid); \ + return; \ + } + QUANT_SDPA_DISPATCH(Affine, 32, 4) + QUANT_SDPA_DISPATCH(Affine, 32, 6) + QUANT_SDPA_DISPATCH(Affine, 32, 8) + QUANT_SDPA_DISPATCH(Affine, 64, 4) + QUANT_SDPA_DISPATCH(Affine, 64, 6) + QUANT_SDPA_DISPATCH(Affine, 64, 8) + QUANT_SDPA_DISPATCH(Mxfp4, 32, 4) + QUANT_SDPA_DISPATCH(Nvfp4, 16, 4) + QUANT_SDPA_DISPATCH(Mxfp8, 32, 8) + // TurboQuant requires group_size == head_dim (one norm per vector). + // Gate on D to avoid the (D % group_size) == 0 static_assert failing. + if constexpr (D >= 64) { + QUANT_SDPA_DISPATCH(TurboQuant3, 64, 3) + QUANT_SDPA_DISPATCH(TurboQuant4, 64, 4) + } + if constexpr (D >= 128) { + QUANT_SDPA_DISPATCH(TurboQuant3, 128, 3) + QUANT_SDPA_DISPATCH(TurboQuant4, 128, 4) + } + if constexpr (D >= 256) { + QUANT_SDPA_DISPATCH(TurboQuant3, 256, 3) + QUANT_SDPA_DISPATCH(TurboQuant4, 256, 4) + } +#undef QUANT_SDPA_DISPATCH +} + template [[kernel]] void sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index c79cd51ff0..b79a501342 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -15,6 +15,68 @@ namespace mlx::core::fast { namespace { +// Select block count for vector 2-pass attention kernels. +int select_sdpa_blocks( + char devc, + int N, + int n_simds, + int head_dim, + [[maybe_unused]] bool quantized) { + if (devc == 's') { + int blocks = 64; + if (N > 1024 && n_simds > 4) { + if (N <= 8192) { + blocks = 128; + } else if (N <= 32768) { + blocks = 256; + } else if (N <= 65536) { + blocks = 512; + } else { + blocks = 1024; + } + } + return blocks; + } + + if (devc == 'd') { + int blocks = 128; + if (n_simds <= 2 && N > 8192) { + blocks = 256; + } else if (n_simds >= 6) { + if (N >= 16384 && N < 65536) { + blocks = 512; + } else if (N >= 65536) { + blocks = 1024; + } + } + return blocks; + } + + if (devc == 'g' || devc == 'p') { + if (n_simds <= 1) { + if (N <= 2048) { + return 32; + } else if (N <= 8192) { + return 64; + } else { + return 128; + } + } + if (head_dim >= 128) { + return 32; + } + if (N <= 8192) { + return 32; + } else if (N <= 32768) { + return 64; + } else { + return 128; + } + } + + return (n_simds >= 4) ? 64 : 32; +} + void sdpa_full_self_attention_nax( const Stream& s, metal::Device& d, @@ -442,38 +504,9 @@ void sdpa_vector_2pass( char devc = d.get_architecture().back(); int N = k.shape(2); - int blocks; - if (devc == 's') { - blocks = 64; - if (N > 1024 && n_simds > 4) { - if (N <= 8192) { - blocks = 128; - } else if (N <= 32768) { - blocks = 256; - } else if (N <= 65536) { - blocks = 512; - } else { - blocks = 1024; - } - } - } else if (devc == 'd') { - blocks = 128; - if (n_simds <= 2 && N > 8192) { - blocks = 256; - } else if (n_simds >= 6) { - if (N >= 16384 && N < 65536) { - blocks = 512; - } else if (N >= 65536) { - blocks = 1024; - } - } - } else { - if (n_simds >= 4) { - blocks = 64; - } else { - blocks = 32; - } - } + int blocks = + select_sdpa_blocks(devc, N, n_simds, q.shape(-1), /*quantized=*/false); + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); @@ -583,6 +616,189 @@ void sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +int quant_mode_to_int(QuantizationMode mode) { + switch (mode) { + case QuantizationMode::Affine: + return 0; + case QuantizationMode::Mxfp4: + return 1; + case QuantizationMode::Mxfp8: + return 2; + case QuantizationMode::Nvfp4: + return 3; + case QuantizationMode::TurboQuant3: + return 4; + case QuantizationMode::TurboQuant4: + return 5; + default: + throw std::invalid_argument( + "[quant_sdpa_vector_2pass] Unsupported quantization mode."); + } +} + +void quant_sdpa_vector_2pass( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& k_scales, + const std::optional& k_biases, + const array& v, + const array& v_scales, + const std::optional& v_biases, + array& out, + float scale, + int group_size, + int bits, + bool do_causal, + const std::optional& mask, + const std::optional& sinks, + QuantizationMode mode) { + std::string kname; + kname.reserve(64); + kname += "quant_sdpa_vector_2pass_1_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(q.shape(-1)); + + int N = k.shape(2); + int gqa_factor = q.shape(1) / k.shape(1); + int n_simds = gqa_factor * q.shape(2); + + char devc = d.get_architecture().back(); + int blocks = + select_sdpa_blocks(devc, N, n_simds, q.shape(-1), /*quantized=*/true); + + // Head strides for quantized data (in uint32 units) and scales + size_t k_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); + size_t v_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); + size_t k_group_stride = + k_scales.shape(1) == 1 ? k_scales.strides(0) : k_scales.strides(1); + size_t v_group_stride = + v_scales.shape(1) == 1 ? v_scales.strides(0) : v_scales.strides(1); + + MTL::Size group_dims(32, gqa_factor, q.shape(2)); + MTL::Size grid_dims(k.shape(1), q.shape(0), blocks); + + Shape intermediate_shape; + intermediate_shape.reserve(out.ndim() + 1); + intermediate_shape.insert( + intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1); + intermediate_shape.push_back(blocks); + intermediate_shape.push_back(out.shape().back()); + array intermediate(intermediate_shape, q.dtype(), nullptr, {}); + intermediate_shape.pop_back(); + array sums(intermediate_shape, float32, nullptr, {}); + array maxs(std::move(intermediate_shape), float32, nullptr, {}); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); + sums.set_data(allocator::malloc(sums.nbytes())); + maxs.set_data(allocator::malloc(maxs.nbytes())); + + bool has_mask = mask.has_value(); + bool bool_mask = has_mask && (*mask).dtype() == bool_; + bool float_mask = has_mask && !bool_mask; + bool query_transposed = !q.flags().row_contiguous; + bool has_sinks = sinks.has_value(); + bool has_affine_bias = mode == QuantizationMode::Affine; + int quant_mode_int = quant_mode_to_int(mode); + metal::MTLFCList func_consts = { + {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, + {&do_causal, MTL::DataType::DataTypeBool, 22}, + {&bool_mask, MTL::DataType::DataTypeBool, 23}, + {&float_mask, MTL::DataType::DataTypeBool, 24}, + {&has_sinks, MTL::DataType::DataTypeBool, 25}, + {&blocks, MTL::DataType::DataTypeInt, 26}, + {&has_affine_bias, MTL::DataType::DataTypeBool, 27}, + {&quant_mode_int, MTL::DataType::DataTypeInt, 28}, + {&bits, MTL::DataType::DataTypeInt, 29}, + {&group_size, MTL::DataType::DataTypeInt, 30}, + }; + std::string hash_name = kname; + hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; + hash_name += has_sinks ? "_s" : "_ns"; + hash_name += has_affine_bias ? "_affine_" : "_noaffine_"; + hash_name += std::to_string(quant_mode_int) + "_"; + hash_name += std::to_string(bits) + "_"; + hash_name += std::to_string(group_size) + "_"; + hash_name += std::to_string(blocks); + + auto& compute_encoder = metal::get_command_encoder(s); + compute_encoder.add_temporary(intermediate); + compute_encoder.add_temporary(sums); + compute_encoder.add_temporary(maxs); + auto kernel = d.get_kernel(kname, hash_name, func_consts); + check_kernel_threadgroup_size(kernel, group_dims, hash_name); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(k_scales, 2); + compute_encoder.set_input_array(v, 3); + compute_encoder.set_input_array(v_scales, 4); + compute_encoder.set_output_array(intermediate, 5); + compute_encoder.set_output_array(sums, 6); + compute_encoder.set_output_array(maxs, 7); + compute_encoder.set_bytes(N, 9); + compute_encoder.set_bytes(k_stride, 10); + compute_encoder.set_bytes(v_stride, 11); + compute_encoder.set_bytes(k_group_stride, 12); + compute_encoder.set_bytes(v_group_stride, 13); + compute_encoder.set_bytes(scale, 14); + + if (has_mask) { + auto& m = *mask; + compute_encoder.set_input_array(m, 15 + float_mask); + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); + compute_encoder.set_bytes(kv_seq_stride, 17); + compute_encoder.set_bytes(q_seq_stride, 18); + compute_encoder.set_bytes(head_stride, 19); + } + + if (has_affine_bias) { + compute_encoder.set_input_array(*k_biases, 20); + compute_encoder.set_input_array(*v_biases, 21); + } + if (has_sinks) { + compute_encoder.set_input_array(*sinks, 22); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Second pass kernel + kname.clear(); + kname += "sdpa_vector_2pass_2_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(out.shape(-1)); + + func_consts = { + {&blocks, MTL::DataType::DataTypeInt, 26}, + }; + hash_name = kname + "_" + std::to_string(blocks); + + kernel = d.get_kernel(kname, hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(intermediate, 0); + compute_encoder.set_input_array(sums, 1); + compute_encoder.set_input_array(maxs, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(blocks, 4); + + group_dims = MTL::Size(1024, 1, 1); + grid_dims = MTL::Size(q.shape(0) * q.shape(1), q.shape(2), 1); + check_kernel_threadgroup_size(kernel, group_dims, kname); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + } // namespace bool ScaledDotProductAttention::use_fallback( @@ -640,6 +856,32 @@ bool ScaledDotProductAttention::supports_bool_mask() { return true; } +bool QuantizedScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + bool is_training, + Stream s) { + if (is_training || s.device == Device::cpu) { + return true; + } + + bool supported_type = (q.dtype() == float32) || (q.dtype() == float16) || + (q.dtype() == bfloat16); + if (!supported_type) { + return true; + } + + int query_sequence_length = q.shape(2); + int key_sequence_length = k.shape(2); + int query_head_dim = q.shape(-1); + int gqa_factor = q.shape(1) / k.shape(1); + return query_sequence_length > 8 || + query_sequence_length > key_sequence_length || + !(query_head_dim == 64 || query_head_dim == 128 || + query_head_dim == 256 || query_head_dim == 512) || + (query_sequence_length * gqa_factor > 32); +} + void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -785,6 +1027,136 @@ void ScaledDotProductAttention::eval_gpu( metal::get_command_encoder(s).add_temporaries(std::move(copies)); } +void QuantizedScaledDotProductAttention::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& d = metal::device(s.device); + + bool is_affine = mode_ == QuantizationMode::Affine; + + // Inputs layout: + // [q, k, k_scales, k_biases (if affine), v, v_scales, v_biases (if affine), + // mask (if present), sinks (if present)] + auto& q_pre = inputs[0]; + auto& k_pre = inputs[1]; + auto& k_scales_pre = inputs[2]; + int idx = 3; + const array* k_biases_pre = nullptr; + if (is_affine) { + k_biases_pre = &inputs[idx++]; + } + auto& v_pre = inputs[idx++]; + auto& v_scales_pre = inputs[idx++]; + const array* v_biases_pre = nullptr; + if (is_affine) { + v_biases_pre = &inputs[idx++]; + } + auto& o = outputs[0]; + + std::vector copies; + copies.reserve(inputs.size()); + + auto copy_unless = [&copies, &s]( + auto predicate, const array& arr) -> const array& { + if (!predicate(arr)) { + array arr_copy = contiguous_copy_gpu(arr, s); + copies.push_back(std::move(arr_copy)); + return copies.back(); + } else { + return arr; + } + }; + + auto is_matrix_contiguous = [](const array& arr) { + return arr.strides(-1) == 1; + }; + + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + auto bidx = shape[0] == 1 ? 1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + const auto& k = copy_unless(kv_copy_unless, k_pre); + const auto& k_scales = copy_unless(kv_copy_unless, k_scales_pre); + std::optional k_biases = std::nullopt; + if (is_affine) { + k_biases = copy_unless(kv_copy_unless, *k_biases_pre); + } + const auto& v = copy_unless(kv_copy_unless, v_pre); + const auto& v_scales = copy_unless(kv_copy_unless, v_scales_pre); + std::optional v_biases = std::nullopt; + if (is_affine) { + v_biases = copy_unless(kv_copy_unless, *v_biases_pre); + } + + std::optional mask = std::nullopt; + if (has_arr_mask_) { + auto mask_copy_unless = [&q](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); + }; + mask = copy_unless(mask_copy_unless, inputs[idx++]); + } + + std::optional sinks = std::nullopt; + if (has_sinks_) { + sinks = copy_unless(is_matrix_contiguous, inputs[idx++]); + } + + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { + o.copy_shared_buffer(q); + } else { + o.set_data(allocator::malloc(o.nbytes())); + } + + bool do_causal = do_causal_ && q.shape(2) > 1; + quant_sdpa_vector_2pass( + s, + d, + q, + k, + k_scales, + k_biases, + v, + v_scales, + v_biases, + o, + scale_, + group_size_, + bits_, + do_causal, + mask, + sinks, + mode_); + + metal::get_command_encoder(s).add_temporaries(std::move(copies)); +} + bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { return true; } diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..085f8331d7 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -40,6 +40,14 @@ bool fast::ScaledDotProductAttention::supports_bool_mask() { return false; } +bool fast::QuantizedScaledDotProductAttention::use_fallback( + const array&, + const array&, + bool, + Stream) { + return true; +} + bool fast::ScaledDotProductAttentionVJP::use_fallback( const array& q, Stream s) { @@ -168,6 +176,7 @@ NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) NO_GPU_USE_FALLBACK(RoPE) NO_GPU_MULTI(ScaledDotProductAttention) +NO_GPU_MULTI(QuantizedScaledDotProductAttention) NO_GPU_MULTI(ScaledDotProductAttentionVJP) NO_GPU_MULTI(ConvertFP8) NO_GPU_MULTI(Quantize) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..2b4caa7499 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,10 +1,12 @@ // Copyright © 2023-2024 Apple Inc. #include #include +#include #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" +#include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" @@ -609,6 +611,23 @@ bool RoPE::is_equivalent(const Primitive& other) const { forward_ == a_other.forward_); } +std::pair prepare_sdpa_array_mask( + const array& mask, + Dtype out_type, + const Shape& full_mask_shape, + std::string_view tag, + Stream s) { + bool has_bool_mask = mask.dtype() == bool_; + if (!has_bool_mask && promote_types(mask.dtype(), out_type) != out_type) { + std::ostringstream msg; + msg << "[" << tag << "] Mask type must promote to output type " << out_type + << "."; + throw std::invalid_argument(msg.str()); + } + auto prepared_mask = has_bool_mask ? mask : astype(mask, out_type, s); + return {broadcast_to(prepared_mask, full_mask_shape, s), has_bool_mask}; +} + /** Computes: O = softmax(Q @ K.T) @ V **/ array scaled_dot_product_attention( const array& queries, @@ -653,7 +672,6 @@ array scaled_dot_product_attention( } else if (mask_arr) { has_mask = true; has_arr_mask = true; - has_bool_mask = mask_arr->dtype() == bool_; } if (has_arr_mask && mask_arr->ndim() > 4) { @@ -791,20 +809,16 @@ array scaled_dot_product_attention( auto stream = to_stream(s); std::vector inputs = {q, k, v}; if (has_arr_mask) { - // Check type - has_bool_mask = mask_arr->dtype() == bool_; - if (promote_types(mask_arr->dtype(), final_type) != final_type) { - std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask type must promote to output type " - << final_type << "."; - throw std::invalid_argument(msg.str()); - } else if (!has_bool_mask) { - mask_arr = astype(*mask_arr, final_type, stream); - } - // Broadcast mask auto mask_shape = queries.shape(); mask_shape.back() = keys.shape(-2); - inputs.push_back(broadcast_to(*mask_arr, mask_shape, stream)); + auto [prepared_mask, prepared_bool_mask] = prepare_sdpa_array_mask( + *mask_arr, + final_type, + mask_shape, + "scaled_dot_product_attention", + stream); + has_bool_mask = prepared_bool_mask; + inputs.push_back(std::move(prepared_mask)); } if (has_sinks) { if (promote_types(sinks->dtype(), final_type) != final_type) { @@ -861,6 +875,419 @@ array scaled_dot_product_attention( return fallback(std::move(inputs))[0]; } +array quantized_scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& key_scales, + const std::optional& key_biases, + const array& values, + const array& value_scales, + const std::optional& value_biases, + const float scale, + const std::optional& mask /* = std::nullopt */, + const std::optional& sinks /* = std::nullopt */, + const std::optional group_size_ /* = std::nullopt */, + const std::optional bits_ /* = std::nullopt */, + const std::string& mode /* = "mxfp4" */, + bool causal /* = false */, + StreamOrDevice s /* = {} */) { + constexpr const char* tag = "quantized_scaled_dot_product_attention"; + + // Parse mode and get parameters + auto qmode = string_to_quantization_mode(mode, tag); + bool is_affine = qmode == QuantizationMode::Affine; + bool is_turbo = qmode == QuantizationMode::TurboQuant3 || + qmode == QuantizationMode::TurboQuant4; + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); + + // Validate mode-specific group_size and bits + if (is_affine) { + if (group_size != 32 && group_size != 64) { + std::ostringstream msg; + msg << "[" << tag << "] Affine mode supports group_size 32 or 64 " + << "but received " << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits != 4 && bits != 6 && bits != 8) { + std::ostringstream msg; + msg << "[" << tag + << "] Affine mode supports bits 4, 6, or 8 but received " << bits + << "."; + throw std::invalid_argument(msg.str()); + } + if (!key_biases.has_value() || !value_biases.has_value()) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Affine mode requires " + "key_biases and value_biases."); + } + } else if (is_turbo) { + int expected_bits = (qmode == QuantizationMode::TurboQuant3) ? 3 : 4; + if (bits != expected_bits) { + std::ostringstream msg; + msg << "[" << tag << "] Mode '" << mode << "' requires bits " + << expected_bits << "."; + throw std::invalid_argument(msg.str()); + } + if (key_biases.has_value() || value_biases.has_value()) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Biases are not supported " + "for TurboQuant modes."); + } + } else { + // FP modes have fixed params - verify if user overrode them incorrectly + auto [expected_gs, expected_bits] = + quantization_params_from_mode(qmode, std::nullopt, std::nullopt); + if (group_size != expected_gs || bits != expected_bits) { + std::ostringstream msg; + msg << "[" << tag << "] Mode '" << mode << "' requires group_size " + << expected_gs << " and bits " << expected_bits << "."; + throw std::invalid_argument(msg.str()); + } + if (key_biases.has_value() || value_biases.has_value()) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Biases should only be " + "provided for affine mode."); + } + } + + // Validate rank 4 for all inputs + for (const auto& t : {queries, keys, key_scales, values, value_scales}) { + if (t.ndim() != 4) { + std::ostringstream msg; + msg << "[" << tag << "] input with shape " << t.shape() + << " expected to be rank 4."; + throw std::invalid_argument(msg.str()); + } + } + if (is_affine && (key_biases->ndim() != 4 || value_biases->ndim() != 4)) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Biases must be rank 4."); + } + + // Validate dtypes + auto final_type = queries.dtype(); + if (!issubdtype(final_type, floating)) { + std::ostringstream msg; + msg << "[" << tag << "] queries must be floating type but got " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + if (!(final_type == float16 || final_type == bfloat16 || + final_type == float32)) { + std::ostringstream msg; + msg << "[" << tag + << "] queries must be float16, bfloat16, or float32 for quantized " + "attention; received " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + if (keys.dtype() != uint32 || values.dtype() != uint32) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Keys and values must be " + "uint32."); + } + if (!is_affine && !is_turbo && + (key_scales.dtype() != uint8 || value_scales.dtype() != uint8)) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Scales must be uint8 for fp " + "quantization."); + } + if (is_turbo) { + auto check_turbo_scale_dtype = [&](const array& s, const char* name) { + if (s.dtype() != float16 && s.dtype() != bfloat16 && + s.dtype() != float32) { + std::ostringstream msg; + msg << "[" << tag << "] TurboQuant " << name + << " scales must be float16, bfloat16, or float32."; + throw std::invalid_argument(msg.str()); + } + }; + check_turbo_scale_dtype(key_scales, "key"); + check_turbo_scale_dtype(value_scales, "value"); + } + + // Compute and validate dimensions + auto key_head_dim = (keys.shape(-1) * 32) / bits; + auto value_head_dim = (values.shape(-1) * 32) / bits; + auto n_q_heads = queries.shape(-3); + auto n_kv_heads = keys.shape(-3); + + if (queries.shape(0) != keys.shape(0) || + queries.shape(0) != values.shape(0)) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Batch dimensions must match."); + } + if (n_q_heads % n_kv_heads != 0) { + std::ostringstream msg; + msg << "[" << tag << "] n_heads must be a multiple of n_kv_heads, found " + << n_q_heads << " vs " << n_kv_heads << "."; + throw std::invalid_argument(msg.str()); + } + if (keys.shape(-3) != values.shape(-3)) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Keys and values must have " + "matching n_kv_heads."); + } + if (queries.shape(-1) != key_head_dim || + queries.shape(-1) != value_head_dim) { + std::ostringstream msg; + msg << "[" << tag << "] Query head dim " << queries.shape(-1) + << " must match key (" << key_head_dim << ") and value (" + << value_head_dim << ")."; + throw std::invalid_argument(msg.str()); + } + if (queries.shape(-1) % group_size != 0) { + std::ostringstream msg; + msg << "[" << tag << "] Head dim " << queries.shape(-1) + << " must be divisible by group_size " << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (is_turbo) { + if (group_size != queries.shape(-1)) { + std::ostringstream msg; + msg << "[" << tag << "] TurboQuant requires group_size == head_dim," + << " got group_size=" << group_size + << " head_dim=" << queries.shape(-1) << "."; + throw std::invalid_argument(msg.str()); + } + auto head_dim = queries.shape(-1); + if (head_dim != 64 && head_dim != 128 && head_dim != 256) { + std::ostringstream msg; + msg << "[" << tag << "] TurboQuant only supports head_dim in {64, 128, 256}," + << " got " << head_dim << "."; + throw std::invalid_argument(msg.str()); + } + } + + // Validate scale/bias shapes + auto expected_scale_dim = queries.shape(-1) / group_size; + for (const auto& [qdata, scale, bias, name] : + {std::tuple{&keys, &key_scales, &key_biases, "key"}, + std::tuple{&values, &value_scales, &value_biases, "value"}}) { + if (scale->shape(-1) != expected_scale_dim || + scale->shape(-3) != qdata->shape(-3) || + scale->shape(-2) != qdata->shape(-2)) { + std::ostringstream msg; + msg << "[" << tag << "] " << name << " scale shape mismatch."; + throw std::invalid_argument(msg.str()); + } + if (is_affine && bias->has_value() && (*bias)->shape() != scale->shape()) { + std::ostringstream msg; + msg << "[" << tag << "] " << name + << " bias shape must match scale shape."; + throw std::invalid_argument(msg.str()); + } + } + + // Validate mask + bool do_causal = causal; + bool has_arr_mask = mask.has_value(); + bool has_sinks = sinks.has_value(); + if (do_causal && has_arr_mask) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Received both causal=true " + "and an array mask. Please provide only one mask type."); + } + if (has_arr_mask && mask->ndim() > 4) { + std::ostringstream msg; + msg << "[" << tag << "] Mask with shape " << mask->shape() + << " expected to have at most rank 4."; + throw std::invalid_argument(msg.str()); + } + + auto q = astype(queries, final_type, s); + // Inputs layout: + // [q, k, k_scales, k_biases (if affine), v, v_scales, v_biases (if affine), + // mask (if present), sinks (if present)] + auto fallback = [scale, + n_q_heads, + n_kv_heads, + do_causal, + has_arr_mask, + has_sinks, + is_affine, + is_turbo, + group_size, + bits, + mode, + s](const std::vector& inputs) { + if (is_turbo) { + throw std::runtime_error( + "[quantized_scaled_dot_product_attention] TurboQuant mode requires " + "a Metal GPU device."); + } + auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); + int n_repeats = n_q_heads / n_kv_heads; + + auto k = inputs[1]; + auto k_scales = inputs[2]; + std::optional k_biases = std::nullopt; + int idx = 3; + if (is_affine) { + k_biases = inputs[idx++]; + } + auto v = inputs[idx++]; + auto v_scales = inputs[idx++]; + std::optional v_biases = std::nullopt; + if (is_affine) { + v_biases = inputs[idx++]; + } + + std::optional arr_mask = + has_arr_mask ? std::optional{inputs[idx++]} : std::nullopt; + std::optional sinks_opt = + has_sinks ? std::optional{inputs[idx++]} : std::nullopt; + + if (n_repeats > 1) { + q = unflatten(q, 1, {n_kv_heads, n_repeats}, s); + k = expand_dims(k, 2, s); + k_scales = expand_dims(k_scales, 2, s); + if (k_biases) { + k_biases = expand_dims(*k_biases, 2, s); + } + v = expand_dims(v, 2, s); + v_scales = expand_dims(v_scales, 2, s); + if (v_biases) { + v_biases = expand_dims(*v_biases, 2, s); + } + } + + auto scores = quantized_matmul( + q, + k, + k_scales, + k_biases, + /*transpose=*/true, + group_size, + bits, + mode, + s); + if (has_arr_mask || do_causal) { + auto make_or_fetch_mask = [&]() { + if (do_causal) { + int kL = k.shape(-2); + int qL = q.shape(-2); + int offset = kL - qL; + auto q_idx = arange(offset, qL + offset, s); + auto k_idx = arange(0, kL, s); + q_idx = expand_dims(q_idx, 1, s); + k_idx = expand_dims(k_idx, 0, s); + return greater_equal(q_idx, k_idx, s); + } + return *arr_mask; + }; + auto m = make_or_fetch_mask(); + if (n_repeats > 1 && m.ndim() >= 3) { + if (m.shape(-3) == 1) { + m = expand_dims(m, -3, s); + } else { + m = unflatten(m, -3, {n_kv_heads, n_repeats}, s); + } + } + if (m.dtype() == bool_) { + scores = where( + m, scores, array(finfo(scores.dtype()).min, scores.dtype()), s); + } else { + scores = add(scores, m, s); + } + } + + if (has_sinks) { + auto sinks = *sinks_opt; + // scores has shape B N_q N_k L_q L_k + sinks = expand_dims(sinks, {0, 2, 3}, s); + if (scores.ndim() == 5) { + sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s); + } + auto bsx_shape = scores.shape(); + bsx_shape.back() = 1; + scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s); + } + scores = softmax(scores, std::vector{-1}, true, s); + if (has_sinks) { + auto start = Shape(scores.ndim(), 0); + start.back() = 1; + auto stop = scores.shape(); + scores = slice(scores, std::move(start), std::move(stop), s); + } + auto out = quantized_matmul( + scores, + v, + v_scales, + v_biases, + /*transpose=*/false, + group_size, + bits, + mode, + s); + if (n_repeats > 1) { + out = flatten(out, 1, 2, s); + } + return std::vector{out}; + }; + + auto stream = to_stream(s); + Shape full_mask_shape{ + queries.shape(0), queries.shape(1), queries.shape(2), keys.shape(-2)}; + + // For TurboQuant, scales are per-vector norms stored as floats (not uint8). + // Cast them to final_type so the Metal kernel sees the expected dtype. + auto ks = is_turbo ? astype(key_scales, final_type, stream) : key_scales; + auto vs = is_turbo ? astype(value_scales, final_type, stream) : value_scales; + + std::vector inputs = {q, keys, ks}; + if (is_affine) { + inputs.push_back(*key_biases); + } + inputs.push_back(values); + inputs.push_back(vs); + if (is_affine) { + inputs.push_back(*value_biases); + } + if (has_arr_mask) { + auto prepared_mask = prepare_sdpa_array_mask( + *mask, final_type, full_mask_shape, tag, stream); + inputs.push_back(std::move(prepared_mask.first)); + } + if (has_sinks) { + if (promote_types(sinks->dtype(), final_type) != final_type) { + std::ostringstream msg; + msg << "[" << tag << "] Type of sinks must promote to output type " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) { + std::ostringstream msg; + msg << "[" << tag << "] Received invalid shape for sinks " + << sinks->shape() << "."; + throw std::invalid_argument(msg.str()); + } + inputs.push_back(astype(*sinks, final_type, stream)); + } + + int out_dim = value_head_dim; + Shape out_shape{ + queries.shape(0), queries.shape(1), queries.shape(2), out_dim}; + + if (QuantizedScaledDotProductAttention::use_fallback( + q, keys, detail::in_grad_tracing(), stream)) { + return fallback(std::move(inputs))[0]; + } + + auto primitive = std::make_shared( + stream, + fallback, + scale, + has_arr_mask, + has_sinks, + do_causal, + group_size, + bits, + qmode); + return array(std::move(out_shape), final_type, primitive, std::move(inputs)); +} + std::vector ScaledDotProductAttention::vjp( const std::vector& primals, const std::vector& cotangents, @@ -915,6 +1342,16 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { output_logsumexp_ == a_other.output_logsumexp_; } +bool QuantizedScaledDotProductAttention::is_equivalent( + const Primitive& other) const { + const QuantizedScaledDotProductAttention& a_other = + static_cast(other); + return scale_ == a_other.scale_ && has_arr_mask_ == a_other.has_arr_mask_ && + has_sinks_ == a_other.has_sinks_ && do_causal_ == a_other.do_causal_ && + group_size_ == a_other.group_size_ && bits_ == a_other.bits_ && + mode_ == a_other.mode_; +} + bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const { const ScaledDotProductAttentionVJP& a_other = static_cast(other); diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..b176e52e03 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -54,6 +54,24 @@ MLX_API array scaled_dot_product_attention( const std::optional& sinks = {}, StreamOrDevice s = {}); +/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/ +MLX_API array quantized_scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& key_scales, + const std::optional& key_biases, + const array& values, + const array& value_scales, + const std::optional& value_biases, + const float scale, + const std::optional& mask = std::nullopt, + const std::optional& sinks = std::nullopt, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "mxfp4", + bool causal = false, + StreamOrDevice s = {}); + using TemplateArg = std::variant; using ScalarArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..961a6a139e 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -260,6 +260,63 @@ class ScaledDotProductAttention : public Custom { bool output_logsumexp_; }; +class QuantizedScaledDotProductAttention : public Custom { + public: + QuantizedScaledDotProductAttention( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool has_arr_mask, + bool has_sinks, + bool do_causal, + int group_size, + int bits, + QuantizationMode mode) + : Custom(stream, std::move(fallback)), + scale_(scale), + has_arr_mask_(has_arr_mask), + has_sinks_(has_sinks), + do_causal_(do_causal), + group_size_(group_size), + bits_(bits), + mode_(mode) {} + + void eval_cpu(const std::vector&, std::vector&) override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + static bool + use_fallback(const array& q, const array& k, bool is_training, Stream s); + + bool is_equivalent(const Primitive& other) const override; + + DEFINE_NAME(QuantizedScaledDotProductAttention); + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple( + nullptr, + scale_, + has_arr_mask_, + has_sinks_, + do_causal_, + group_size_, + bits_, + mode_); + } + + private: + float scale_; + bool has_arr_mask_; + bool has_sinks_; + bool do_causal_; + int group_size_; + int bits_; + QuantizationMode mode_; +}; + class ScaledDotProductAttentionVJP : public Custom { public: ScaledDotProductAttentionVJP( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index defcc2f6e0..e4639923c5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4365,35 +4365,6 @@ array conv_general( {in, wt}); } -std::pair quantization_params_from_mode( - QuantizationMode mode, - std::optional group_size_, - std::optional bits_) { - int default_group_size; - int default_bits; - switch (mode) { - case QuantizationMode::Affine: - default_group_size = 64; - default_bits = 4; - break; - case QuantizationMode::Nvfp4: - default_group_size = 16; - default_bits = 4; - break; - case QuantizationMode::Mxfp4: - default_group_size = 32; - default_bits = 4; - break; - case QuantizationMode::Mxfp8: - default_group_size = 32; - default_bits = 8; - break; - } - return { - group_size_.has_value() ? *group_size_ : default_group_size, - bits_.has_value() ? *bits_ : default_bits}; -} - std::pair validate_mode_with_type( std::string_view tag, const array& scales, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index f3acec574b..1ad17bf41c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3425,9 +3425,13 @@ std::string quantization_mode_to_string(QuantizationMode mode) { case QuantizationMode::Mxfp8: return "mxfp8"; case QuantizationMode::Nvfp4: - default: return "nvfp4"; + case QuantizationMode::TurboQuant3: + return "turbo3"; + case QuantizationMode::TurboQuant4: + return "turbo4"; } + throw std::runtime_error("Unknown quantization mode"); } QuantizationMode string_to_quantization_mode( @@ -3441,6 +3445,10 @@ QuantizationMode string_to_quantization_mode( return QuantizationMode::Mxfp8; } else if (mode == "nvfp4") { return QuantizationMode::Nvfp4; + } else if (mode == "turbo3") { + return QuantizationMode::TurboQuant3; + } else if (mode == "turbo4") { + return QuantizationMode::TurboQuant4; } std::string msg; if (!tag.empty()) { @@ -3450,6 +3458,42 @@ QuantizationMode string_to_quantization_mode( throw std::invalid_argument(msg); } +std::pair quantization_params_from_mode( + QuantizationMode mode, + std::optional group_size_, + std::optional bits_) { + int default_group_size; + int default_bits; + switch (mode) { + case QuantizationMode::Affine: + default_group_size = 64; + default_bits = 4; + break; + case QuantizationMode::Nvfp4: + default_group_size = 16; + default_bits = 4; + break; + case QuantizationMode::Mxfp4: + default_group_size = 32; + default_bits = 4; + break; + case QuantizationMode::Mxfp8: + default_group_size = 32; + default_bits = 8; + break; + case QuantizationMode::TurboQuant3: + default_group_size = 64; + default_bits = 3; + break; + case QuantizationMode::TurboQuant4: + default_group_size = 64; + default_bits = 4; + break; + } + return { + group_size_.value_or(default_group_size), bits_.value_or(default_bits)}; +} + std::pair, std::vector> QuantizedMatmul::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..d858b42807 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -152,13 +152,27 @@ class MLX_API UnaryPrimitive : public Primitive { UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; }; -enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; +enum class QuantizationMode { + Affine, + Mxfp4, + Mxfp8, + Nvfp4, + TurboQuant3, + TurboQuant4 +}; std::string quantization_mode_to_string(QuantizationMode mode); QuantizationMode string_to_quantization_mode( const std::string& mode, std::string_view error_tag = ""); +// Returns (group_size, bits) for a given quantization mode. +// Uses provided values if given, otherwise returns mode defaults. +std::pair quantization_params_from_mode( + QuantizationMode mode, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt); + class Abs : public UnaryPrimitive { public: explicit Abs(Stream stream) : UnaryPrimitive(stream) {} diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..01ad481832 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -296,6 +296,105 @@ void init_fast(nb::module_& parent_module) { out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") )pbdoc"); + m.def( + "quantized_scaled_dot_product_attention", + [](const mx::array& q, + const mx::array& k, + const mx::array& k_scales, + const mx::array& v, + const mx::array& v_scales, + const std::optional& k_biases, + const std::optional& v_biases, + const float scale, + const std::optional& mask, + const std::optional& sinks, + std::optional group_size, + std::optional bits, + const std::string& mode, + bool causal, + mx::StreamOrDevice s) { + return mx::fast::quantized_scaled_dot_product_attention( + q, + k, + k_scales, + k_biases, + v, + v_scales, + v_biases, + scale, + mask, + sinks, + group_size, + bits, + mode, + causal, + s); + }, + "q"_a, + "k"_a, + "k_scales"_a, + "v"_a, + "v_scales"_a, + nb::kw_only(), + "k_biases"_a = nb::none(), + "v_biases"_a = nb::none(), + "scale"_a, + "mask"_a = nb::none(), + "sinks"_a = nb::none(), + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "mxfp4", + "causal"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, v: array, v_scales: array, *, k_biases: Optional[array] = None, v_biases: Optional[array] = None, scale: float, mask: Optional[array] = None, sinks: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array")); + + m.def( + "quantized_scaled_dot_product_attention", + &mx::fast::quantized_scaled_dot_product_attention, + "q"_a, + "k"_a, + "k_scales"_a, + "k_biases"_a = nb::none(), + "v"_a, + "v_scales"_a, + "v_biases"_a = nb::none(), + nb::kw_only(), + "scale"_a, + "mask"_a = nb::none(), + "sinks"_a = nb::none(), + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "mxfp4", + "causal"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: Optional[array] = None, v: array, v_scales: array, v_biases: Optional[array] = None, *, scale: float, mask: Optional[array] = None, sinks: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = \"mxfp4\", causal: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + A fast implementation of multi-head attention where the keys and values are quantized. + + see :func:`scaled_dot_product_attention` for more details. + + Args: + q (array): Input query array. + k (array): Input keys array. + k_scales (array): Scales for the quantized keys array. + k_biases (array or None): Biases for the affine-quantized keys array. Required for affine mode, None for fp modes. + v (array): Input values array. + v_scales (array): Scales for the quantized values array. + v_biases (array or None): Biases for the affine-quantized values array. Required for affine mode, None for fp modes. + scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) + mask (array, optional): An additive or boolean mask to apply to the query-key scores. + sinks (array, optional): An optional array of attention sinks with shape ``[N_q]``. + group_size (int, optional): The group size used in the KV quantization. Defaults follow the quantization ``mode``. + bits (int, optional): The bits used in the KV quantization. Defaults follow the quantization ``mode``. + mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, ``"affine"``, ``"turbo3"``, or ``"turbo4"``. TurboQuant modes use a WHT rotation + Lloyd-Max codebook and require ``group_size == head_dim``; ``k_biases`` and ``v_biases`` must be ``None``. + causal (bool, optional): Whether to apply lower-right aligned causal masking. + Cannot be used together with ``mask``. + Returns: + array: The output array. + )pbdoc"); + m.def( "metal_kernel", [](const std::string& name, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index a7472e9920..3fa0579336 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -753,6 +753,446 @@ def test_qmv_small_non_multiples(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_quantized_sdpa(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lq, Lk = 4, 640 + + for D in [128, 512]: + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + with self.subTest(D=D, mode=mode): + bits = 8 if mode == "mxfp8" else 4 + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_affine(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lq, Lk, D = 4, 640, 128 + + for bits in [4, 6, 8]: + with self.subTest(bits=bits): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + ) + + self.assertEqual(out.shape, ref.shape) + if bits == 6: + tol = 1e-1 + elif bits == 4: + tol = 5e-2 + else: + tol = 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_masked(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + bits = 8 if mode == "mxfp8" else 4 + for Lq in [4, 9]: + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + additive_mask = mx.where( + bool_mask, + mx.zeros((Lq, Lk), dtype=mx.float32), + mx.full((Lq, Lk), -1e9, dtype=mx.float32), + ) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + for mask_name, mask in { + "bool": bool_mask, + "additive": additive_mask, + }.items(): + with self.subTest(mode=mode, bits=bits, Lq=Lq, mask=mask_name): + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + mask=mask, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_affine_masked(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for bits in [4, 6, 8]: + for Lq in [4, 9]: + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + additive_mask = mx.where( + bool_mask, + mx.zeros((Lq, Lk), dtype=mx.float32), + mx.full((Lq, Lk), -1e9, dtype=mx.float32), + ) + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + for mask_name, mask in { + "bool": bool_mask, + "additive": additive_mask, + }.items(): + with self.subTest(bits=bits, Lq=Lq, mask=mask_name): + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + mask=mask, + ) + + self.assertEqual(out.shape, ref.shape) + if bits == 6: + tol = 1e-1 + elif bits == 4: + tol = 5e-2 + else: + tol = 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_sinks(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + sinks = mx.array([0.7, -0.4], dtype=mx.float32) + + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + bits = 8 if mode == "mxfp8" else 4 + for Lq in [4, 9]: + with self.subTest(mode=mode, bits=bits, Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, sinks=sinks + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + sinks=sinks, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_masked_with_sinks(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + sinks = mx.array([0.5, -0.3], dtype=mx.float32) + + for Lq in [4, 9]: + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + additive_mask = mx.where( + bool_mask, + mx.zeros((Lq, Lk), dtype=mx.float32), + mx.full((Lq, Lk), -1e9, dtype=mx.float32), + ) + + mode = "mxfp4" + bits = 4 + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + for mask_name, mask in { + "bool": bool_mask, + "additive": additive_mask, + }.items(): + with self.subTest(Lq=Lq, mask=mask_name): + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask, sinks=sinks + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + mask=mask, + sinks=sinks, + ) + + self.assertEqual(out.shape, ref.shape) + self.assertLess((out - ref).abs().max(), 5e-2) + + def test_quantized_sdpa_affine_masked_with_sinks(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + bits = 4 + sinks = mx.array([0.2, -0.1], dtype=mx.float32) + + for Lq in [4, 9]: + with self.subTest(Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + bool_mask = mx.random.uniform(shape=(Lq, Lk)) > 0.2 + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=bool_mask, sinks=sinks + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + mask=bool_mask, + sinks=sinks, + ) + + self.assertEqual(out.shape, ref.shape) + self.assertLess((out - ref).abs().max(), 5e-2) + + def test_quantized_sdpa_causal(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for mode in ["mxfp4", "mxfp8", "nvfp4"]: + bits = 8 if mode == "mxfp8" else 4 + for Lq in [4, 9]: + with self.subTest(mode=mode, bits=bits, Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales = mx.quantize(k, mode=mode) + v_q, v_scales = mx.quantize(v, mode=mode) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask="causal" + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode=mode, + bits=bits, + causal=True, + ) + + self.assertEqual(out.shape, ref.shape) + tol = 5e-2 if bits == 4 else 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_affine_causal(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + mx.random.seed(0) + B, Hq, Hkv = 1, 2, 1 + Lk, D = 640, 128 + + for bits in [4, 6, 8]: + for Lq in [4, 9]: + with self.subTest(bits=bits, Lq=Lq): + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + + k_q, k_scales, k_biases = mx.quantize( + k, group_size=32, bits=bits, mode="affine" + ) + v_q, v_scales, v_biases = mx.quantize( + v, group_size=32, bits=bits, mode="affine" + ) + + ref = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask="causal" + ) + out = mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + k_biases, + v_q, + v_scales, + v_biases, + scale=1.0, + mode="affine", + group_size=32, + bits=bits, + causal=True, + ) + + self.assertEqual(out.shape, ref.shape) + if bits == 6: + tol = 1e-1 + elif bits == 4: + tol = 5e-2 + else: + tol = 2e-2 + self.assertLess((out - ref).abs().max(), tol) + + def test_quantized_sdpa_causal_with_array_mask_error(self): + if mx.default_device() == mx.cpu: + self.skipTest("Quantized fast attention is only available on GPU.") + + B, Hq, Hkv = 1, 2, 1 + Lq, Lk, D = 4, 640, 128 + q = mx.random.normal(shape=(B, Hq, Lq, D)) + k = mx.random.normal(shape=(B, Hkv, Lk, D)) + v = mx.random.normal(shape=(B, Hkv, Lk, D)) + mask = mx.ones(shape=(Lq, Lk), dtype=mx.bool_) + + k_q, k_scales = mx.quantize(k, mode="mxfp4") + v_q, v_scales = mx.quantize(v, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.fast.quantized_scaled_dot_product_attention( + q, + k_q, + k_scales, + v_q, + v_scales, + scale=1.0, + mode="mxfp4", + bits=4, + mask=mask, + causal=True, + ) + def test_gather_qmm(self): def quantize(w, transpose=True, group_size=None, bits=None, mode="affine"): if mode == "affine": @@ -1168,6 +1608,193 @@ def test_quantize_strided(self): expected = mx.dequantize(w_q, mx.contiguous(scales), mode=mode) self.assertTrue(mx.allclose(w_hat, expected)) + def test_quantized_sdpa_turbo(self): + """TurboQuant 3-bit and 4-bit SDPA: pack indices, run kernel, compare + against standard SDPA using the dequantized K and V. + + Tests all supported head_dim values (64, 128, 256) and batch sizes + to exercise each Metal template instantiation independently. + """ + if mx.default_device() == mx.cpu: + self.skipTest("TurboQuant SDPA requires a Metal GPU.") + + import numpy as np + + CB3 = np.array( + [-1.7481, -1.0498, -0.5012, -0.1624, 0.1624, 0.5012, 1.0498, 1.7481], + dtype=np.float32, + ) + CB4 = np.array( + [ + -1.9672, -1.3305, -1.0130, -0.7811, + -0.5714, -0.4053, -0.2382, -0.0784, + 0.0784, 0.2382, 0.4053, 0.5714, + 0.7811, 1.0130, 1.3305, 1.9672, + ], + dtype=np.float32, + ) + + def pack_3bit(indices_np): + """Pack 3-bit indices to uint32 using PackReader<3> byte layout. + 8 indices → 3 bytes (24 bits), little-endian, stored as uint32 words. + """ + *batch, D = indices_np.shape + assert D % 8 == 0 + n_packs = D // 8 + flat = indices_np.reshape(-1, D).astype(np.int64) + n_tok = flat.shape[0] + buf = np.zeros((n_tok, n_packs * 3), dtype=np.uint8) + for t in range(n_tok): + for p in range(n_packs): + ix = flat[t, p * 8 : p * 8 + 8] + v = ( + int(ix[0]) + | (int(ix[1]) << 3) + | (int(ix[2]) << 6) + | (int(ix[3]) << 9) + | (int(ix[4]) << 12) + | (int(ix[5]) << 15) + | (int(ix[6]) << 18) + | (int(ix[7]) << 21) + ) + buf[t, p * 3] = v & 0xFF + buf[t, p * 3 + 1] = (v >> 8) & 0xFF + buf[t, p * 3 + 2] = (v >> 16) & 0xFF + n_u32 = D * 3 // 32 + u32 = np.frombuffer(buf.tobytes(), dtype="