Add olmo hybrid#1315
Open
cmurray1105 wants to merge 4 commits into
Open
Conversation
First MLX implementation of allenai/OLMo-Hybrid-7B — a 32-layer model
alternating 3:1 GatedDeltaNet (GDN) / full-Attention layers.
GDN is a sequential recurrence with fixed O(1) memory cost at any context
length, combining the delta rule (Widrow & Hoff 1960) with Mamba-style
learned forget gates (Yang et al., arxiv 2412.06464).
Key implementation notes:
- GatedDeltaNet: full sequential recurrence with chunked mx.eval() (CHUNK=32)
to cap peak memory regardless of sequence length
- ShortConv: causal depthwise conv1d via mx.conv1d with groups=C; carries
real prior-token context during decode (avoids zero-pad corruption)
- Heterogeneous cache: GDNCache (recurrent S matrix + conv context) for GDN
layers, KVCache for attention layers; make_cache() returns the right type
- Two correctness bugs fixed vs naive port:
QK-norm applied to full (H*D) projection before reshape, not per-head
rope_theta = 500_000 (not 10_000 -- OLMo2 large-context default)
HF model: allenai/OLMo-Hybrid-7B (Apache 2.0)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds MLX support for allenai/OLMo-Hybrid-7B — a 32-layer model combining GatedDeltaNet (GDN) recurrent layers with standard full attention in a 3:1 pattern (24 GDN + 8 Attention).
GDN is a fixed-memory alternative to the KV cache: instead of growing linearly with context length, it maintains a constant-size state matrix S regardless of sequence length. This makes 65k token contexts practical on consumer hardware.
Implementation notes:
GatedDeltaNet: sequential delta-rule recurrence with chunkedmx.eval()(every 32 tokens) to prevent graph explosion at long contextsShortConv: causal depthwise conv1d viamx.conv1dwithgroups=C; real prior-token context passed during decode to avoid zero-pad corruption after the first generated tokenGDNCache(fixed-size recurrent state + conv context) for GDN layers,KVCachefor attention layers —make_cache()returns the right type per layer(H×D)projection before reshape — not per-head(D,)— weight shape in the checkpoint is(H*D,)rope_theta = 500_000(OLMo2 large-context default, not10_000)Tested on M4 Max 36GB: 16.3 tok/s prefill, 4.9 tok/s decode (float16). Decode speed reflects GDN's sequential recurrence — a natural follow-up is a parallel prefix scan implementation for faster prefill.