feat(bonsai): run Bonsai-Q1 (bits=1) on vanilla MLX via JIT Metal kernels + projection fusion #21
Open
dusterbloom wants to merge 44 commits into
fix(deps): update rust crate toml to v1
…l-action-digest chore(deps): update taiki-e/install-action digest to cca35ed
…file chore(deps): update rust crate tokio to v1.52.2
…-lockfile chore(deps): update rust crate tower-http to v0.6.9
…anbanda#143) Adds AnyCache::trim_by to roll back KV layers for speculative decode while leaving hybrid Arrays state untouched.
CI: https://github.com/panbanda/higgs/actions/runs/25312580791
…#148)
* feat(qwen3_next): mixed-bit Qwen3.5 GDN BA loading fallback

Adds a fallback path for loading Qwen3.5 models with mixed-bit GDN projection weights (some layers q4, some q8; common in unsloth's dynamic-quant variants). The default fused-projection loader fuses `in_proj_a` + `in_proj_b` into a single matmul; mixed-bit weights have incompatible shapes and the fusion fails.

Behaviour:
1. Detect via `is_mixed_bit_gdn_ba_fusion_error`, which matches a `ModelError::ShapeMismatch` whose message contains both `in_proj_ba` and `requires separate GDN projections`.
2. On detection, retry the load with `args.use_separate_gdn_projections = true`, taking the `load_qwen3_5_moe_weights_direct` path. Forward dispatches go from 2 to 4 GDN ops per layer: slightly slower but correct.
3. Forced separate (via `args.use_separate_gdn_projections` config or `HIGGS_SEPARATE_GDN_PROJ` env var) skips the fused attempt entirely.

Also adds:
* `qwen3_5_quantization_config`: parses `{group_size, bits}` from the per-layer `quantization` map in `config.json`.
* `qwen3_5_mixed_ba_quantization_layers`: scans for the layers where `in_proj_a` and `in_proj_b` differ in bits or group_size.
* `can_concatenate_axis0`: guard used inside `load_qwen3_5_moe_weights_fused` to emit the diagnostic `ShapeMismatch` error rather than panicking on the concat.
* `load_qwen3_5_model_with_gdn_fallback`: private helper called by both `load_qwen3_5_model` (dense) and `load_qwen3_5_moe_model` (MoE), unifying the fallback path.

Adaptations from feat/magic-canvas → origin/main:
* The dense `load_qwen3_5_model` previously only honoured the env var; now it honours `args.use_separate_gdn_projections` too, matching the MoE path. Strict improvement: the config flag is set only by the env var or by mixed-bit detection.
* No `unwrap()`, no `as` casts (use `i32::try_from`); match arms enumerate variants. No file-level allows added.

Verification on origin/main (rustc 1.95.0):
* `cargo check -p higgs-models`: clean
* `cargo clippy --all-targets --all-features -- -D warnings`: clean
* `cargo fmt --check`: clean
* `cargo test -p higgs-models --lib`: 333/333 pass (3 new)

Source: feat/magic-canvas commit `061e500c`. Direct cherry-pick had 5 conflict regions because origin/main has evolved the load functions independently; this is a manual surgical port that preserves origin/main's structure while adding the fallback behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(qwen3_next): preserve explicit GDN projection config

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
…panbanda#141)
* perf(models): fused MoE gate+up, 3→2 expert matmuls per layer

SwitchMlpWeights::forward_gather_fused() lazy-concatenates gate+up weights on first call, then dispatches a single gather_qmm instead of two separate calls. FfnBlock::forward() now routes through the fused path instead of forward_gather_global_sort().

Measured on 35B-A3B-3bit M4 base:
- S=1 decode: 27ms → 17ms (−37%)
- S=16 verify: 253ms → 112ms (−56%)
- MoE/layer at K=1: 0.47ms (down from ~0.68ms)

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>

* style: cargo fmt qwen3_next.rs

Reflow `let fw/fs/fb = ops::concatenate_axis(..)` from broken-line indentation back onto single lines so `cargo fmt --all -- --check` passes in CI.

* fix(clippy): backtick MoE/gather_qmm doc + safe top_k u32 cast

Two errors flagged by `-D clippy::doc-markdown` and `-D clippy::cast-sign-loss`/`-D clippy::as-conversions`:
- Backtick `MoE` and `gather_qmm` in the `fused_gate_up` doc comment.
- Replace `top_k as u32` with the same `u32::try_from(top_k).map_err(...)` pattern already used by `forward_gather_global_sort`.

* fix(qwen3_next): gate MoE gate-up fusion behind opt-in

---------

Co-authored-by: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
…anda#142)
* feat(bonsai_q1): add upstream-guarded packed engine
* fix(bonsai-q1): address review feedback

---------

Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
* fix: default qwen3.6 to non-thinking mode
* fix: address qwen thinking review feedback
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
* Add configurable MTP draft depth
* Address MTP PR feedback
* Support dense MTP sidecar checkpoints
* Fix dense MTP formatting
* Add architecture-neutral speculative drafts
* Fix speculative decode lint
* Sanitize local benchmark paths
* Optimize MTP speculative decoding
* Harden audit and benchmark metadata
* Address speculative review blockers
* Resolve remaining model review threads
* Enable CodeRabbit approval workflow
Optimize MTP follow-up modes
…rib-2.x-lockfile
…l-action-digest
feat(mtp): release speculative decoding optimizations
…s--main chore: release main
Stock oxideai/mlx-rs ships no bits=1 affine kernel (MLX gates affine quant to bits>=2), so ops::quantized_matmul / ops::dequantize with bits=1 failed at runtime with 'Unable to load kernel affine_dequantize_*_b_1'.

Rather than fork mlx-rs (which forces a full from-source mlx-c rebuild), add the kernels from this crate using mlx-c's runtime JIT facility (mlx_fast_metal_kernel), already compiled into mlx-sys and reached over the -sys FFI. We stay on the official oxideai pin with no extra native recompile; kernels JIT-compile at first use.

New crates/higgs-models/src/metal_kernel.rs:
- bonsai_q1_qmv: fused 1-bit quantized matvec (decode hot path); one simdgroup per output row over the packed weights, simd_sum reduce.
- bonsai_q1_dequant: packed -> dense f16 (embedding gather + prefill matmul).

FFI plumbing mirrors the proven qgemv_4bit pattern in qwen3_next; the unpack/affine math mirrors PackedQ1Linear::dequant_row_to_fp32.

Wired into BonsaiQ1GpuLinear::forward (decode uses qmv, prefill dequant+matmul) and BonsaiQ1Gpu::embed_rows. model_loader comment corrected and the obsolete bits=1 guard test repurposed to assert routing into the packed engine.

Verified on vanilla MLX (no fork, no native rebuild):
- Oracle unit tests vs CPU reference pass; clippy (nursery) + fmt clean.
- Bonsai-1.7B-mlx-1bit: ~113 tok/s decode (8.82 ms/step), coherent output.
- Bonsai-8B-mlx-1bit: ~28 tok/s decode (35.5 ms/step), coherent output; exercises K-chunking + nsg=16 (inter=12288) and the untied lm_head.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
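The unpack/affine math mentioned above can be illustrated with a small CPU-side sketch. This is not the repository's `PackedQ1Linear::dequant_row_to_fp32`; the function names, the LSB-first packing into `u32` words, and the convention `w = scale * bit + bias` with one scale/bias pair per group are assumptions made for illustration. The `bits_per_weight` helper shows one way the quoted 1.25 bpw could arise (a group size of 128 with 16-bit scale and bias), which is an inference rather than something the PR states.

```rust
/// Dequantize one row of 1-bit affine-quantized weights on the CPU.
/// Assumed layout (hypothetical, not confirmed by the PR): bits are packed
/// LSB-first into u32 words, and each run of `group_size` weights shares one
/// (scale, bias) pair, with w = scale * bit + bias.
fn dequant_row_q1(packed: &[u32], scales: &[f32], biases: &[f32], group_size: usize) -> Vec<f32> {
    assert_eq!(scales.len(), biases.len());
    let n = scales.len() * group_size;
    assert!(packed.len() * 32 >= n, "not enough packed bits for the row");
    (0..n)
        .map(|i| {
            // Extract bit i from the packed words.
            let bit = (packed[i / 32] >> (i % 32)) & 1;
            let g = i / group_size;
            scales[g] * bit as f32 + biases[g]
        })
        .collect()
}

/// Storage cost per weight: one payload bit plus the amortized scale and bias.
fn bits_per_weight(group_size: usize, scale_bits: usize, bias_bits: usize) -> f64 {
    1.0 + (scale_bits + bias_bits) as f64 / group_size as f64
}
```

A GPU kernel would do the same per-element arithmetic, only with each thread handling a slice of the row and writing f16 instead of f32.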
Concatenate q/k/v (and gate/up) packed weights row-wise at load so each layer issues one fused matvec dispatch instead of three (and two), then splits the output by slicing.

Cuts per-token graph-build cost (forward_ms ~19% lower). GPU eval is unchanged (it is memory/glue-bound, not projection-dispatch-bound), so the end-to-end win is small (~1%); the main value is fewer ops for the upcoming compiled-decode path. Output stays coherent on 1.7B and 8B; oracle kernel tests still pass.

Also added an #[ignore] decode-section profiling test (forward_profiled) used to localize the bottleneck.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
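The fusion described in this commit can be sketched on plain dense matrices. The `Linear`, `fuse`, and `split` names below are hypothetical, and dense `f32` weights stand in for the packed 1-bit weights; the point is only that stacking projections row-wise and slicing the output gives the same results as separate matvecs.

```rust
/// Dense row-major matrix standing in for a packed projection weight.
struct Linear {
    rows: usize,
    cols: usize,
    w: Vec<f32>,
}

impl Linear {
    /// y = W x, one dot product per output row.
    fn matvec(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.cols);
        self.w
            .chunks(self.cols)
            .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
            .collect()
    }
}

/// Concatenate projections row-wise (done once, at load time) and remember
/// each projection's row count so the output can be split later.
fn fuse(parts: &[&Linear]) -> (Linear, Vec<usize>) {
    let cols = parts[0].cols;
    let mut w = Vec::new();
    let mut sizes = Vec::new();
    for p in parts {
        assert_eq!(p.cols, cols, "fused projections must share the input width");
        w.extend_from_slice(&p.w);
        sizes.push(p.rows);
    }
    (Linear { rows: sizes.iter().sum(), cols, w }, sizes)
}

/// Split the fused output back into per-projection slices without copying.
fn split<'a>(y: &'a [f32], sizes: &[usize]) -> Vec<&'a [f32]> {
    let mut out = Vec::new();
    let mut start = 0;
    for &n in sizes {
        out.push(&y[start..start + n]);
        start += n;
    }
    out
}
```

Because the rows are independent, the fused call changes only the number of dispatches, not the arithmetic, which is consistent with the commit's note that GPU eval time is unchanged.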
What
Runs Bonsai-Q1 (1.25-bpw, `bits=1` affine) checkpoints on the stock `oxideai/mlx-rs` pin (no fork, no forked metallib, no extra native rebuild) by adding the missing `bits=1` kernels from our own crate at runtime.
Why
Stock MLX gates affine quant to `bits >= 2`, so `ops::quantized_matmul` / `ops::dequantize` with `bits=1` fail at runtime (`Unable to load kernel affine_dequantize_*_b_1`). The only prior way to run Bonsai-Q1 was a forked mlx-rs that adds the `_b_1` kernels, forcing a full from-source mlx-c recompile.
How
mlx-c already ships the runtime JIT facility (`mlx_fast_metal_kernel_*`), compiled into `mlx-sys` and reachable over the `-sys` FFI. We use it (same proven pattern as `qwen3_next`'s `qgemv_4bit`) to JIT-compile two `bits=1` Metal kernels from `higgs-models`:
- `bonsai_q1_qmv`: fused 1-bit quantized matvec (decode hot path); one simdgroup per output row over the packed weights, `simd_sum` reduce.
- `bonsai_q1_dequant`: packed → dense f16 (embedding gather + prefill matmul).

Wired into `BonsaiQ1GpuLinear::forward` (decode uses the fused matvec; prefill dequantizes to dense, then `matmul`) and `embed_rows`. The unpack/affine math mirrors the in-repo CPU reference `PackedQ1Linear::dequant_row_to_fp32`.

The second commit fuses q/k/v and gate/up into single matvec dispatches (one fused matvec instead of three / two per layer; output split by slicing).
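To make the "fused" part of the matvec concrete, here is a CPU sketch of one output element computed directly from packed bits, without materializing dense weights. It is an illustration under assumptions, not the Metal kernel from this PR: `q1_dot` is a hypothetical name, bits are assumed LSB-first in `u32` words, and weights are assumed to be `scale * bit + bias` per group. Under that convention each group contributes `scale * (sum of x where the bit is set) + bias * (sum of all x in the group)`.

```rust
/// One output element of a 1-bit affine matvec, read straight from packed bits.
/// Assumed convention (hypothetical): w_i = scale_g * bit_i + bias_g, bits
/// packed LSB-first, one (scale, bias) pair per `group_size` inputs.
fn q1_dot(packed: &[u32], scales: &[f32], biases: &[f32], group_size: usize, x: &[f32]) -> f32 {
    assert_eq!(scales.len(), biases.len());
    assert_eq!(x.len(), scales.len() * group_size);
    let mut acc = 0.0f32;
    for (g, xs) in x.chunks(group_size).enumerate() {
        let mut set_sum = 0.0f32; // sum of x_i whose weight bit is 1
        let mut all_sum = 0.0f32; // sum of every x_i in the group
        for (j, &xi) in xs.iter().enumerate() {
            let i = g * group_size + j;
            if ((packed[i / 32] >> (i % 32)) & 1) == 1 {
                set_sum += xi;
            }
            all_sum += xi;
        }
        // Apply the affine parameters once per group instead of once per weight.
        acc += scales[g] * set_sum + biases[g] * all_sum;
    }
    acc
}
```

On the GPU, the description above assigns one simdgroup per output row; the sequential loop here stands in for lanes striding across the row followed by a `simd_sum` reduction.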
Results (M4, vanilla MLX, no fork)
Coherent output on both models; matches the historical-best architecture (the fp16-attention fix already in `main`), now fork-free.
Testing
- Kernel oracle unit tests against the CPU reference (`cargo test -p higgs-models bonsai_q1::tests`).
- `model_loader` routing test updated (Bonsai-Q1 now routes to the packed engine instead of being rejected).
- `clippy` (nursery) + `fmt` clean; end-to-end coherence + tok/s verified by serving both models.
Scope notes
- A compiled-decode (`compile_with_state`) path was prototyped and rejected as a measured negative result (~55% slower: threading the whole model as compile state + fixed-slab masked attention costs more than the per-token graph rebuild it removes); not included.

🤖 Generated with Claude Code