Skip to content

feat(bonsai): run Bonsai-Q1 (bits=1) on vanilla MLX via JIT Metal kernels + projection fusion#21

Open
dusterbloom wants to merge 44 commits into
mainfrom
dusterbloom/bonsai-enable
Open

feat(bonsai): run Bonsai-Q1 (bits=1) on vanilla MLX via JIT Metal kernels + projection fusion#21
dusterbloom wants to merge 44 commits into
mainfrom
dusterbloom/bonsai-enable

Conversation

@dusterbloom

Copy link
Copy Markdown
Owner

What

Runs Bonsai-Q1 (1.25-bpw, bits=1 affine) checkpoints on the stock oxideai/mlx-rs pin — no fork, no forked metallib, no extra native rebuild — by adding the missing bits=1 kernels from our own crate at runtime.

Why

Stock MLX gates affine quant to bits >= 2, so ops::quantized_matmul / ops::dequantize with bits=1 fail at runtime (Unable to load kernel affine_dequantize_*_b_1). The only prior way to run Bonsai-Q1 was a forked mlx-rs that adds the _b_1 kernels, forcing a full from-source mlx-c recompile.

How

mlx-c already ships the runtime JIT facility (mlx_fast_metal_kernel_*), compiled into mlx-sys and reachable over the -sys FFI. We use it (same proven pattern as qwen3_next's qgemv_4bit) to JIT-compile two bits=1 Metal kernels from higgs-models:

  • bonsai_q1_qmv — fused 1-bit quantized matvec (decode hot path): one simdgroup per output row over the packed weights, simd_sum reduce.
  • bonsai_q1_dequant — packed → dense f16 (embedding gather + prefill matmul).

Wired into BonsaiQ1GpuLinear::forward (decode uses the fused matvec; prefill dequantizes to dense then matmul) and embed_rows. The unpack/affine math mirrors the in-repo CPU reference PackedQ1Linear::dequant_row_to_fp32.

Second commit fuses q/k/v and gate/up into single matvec dispatches (one fused matvec instead of three / two per layer; output split by slicing).

Results (M4, vanilla MLX, no fork)

Model Decode Residency
Bonsai-1.7B-mlx-1bit ~115 tok/s 256 MB
Bonsai-8B-mlx-1bit ~29 tok/s 1.22 GB

Coherent output on both; matches the historical-best architecture (the fp16-attention fix already in main), now fork-free.

Testing

  • New oracle unit tests validate both kernels against the CPU reference within f16 tolerance (cargo test -p higgs-models bonsai_q1::tests).
  • model_loader routing test updated (Bonsai-Q1 now routes to the packed engine instead of being rejected).
  • clippy (nursery) + fmt clean; end-to-end coherence + tok/s verified by serving both models.

Scope notes

  • Build stays fast (no mlx-sys recompile — kernels JIT at runtime).
  • A compiled-decode (compile_with_state) path was prototyped and rejected as a measured negative result (~55% slower: threading the whole model as compile state + fixed-slab masked attention exceeds the per-token graph-rebuild it removes) — not included.

🤖 Generated with Claude Code

renovate Bot and others added 30 commits April 29, 2026 03:05
fix(deps): update rust crate toml to v1
…l-action-digest

chore(deps): update taiki-e/install-action digest to cca35ed
…file

chore(deps): update rust crate tokio to v1.52.2
…-lockfile

chore(deps): update rust crate tower-http to v0.6.9
…anbanda#143)

Adds AnyCache::trim_by to roll back KV layers for speculative decode while leaving hybrid Arrays state untouched.\n\nCI: https://github.com/panbanda/higgs/actions/runs/25312580791
…#148)

* feat(qwen3_next): mixed-bit Qwen3.5 GDN BA loading fallback

Adds a fallback path for loading Qwen3.5 models with mixed-bit GDN
projection weights (some layers q4, some q8 — common in unsloth's
dynamic-quant variants). The default fused-projection loader fuses
`in_proj_a` + `in_proj_b` into a single matmul; mixed-bit weights
have incompatible shapes and the fusion fails.

Behaviour:

  1. Detect via `is_mixed_bit_gdn_ba_fusion_error` — matches a
     `ModelError::ShapeMismatch` whose message contains both
     `in_proj_ba` and `requires separate GDN projections`.

  2. On detection, retry the load with
     `args.use_separate_gdn_projections = true`, taking the
     `load_qwen3_5_moe_weights_direct` path. Forward dispatches go from
     2 to 4 GDN ops per layer — slightly slower but correct.

  3. Forced separate (via `args.use_separate_gdn_projections` config or
     `HIGGS_SEPARATE_GDN_PROJ` env var) skips the fused attempt
     entirely.

Also adds:

  * `qwen3_5_quantization_config` — parses `{group_size, bits}` from
    the per-layer `quantization` map in `config.json`.
  * `qwen3_5_mixed_ba_quantization_layers` — scans for the layers
    where `in_proj_a` and `in_proj_b` differ in bits or group_size.
  * `can_concatenate_axis0` — guard used inside
    `load_qwen3_5_moe_weights_fused` to emit the diagnostic
    `ShapeMismatch` error rather than panicking on the concat.
  * `load_qwen3_5_model_with_gdn_fallback` — private helper called by
    both `load_qwen3_5_model` (dense) and `load_qwen3_5_moe_model`
    (MoE), unifying the fallback path.

Adaptations from feat/magic-canvas → origin/main:

  * The dense `load_qwen3_5_model` previously only honoured the env
    var; now it honours `args.use_separate_gdn_projections` too,
    matching the MoE path. Strict improvement: the config flag is set
    only by the env var or by mixed-bit detection.

  * No `unwrap()`, no `as` casts (use `i32::try_from`); match arms
    enumerate variants. No file-level allows added.

Verification on origin/main (rustc 1.95.0):

  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 333/333 pass (3 new)

Source: feat/magic-canvas commit `061e500c`. Direct cherry-pick had 5
conflict regions because origin/main has evolved the load functions
independently; this is a manual surgical port that preserves
origin/main's structure while adding the fallback behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(qwen3_next): preserve explicit GDN projection config

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
…panbanda#141)

* perf(models): fused MoE gate+up — 3→2 expert matmuls per layer

SwitchMlpWeights::forward_gather_fused() lazy-concatenates gate+up
weights on first call, then dispatches a single gather_qmm instead
of two separate calls. FfnBlock::forward() now routes through the
fused path instead of forward_gather_global_sort().

Measured on 35B-A3B-3bit M4 base:
- S=1 decode:   27ms → 17ms  (−37%)
- S=16 verify: 253ms → 112ms (−56%)
- MoE/layer at K=1: 0.47ms (down from ~0.68ms)

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>

* style: cargo fmt qwen3_next.rs

Reflow `let fw/fs/fb = ops::concatenate_axis(..)` from broken-line
indentation back onto single lines so `cargo fmt --all -- --check`
passes in CI.

* fix(clippy): backtick MoE/gather_qmm doc + safe top_k u32 cast

Two errors flagged by `-D clippy::doc-markdown` and
`-D clippy::cast-sign-loss`/`-D clippy::as-conversions`:

- Backtick `MoE` and `gather_qmm` in the `fused_gate_up` doc comment.
- Replace `top_k as u32` with the same `u32::try_from(top_k).map_err(...)`
  pattern already used by `forward_gather_global_sort`.

* fix(qwen3_next): gate MoE gate-up fusion behind opt-in

---------

Co-authored-by: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
…anda#142)

* feat(bonsai_q1): add upstream-guarded packed engine

* fix(bonsai-q1): address review feedback

---------

Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
* fix: default qwen3.6 to non-thinking mode

* fix: address qwen thinking review feedback
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
* Add configurable MTP draft depth

* Address MTP PR feedback

* Support dense MTP sidecar checkpoints

* Fix dense MTP formatting

* Add architecture-neutral speculative drafts

* Fix speculative decode lint

* Sanitize local benchmark paths

* Optimize MTP speculative decoding

* Harden audit and benchmark metadata

* Address speculative review blockers

* Resolve remaining model review threads

* Enable CodeRabbit approval workflow
panbanda and others added 14 commits May 29, 2026 17:25
feat(mtp): release speculative decoding optimizations
Stock oxideai/mlx-rs ships no bits=1 affine kernel (MLX gates affine quant to
bits>=2), so ops::quantized_matmul / ops::dequantize with bits=1 failed at
runtime with 'Unable to load kernel affine_dequantize_*_b_1'. Rather than fork
mlx-rs (which forces a full from-source mlx-c rebuild), add the kernels from
this crate using mlx-c's runtime JIT facility (mlx_fast_metal_kernel), already
compiled into mlx-sys and reached over the -sys FFI. We stay on the official
oxideai pin with no extra native recompile; kernels JIT-compile at first use.

New crates/higgs-models/src/metal_kernel.rs:
- bonsai_q1_qmv: fused 1-bit quantized matvec (decode hot path) — one simdgroup
  per output row over the packed weights, simd_sum reduce.
- bonsai_q1_dequant: packed -> dense f16 (embedding gather + prefill matmul).
FFI plumbing mirrors the proven qgemv_4bit pattern in qwen3_next; the unpack/
affine math mirrors PackedQ1Linear::dequant_row_to_fp32.

Wired into BonsaiQ1GpuLinear::forward (decode uses qmv, prefill dequant+matmul)
and BonsaiQ1Gpu::embed_rows. model_loader comment corrected and the obsolete
bits=1 guard test repurposed to assert routing into the packed engine.

Verified on vanilla MLX (no fork, no native rebuild):
- Oracle unit tests vs CPU reference pass; clippy (nursery) + fmt clean.
- Bonsai-1.7B-mlx-1bit: ~113 tok/s decode (8.82 ms/step), coherent output.
- Bonsai-8B-mlx-1bit: ~28 tok/s decode (35.5 ms/step), coherent output;
  exercises K-chunking + nsg=16 (inter=12288) and the untied lm_head.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Concatenate q/k/v (and gate/up) packed weights row-wise at load so each layer
issues one fused matvec dispatch instead of three (and two), then splits the
output by slicing. Cuts per-token graph-build cost (forward_ms ~19% lower).
GPU eval is unchanged — it is memory/glue-bound, not projection-dispatch-bound —
so the end-to-end win is small (~1%); the main value is fewer ops for the
upcoming compiled-decode path. Output stays coherent on 1.7B and 8B; oracle
kernel tests still pass.

Also added an #[ignore] decode-section profiling test (forward_profiled) used to
localize the bottleneck.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants