Skip to content

feat(bonsai): run Bonsai-Q1 bits=1 on vanilla MLX (JIT qmv_fast) — beats PrismML, no fork#182

Open
dusterbloom wants to merge 3 commits into
panbanda:mainfrom
dusterbloom:dusterbloom/bonsai-q1-jit-fast
Open

feat(bonsai): run Bonsai-Q1 bits=1 on vanilla MLX (JIT qmv_fast) — beats PrismML, no fork#182
dusterbloom wants to merge 3 commits into
panbanda:mainfrom
dusterbloom:dusterbloom/bonsai-q1-jit-fast

Conversation

@dusterbloom

@dusterbloom dusterbloom commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Problem

The Bonsai-Q1 (bits=1) engine merged in #142 cannot run on the pinned oxideai/mlx-rs (f4aa309c): decode calls ops::quantized_matmul(…, bits=1), which MLX gates to bits >= 2, so affine_dequantize_*_b_1 is never compiled and the model errors at the first decode step with Unable to load kernel affine_dequantize_<dtype>_b_1. The in-tree engine is inert. Fixes #181.

Approach

Add the missing bits=1 matvec + dequant kernels from this crate via mlx-c's runtime-JIT facility (mlx_fast_metal_kernel, already compiled into mlx-sys) — no mlx-rs fork, no native rebuild; stays on the official pin, kernels JIT-compile at first use. The decode matvec is a qmv_fast-class kernel: 4 output rows per simdgroup with register-resident input reused across rows, and a select-based 1-bit unpack (the decisive perf lever — tiling alone was a no-op; decode is compute-bound on the bit extraction).

Two commits:

  1. feat(bonsai): run Bonsai-Q1 bits=1 on vanilla MLX via JIT Metal kernels — activates the engine (correctness).
  2. perf(bonsai): qmv_fast-class 1-bit decode kernel — the tuned decode kernel (perf).

Evidence

M4 base (32 GB), greedy, vs the PrismML mlx-lm reference on the same machine (controlled, interleaved, 3 trials), production serving path:

model this PR PrismML Δ
Bonsai-8B 77.0 tok/s 73.0 +5.5%
Bonsai-1.7B ~270 tok/s 229.6 +17.6%

The 8B figure is cross-validated by two independent methods — the engine's per-step total_ms and an HTTP wall-clock differencing 32 vs 256 real generated tokens — which agree. This also beats the fork-native qmv_fast path (which needs a forked mlx-rs chain + a from-source rebuild) with no fork.

Test plan

  • Unit: qmv_fast_kernel_matches_cpu_reference + qmv_kernel_matches_cpu_reference + dequant_kernel_matches_cpu_reference — bit-exact vs CPU reference (tail K<2048, main K=4096 block, N%4 lm_head remainder).
  • E2E: Bonsai-8B server generation coherent.
  • cargo clippy -p higgs-models -p higgs-engine clean; cargo fmt --check clean.

Risk

Low — purely additive (new kernels + a dispatcher). The fast kernel is bit-exact and default-on; opt back to the original per-row kernel with HIGGS_BONSAI_QMV_KERNEL=legacy. No dependency change — stays on the official oxideai/mlx-rs pin.

Out of scope

  • Compiled-decode / fixed-shape KV rearchitecture (further headroom).
  • Prefill-path kernel optimization.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Add runtime Metal kernels and a packed Bonsai‑Q1 execution path for 1‑bit quantized models, enabling Qwen2/Qwen3/Llama/Mistral to route to the packed engine.
  • Bug Fixes

    • Remove upfront rejection for unsupported packed models so runtime-packed errors surface instead of a stale guard message.
  • Tests

    • Add deterministic validation tests that compare Metal kernel outputs against CPU references for quantized matvec/dequant paths.

