Add olmo hybrid #1315

Open
cmurray1105 wants to merge 4 commits into ml-explore:main from cmurray1105:add-olmo-hybrid

@cmurray1105

Adds MLX support for allenai/OLMo-Hybrid-7B — a 32-layer model combining GatedDeltaNet (GDN) recurrent layers with standard full attention in a 3:1 pattern (24 GDN + 8 Attention).
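As a rough illustration of the 3:1 layout, the sketch below builds a per-layer type list. The PR only states the ratio (24 GDN + 8 attention); placing attention at every fourth position, and the names `layer_types` and `period`, are assumptions made for this example, not the model's confirmed configuration.

```python
# Hypothetical layout: only the 3:1 ratio (24 GDN + 8 attention) comes from the PR.
# Putting attention at every 4th position is an assumption for illustration.
NUM_LAYERS = 32


def layer_types(num_layers: int = NUM_LAYERS, period: int = 4) -> list[str]:
    """Return 'gdn' or 'attention' for each layer index."""
    return [
        "attention" if (i + 1) % period == 0 else "gdn"
        for i in range(num_layers)
    ]


types = layer_types()
print(types[:8])
print(types.count("gdn"), types.count("attention"))
```

A list like this is one simple way to drive per-layer decisions such as which block class to build or which cache type to allocate.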

GDN is a fixed-memory alternative to the KV cache: instead of growing linearly with context length, each GDN layer maintains a constant-size state matrix S regardless of sequence length. Only the 8 attention layers still keep a KV cache that grows with context, which makes 65k-token contexts practical on consumer hardware.
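To make the scaling argument concrete, here is a back-of-the-envelope comparison of a growing KV cache against a fixed recurrent state. Every dimension below (head counts, head size, state shape, 2-byte dtype) is a hypothetical placeholder chosen for round numbers, not a value read from the OLMo-Hybrid-7B config.

```python
# All sizes are hypothetical placeholders, not the real model config.
def kv_cache_bytes(seq_len, n_layers=8, n_kv_heads=32, head_dim=128, dtype_bytes=2):
    """Keys and values: two (seq_len, n_kv_heads, head_dim) tensors per layer."""
    return 2 * seq_len * n_kv_heads * head_dim * dtype_bytes * n_layers


def gdn_state_bytes(n_layers=24, n_heads=32, key_dim=128, value_dim=128, dtype_bytes=2):
    """One (value_dim x key_dim) state matrix per head per layer; no seq_len term."""
    return n_layers * n_heads * key_dim * value_dim * dtype_bytes


GIB = 2**30
for seq_len in (4_096, 65_536):
    hybrid = kv_cache_bytes(seq_len, n_layers=8) + gdn_state_bytes()
    all_attention = kv_cache_bytes(seq_len, n_layers=32)
    print(f"{seq_len:>6} tokens: hybrid {hybrid / GIB:.2f} GiB, "
          f"all-attention {all_attention / GIB:.2f} GiB")
```

Under these assumed sizes the recurrent state is a constant term, so the hybrid's cache grows only with the attention layers' share of the stack.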

Implementation notes:

  • GatedDeltaNet: sequential delta-rule recurrence with chunked mx.eval() (every 32 tokens) to prevent graph explosion at long contexts
  • ShortConv: causal depthwise conv1d via mx.conv1d with groups=C; real prior-token context passed during decode to avoid zero-pad corruption after the first generated token
  • Heterogeneous cache: GDNCache (fixed-size recurrent state + conv context) for GDN layers, KVCache for attention layers — make_cache() returns the right type per layer
  • Two correctness bugs fixed relative to a naive port:
    • QK-norm is applied to the full (H*D) projection before the reshape into heads, not per head over (D,); the norm weight in the checkpoint has shape (H*D,)
    • rope_theta = 500_000 (the OLMo2 large-context default), not 10_000
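The QK-norm point can be checked numerically. The sketch below assumes an RMS-style norm and uses NumPy in place of MLX so it runs anywhere; the toy sizes and the helper `rms_norm` are illustrative, not taken from the PR's code. It shows that normalizing over all H*D features and then reshaping gives different values from reshaping first and normalizing each head over D, even when the same weights are used.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """RMS-normalize over the last axis, then apply a learned scale."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight


T, H, D = 3, 4, 8                          # hypothetical toy sizes
rng = np.random.default_rng(0)
q_proj = rng.standard_normal((T, H * D))   # output of the Q projection
weight = rng.standard_normal(H * D)        # checkpoint-style weight, shape (H*D,)

# Full-projection norm: statistics over all H*D features, then split into heads.
q_full = rms_norm(q_proj, weight).reshape(T, H, D)

# Per-head norm: split first, statistics over D within each head.
q_per_head = rms_norm(q_proj.reshape(T, H, D), weight.reshape(H, D))

max_diff = float(np.abs(q_full - q_per_head).max())
print(q_full.shape, max_diff > 0)
```

The two variants produce the same output shape, which is why the mismatch is easy to miss in a port: nothing fails, the activations are just scaled differently.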

Tested on M4 Max 36GB: 16.3 tok/s prefill, 4.9 tok/s decode (float16). Decode speed reflects GDN's sequential recurrence — a natural follow-up is a parallel prefix scan implementation for faster prefill.

First MLX implementation of allenai/OLMo-Hybrid-7B — a 32-layer model
alternating 3:1 GatedDeltaNet (GDN) / full-Attention layers.

GDN is a sequential recurrence with fixed O(1) memory cost at any context
length, combining the delta rule (Widrow & Hoff 1960) with Mamba-style
learned forget gates (Yang et al., arXiv:2412.06464).
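A minimal single-head sketch of the recurrence described above, written in NumPy rather than MLX. It follows the gated delta rule as commonly stated, S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T with output o_t = S_t q_t; the function names, shapes, and the optional `S0` argument are assumptions for illustration and not the PR's actual code. The chunked mx.eval() mentioned in the notes is an MLX lazy-evaluation detail and has no counterpart here.

```python
import numpy as np


def gated_delta_step(S, q, k, v, alpha, beta):
    """One token of the gated delta rule for a single head.

    S: (d_v, d_k) state, q/k: (d_k,), v: (d_v,), alpha/beta: scalars in [0, 1].
    """
    S = alpha * S                            # forget gate decays the old state
    pred = S @ k                             # value the state currently recalls for k
    S = S + beta * np.outer(v - pred, k)     # delta-rule correction toward v
    return S, S @ q


def gated_delta_scan(q, k, v, alpha, beta, S0=None):
    """Sequential scan over T tokens; the state size never depends on T."""
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_v, d_k)) if S0 is None else S0
    outs = np.empty((T, d_v))
    for t in range(T):
        S, outs[t] = gated_delta_step(S, q[t], k[t], v[t], alpha[t], beta[t])
    return S, outs


rng = np.random.default_rng(0)
T, d_k, d_v = 10, 4, 6                       # hypothetical toy sizes
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)  # unit-norm keys
v = rng.standard_normal((T, d_v))
alpha = rng.uniform(0.9, 1.0, size=T)
beta = rng.uniform(0.0, 1.0, size=T)

S_full, o_full = gated_delta_scan(q, k, v, alpha, beta)

# Processing in two pieces with the carried state gives the same result,
# which is what a recurrent cache relies on between prefill and decode.
S_half, o_a = gated_delta_scan(q[:5], k[:5], v[:5], alpha[:5], beta[:5])
S_two, o_b = gated_delta_scan(q[5:], k[5:], v[5:], alpha[5:], beta[5:], S0=S_half)
print(S_full.shape, np.allclose(S_full, S_two))
```

Because each step depends on the previous state, the loop cannot be trivially vectorized over T, which is the cost the sequential formulation pays for its constant memory.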

Key implementation notes:
- GatedDeltaNet: full sequential recurrence with chunked mx.eval() (CHUNK=32)
  to cap peak memory regardless of sequence length
- ShortConv: causal depthwise conv1d via mx.conv1d with groups=C; carries
  real prior-token context during decode (avoids zero-pad corruption)
- Heterogeneous cache: GDNCache (recurrent S matrix + conv context) for GDN
  layers, KVCache for attention layers; make_cache() returns the right type
- Two correctness bugs fixed vs naive port:
  - QK-norm applied to full (H*D) projection before reshape, not per-head
  - rope_theta = 500_000 (not 10_000; OLMo2 large-context default)
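The ShortConv note above can be illustrated with a small NumPy sketch of a causal depthwise convolution that optionally takes the previous K-1 inputs as context. The function name, the (C, K) weight layout, and the kernel size K=4 are assumptions for this example; the PR itself uses mx.conv1d with groups=C. The demo compares token-by-token decode with carried context against decode that zero-pads every call.

```python
import numpy as np


def causal_depthwise_conv(x, w, context=None):
    """x: (T, C) inputs, w: (C, K) per-channel kernels.

    context: (K-1, C) previous inputs, or None to left-pad with zeros.
    Returns the outputs (T, C) and the context to carry into the next call.
    """
    T, C = x.shape
    K = w.shape[1]
    if context is None:
        context = np.zeros((K - 1, C))
    padded = np.concatenate([context, x], axis=0)   # (T + K - 1, C)
    y = np.empty((T, C))
    for t in range(T):
        window = padded[t:t + K]                    # K inputs ending at token t
        y[t] = np.sum(window * w.T, axis=0)         # each channel uses its own kernel
    return y, padded[T:]                            # last K-1 inputs


rng = np.random.default_rng(0)
T, C, K = 6, 3, 4                                   # hypothetical toy sizes
x = rng.standard_normal((T, C))
w = rng.standard_normal((C, K))

y_full, _ = causal_depthwise_conv(x, w)             # reference: whole sequence at once

y_prefill, ctx = causal_depthwise_conv(x[:4], w)    # prefill the first 4 tokens
carried, zero_padded = [y_prefill], [y_prefill]
for t in range(4, T):
    y_t, ctx = causal_depthwise_conv(x[t:t + 1], w, context=ctx)
    carried.append(y_t)
    y_bad, _ = causal_depthwise_conv(x[t:t + 1], w)  # forgets prior tokens
    zero_padded.append(y_bad)

y_carried = np.concatenate(carried)
y_zero_pad = np.concatenate(zero_padded)
print(np.allclose(y_carried, y_full), np.allclose(y_zero_pad, y_full))
```

With carried context the decode path reproduces the full-sequence result, while the zero-padded path diverges as soon as the first generated token is processed.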

HF model: allenai/OLMo-Hybrid-7B (Apache 2.0)