Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@ def setup_arg_parser():
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--kv-cache-type",
type=str,
help="KV cache type. Use 'turbo3' or 'turbo4' for TurboQuant compression.",
default=None,
)
parser.add_argument(
"--turbo-fp16-layers",
type=int,
help="Number of attention layers at each end to keep in float16 when using TurboQuant.",
default=0,
)
parser.add_argument(
"--draft-model",
type=str,
Expand Down Expand Up @@ -317,6 +329,8 @@ def generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
kv_cache_type: Optional[str] = None,
turbo_fp16_layers: int = 0,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
input_embeddings: Optional[mx.array] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
Expand Down Expand Up @@ -369,10 +383,18 @@ def generate_step(

# Create the KV cache for generation
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
)
if kv_cache_type in ("turbo3", "turbo4"):
turbo_bits = int(kv_cache_type[-1])
prompt_cache = cache.make_turbo_cache(
model,
bits=turbo_bits,
fp16_layers=turbo_fp16_layers,
)
else:
prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
)

prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

Expand Down Expand Up @@ -2081,6 +2103,8 @@ def main():
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
kv_cache_type=args.kv_cache_type,
turbo_fp16_layers=args.turbo_fp16_layers,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
Expand Down
70 changes: 69 additions & 1 deletion mlx_lm/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,59 @@ def quantized_scaled_dot_product_attention(
return out


def _turbo_scaled_dot_product_attention(
queries: mx.array,
q_keys: tuple[mx.array, mx.array],
q_values: tuple[mx.array, mx.array],
scale: float,
mask: Optional[mx.array],
mode: str,
group_size: int,
) -> mx.array:
"""SDPA for TurboQuant-compressed keys/values.

Rotates queries with a normalized WHT before calling the fused kernel.
The kernel handles causal masking, GQA fan-out, online softmax, and
the full attention reduction in one Metal pass.
"""
import math

D = queries.shape[-1]
inv_sqrt_d = 1.0 / math.sqrt(D)

# WHT in float32 to match encoding precision (bfloat16 butterfly accumulates
# enough error to shift softmax peaks on large-scale models).
try:
from .turbo_metal import is_available, wht_rotate_metal
if is_available():
q_rot = wht_rotate_metal(queries, scale=inv_sqrt_d)
else:
raise ImportError
except (ImportError, Exception):
from .turbo_cache import _hadamard_transform
q_rot = _hadamard_transform(queries.astype(mx.float32), scale=inv_sqrt_d)
q_rot = q_rot.astype(queries.dtype)

k_packed, k_scales = q_keys
v_packed, v_scales = q_values

causal = mask == "causal" if isinstance(mask, str) else False
arr_mask = None if (mask is None or isinstance(mask, str)) else mask

return mx.fast.quantized_scaled_dot_product_attention(
q_rot,
k_packed,
k_scales,
v_packed,
v_scales,
scale=scale,
mask=arr_mask,
mode=mode,
group_size=group_size,
causal=causal,
)


def scaled_dot_product_attention(
queries,
keys,
Expand All @@ -114,7 +167,22 @@ def scaled_dot_product_attention(
mask: Optional[mx.array],
sinks: Optional[mx.array] = None,
) -> mx.array:
if hasattr(cache, "bits"):
from .turbo_cache import TurboQuantKVCache

if isinstance(cache, TurboQuantKVCache) and isinstance(keys, tuple):
if sinks is not None:
raise ValueError("TurboQuant SDPA does not support attention sinks.")
return _turbo_scaled_dot_product_attention(
queries,
keys,
values,
scale=scale,
mask=mask,
mode=cache.mode,
group_size=cache.group_size,
)
elif hasattr(cache, "bits") and not isinstance(cache, TurboQuantKVCache):
# Standard QuantizedKVCache (affine/mxfp4/…)
if sinks is not None:
raise ValueError("Quantized SDPA does not support attention sinks.")
return quantized_scaled_dot_product_attention(
Expand Down
48 changes: 48 additions & 0 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,54 @@ def make_prompt_cache(
return [KVCache() for _ in range(num_layers)]


def make_turbo_cache(
model: nn.Module,
bits: int = 3,
fp16_layers: int = 0,
) -> List[Any]:
"""Build a per-layer cache list using TurboQuant for full-attention layers.

Layers that use a linear-attention state (e.g. DeltaNet / SSM) are left
with their default cache type from the model's own ``make_cache()``.
Standard full-attention layers receive :class:`TurboQuantKVCache`.

The first and last ``fp16_layers`` attention layers keep a plain
:class:`KVCache` (float16) to preserve quality at sequence boundaries.

Args:
model: The language model (must expose ``.layers``).
bits: TurboQuant bit-width, 3 or 4. Default: ``3``.
fp16_layers: Number of attention layers at each end to keep in
float16. Default: ``0`` (all attention layers compressed).

Returns:
A list of cache objects, one per model layer.
"""
from .turbo_cache import TurboQuantKVCache

# Start from the model's own cache (handles SSM/DeltaNet layers correctly)
base_caches = make_prompt_cache(model)

# Identify which positions are plain KVCache (full-attention layers)
attn_indices = [
i for i, c in enumerate(base_caches) if isinstance(c, KVCache)
]

# Determine the turbo range (skip first and last fp16_layers).
# `-fp16_layers or None` handles fp16_layers=0: -0 == 0 which truncates
# the slice to empty, so we use None (meaning "to end") instead.
turbo_indices = set(attn_indices[fp16_layers : -fp16_layers or None])

caches = []
for i, c in enumerate(base_caches):
if i in turbo_indices:
caches.append(TurboQuantKVCache(bits=bits))
else:
caches.append(c)

return caches


def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Expand Down
Loading