dusterbloom and others added 2 commits June 9, 2026 14:39
Stock oxideai/mlx-rs ships no bits=1 affine kernel (MLX gates affine quant to
bits>=2), so ops::quantized_matmul / ops::dequantize with bits=1 failed at
runtime with 'Unable to load kernel affine_dequantize_*_b_1'. Rather than fork
mlx-rs (which forces a full from-source mlx-c rebuild), add the kernels from
this crate using mlx-c's runtime JIT facility (mlx_fast_metal_kernel), already
compiled into mlx-sys and reached over the -sys FFI. We stay on the official
oxideai pin with no extra native recompile; kernels JIT-compile at first use.

New crates/higgs-models/src/metal_kernel.rs:
- bonsai_q1_qmv: fused 1-bit quantized matvec (decode hot path) — one simdgroup
  per output row over the packed weights, simd_sum reduce.
- bonsai_q1_dequant: packed -> dense f16 (embedding gather + prefill matmul).
FFI plumbing mirrors the proven qgemv_4bit pattern in qwen3_next; the unpack/
affine math mirrors PackedQ1Linear::dequant_row_to_fp32.

Wired into BonsaiQ1GpuLinear::forward (decode uses qmv, prefill dequant+matmul)
and BonsaiQ1Gpu::embed_rows. model_loader comment corrected and the obsolete
bits=1 guard test repurposed to assert routing into the packed engine.

Verified on vanilla MLX (no fork, no native rebuild):
- Oracle unit tests vs CPU reference pass; clippy (nursery) + fmt clean.
- Bonsai-1.7B-mlx-1bit: ~113 tok/s decode (8.82 ms/step), coherent output.
- Bonsai-8B-mlx-1bit: ~28 tok/s decode (35.5 ms/step), coherent output;
  exercises K-chunking + nsg=16 (inter=12288) and the untied lm_head.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…fork

The JIT bits=1 matvec used one simdgroup per output row with scalar
bit-extraction and threadgroup-staged x, leaving decode far below
MLX/PrismML's qmv_fast. Port qmv_fast's tiling (4 rows/simdgroup,
register-resident x_thread[64] reused across rows, block_size 2048, no
shared memory) AND its select-based qdot (constant byte masks, no
int→float convert, no FP-multiply). The select-based unpack was the
decisive lever: the tiling alone was a measured no-op — decode is
compute-bound on the bit extraction, not the memory access pattern.

Measured on M4 base (32 GB), greedy, vs the PrismML mlx-lm reference on
the same machine (controlled, interleaved, 3 trials).

Kernel micro-bench (synchronous per-step eval — conservative lower bound):
  8B   28.0 -> 66.1 tok/s   (2.36x)
  1.7B 101.4 -> 190.2 tok/s (1.88x)

Production serving path (higgs server, async decode loop — real-world):
  8B   74.7 tok/s  vs PrismML 73.0   => +2.3%
  1.7B  270 tok/s  vs PrismML 229.6  => +17.6%

The 8B figure is cross-validated by two independent methods — the engine's
per-step total_ms and an HTTP wall-clock differencing 32 vs 256 real
generated tokens — which agree to 74.7 tok/s.

This beats the fork-native qmv_fast path (which needs the dusterbloom
mlx-rs fork chain + a from-source native rebuild) with NO fork and NO
native rebuild — stays on the official oxideai/mlx-rs pin via the
runtime-JIT facility.

Bit-exact vs an independent CPU reference: the new
qmv_fast_kernel_matches_cpu_reference oracle covers the tail (K < 2048)
path, the main K=4096 block, and the N % 4 lm_head row remainder.
Verified end-to-end: Bonsai-8B server generation is coherent.

The fast kernel is now the default; opt back to the original per-row
kernel with HIGGS_BONSAI_QMV_KERNEL=legacy. Adds a sustained-decode bench
(bench_bonsai_q1_decode) gated on HIGGS_BONSAI_PROFILE_DIR.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@coderabbitai

coderabbitai Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 36e5f8c5-5f70-484d-b141-cbd592f425a7

📥 Commits

Reviewing files that changed from the base of the PR and between fa77fbb and d510212.

