feat(bonsai): run Bonsai-Q1 bits=1 on vanilla MLX (JIT qmv_fast) — beats PrismML, no fork#182
dusterbloom wants to merge 3 commits into
Stock oxideai/mlx-rs ships no bits=1 affine kernel (MLX gates affine quant to bits>=2), so ops::quantized_matmul / ops::dequantize with bits=1 failed at runtime with 'Unable to load kernel affine_dequantize_*_b_1'.

Rather than fork mlx-rs (which forces a full from-source mlx-c rebuild), add the kernels from this crate using mlx-c's runtime JIT facility (mlx_fast_metal_kernel), already compiled into mlx-sys and reached over the -sys FFI. We stay on the official oxideai pin with no extra native recompile; kernels JIT-compile at first use.

New crates/higgs-models/src/metal_kernel.rs:
- bonsai_q1_qmv: fused 1-bit quantized matvec (decode hot path); one simdgroup per output row over the packed weights, simd_sum reduce.
- bonsai_q1_dequant: packed -> dense f16 (embedding gather + prefill matmul).

FFI plumbing mirrors the proven qgemv_4bit pattern in qwen3_next; the unpack/affine math mirrors PackedQ1Linear::dequant_row_to_fp32.

Wired into BonsaiQ1GpuLinear::forward (decode uses qmv, prefill dequant+matmul) and BonsaiQ1Gpu::embed_rows. model_loader comment corrected and the obsolete bits=1 guard test repurposed to assert routing into the packed engine.

Verified on vanilla MLX (no fork, no native rebuild):
- Oracle unit tests vs CPU reference pass; clippy (nursery) + fmt clean.
- Bonsai-1.7B-mlx-1bit: ~113 tok/s decode (8.82 ms/step), coherent output.
- Bonsai-8B-mlx-1bit: ~28 tok/s decode (35.5 ms/step), coherent output; exercises K-chunking + nsg=16 (inter=12288) and the untied lm_head.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
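For readers unfamiliar with the 1-bit unpack and affine math this commit message refers to, here is a minimal CPU sketch. It assumes an MLX-style affine scheme (`w = scale * q + bias` with `q` in {0, 1}), LSB-first packing of 8 weights per byte, and one scale/bias pair per group. The function names `dequant_row_q1` and `matvec_row_q1` are hypothetical and are not the crate's `PackedQ1Linear::dequant_row_to_fp32`.

```rust
/// Dequantize one packed 1-bit row: w[i] = scales[g] * bit(i) + biases[g].
/// Assumption: bit i lives in byte i / 8 at position i % 8 (LSB first),
/// and each group of `group_size` weights shares one (scale, bias) pair.
fn dequant_row_q1(packed: &[u8], scales: &[f32], biases: &[f32], group_size: usize) -> Vec<f32> {
    let n = packed.len() * 8;
    assert_eq!(n % group_size, 0, "row length must be a multiple of group_size");
    assert_eq!(scales.len(), n / group_size);
    assert_eq!(biases.len(), n / group_size);
    (0..n)
        .map(|i| {
            let bit = (packed[i / 8] >> (i % 8)) & 1;
            let g = i / group_size;
            scales[g] * f32::from(bit) + biases[g]
        })
        .collect()
}

/// Reference matvec for a single output row: dot(dequant(row), x).
fn matvec_row_q1(packed: &[u8], scales: &[f32], biases: &[f32], group_size: usize, x: &[f32]) -> f32 {
    let w = dequant_row_q1(packed, scales, biases, group_size);
    assert_eq!(w.len(), x.len(), "x must have one entry per weight");
    w.iter().zip(x).map(|(wi, xi)| wi * xi).sum()
}
```

With a single byte `0b0000_0101`, scale `2.0`, bias `-1.0` and a group size of 8, weights 0 and 2 dequantize to `1.0` and the rest to `-1.0`. A reference like this is the kind of oracle a GPU kernel can be compared against.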
…fork

The JIT bits=1 matvec used one simdgroup per output row with scalar bit-extraction and threadgroup-staged x, leaving decode far below MLX/PrismML's qmv_fast. Port qmv_fast's tiling (4 rows/simdgroup, register-resident x_thread[64] reused across rows, block_size 2048, no shared memory) AND its select-based qdot (constant byte masks, no int→float convert, no FP-multiply). The select-based unpack was the decisive lever: the tiling alone was a measured no-op, because decode is compute-bound on the bit extraction, not the memory access pattern.

Measured on M4 base (32 GB), greedy, vs the PrismML mlx-lm reference on the same machine (controlled, interleaved, 3 trials).

Kernel micro-bench (synchronous per-step eval; conservative lower bound):
- 8B: 28.0 -> 66.1 tok/s (2.36x)
- 1.7B: 101.4 -> 190.2 tok/s (1.88x)

Production serving path (higgs server, async decode loop; real-world):
- 8B: 74.7 tok/s vs PrismML 73.0 => +2.3%
- 1.7B: 270 tok/s vs PrismML 229.6 => +17.6%

The 8B figure is cross-validated by two independent methods (the engine's per-step total_ms and an HTTP wall-clock differencing 32 vs 256 real generated tokens), which agree to 74.7 tok/s.

This beats the fork-native qmv_fast path (which needs the dusterbloom mlx-rs fork chain + a from-source native rebuild) with NO fork and NO native rebuild; it stays on the official oxideai/mlx-rs pin via the runtime-JIT facility.

Bit-exact vs an independent CPU reference: the new qmv_fast_kernel_matches_cpu_reference oracle covers the tail (K < 2048) path, the main K=4096 block, and the N % 4 lm_head row remainder. Verified end-to-end: Bonsai-8B server generation is coherent.

The fast kernel is now the default; opt back to the original per-row kernel with HIGGS_BONSAI_QMV_KERNEL=legacy. Adds a sustained-decode bench (bench_bonsai_q1_decode) gated on HIGGS_BONSAI_PROFILE_DIR.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
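The idea behind the select-based qdot can be illustrated on the CPU. For one group, the sum of `(scale * bit_i + bias) * x_i` equals `scale * (sum of x_i where bit_i is set) + bias * (sum of all x_i)`, so the per-element work reduces to a mask test and an add. The sketch below is not the Metal kernel; `qdot_naive`, `qdot_select`, and the LSB-first bit order are assumptions made for illustration.

```rust
/// Constant per-bit masks: no shift and no int->float conversion per element.
const MASKS: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128];

/// Naive form: extract the bit, convert it to float, multiply per element.
fn qdot_naive(packed: &[u8], x: &[f32], scale: f32, bias: f32) -> f32 {
    let mut acc = 0.0f32;
    for (i, &xi) in x.iter().enumerate() {
        let bit = f32::from((packed[i / 8] >> (i % 8)) & 1);
        acc += (scale * bit + bias) * xi;
    }
    acc
}

/// Select form: accumulate x only where the bit is set, then apply
/// scale and bias once for the whole group.
fn qdot_select(packed: &[u8], x: &[f32], scale: f32, bias: f32) -> f32 {
    let mut selected = 0.0f32; // sum of x_i whose bit is 1
    let mut total = 0.0f32; // sum of all x_i
    for (i, &xi) in x.iter().enumerate() {
        selected += if packed[i / 8] & MASKS[i % 8] != 0 { xi } else { 0.0 };
        total += xi;
    }
    scale * selected + bias * total
}
```

The two forms are equal in exact arithmetic; in floating point they can differ by rounding because the operations are ordered differently, which is one reason to compare a kernel against a reference that uses the same accumulation order.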
No actionable comments were generated in the recent review.
Walkthrough
Model loader now routes Bonsai-Q1 configs to a packed Bonsai-Q1 engine. Added an internal
Changes
Bonsai-Q1 Metal Kernel Infrastructure
Estimated code review effort: 4 (Complex), ~60 minutes
Pre-merge checks: 5 passed
Actionable comments posted: 2
🧹 Nitpick comments (2)
crates/higgs-models/src/bonsai_q1.rs (1)
1297-1315: ⚡ Quick win
This test never exercises the legacy fallback kernel.
`bonsai_q1_qmv()` defaults to `qmv_fast`, so this test hits the same path as `qmv_fast_kernel_matches_cpu_reference()` unless `HIGGS_BONSAI_QMV_KERNEL=legacy` was set before the first dispatcher call. That leaves the advertised opt-out path untested.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@crates/higgs-models/src/bonsai_q1.rs` around lines 1297 - 1315, The test qmv_kernel_matches_cpu_reference never forces the legacy kernel because bonsai_q1_qmv chooses qmv_fast by default; set the environment variable HIGGS_BONSAI_QMV_KERNEL="legacy" at the very start of the qmv_kernel_matches_cpu_reference test (before any dispatcher/selection happens or before calling bonsai_q1_qmv or BonsaiQ1GpuLinear::from_packed) so the legacy path is selected, and restore or remove the env var at test end to avoid affecting other tests; ensure the change targets the qmv_kernel_matches_cpu_reference test and does not alter qmv_fast_kernel_matches_cpu_reference.
crates/higgs-engine/src/model_loader.rs (1)
317-337: ⚡ Quick win
This test does not actually prove the Bonsai-Q1 branch was taken.
A config-only directory can fail in both the packed loader and the transformer loader, so this assertion only proves the old guard string disappeared. It won't catch a regression where `load_model()` stops dispatching to `load_bonsai_q1()` and falls back to the transformer path again. Please make the failure Bonsai-specific, or factor the dispatch decision into a helper that can be unit-tested directly.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@crates/higgs-engine/src/model_loader.rs` around lines 317 - 337, The test load_model_routes_bonsai_q1_to_packed_engine only checks that the old guard string is gone but doesn't prove dispatch chose the Bonsai path; update the test to either (A) assert a Bonsai-specific failure by invoking the loader path that load_bonsai_q1() would produce (for example triggering or detecting an EngineError/ModelError variant or message unique to load_bonsai_q1), or (B) factor the dispatch logic into a small pure helper (e.g. should_route_to_bonsai_q1 or choose_model_loader) and add a focused unit test for that function feeding it the same config_with bits=1/group=128 so you can directly assert it returns the Bonsai/packed-engine decision rather than relying on a config-only load_model() failure; locate references to load_model, load_bonsai_q1, and the test function name to implement the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@crates/higgs-models/src/bonsai_q1.rs`:
- Around line 413-418: The current logic computes m by dividing the total
element count by self.in_features, which can silently accept inputs where
in_features is matched by the product of trailing dimensions; instead validate
the trailing (last) dimension before flattening: check that x.shape().last()
(the last dimension) exactly equals self.in_features and reject (return Err /
panic) when it does not, and compute m as the product of all leading dims (i.e.,
product of x.shape().iter().take(x.ndim()-1)); replace the existing total /
self.in_features branch with this last-dimension check so the GEMV path only
runs for inputs whose final axis equals in_features.
In `@crates/higgs-models/src/metal_kernel.rs`:
- Around line 241-243: The doc comment above the legacy kernel is incorrect:
update the wording around the dispatcher (referenced by bonsai_q1_qmv) and the
environment escape hatch to state that use_fast_qmv() makes the fast kernel the
default and that the legacy kernel is selected only when
HIGGS_BONSAI_QMV_KERNEL=legacy (i.e., fast is the default and legacy is the
opt-out); change the comments at the current block (around the per-row 1-bit
matvec comment) and the similar block near lines 360-361 to reflect that the
fast kernel is the default and the legacy path is chosen only when the env var
equals "legacy".
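The trailing-dimension check requested in the first inline comment could look roughly like the sketch below. The helper name `leading_rows`, the `usize` shape type, and the `String` error are assumptions for illustration; the real `BonsaiQ1GpuLinear::forward` works on MLX arrays and its own error type.

```rust
/// Returns `m`, the product of all leading dims, if and only if the
/// trailing dim equals `in_features`. A 1-D input yields m = 1.
fn leading_rows(shape: &[usize], in_features: usize) -> Result<usize, String> {
    match shape.split_last() {
        // Trailing dim matches: flatten everything in front of it.
        Some((&last, leading)) if last == in_features => Ok(leading.iter().product()),
        // Trailing dim differs: reject instead of dividing the element count.
        Some((&last, _)) => Err(format!(
            "expected trailing dim {in_features}, got {last}"
        )),
        None => Err("input must have at least one dimension".to_string()),
    }
}
```

A shape such as `[4, 2048]` with `in_features = 4096` shows why this matters: dividing the element count gives `8192 / 4096 = 2` and would be accepted silently, while the last-dimension check rejects it.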
---
Nitpick comments:
In `@crates/higgs-engine/src/model_loader.rs`:
- Around line 317-337: The test load_model_routes_bonsai_q1_to_packed_engine
only checks that the old guard string is gone but doesn't prove dispatch chose
the Bonsai path; update the test to either (A) assert a Bonsai-specific failure
by invoking the loader path that load_bonsai_q1() would produce (for example
triggering or detecting an EngineError/ModelError variant or message unique to
load_bonsai_q1), or (B) factor the dispatch logic into a small pure helper (e.g.
should_route_to_bonsai_q1 or choose_model_loader) and add a focused unit test
for that function feeding it the same config_with bits=1/group=128 so you can
directly assert it returns the Bonsai/packed-engine decision rather than relying
on a config-only load_model() failure; locate references to load_model,
load_bonsai_q1, and the test function name to implement the change.
In `@crates/higgs-models/src/bonsai_q1.rs`:
- Around line 1297-1315: The test qmv_kernel_matches_cpu_reference never forces
the legacy kernel because bonsai_q1_qmv chooses qmv_fast by default; set the
environment variable HIGGS_BONSAI_QMV_KERNEL="legacy" at the very start of the
qmv_kernel_matches_cpu_reference test (before any dispatcher/selection happens
or before calling bonsai_q1_qmv or BonsaiQ1GpuLinear::from_packed) so the legacy
path is selected, and restore or remove the env var at test end to avoid
affecting other tests; ensure the change targets the
qmv_kernel_matches_cpu_reference test and does not alter
qmv_fast_kernel_matches_cpu_reference.
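Option (B) from the model_loader nitpick, a pure dispatch helper that can be unit-tested without touching the filesystem, might be sketched as follows. `ModelLoader`, `QuantConfig`, and the rule "route on `bits == 1`" are assumptions; the real `load_model()` may inspect more of the config.

```rust
/// Which loader `load_model()` should hand the config to (hypothetical enum).
#[derive(Debug, PartialEq, Eq)]
enum ModelLoader {
    BonsaiQ1Packed,
    Transformer,
}

/// Minimal stand-in for the quantization section of a model config.
struct QuantConfig {
    bits: u32,
}

/// Pure decision function: no I/O, so a test can call it directly.
/// Assumption: 1-bit quantization alone selects the packed Bonsai-Q1 engine.
fn choose_model_loader(quant: Option<&QuantConfig>) -> ModelLoader {
    match quant {
        Some(q) if q.bits == 1 => ModelLoader::BonsaiQ1Packed,
        _ => ModelLoader::Transformer,
    }
}
```

A test can then assert the routing decision itself rather than inferring it from how a config-only `load_model()` call happens to fail.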
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c5c1aa9f-9914-419d-afda-2b2af3e5cac8
📒 Files selected for processing (4)
crates/higgs-engine/src/model_loader.rs
crates/higgs-models/src/bonsai_q1.rs
crates/higgs-models/src/lib.rs
crates/higgs-models/src/metal_kernel.rs
…shape guard

- metal_kernel.rs: correct the dispatcher docs. The fast kernel is the default; `HIGGS_BONSAI_QMV_KERNEL=legacy` is the opt-out. The comments still described the pre-flip behavior (fast = opt-in).
- bonsai_q1.rs tests: `qmv_kernel_matches_cpu_reference` now calls `bonsai_q1_qmv_legacy` directly so it actually exercises the legacy kernel. The dispatcher defaults to fast, and setting the env var in-test is unreliable because `use_fast_qmv` caches it in a `OnceLock` on the first dispatcher call anywhere in the process. Makes `bonsai_q1_qmv_legacy` `pub` to match `bonsai_q1_qmv_fast`.
- `BonsaiQ1GpuLinear::forward`: reject inputs whose trailing dim isn't `in_features` instead of silently deriving the wrong `m` from the element count.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
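The `OnceLock` caching mentioned in this commit is the reason an in-test environment override is fragile. The sketch below shows the general pattern under stated assumptions: the body of `use_fast_qmv` is a guess at the shape of the real function, and `parse_use_fast` is a hypothetical helper introduced so the parsing rule can be tested without touching the process environment.

```rust
use std::sync::OnceLock;

/// Parsing rule, assumed from the PR text: anything except "legacy" means fast.
fn parse_use_fast(value: Option<&str>) -> bool {
    value != Some("legacy")
}

/// Cached decision. The environment is read once, on the first call made
/// anywhere in the process; later changes to the variable are not observed.
fn use_fast_qmv() -> bool {
    static USE_FAST: OnceLock<bool> = OnceLock::new();
    *USE_FAST.get_or_init(|| {
        let value = std::env::var("HIGGS_BONSAI_QMV_KERNEL").ok();
        parse_use_fast(value.as_deref())
    })
}
```

Because tests in one binary share a process, whichever test reaches the dispatcher first fixes the cached value for all the others. Calling the legacy kernel function directly, as the commit does, avoids depending on that ordering.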
Thanks @coderabbitai — addressed all three in d510212:
clippy + fmt clean; both oracle tests pass.
Problem

The Bonsai-Q1 (`bits=1`) engine merged in #142 cannot run on the pinned `oxideai/mlx-rs` (`f4aa309c`): decode calls `ops::quantized_matmul(…, bits=1)`, which MLX gates to `bits >= 2`, so `affine_dequantize_*_b_1` is never compiled and the model errors at the first decode step with `Unable to load kernel affine_dequantize_<dtype>_b_1`. The in-tree engine is inert. Fixes #181.

Approach

Add the missing `bits=1` matvec + dequant kernels from this crate via mlx-c's runtime-JIT facility (`mlx_fast_metal_kernel`, already compiled into `mlx-sys`): no mlx-rs fork, no native rebuild; stays on the official pin, kernels JIT-compile at first use. The decode matvec is a `qmv_fast`-class kernel: 4 output rows per simdgroup with register-resident input reused across rows, and a `select`-based 1-bit unpack (the decisive perf lever; tiling alone was a no-op because decode is compute-bound on the bit extraction).

Two commits:
- `feat(bonsai): run Bonsai-Q1 bits=1 on vanilla MLX via JIT Metal kernels`: activates the engine (correctness).
- `perf(bonsai): qmv_fast-class 1-bit decode kernel`: the tuned decode kernel (perf).

Evidence

M4 base (32 GB), greedy, vs the PrismML `mlx-lm` reference on the same machine (controlled, interleaved, 3 trials), production serving path:

The 8B figure is cross-validated by two independent methods (the engine's per-step `total_ms` and an HTTP wall-clock differencing 32 vs 256 real generated tokens), which agree. This also beats the fork-native `qmv_fast` path (which needs a forked mlx-rs chain + a from-source rebuild) with no fork.

Test plan

- `qmv_fast_kernel_matches_cpu_reference` + `qmv_kernel_matches_cpu_reference` + `dequant_kernel_matches_cpu_reference`: bit-exact vs CPU reference (tail K<2048, main K=4096 block, N%4 lm_head remainder).
- `cargo clippy -p higgs-models -p higgs-engine` clean; `cargo fmt --check` clean.

Risk

Low: purely additive (new kernels + a dispatcher). The fast kernel is bit-exact and default-on; opt back to the original per-row kernel with `HIGGS_BONSAI_QMV_KERNEL=legacy`. No dependency change; stays on the official `oxideai/mlx-rs` pin.

Out of scope
🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Bug Fixes
Tests