Motivation
MLX’s current scaled_dot_product_attention is optimized but not fully IO-aware in the way FlashAttention is. Users have reported slower inference compared to engines that ship FlashAttention (e.g., llama.cpp with flash attention enabled).
There is already community work on PagedAttention kernels for Metal that shows substantial throughput improvements (e.g., ~77% on Qwen 30B 4-bit).
Request
- Integrate FlashAttention-style or PagedAttention kernels directly into the MLX backend.
- Expose an API that lets a transformer implementation select Flash/Paged attention when available (a rough sketch of such an API follows this list).
- Ensure compatibility with:
  - Causal masking
  - Quantized keys/values (q4/q6/q8)
  - KV cache usage in decode
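
As a minimal sketch of the selection API requested above (not MLX’s actual API): `flash_scaled_dot_product_attention` is a hypothetical kernel name used only for illustration, and the fallback path uses the existing `mx.fast.scaled_dot_product_attention`. The usage example shows the causal-masking case via an explicit additive mask.

```python
# Hedged sketch of a backend-selection wrapper; names marked "hypothetical"
# do not exist in MLX today.
import mlx.core as mx


def attention(q, k, v, *, scale, mask=None, backend="auto"):
    """Dispatch to a Flash/Paged kernel if one is available, else fall back."""
    if backend in ("auto", "flash") and hasattr(mx.fast, "flash_scaled_dot_product_attention"):
        # Hypothetical fused Flash kernel; the hasattr check keeps this runnable today.
        return mx.fast.flash_scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
    # Existing fused attention path in MLX.
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)


# Usage with an explicit additive causal mask, shapes (batch, heads, seq, head_dim):
L, D = 128, 64
q = mx.random.normal((1, 8, L, D))
k = mx.random.normal((1, 8, L, D))
v = mx.random.normal((1, 8, L, D))
causal_mask = mx.triu(mx.full((L, L), float("-inf")), k=1)
out = attention(q, k, v, scale=D ** -0.5, mask=causal_mask)
```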
Benefits
- Significant speedups for long contexts, closing the gap with other optimized engines.
- Better memory throughput and scaling for larger models.
Notes
I’m happy to help with integration details and benchmarking.