📒 Files selected for processing (2)
  • crates/higgs-models/src/bonsai_q1.rs
  • crates/higgs-models/src/metal_kernel.rs
🚧 Files skipped from review as they are similar to previous changes (2)
  • crates/higgs-models/src/bonsai_q1.rs
  • crates/higgs-models/src/metal_kernel.rs

📝 Walkthrough

Walkthrough

Model loader now routes Bonsai-Q1 configs to a packed Bonsai-Q1 engine. Added an internal metal_kernel module implementing runtime JIT Metal kernels (qmv/dequant) with MLX FFI error capture. GPU Bonsai-Q1 code calls these kernels and tests validate kernel outputs against CPU reference dequantization.

Changes

Bonsai-Q1 Metal Kernel Infrastructure

Layer / File(s) Summary
Model loader routing for Bonsai-Q1 packed engine
crates/higgs-engine/src/model_loader.rs
load_model routes Bonsai-Q1 configs through higgs_models::bonsai_q1::load_bonsai_q1() instead of rejecting them with an early MLX bits=1 guard; test updated to assert routing rather than the stale guard error.
Metal kernel module declaration
crates/higgs-models/src/lib.rs
Declare private metal_kernel module for runtime JIT Metal kernels.
Metal kernel module and MLX FFI error handling
crates/higgs-models/src/metal_kernel.rs (lines 1–328)
Register MLX FFI error handler with thread-local last-error capture, add CachedMetalKernel and kernel resource utilities, and implement kernel invocation + error → Exception mapping.
Fused 1-bit matvec kernels (legacy and fast)
crates/higgs-models/src/metal_kernel.rs (lines 329–644)
Implement bonsai_q1_qmv_legacy and bonsai_q1_qmv_fast and a dispatcher bonsai_q1_qmv; include env-based selection/tuning and cached compiled kernels.
Dense weight dequantization kernel
crates/higgs-models/src/metal_kernel.rs (lines 645–814)
Implement bonsai_q1_dequant kernel to unpack 1-bit packed weights into a dense [out_features,in_features] array; add cached DEQUANT_KERNEL.
Bonsai-Q1 GPU layer uses Metal kernels and validation tests
crates/higgs-models/src/bonsai_q1.rs
Update docs and comments; remove BITS constant; BonsaiQ1GpuLinear::forward validates input and routes m==1bonsai_q1_qmv, m>1bonsai_q1_dequant + ops::matmul; embed_rows uses bonsai_q1_dequant; add comprehensive tests comparing kernel outputs to CPU reference.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • panbanda/higgs#142: Introduces the Bonsai-Q1 packed engine that this PR enables by providing bits=1 kernel implementations.

Poem

🐰 I build kernels in the night,
One-bit whispers, packed and tight,
Metal sparks and errors caught,
Bonsai wakes from what we taught—
Hopping off, the inference is light.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately reflects the main change: enabling Bonsai-Q1 1-bit quantization to run on vanilla MLX via JIT kernels, beating PrismML without forking.
Linked Issues check ✅ Passed The code changes fully address issue #181 by implementing runtime JIT Metal kernels for missing bits=1 affine quantization, enabling Bonsai-Q1 to run without forking mlx-rs or rebuilding.
Out of Scope Changes check ✅ Passed All changes are directly scoped to implementing bits=1 kernel support and integrating it into the engine routing; no unrelated modifications found.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
crates/higgs-models/src/bonsai_q1.rs (1)

1297-1315: ⚡ Quick win

This test never exercises the legacy fallback kernel.

