diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..4d10ee0f7 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -207,6 +207,18 @@ def setup_arg_parser(): type=int, default=DEFAULT_QUANTIZED_KV_START, ) + parser.add_argument( + "--kv-cache-type", + type=str, + help="KV cache type. Use 'turbo3' or 'turbo4' for TurboQuant compression.", + default=None, + ) + parser.add_argument( + "--turbo-fp16-layers", + type=int, + help="Number of attention layers at each end to keep in float16 when using TurboQuant.", + default=0, + ) parser.add_argument( "--draft-model", type=str, @@ -317,6 +329,8 @@ def generate_step( kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, + kv_cache_type: Optional[str] = None, + turbo_fp16_layers: int = 0, prompt_progress_callback: Optional[Callable[[int, int], None]] = None, input_embeddings: Optional[mx.array] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: @@ -369,10 +383,18 @@ def generate_step( # Create the KV cache for generation if prompt_cache is None: - prompt_cache = cache.make_prompt_cache( - model, - max_kv_size=max_kv_size, - ) + if kv_cache_type in ("turbo3", "turbo4"): + turbo_bits = int(kv_cache_type[-1]) + prompt_cache = cache.make_turbo_cache( + model, + bits=turbo_bits, + fp16_layers=turbo_fp16_layers, + ) + else: + prompt_cache = cache.make_prompt_cache( + model, + max_kv_size=max_kv_size, + ) prompt_progress_callback = prompt_progress_callback or (lambda *_: None) @@ -2081,6 +2103,8 @@ def main(): kv_bits=args.kv_bits, kv_group_size=args.kv_group_size, quantized_kv_start=args.quantized_kv_start, + kv_cache_type=args.kv_cache_type, + turbo_fp16_layers=args.turbo_fp16_layers, draft_model=draft_model, num_draft_tokens=args.num_draft_tokens, ) diff --git a/mlx_lm/models/base.py b/mlx_lm/models/base.py index d7c3efb28..43b01c093 100644 --- a/mlx_lm/models/base.py +++ b/mlx_lm/models/base.py @@ -105,6 +105,59 @@ def quantized_scaled_dot_product_attention( 
def _turbo_scaled_dot_product_attention(
    queries: mx.array,
    q_keys: tuple[mx.array, mx.array],
    q_values: tuple[mx.array, mx.array],
    scale: float,
    mask: Optional[mx.array],
    mode: str,
    group_size: int,
) -> mx.array:
    """Run SDPA over TurboQuant-compressed keys and values.

    The queries are rotated with a Walsh-Hadamard transform scaled by
    ``1/sqrt(D)`` (``D`` being the last query dimension) and then passed,
    together with the packed keys/values and their scales, to
    ``mx.fast.quantized_scaled_dot_product_attention``.

    Args:
        queries: Query tensor; its last axis is the head dimension.
        q_keys: ``(packed, scales)`` pair for the keys.
        q_values: ``(packed, scales)`` pair for the values.
        scale: Softmax scale forwarded unchanged to the kernel.
        mask: ``None``, the string ``"causal"``, or an explicit mask array.
        mode: Quantization mode string forwarded to the kernel.
        group_size: Group size forwarded to the kernel.

    Returns:
        Whatever ``mx.fast.quantized_scaled_dot_product_attention`` returns.

    Raises:
        ValueError: If ``mask`` is a string other than ``"causal"``.
    """
    import math

    from .turbo_cache import _hadamard_transform

    # Split the mask argument into the kernel's (causal flag, array mask) pair.
    # An unrecognised string used to be dropped silently, which would run the
    # attention completely unmasked; fail loudly instead.
    if isinstance(mask, str):
        if mask != "causal":
            raise ValueError(
                f"TurboQuant SDPA: unsupported mask string {mask!r}; "
                "expected 'causal', an array, or None."
            )
        causal, arr_mask = True, None
    else:
        causal, arr_mask = False, mask

    inv_sqrt_d = 1.0 / math.sqrt(queries.shape[-1])

    # _hadamard_transform already selects the Metal kernel when it is usable
    # and the float32 butterfly otherwise, so a second dispatch layer with a
    # catch-all ``except`` is not needed here.  The cast restores the query
    # dtype when the float32 fallback path was taken.
    q_rot = _hadamard_transform(queries, scale=inv_sqrt_d).astype(queries.dtype)

    k_packed, k_scales = q_keys
    v_packed, v_scales = q_values

    return mx.fast.quantized_scaled_dot_product_attention(
        q_rot,
        k_packed,
        k_scales,
        v_packed,
        v_scales,
        scale=scale,
        mask=arr_mask,
        mode=mode,
        group_size=group_size,
        causal=causal,
    )
def make_turbo_cache(
    model: nn.Module,
    bits: int = 3,
    fp16_layers: int = 0,
) -> List[Any]:
    """Build a per-layer cache list that uses TurboQuant for attention layers.

    The list starts out as ``make_prompt_cache(model)``.  Every entry that is
    a :class:`KVCache` is treated as a full-attention layer and is replaced by
    a :class:`TurboQuantKVCache`, except for the first and last
    ``fp16_layers`` such entries, which are left untouched.  All other cache
    objects (whatever the model's own cache construction produced) are
    passed through unchanged.

    Args:
        model: The language model handed to ``make_prompt_cache``.
        bits: TurboQuant bit-width passed to :class:`TurboQuantKVCache`.
            Default: ``3``.
        fp16_layers: Number of attention layers at each end to keep as plain
            :class:`KVCache`. Default: ``0`` (all attention layers compressed).

    Returns:
        A list of cache objects, one per entry of ``make_prompt_cache(model)``.

    Raises:
        ValueError: If ``fp16_layers`` is negative.
    """
    from .turbo_cache import TurboQuantKVCache

    # A negative count would turn into an unrelated slice below and silently
    # select the wrong layers.
    if fp16_layers < 0:
        raise ValueError(
            f"make_turbo_cache: fp16_layers must be >= 0, got {fp16_layers}"
        )

    # Start from the model's own cache so non-KVCache layers keep their type.
    base_caches = make_prompt_cache(model)

    # Positions holding a plain KVCache are the candidates for compression.
    attn_indices = [i for i, c in enumerate(base_caches) if isinstance(c, KVCache)]

    # Explicit bounds instead of a negative stop index: with fp16_layers == 0
    # the stop is simply len(attn_indices), and when 2 * fp16_layers covers
    # every attention layer the range is empty.
    start = fp16_layers
    stop = max(start, len(attn_indices) - fp16_layers)
    turbo_indices = set(attn_indices[start:stop])

    return [
        TurboQuantKVCache(bits=bits) if i in turbo_indices else c
        for i, c in enumerate(base_caches)
    ]
def _hadamard_transform(x: mx.array, scale: float = 1.0) -> mx.array:
    """Apply a Walsh-Hadamard transform along the last axis, times ``scale``.

    Uses ``wht_rotate_metal`` when ``turbo_metal.is_available()`` reports
    True; otherwise runs an iterative butterfly built from MLX ops.

    Args:
        x: Input whose last axis has length ``D``; ``D`` must be a power of 2.
        scale: Factor multiplied into every output element.

    Returns:
        The transformed array.  The butterfly fallback returns float32 (the
        caller must cast if needed); the Metal path returns whatever
        ``wht_rotate_metal`` produces.

    Raises:
        ValueError: If ``D`` is not a positive power of 2.
    """
    *batch, D = x.shape

    # Validate before dispatching: previously only the fallback checked this
    # (with an ``assert``, which is stripped under ``python -O``), so the
    # Metal kernel could be instantiated for a D whose butterfly indexes
    # outside its per-thread ``buf[D]``.
    if D < 1 or D & (D - 1):
        raise ValueError(f"head_dim {D} must be a power of 2 for WHT")

    try:
        from .turbo_metal import is_available, wht_rotate_metal

        if is_available():
            return wht_rotate_metal(x, scale)
    except Exception:
        # Deliberate best effort: any failure to import or launch the Metal
        # kernel falls through to the portable implementation below.
        pass

    # Fallback: iterative butterfly in MLX ops, accumulated in float32.
    h = x.astype(mx.float32)
    stride = 1
    while stride < D:
        # View the last axis as (blocks, 2, stride) and combine the two halves
        # of every block as (a + b, a - b).
        h = h.reshape(*batch, D // (2 * stride), 2, stride)
        a, b = h[..., 0, :], h[..., 1, :]
        h = mx.concatenate([a + b, a - b], axis=-1).reshape(*batch, D)
        stride *= 2
    if scale != 1.0:
        h = h * scale
    return h
+ scales: float16/bfloat16 array of shape [B, H, L, 1]. + """ + try: + from .turbo_metal import is_available, turbo_encode_metal + if is_available(): + return turbo_encode_metal(x, codebook, bits, rotate) + except Exception: + pass + return _turbo_encode_python(x, codebook, bits, rotate) + +def _pack_indices(indices: mx.array, bits: int) -> mx.array: + """Pack bit-width indices into uint32 words using PackReader layout. + + For bits=3: 8 indices → 3 bytes → D*3//32 uint32 words per vector. + For bits=4: 8 indices → 1 uint32 word (lower nibble first) → D//8 words. + + Uses numpy for reliable bit manipulation, then converts back to MLX. + The packing happens once per generation step (O(Lk × D)), not in the + hot inner loop, so numpy overhead is acceptable. + """ + import numpy as np + + # Materialise indices to numpy (this triggers MLX evaluation) + indices_np = np.array(indices, copy=False).astype(np.int64) + *batch, D = indices_np.shape + + if bits == 4: + # 8 × 4-bit per uint32, lower nibble first + n_u32 = D // 8 + flat = indices_np.reshape(-1, D) + n_tok = flat.shape[0] + u32 = np.zeros((n_tok, n_u32), dtype=np.uint32) + for i in range(8): + u32 |= (flat[:, i::8].astype(np.uint32) & 0xF) << (4 * i) + return mx.array(u32.reshape(*batch, n_u32)) + + elif bits == 3: + # 8 × 3-bit per pack → 3 bytes; D*3//32 uint32 words per vector. + # Vectorised: reshape into [n_tok, n_packs, 8], shift each column by + # its bit offset, sum → packed_24 [n_tok, n_packs] with 24 bits each. 
+ n_packs = D // 8 + n_u32 = D * 3 // 32 + flat = indices_np.reshape(-1, D) + n_tok = flat.shape[0] + + shifts = np.array([0, 3, 6, 9, 12, 15, 18, 21], dtype=np.int64) + idx_r = flat.reshape(n_tok, n_packs, 8).astype(np.int64) + packed_24 = np.sum(idx_r << shifts, axis=-1) # [n_tok, n_packs] + + b0 = (packed_24 & 0xFF).astype(np.uint8) + b1 = ((packed_24 >> 8) & 0xFF).astype(np.uint8) + b2 = ((packed_24 >> 16) & 0xFF).astype(np.uint8) + bytes_arr = np.stack([b0, b1, b2], axis=-1).reshape(n_tok, n_packs * 3) + + u32 = np.frombuffer(bytes_arr.tobytes(), dtype=" 1 or unsupported head_dim): stores float16, + returns float16 so the standard SDPA path is used. + - **Generation** (Lq == 1, head_dim in {64, 128, 256}): compresses new + tokens and returns (packed, scales) tuples so + ``mx.fast.quantized_scaled_dot_product_attention`` is used. + + On the first generation step after prefill, all cached float16 tokens + are re-compressed to TurboQuant format (one-time amortised cost). + + Attributes: + bits: Quantization bit-width (3 or 4). + mode: Mode string for mx.fast.quantized_scaled_dot_product_attention. + group_size: Equals head_dim (set on first call). 
+ """ + + step = 256 # allocation step size + + def __init__(self, bits: int = 3) -> None: + if bits not in (3, 4): + raise ValueError(f"TurboQuantKVCache: bits must be 3 or 4, got {bits}") + self.bits = bits + self.mode = f"turbo{bits}" + + # Float16 buffer (prefill phase) + self._keys_f16: Optional[mx.array] = None + self._values_f16: Optional[mx.array] = None + + # Compressed buffer (generation phase): (packed_uint32, float16_scales) + self._keys_tq: Optional[tuple[mx.array, mx.array]] = None + self._values_tq: Optional[tuple[mx.array, mx.array]] = None + + self.offset: int = 0 + self.group_size: Optional[int] = None # set to head_dim on first call + self._in_generation: bool = False + + def _codebook(self) -> mx.array: + return _CODEBOOKS[self.bits] + + def _compress( + self, keys: mx.array, values: mx.array + ) -> tuple[ + tuple[mx.array, mx.array], + tuple[mx.array, mx.array], + ]: + cb = self._codebook() + k_packed, k_scales = _turbo_encode(keys, cb, self.bits, rotate=True) + v_packed, v_scales = _turbo_encode(values, cb, self.bits, rotate=False) + return (k_packed, k_scales), (v_packed, v_scales) + + def _init_tq_buffers( + self, + B: int, + H: int, + D: int, + dtype: mx.Dtype, + n_steps: int, + ) -> None: + """Allocate pre-sized TurboQuant buffers.""" + n_u32 = D * self.bits // 32 + shape = (B, H, n_steps) + self._keys_tq = ( + mx.zeros((*shape, n_u32), dtype=mx.uint32), + mx.zeros((*shape, 1), dtype=dtype), + ) + self._values_tq = ( + mx.zeros((*shape, n_u32), dtype=mx.uint32), + mx.zeros((*shape, 1), dtype=dtype), + ) + + def _expand_tq_buffers(self) -> None: + """Grow TurboQuant buffers using doubling (amortised O(1) reallocations).""" + B, H, cur_len, n_u32 = self._keys_tq[0].shape + n_extra = max(self.step, cur_len) # double: grow by at least current size + pad_d = mx.zeros((B, H, n_extra, n_u32), dtype=mx.uint32) + pad_s = mx.zeros((B, H, n_extra, 1), dtype=self._keys_tq[1].dtype) + + self._keys_tq = ( + mx.concatenate([self._keys_tq[0], pad_d], 
axis=-2), + mx.concatenate([self._keys_tq[1], pad_s], axis=-2), + ) + self._values_tq = ( + mx.concatenate([self._values_tq[0], pad_d], axis=-2), + mx.concatenate([self._values_tq[1], pad_s], axis=-2), + ) + + def _transition_to_generation(self, D: int) -> None: + """Re-compress all float16 prefill tokens to TurboQuant format.""" + k_f = self._keys_f16[..., : self.offset, :] + v_f = self._values_f16[..., : self.offset, :] + B, H, L, _ = k_f.shape + (k_packed, k_scales), (v_packed, v_scales) = self._compress(k_f, v_f) + + # Allocate with extra room for generation + n_alloc = (L + self.step - 1) // self.step * self.step + self.step + self._init_tq_buffers(B, H, D, k_f.dtype, n_alloc) + + self._keys_tq[0][..., :L, :] = k_packed + self._keys_tq[1][..., :L, :] = k_scales + self._values_tq[0][..., :L, :] = v_packed + self._values_tq[1][..., :L, :] = v_scales + + # Free float16 buffers + self._keys_f16 = None + self._values_f16 = None + self._in_generation = True + + def update_and_fetch( + self, keys: mx.array, values: mx.array + ) -> tuple: + """Append keys/values to the cache and return all cached data. + + During prefill (num_steps > 1) or unsupported head_dim: + Returns (keys_f16, values_f16). + During generation (num_steps == 1, head_dim in {64, 128, 256}): + Returns ((k_packed, k_scales), (v_packed, v_scales)). 
+ """ + B, H, num_steps, D = keys.shape + prev = self.offset + self.offset += num_steps + + use_turbo = num_steps == 1 and D in _SUPPORTED_DIMS + + if not use_turbo: + new_size = (prev + num_steps + self.step - 1) // self.step * self.step + + if self._keys_f16 is None: + self._keys_f16 = mx.zeros((B, H, new_size, D), dtype=keys.dtype) + self._values_f16 = mx.zeros((B, H, new_size, D), dtype=values.dtype) + elif prev + num_steps > self._keys_f16.shape[-2]: + pad_k = mx.zeros( + (B, H, new_size - self._keys_f16.shape[-2], D), dtype=keys.dtype + ) + pad_v = mx.zeros( + (B, H, new_size - self._values_f16.shape[-2], D), + dtype=values.dtype, + ) + self._keys_f16 = mx.concatenate([self._keys_f16, pad_k], axis=-2) + self._values_f16 = mx.concatenate([self._values_f16, pad_v], axis=-2) + + self._keys_f16[..., prev : self.offset, :] = keys + self._values_f16[..., prev : self.offset, :] = values + + return ( + self._keys_f16[..., : self.offset, :], + self._values_f16[..., : self.offset, :], + ) + + self.group_size = D + + if not self._in_generation: + # First generation step: re-compress any stored prefill tokens + if self._keys_f16 is not None: + self._transition_to_generation(D) + else: + # No prefill: jump straight to generation + self._in_generation = True + + # Grow buffers if needed + if self._keys_tq is None: + self._init_tq_buffers(B, H, D, keys.dtype, self.step) + elif prev + num_steps > self._keys_tq[0].shape[-2]: + self._expand_tq_buffers() + + # Compress the new token + (k_packed, k_scales), (v_packed, v_scales) = self._compress(keys, values) + + self._keys_tq[0][..., prev : self.offset, :] = k_packed + self._keys_tq[1][..., prev : self.offset, :] = k_scales + self._values_tq[0][..., prev : self.offset, :] = v_packed + self._values_tq[1][..., prev : self.offset, :] = v_scales + + return ( + ( + self._keys_tq[0][..., : self.offset, :], + self._keys_tq[1][..., : self.offset, :], + ), + ( + self._values_tq[0][..., : self.offset, :], + self._values_tq[1][..., : 
self.offset, :], + ), + ) + + def is_trimmable(self) -> bool: + return False + + def empty(self) -> bool: + return self.offset == 0 + + def size(self) -> int: + return self.offset + + @property + def nbytes(self) -> int: + total = 0 + if self._keys_f16 is not None: + total += self._keys_f16.nbytes + self._values_f16.nbytes + if self._keys_tq is not None: + for arr in (*self._keys_tq, *self._values_tq): + total += arr.nbytes + return total + + @property + def state(self): + # Serialisation not supported; included for interface compliance. + return [] + + @state.setter + def state(self, v): + if v: + raise ValueError("TurboQuantKVCache does not support state loading.") + + @property + def meta_state(self): + return str(self.bits) + + @meta_state.setter + def meta_state(self, v): + self.bits = int(v) + self.mode = f"turbo{self.bits}" diff --git a/mlx_lm/models/turbo_metal.py b/mlx_lm/models/turbo_metal.py new file mode 100644 index 000000000..4d689c3ff --- /dev/null +++ b/mlx_lm/models/turbo_metal.py @@ -0,0 +1,228 @@ +# Copyright © 2024 Apple Inc. +"""Fused Metal kernels for TurboQuant KV-cache operations. + +Two kernels, same pattern: one thread per token, all intermediate +values in float32 registers, single GPU dispatch. 
+ + turbo_encode_metal — WHT (optional) → L2 norm → codebook → bit pack + wht_rotate_metal — WHT → scale (used to pre-rotate Q before SDPA) +""" + +import math + +import mlx.core as mx + + +_ENCODE_HEADER = r""" +template +inline void turbo_encode_impl( + device const T* x_in, + device uint32_t* packed_out, + device T* scale_out, + device const float* codebook, + uint tok) +{ + constexpr int N_LEVELS = 1 << BITS; + constexpr int N_U32 = D * BITS / 32; + const float SQRT_D = metal::sqrt(float(D)); + const float INV_SQRT_D = 1.0f / SQRT_D; + + device const T* src = x_in + tok * D; + + float buf[D]; + for (int i = 0; i < D; i++) buf[i] = float(src[i]); + + if constexpr (ROTATE) { + for (int stride = 1; stride < D; stride <<= 1) { + for (int i = 0; i < D; i += stride * 2) { + for (int j = 0; j < stride; j++) { + float a = buf[i + j]; + float b = buf[i + j + stride]; + buf[i + j] = a + b; + buf[i + j + stride] = a - b; + } + } + } + for (int i = 0; i < D; i++) buf[i] *= INV_SQRT_D; + } + + float norm2 = 0.0f; + for (int i = 0; i < D; i++) norm2 += buf[i] * buf[i]; + float norm = metal::sqrt(norm2); + + scale_out[tok] = T(norm * INV_SQRT_D); + + float inv_n_s = SQRT_D / (norm + 1e-8f); + + device uchar* out_bytes = (device uchar*)(packed_out + tok * N_U32); + + for (int p = 0; p < D / 8; p++) { + uint val = 0u; + for (int i = 0; i < 8; i++) { + float xs = buf[p * 8 + i] * inv_n_s; + int best = 0; + float bd = (xs - codebook[0]) * (xs - codebook[0]); + for (int k = 1; k < N_LEVELS; k++) { + float d = xs - codebook[k]; + float dd = d * d; + if (dd < bd) { bd = dd; best = k; } + } + val |= (uint(best) & uint(N_LEVELS - 1)) << uint(BITS * i); + } + if constexpr (BITS == 3) { + out_bytes[p * 3] = uchar(val & 0xFFu); + out_bytes[p * 3 + 1] = uchar((val >> 8u) & 0xFFu); + out_bytes[p * 3 + 2] = uchar((val >> 16u) & 0xFFu); + } else { + packed_out[tok * N_U32 + p] = val; + } + } +} +""" + +_ENCODE_SOURCE = r""" +uint tok = thread_position_in_grid.x; +if (tok >= 
def turbo_encode_metal(
    x: mx.array,
    codebook: mx.array,
    bits: int,
    rotate: bool,
) -> tuple[mx.array, mx.array]:
    """Encode ``x`` to TurboQuant packed format via a single GPU dispatch.

    The input is flattened to ``(N_tokens, D)`` and the encode kernel is
    launched with one thread per token.

    Args:
        x: Array of shape ``[*batch, D]``; its dtype is used as the kernel's
            ``T`` template argument and as the dtype of the returned scales.
        codebook: At least ``2**bits`` levels; cast to float32 before launch.
        bits: 3 or 4.
        rotate: Apply the WHT before quantization (True for K, False for V).

    Returns:
        packed: ``[*batch, D*bits//32]`` uint32.
        scales: ``[*batch, 1]``, same dtype as ``x``.

    Raises:
        ValueError: If ``bits``, ``D`` or the codebook size do not match what
            the kernel's fixed-size loops and packed layout require.
    """
    *batch, D = x.shape

    # The kernel only has packing code for 3-bit (byte-wise) and 4-bit
    # (one word per 8 values) layouts.
    if bits not in (3, 4):
        raise ValueError(f"turbo_encode_metal: bits must be 3 or 4, got {bits}")
    # Values are packed 8 at a time and each token's row must be a whole
    # number of uint32 words; otherwise the kernel writes past the row.
    if D < 8 or D % 8 or (D * bits) % 32:
        raise ValueError(
            f"turbo_encode_metal: head_dim {D} is not packable with {bits} bits "
            "(need D % 8 == 0 and D * bits % 32 == 0)"
        )
    # The in-kernel butterfly indexes buf[i + j + stride] and stays inside
    # buf[D] only when D is a power of 2.
    if rotate and D & (D - 1):
        raise ValueError(
            f"turbo_encode_metal: head_dim {D} must be a power of 2 when rotate=True"
        )
    # The kernel reads codebook[0 .. 2**bits - 1].
    if codebook.size < (1 << bits):
        raise ValueError(
            f"turbo_encode_metal: codebook has {codebook.size} levels, "
            f"need {1 << bits} for bits={bits}"
        )

    n_u32 = D * bits // 32
    N_tokens = math.prod(batch)  # math.prod(()) == 1 covers a 1-D input

    # Nothing to encode: avoid launching with a zero-sized grid/threadgroup.
    if N_tokens == 0:
        return (
            mx.zeros((*batch, n_u32), dtype=mx.uint32),
            mx.zeros((*batch, 1), dtype=x.dtype),
        )

    # Threadgroup size: keep register pressure below spilling threshold.
    # Each thread holds float buf[D]; larger D -> fewer threads per group.
    # Formula: max(32, 8192 // D) gives {64: 128, 128: 64, 256: 32}.
    tg_size = min(N_tokens, max(32, 8192 // D))

    packed_flat, scales_flat = _get_encode_kernel()(
        inputs=[
            x.reshape(N_tokens, D),
            codebook.astype(mx.float32),
            mx.array([N_tokens], dtype=mx.int32),
        ],
        template=[("T", x.dtype), ("D", D), ("BITS", bits), ("ROTATE", rotate)],
        output_shapes=[(N_tokens, n_u32), (N_tokens, 1)],
        output_dtypes=[mx.uint32, x.dtype],
        grid=(N_tokens, 1, 1),
        threadgroup=(tg_size, 1, 1),
        stream=mx.gpu,
    )
    return packed_flat.reshape(*batch, n_u32), scales_flat.reshape(*batch, 1)
+ """ + *batch, D = x.shape + N_tokens = math.prod(batch) if batch else 1 + tg_size = min(N_tokens, max(32, 8192 // D)) + + (y_flat,) = _get_wht_kernel()( + inputs=[ + x.reshape(N_tokens, D), + mx.array([scale], dtype=mx.float32), + mx.array([N_tokens], dtype=mx.int32), + ], + template=[("T", x.dtype), ("D", D)], + output_shapes=[(N_tokens, D)], + output_dtypes=[x.dtype], + grid=(N_tokens, 1, 1), + threadgroup=(tg_size, 1, 1), + stream=mx.gpu, + ) + return y_flat.reshape(*batch, D) + + +def is_available() -> bool: + """Return True if Metal kernels can be used.""" + return mx.metal.is_available() diff --git a/tests/test_turbo_cache.py b/tests/test_turbo_cache.py new file mode 100644 index 000000000..864aa8e61 --- /dev/null +++ b/tests/test_turbo_cache.py @@ -0,0 +1,382 @@ +# Copyright © 2024 Apple Inc. +"""Integration tests for TurboQuantKVCache. + +Run with: + python -m pytest tests/test_turbo_cache.py -v + +These tests do NOT require a GPU: shape/API checks run on CPU. +The numerical accuracy test is skipped on CPU (TurboQuant SDPA needs Metal). 
+""" + +import math +import unittest + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.turbo_cache import ( + TurboQuantKVCache, + _hadamard_transform, + _turbo_encode, +) +from mlx_lm.models.cache import make_turbo_cache, KVCache + + +# ── WHT correctness ─────────────────────────────────────────────────────────── + + +class TestHadamardTransform(unittest.TestCase): + def test_orthogonality_d64(self): + """WHT * WHT^T = D * I (before normalization).""" + D = 64 + x = mx.random.normal(shape=(1, 1, 1, D)) + # Apply twice without normalization: should get D * x + h = _hadamard_transform(x, scale=1.0) + hh = _hadamard_transform(h, scale=1.0) + mx.eval(hh) + self.assertTrue(mx.allclose(hh, x * D, atol=1e-3).item()) + + def test_normalized_dot_product(self): + """dot(WHT(a)/√D, WHT(b)/√D) == dot(a, b) (isometry).""" + D = 128 + a = mx.random.normal(shape=(D,)) + b = mx.random.normal(shape=(D,)) + inv_d = 1.0 / math.sqrt(D) + a_rot = _hadamard_transform(a, scale=inv_d) + b_rot = _hadamard_transform(b, scale=inv_d) + dot_rot = float(mx.sum(a_rot * b_rot).item()) + dot_orig = float(mx.sum(a * b).item()) + mx.eval(a_rot, b_rot) + self.assertAlmostEqual(dot_rot, dot_orig, places=2) + + +# ── Encoding shape checks ───────────────────────────────────────────────────── + + +class TestTurboEncode(unittest.TestCase): + def _check_encode(self, D: int, bits: int): + from mlx_lm.models.turbo_cache import _CODEBOOKS + + B, H, L = 1, 2, 8 + x = mx.random.normal(shape=(B, H, L, D)) + cb = _CODEBOOKS[bits] + packed, scales = _turbo_encode(x, cb, bits) + mx.eval(packed, scales) + + expected_u32 = D * bits // 32 + self.assertEqual(packed.shape, (B, H, L, expected_u32)) + self.assertEqual(packed.dtype, mx.uint32) + self.assertEqual(scales.shape, (B, H, L, 1)) + + def test_encode_3bit_d64(self): + self._check_encode(64, 3) + + def test_encode_3bit_d128(self): + self._check_encode(128, 3) + + def test_encode_4bit_d64(self): + self._check_encode(64, 4) + + def 
test_encode_4bit_d128(self): + self._check_encode(128, 4) + + +# ── TurboQuantKVCache API ───────────────────────────────────────────────────── + + +class TestTurboQuantKVCache(unittest.TestCase): + def _make_kv(self, B, H, L, D): + k = 0.1 * mx.random.normal(shape=(B, H, L, D)) + v = 0.1 * mx.random.normal(shape=(B, H, L, D)) + return k, v + + def test_prefill_returns_float(self): + """During prefill (L > 1), update_and_fetch returns float16.""" + c = TurboQuantKVCache(bits=3) + k, v = self._make_kv(1, 2, 16, 64) + rk, rv = c.update_and_fetch(k, v) + mx.eval(rk, rv) + self.assertIsInstance(rk, mx.array) + self.assertIsInstance(rv, mx.array) + self.assertEqual(rk.shape[-2], 16) + + def test_generation_returns_tuple(self): + """During generation (L == 1, D in {64,128}), returns (packed, scales).""" + c = TurboQuantKVCache(bits=3) + k, v = self._make_kv(1, 2, 1, 64) + rk, rv = c.update_and_fetch(k, v) + mx.eval(rk[0], rk[1], rv[0], rv[1]) + self.assertIsInstance(rk, tuple) + self.assertEqual(len(rk), 2) # (packed, scales) + self.assertEqual(rk[0].dtype, mx.uint32) + + def test_prefill_then_generation(self): + """Prefill then generation: transition compresses float16 tokens.""" + c = TurboQuantKVCache(bits=3) + B, H, D = 1, 2, 64 + + # Prefill 32 tokens + kp, vp = self._make_kv(B, H, 32, D) + c.update_and_fetch(kp, vp) + self.assertFalse(c._in_generation) + + # Generation step: triggers transition + kg, vg = self._make_kv(B, H, 1, D) + rk, rv = c.update_and_fetch(kg, vg) + mx.eval(rk[0], rk[1]) + self.assertTrue(c._in_generation) + # Cache now holds 33 tokens + self.assertEqual(rk[0].shape[-2], 33) + self.assertEqual(c.offset, 33) + + def test_unsupported_dim_falls_back(self): + """head_dim not in {64,128} returns float (graceful fallback).""" + c = TurboQuantKVCache(bits=3) + k, v = self._make_kv(1, 2, 1, 32) # D=32 unsupported + rk, rv = c.update_and_fetch(k, v) + mx.eval(rk, rv) + # Should return float, not tuple + self.assertIsInstance(rk, mx.array) + + def 
test_offset_tracking(self): + c = TurboQuantKVCache(bits=4) + for _ in range(5): + k, v = self._make_kv(1, 1, 1, 64) + c.update_and_fetch(k, v) + self.assertEqual(c.offset, 5) + + +# ── make_turbo_cache ────────────────────────────────────────────────────────── + + +class TestMakeTurboCache(unittest.TestCase): + def _make_mock_model(self, n_attn: int, n_linear: int): + """Simple mock with mixed attention and linear-attention layers.""" + + class AttnLayer(nn.Module): + is_linear = False + + class LinearLayer(nn.Module): + is_linear = True + + class MockModel(nn.Module): + def __init__(self): + super().__init__() + # Interleave: 3 linear + 1 attn pattern + self.layers = [] + for _ in range(n_attn): + for _ in range(n_linear): + self.layers.append(LinearLayer()) + self.layers.append(AttnLayer()) + + def make_cache(self): + from mlx_lm.models.cache import ArraysCache + + return [ + ArraysCache(size=2) if l.is_linear else KVCache() + for l in self.layers + ] + + return MockModel() + + def test_turbo_cache_replaces_kvcache(self): + model = self._make_mock_model(n_attn=4, n_linear=3) + caches = make_turbo_cache(model, bits=3, fp16_layers=0) + turbo_count = sum(isinstance(c, TurboQuantKVCache) for c in caches) + kvcache_count = sum(isinstance(c, KVCache) for c in caches) + self.assertEqual(turbo_count, 4) + self.assertEqual(kvcache_count, 0) + + def test_fp16_layers_kept(self): + """First and last fp16_layers attention layers remain as KVCache.""" + model = self._make_mock_model(n_attn=6, n_linear=3) + caches = make_turbo_cache(model, bits=3, fp16_layers=1) + turbo_count = sum(isinstance(c, TurboQuantKVCache) for c in caches) + kvcache_count = sum(isinstance(c, KVCache) for c in caches) + self.assertEqual(kvcache_count, 2) # first + last kept + self.assertEqual(turbo_count, 4) + + def test_fp16_layers_exact_boundary(self): + """When fp16_layers*2 == n_attn, no layers remain as turbo.""" + # 4 attention layers, fp16_layers=2 → first 2 + last 2 = all 4 are kept fp16 + model 
= self._make_mock_model(n_attn=4, n_linear=0) + caches = make_turbo_cache(model, bits=3, fp16_layers=2) + turbo_count = sum(isinstance(c, TurboQuantKVCache) for c in caches) + kvcache_count = sum(isinstance(c, KVCache) for c in caches) + self.assertEqual(turbo_count, 0) + self.assertEqual(kvcache_count, 4) + + def test_fp16_layers_zero(self): + """fp16_layers=0 means all attention layers are compressed (no fp16 kept).""" + model = self._make_mock_model(n_attn=4, n_linear=0) + caches = make_turbo_cache(model, bits=3, fp16_layers=0) + turbo_count = sum(isinstance(c, TurboQuantKVCache) for c in caches) + self.assertEqual(turbo_count, 4) + + +# ── Numerical accuracy (GPU only) ───────────────────────────────────────────── + + +@unittest.skipUnless(mx.metal.is_available(), "TurboQuant SDPA requires Metal GPU") +class TestTurboNumericalAccuracy(unittest.TestCase): + """Compare TurboQuant SDPA output against float16 reference. + + The tolerance is deliberately loose (3e-2) since we are testing a lossy + compression scheme, not bit-exact arithmetic. + """ + + def _run_turbo_sdpa(self, B, Hq, Hkv, Lk, D, bits): + import math + + from mlx_lm.models.turbo_cache import _CODEBOOKS, _hadamard_transform, _turbo_encode + + mx.random.seed(0) + Lq = 1 + q = 0.1 * mx.random.normal(shape=(B, Hq, Lq, D)) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)) + scale = 1.0 / math.sqrt(D) + + cb = _CODEBOOKS[bits] + # K is WHT-rotated (for score computation); V is NOT rotated + # so the kernel output is directly in the original V space. + k_packed, k_scales = _turbo_encode(k, cb, bits, rotate=True) + v_packed, v_scales = _turbo_encode(v, cb, bits, rotate=False) + + # Q is pre-rotated in Python (float32 precision) before the kernel. + # The kernel receives Q_rot and K_rot (from encoding); the score + # dot(Q_rot, K_rot) = dot(Q, K) by WHT isometry. 
+ q_rot = _hadamard_transform(q.astype(mx.float32), scale=1.0 / math.sqrt(D)) + q_rot = q_rot.astype(q.dtype) + + out_turbo = mx.fast.quantized_scaled_dot_product_attention( + q_rot, + k_packed, + k_scales, + v_packed, + v_scales, + scale=scale, + mode=f"turbo{bits}", + group_size=D, + ) + mx.eval(out_turbo) + + # Reference: float SDPA with WHT-rotated Q/K and original V. + k_rot = _hadamard_transform(k.astype(mx.float32), scale=1.0 / math.sqrt(D)) + k_rot = k_rot.astype(k.dtype) + ref = mx.fast.scaled_dot_product_attention(q_rot, k_rot, v, scale=scale) + mx.eval(ref) + + return out_turbo, ref + + def test_accuracy_3bit_d64(self): + out, ref = self._run_turbo_sdpa(1, 4, 1, 64, 64, 3) + err = float(mx.abs(out - ref).max().item()) + self.assertLess(err, 3e-2, f"3-bit D=64 max error {err:.4f} > 3e-2") + + def test_accuracy_3bit_d128(self): + # GQA=4: matches Qwen3-14B / Qwen3.6-35B-A3B config + out, ref = self._run_turbo_sdpa(1, 8, 2, 128, 128, 3) + err = float(mx.abs(out - ref).max().item()) + self.assertLess(err, 3e-2, f"3-bit D=128 max error {err:.4f} > 3e-2") + + def test_accuracy_3bit_d256_gqa6(self): + # GQA=6: matches Qwen3.6-27B (24Q/4KV heads, head_dim=256) + out, ref = self._run_turbo_sdpa(1, 24, 4, 64, 256, 3) + err = float(mx.abs(out - ref).max().item()) + self.assertLess(err, 3e-2, f"3-bit D=256 max error {err:.4f} > 3e-2") + + def test_accuracy_4bit_d128(self): + out, ref = self._run_turbo_sdpa(1, 4, 1, 128, 128, 4) + err = float(mx.abs(out - ref).max().item()) + self.assertLess(err, 2e-2, f"4-bit D=128 max error {err:.4f} > 2e-2") + + def test_output_shape(self): + out, ref = self._run_turbo_sdpa(1, 4, 1, 64, 64, 3) + self.assertEqual(out.shape, ref.shape) + + +@unittest.skipUnless(mx.metal.is_available(), "TurboQuant SDPA requires Metal GPU") +class TestQwen36_27B(unittest.TestCase): + """End-to-end smoke test matching Qwen3.6-27B's exact attention config. + + Architecture: head_dim=256, 24Q/4KV heads (GQA=6), 64 layers. 
+ This test does NOT require the actual weights — it exercises the Metal + kernel with the same tensor shapes and dtypes used during generation. + """ + + MODEL_ID = "mlx-community/Qwen3.6-27B-4bit" + + @classmethod + def setUpClass(cls): + import os + + if not os.path.exists( + os.path.expanduser( + "~/.cache/huggingface/hub/models--mlx-community--Qwen3.6-27B-4bit" + ) + ): + raise unittest.SkipTest("Qwen3.6-27B-4bit not downloaded") + + def test_turbo3_qwen36_27b_shapes(self): + """Kernel correctness at Qwen3.6-27B attention geometry.""" + import math + + from mlx_lm.models.turbo_cache import ( + _CODEBOOKS, + _hadamard_transform, + _turbo_encode, + ) + + B, Hq, Hkv, Lk, D, bits = 1, 24, 4, 128, 256, 3 + scale = 1.0 / math.sqrt(D) + cb = _CODEBOOKS[bits] + mx.random.seed(0) + + q = 0.1 * mx.random.normal(shape=(B, Hq, 1, D)).astype(mx.bfloat16) + k = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)).astype(mx.bfloat16) + v = 0.1 * mx.random.normal(shape=(B, Hkv, Lk, D)).astype(mx.bfloat16) + + k_packed, k_scales = _turbo_encode(k, cb, bits, rotate=True) + v_packed, v_scales = _turbo_encode(v, cb, bits, rotate=False) + + q_rot = _hadamard_transform(q.astype(mx.float32), scale=1 / math.sqrt(D)) + q_rot = q_rot.astype(q.dtype) + k_rot = _hadamard_transform(k.astype(mx.float32), scale=1 / math.sqrt(D)) + k_rot = k_rot.astype(k.dtype) + + ref = mx.fast.scaled_dot_product_attention(q_rot, k_rot, v, scale=scale) + out = mx.fast.quantized_scaled_dot_product_attention( + q_rot, k_packed, k_scales, v_packed, v_scales, + scale=scale, mode="turbo3", group_size=D, + ) + mx.eval(out, ref) + + self.assertEqual(out.shape, (B, Hq, 1, D)) + err = float(mx.abs(out - ref).max().item()) + self.assertLess(err, 5e-2, f"D=256 gqa=6 max error {err:.4f} > 5e-2") + + def test_turbo3_qwen36_27b_generation(self): + """Full generation loop with TurboQuantKVCache on Qwen3.6-27B.""" + from mlx_lm import load + from mlx_lm.generate import generate_step + + model, tokenizer = load(self.MODEL_ID) + 
mx.eval(model.parameters()) + + tokens = tokenizer.encode("Hello", return_tensors="mlx")[0] + out_tokens = [] + for tok, _ in generate_step( + tokens, model, max_tokens=10, kv_cache_type="turbo3" + ): + t = tok if isinstance(tok, int) else int(tok.item()) + out_tokens.append(t) + mx.eval(tok) + + self.assertEqual(len(out_tokens), 10) + text = tokenizer.decode(out_tokens) + self.assertGreater(len(text.strip()), 0) + + +if __name__ == "__main__": + unittest.main()