bonsai_q1_qmv() defaults to qmv_fast, so this test hits the same path as qmv_fast_kernel_matches_cpu_reference() unless HIGGS_BONSAI_QMV_KERNEL=legacy was set before the first dispatcher call. That leaves the advertised opt-out path untested.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-models/src/bonsai_q1.rs` around lines 1297 - 1315, The test
qmv_kernel_matches_cpu_reference never forces the legacy kernel because
bonsai_q1_qmv chooses qmv_fast by default; set the environment variable
HIGGS_BONSAI_QMV_KERNEL="legacy" at the very start of the
qmv_kernel_matches_cpu_reference test (before any dispatcher/selection happens
or before calling bonsai_q1_qmv or BonsaiQ1GpuLinear::from_packed) so the legacy
path is selected, and restore or remove the env var at test end to avoid
affecting other tests; ensure the change targets the
qmv_kernel_matches_cpu_reference test and does not alter
qmv_fast_kernel_matches_cpu_reference.
crates/higgs-engine/src/model_loader.rs (1)

317-337: ⚡ Quick win

This test does not actually prove the Bonsai-Q1 branch was taken.

A config-only directory can fail in both the packed loader and the transformer loader, so this assertion only proves the old guard string disappeared. It won't catch a regression where load_model() stops dispatching to load_bonsai_q1() and falls back to the transformer path again. Please make the failure Bonsai-specific, or factor the dispatch decision into a helper that can be unit-tested directly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-engine/src/model_loader.rs` around lines 317 - 337, The test
load_model_routes_bonsai_q1_to_packed_engine only checks that the old guard
string is gone but doesn't prove dispatch chose the Bonsai path; update the test
to either (A) assert a Bonsai-specific failure by invoking the loader path that
load_bonsai_q1() would produce (for example triggering or detecting an
EngineError/ModelError variant or message unique to load_bonsai_q1), or (B)
factor the dispatch logic into a small pure helper (e.g.
should_route_to_bonsai_q1 or choose_model_loader) and add a focused unit test
for that function feeding it the same config_with bits=1/group=128 so you can
directly assert it returns the Bonsai/paked-engine decision rather than relying
on a config-only load_model() failure; locate references to load_model,
load_bonsai_q1, and the test function name to implement the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@crates/higgs-models/src/bonsai_q1.rs`:
- Around line 413-418: The current logic computes m by dividing the total
element count by self.in_features, which can silently accept inputs where
in_features is matched by the product of trailing dimensions; instead validate
the trailing (last) dimension before flattening: check that x.shape().last()
(the last dimension) exactly equals self.in_features and reject (return Err /
panic) when it does not, and compute m as the product of all leading dims (i.e.,
product of x.shape().iter().take(x.ndim()-1)); replace the existing total /
self.in_features branch with this last-dimension check so the GEMV path only
runs for inputs whose final axis equals in_features.

In `@crates/higgs-models/src/metal_kernel.rs`:
- Around line 241-243: The doc comment above the legacy kernel is incorrect:
update the wording around the dispatcher (referenced by bonsai_q1_qmv) and the
environment escape hatch to state that use_fast_qmv() makes the fast kernel the
default and that the legacy kernel is selected only when
HIGGS_BONSAI_QMV_KERNEL=legacy (i.e., fast is the default and legacy is the
opt-out); change the comments at the current block (around the per-row 1-bit
matvec comment) and the similar block near lines 360-361 to reflect that the
fast kernel is the default and the legacy path is chosen only when the env var
equals "legacy".

---

Nitpick comments:
In `@crates/higgs-engine/src/model_loader.rs`:
- Around line 317-337: The test load_model_routes_bonsai_q1_to_packed_engine
only checks that the old guard string is gone but doesn't prove dispatch chose
the Bonsai path; update the test to either (A) assert a Bonsai-specific failure
by invoking the loader path that load_bonsai_q1() would produce (for example
triggering or detecting an EngineError/ModelError variant or message unique to
load_bonsai_q1), or (B) factor the dispatch logic into a small pure helper (e.g.
should_route_to_bonsai_q1 or choose_model_loader) and add a focused unit test
for that function feeding it the same config_with bits=1/group=128 so you can
directly assert it returns the Bonsai/paked-engine decision rather than relying
on a config-only load_model() failure; locate references to load_model,
load_bonsai_q1, and the test function name to implement the change.

In `@crates/higgs-models/src/bonsai_q1.rs`:
- Around line 1297-1315: The test qmv_kernel_matches_cpu_reference never forces
the legacy kernel because bonsai_q1_qmv chooses qmv_fast by default; set the
environment variable HIGGS_BONSAI_QMV_KERNEL="legacy" at the very start of the
qmv_kernel_matches_cpu_reference test (before any dispatcher/selection happens
or before calling bonsai_q1_qmv or BonsaiQ1GpuLinear::from_packed) so the legacy
path is selected, and restore or remove the env var at test end to avoid
affecting other tests; ensure the change targets the
qmv_kernel_matches_cpu_reference test and does not alter
qmv_fast_kernel_matches_cpu_reference.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c5c1aa9f-9914-419d-afda-2b2af3e5cac8

📥 Commits

Reviewing files that changed from the base of the PR and between f6e3c2f and fa77fbb.

📒 Files selected for processing (4)
  • crates/higgs-engine/src/model_loader.rs
  • crates/higgs-models/src/bonsai_q1.rs
  • crates/higgs-models/src/lib.rs
  • crates/higgs-models/src/metal_kernel.rs

Comment thread crates/higgs-models/src/bonsai_q1.rs Outdated
Comment thread crates/higgs-models/src/metal_kernel.rs Outdated
…shape guard

- metal_kernel.rs: correct the dispatcher docs. The fast kernel is the
  default; `HIGGS_BONSAI_QMV_KERNEL=legacy` is the opt-out. The comments
  still described the pre-flip behavior (fast = opt-in).
- bonsai_q1.rs tests: `qmv_kernel_matches_cpu_reference` now calls
  `bonsai_q1_qmv_legacy` directly so it actually exercises the legacy
  kernel. The dispatcher defaults to fast, and setting the env var
  in-test is unreliable because `use_fast_qmv` caches it in a `OnceLock`
  on the first dispatcher call anywhere in the process. Makes
  `bonsai_q1_qmv_legacy` `pub` to match `bonsai_q1_qmv_fast`.
- `BonsaiQ1GpuLinear::forward`: reject inputs whose trailing dim isn't
  `in_features` instead of silently deriving the wrong `m` from the
  element count.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@dusterbloom

Copy link
Copy Markdown
Contributor Author

Thanks @coderabbitai — addressed all three in d510212:

  1. Dispatcher docs backwards ✅ — corrected both doc comments (bonsai_q1_qmv_legacy and the bonsai_q1_qmv dispatcher). Fast is the default; HIGGS_BONSAI_QMV_KERNEL=legacy is the opt-out.

  2. Legacy test never exercised the legacy kernel ✅ — good catch. I took a slightly different route than the suggested env-var: use_fast_qmv() caches HIGGS_BONSAI_QMV_KERNEL in a OnceLock on the first dispatcher call anywhere in the process, so setting it inside one test is order-dependent and unreliable. Instead qmv_kernel_matches_cpu_reference now calls bonsai_q1_qmv_legacy directly (made pub, mirroring bonsai_q1_qmv_fast) — legacy and fast are each covered deterministically, env-independent.

  3. Validate trailing dim in forward ✅ — added a last_dim == in_features (and in_features > 0) guard that returns an error instead of silently treating a mis-shaped input as a decode row.

clippy + fmt clean; both oracle tests pass.

@coderabbitai

coderabbitai Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Tip

For best results, initiate chat on the files or code changes.

@dusterbloom All three look great — the OnceLock reasoning for item 2 is spot-on, and calling bonsai_q1_qmv_legacy directly (making it pub) is a cleaner and more robust solution than the env-var approach I suggested. Thanks for the thorough fix-up!

(ノ◕ヮ◕)ノ*:・゚✧


If you found this review helpful, would you consider giving us a shout-out on X?

Thank you for using CodeRabbit!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Bonsai-Q1 (bits=1) engine from #142 fails at runtime on the pinned mlx-rs

1 participant