feat(bonsai-q1): packed engine scaffold with upstream MLX guard#142

Merged
panbanda merged 2 commits into
panbanda:mainfrom
dusterbloom:dusterbloom/bonsai-q1-fp16
May 6, 2026
@dusterbloom dusterbloom commented May 4, 2026

Contributor

Summary

Adds the Rust-side scaffold for Bonsai-Q1, a packed 1-bit Qwen3-shaped engine, while keeping the workspace MLX dependency on upstream oxideai/mlx-rs.

Bonsai-Q1 configs are detected as:

  • model_type == "qwen3"
  • quantization.bits == 1
  • quantization.group_size == 128

Because the pinned upstream MLX revision does not yet provide bits=1 affine quantization kernels, higgs-engine now rejects detected Bonsai-Q1 checkpoints with an explicit unsupported-model error instead of routing them into the normal Qwen3 transformer path or requiring a contributor fork.
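The detection rule and the guard described above can be sketched as follows. This is a minimal illustration, not the PR's code: `ModelConfig`, `Quantization`, and `ensure_supported` are hypothetical names, the config is assumed to be already parsed from `config.json`, and the error text is invented for the example.

```rust
/// Hypothetical, simplified mirror of the config.json fields used for detection.
#[derive(Debug, Clone, PartialEq)]
pub struct Quantization {
    pub bits: u32,
    pub group_size: u32,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ModelConfig {
    pub model_type: String,
    pub quantization: Option<Quantization>,
}

/// True only when all three Bonsai-Q1 markers from the PR description match.
pub fn is_bonsai_q1(cfg: &ModelConfig) -> bool {
    cfg.model_type == "qwen3"
        && matches!(
            &cfg.quantization,
            Some(q) if q.bits == 1 && q.group_size == 128
        )
}

/// Guard (assumed shape): reject a detected checkpoint early instead of
/// letting it fall through to the regular Qwen3 path.
pub fn ensure_supported(cfg: &ModelConfig) -> Result<(), String> {
    if is_bonsai_q1(cfg) {
        return Err(
            "unsupported model: Bonsai-Q1 needs bits=1 affine quantization kernels in MLX"
                .to_owned(),
        );
    }
    Ok(())
}
```

A regular Q4 Qwen3 config (`bits == 4`) passes the guard under this sketch, which is the false-positive case the review later asks to cover with a test.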

What lands in this PR

  • bonsai_q1::PackedQ1Linear, BonsaiQ1Engine, and BonsaiQ1Gpu Rust-side implementation
  • Shared yarn helpers with dtype-safe scalar handling
  • AnyModel::BonsaiQ1 dispatch/cache plumbing for future runtime enablement
  • SteppingKeyValueCache::key_value_arrays_mut for Updatable state borrowing
  • Bonsai-Q1 detection in model_loader with an upstream-MLX guard
  • docs/BONSAI_Q1.md documenting why runtime routing is held back
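The `key_value_arrays_mut` accessor in the list above hands out mutable references to both cache arrays at once. A rough sketch of that borrowing pattern, using a hypothetical `ToyKvCache` with `Vec<f32>` slots in place of the real cache and its MLX arrays (the actual signature may differ):

```rust
/// Hypothetical stand-in for a stepping KV cache; the real type stores MLX arrays.
#[derive(Default)]
pub struct ToyKvCache {
    keys: Option<Vec<f32>>,
    values: Option<Vec<f32>>,
}

impl ToyKvCache {
    /// Borrow both slots mutably in one call. Because the two references
    /// point at disjoint fields, the borrow checker accepts them together,
    /// which two separate `&mut self` getters would not allow.
    pub fn key_value_arrays_mut(
        &mut self,
    ) -> (&mut Option<Vec<f32>>, &mut Option<Vec<f32>>) {
        (&mut self.keys, &mut self.values)
    }
}
```

A caller can then update keys and values in the same scope, for example `let (k, v) = cache.key_value_arrays_mut();`, without cloning either array.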

What does not land

  • No switch to dusterbloom/mlx-rs / dusterbloom/mlx-c / dusterbloom/mlx
  • No production Bonsai-Q1 routing until bits=1 affine support lands upstream
  • No E2E throughput claim for this guarded PR; earlier Bonsai decode numbers depended on the forked MLX stack

Test plan

  • cargo fmt --all -- --check
  • cargo test -p higgs-engine --lib bonsai -- --test-threads=1
  • cargo check -p higgs-models -p higgs-engine
  • git diff --check
  • cargo clippy --all-targets --all-features

Upstream MLX verification

Verified against upstream oxideai/mlx-rs rev f4aa309c79b6be35255ca7d34157dfc10d9ed4c9, not dusterbloom/mlx-rs. A repo scan found no dusterbloom/mlx* dependency references in Cargo.toml, Cargo.lock, crates, or docs; the only MLX git dependency references point at oxideai/mlx-rs.

This guarded PR verifies that the scaffold builds/tests cleanly and that Bonsai-Q1 checkpoints fail early with an explicit unsupported-model error on upstream MLX. Runtime Bonsai-Q1 inference is intentionally not claimed here because upstream MLX does not yet expose the required bits=1 affine kernels.

Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for Bonsai-Q1 1-bit quantized models
    • Implemented YaRN positional encoding for enhanced model accuracy
    • Improved key-value cache operations for better memory efficiency
  • Documentation

    • Added Bonsai-Q1 model support documentation

@coderabbitai

coderabbitai Bot commented May 4, 2026

Contributor
📝 Walkthrough

Walkthrough

Introduces Bonsai-Q1 1-bit quantization support across the Higgs engine and models crates, including CPU and GPU representations, forward and decode paths, YaRN RoPE utilities, KV cache mutation support, and runtime detection to reject unsupported configurations.

Changes

Bonsai-Q1 Architecture Integration

| Layer | File(s) | Summary |
|---|---|---|
| Dependencies & Module Foundation | crates/higgs-models/Cargo.toml, crates/higgs-models/src/lib.rs | Added half = "2.4" dependency for half-precision support. Declared public modules bonsai_q1 and yarn; extended AnyModel enum with BonsaiQ1(BonsaiQ1Gpu) variant. |
| YaRN RoPE Utilities | crates/higgs-models/src/yarn.rs | New module providing YaRN-scaled rope helpers: yarn_get_mscale, compute_yarn_freqs, apply_yarn_rope (crate-public); private correction dimension/range helpers for dynamic rope behavior with precomputation support and tests. |
| Bonsai-Q1 Core Engine | crates/higgs-models/src/bonsai_q1.rs | Complete quantization engine with: public loader load_bonsai_q1; CPU-side structs (PackedQ1Linear, BonsaiQ1LayerWeights, BonsaiQ1Config, BonsaiQ1Engine); GPU mirroring (BonsaiQ1GpuLinear, BonsaiQ1GpuLayer, BonsaiQ1Gpu); forward/logits paths with profiling; stateful decode via BonsaiQ1DecodeState and Updatable trait; helpers for safetensors loading and data conversion. |
| KV Cache Support | crates/higgs-models/src/cache.rs | Added key_value_arrays_mut method to SteppingKeyValueCache for simultaneous mutable references to internal key and value arrays. |
| Model Dispatch & Wiring | crates/higgs-models/src/lib.rs | Routed BonsaiQ1 through forward (KV cache path), forward_hidden (error), batched forward (error), and image_size (non-text); updated cache geometry and construction for KV cache layout; changed hidden_size from const fn to normal method. |
| Runtime Safety & Detection | crates/higgs-engine/src/model_loader.rs | Added is_bonsai_q1 helper to inspect config.json for 1-bit Bonsai-Q1 checkpoints; inserted pre-check in model loading to reject Bonsai-Q1 configs when runtime support unavailable; included unit tests for detection and error handling. |
| Documentation | docs/BONSAI_Q1.md | New documentation describing model metadata (model_type, quantization parameters, dependencies, and code location). |
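To make the "packed 1-bit" idea behind PackedQ1Linear concrete, here is a small CPU sketch of group-wise affine dequantization, `w = scale * bit + bias`. Everything about the layout is an assumption for illustration: bits packed LSB-first into bytes, one scale and one bias per 128-weight group, and the function names themselves. The PR's real storage format is not shown on this page. If each group carries an fp16 scale and an fp16 bias, the cost works out to the 1.25 bits per weight mentioned in the PR's earlier title.

```rust
/// Group size used by Bonsai-Q1 configs (quantization.group_size == 128).
pub const GROUP_SIZE: usize = 128;

/// Dequantize a packed 1-bit weight row with per-group affine parameters.
/// Assumed layout: weight i is bit (i % 8) of byte (i / 8), LSB first.
pub fn dequant_q1(packed: &[u8], scales: &[f32], biases: &[f32]) -> Vec<f32> {
    let n = packed.len() * 8;
    assert_eq!(n % GROUP_SIZE, 0, "row length must be a multiple of the group size");
    assert_eq!(scales.len(), n / GROUP_SIZE, "one scale per group");
    assert_eq!(biases.len(), n / GROUP_SIZE, "one bias per group");
    (0..n)
        .map(|i| {
            let bit = (packed[i / 8] >> (i % 8)) & 1;
            let g = i / GROUP_SIZE;
            // Affine reconstruction: a 0 bit maps to bias, a 1 bit to scale + bias.
            scales[g] * f32::from(bit) + biases[g]
        })
        .collect()
}

/// Storage cost in bits per weight, assuming one fp16 scale and one fp16
/// bias (16 bits each) are stored per group next to the 1-bit payload.
pub fn bits_per_weight(group_size: usize) -> f64 {
    1.0 + (16.0 + 16.0) / group_size as f64
}
```

With `scale = 2.0` and `bias = -1.0`, each weight reconstructs to either `+1.0` or `-1.0`, which is the usual way a two-level weight is expressed in an affine scheme.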

Sequence Diagram

sequenceDiagram
    participant Loader as Model Loader
    participant Config as Config Parser
    participant Engine as Bonsai-Q1 Engine
    participant GPU as GPU Runtime
    participant Cache as KV Cache

    Loader->>Config: Check is_bonsai_q1
    Config-->>Loader: Reject if unsupported
    Loader->>Engine: load_bonsai_q1(dir)
    Engine->>Engine: Parse config, load safetensors
    Engine->>Engine: Build CPU weights, embeddings
    Engine-->>Loader: BonsaiQ1Engine (CPU)
    Loader->>GPU: to_gpu() convert to GPU
    GPU->>GPU: Dequant weights, prep f16
    GPU->>GPU: Precompute YaRN freqs
    GPU-->>Loader: BonsaiQ1Gpu (GPU-ready)
    Loader->>Cache: Create KV cache
    Cache-->>Loader: SteppingKeyValueCache
    Loader->>GPU: forward(inputs, cache)
    GPU->>GPU: QKV projection, RoPE, attention
    GPU->>Cache: Update via key_value_arrays_mut
    GPU-->>Loader: Logits

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 A quantum leap for model dreams,
Bonsai wisdom in 1-bit streams!
YaRN-spun ropes and caches deep,
GPU mirrors secrets keep—
Quantized joy, precise and fleet! 🚀

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'feat(bonsai-q1): packed engine scaffold with upstream MLX guard' accurately captures the main change: adding the Bonsai-Q1 engine scaffold with an upstream MLX guard and configuration detection. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description Check | ✅ Passed | Check skipped because CodeRabbit's high-level summary is enabled. |

@panbanda panbanda force-pushed the dusterbloom/bonsai-q1-fp16 branch from fad54d3 to 4af2603 Compare May 6, 2026 13:30
@panbanda panbanda changed the title feat(bonsai-q1): packed 1.25-bpw engine with fp16 attention path feat(bonsai-q1): packed engine scaffold with upstream MLX guard May 6, 2026

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 2

🧹 Nitpick comments (3)
crates/higgs-models/src/bonsai_q1.rs (1)

484-497: ⚡ Quick win

Suspicious fallback: rope_original_max_seq.unwrap_or(self.config.hidden).

When rope_yarn_factor > 1.0 but original_max_position_embeddings is missing from rope_scaling, this falls back to hidden_size — a hidden dimension is the wrong type to substitute for a position count. For Bonsai-8B (hidden=4096) it happens to land on a numerically plausible value, but any other architecture would silently produce incorrect YaRN correction-range frequencies with no warning.

Bonsai checkpoints always set this field, so this is latent today, but it's a footgun for future YaRN-scaled checkpoints. Either error out, or use a documented numeric default (e.g. 4096) with a comment explaining the choice.
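To see why a hidden size is a poor stand-in for a position count, it helps to look at where the value ends up. The sketch below assumes the standard YaRN correction-dimension formula (the PR's yarn.rs is not shown here, so treat it as an approximation) and pairs it with a hypothetical `resolve_original_max_seq` helper that fails loudly when the field is missing.

```rust
use std::f32::consts::PI;

/// Standard YaRN correction dimension (assumed to match the private helper
/// in yarn.rs): the rotary dimension index at which `num_rotations` full
/// turns fit inside `max_pos` positions.
pub fn yarn_find_correction_dim(num_rotations: f32, dim: f32, base: f32, max_pos: f32) -> f32 {
    dim * (max_pos / (num_rotations * 2.0 * PI)).ln() / (2.0 * base.ln())
}

/// Hypothetical helper: require the original context length instead of
/// silently falling back to an unrelated config value.
pub fn resolve_original_max_seq(orig: Option<usize>) -> Result<i32, String> {
    let seq = orig.ok_or_else(|| {
        "rope_yarn_factor > 1.0 requires rope_scaling.original_max_position_embeddings"
            .to_string()
    })?;
    i32::try_from(seq).map_err(|_| "orig_max_seq overflows i32".to_string())
}
```

Because `max_pos` sits inside the logarithm, any wrong value shifts the low/high correction range and therefore which rotary dimensions get interpolated, without producing an error.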

♻️ Proposed fix — fail loudly on missing field
 let (yarn_freqs, yarn_mscale) = match self.config.rope_yarn_factor {
     Some(factor) if factor > 1.0 => {
-        let orig = i32::try_from(
-            self.config
-                .rope_original_max_seq
-                .unwrap_or(self.config.hidden),
-        )
-        .map_err(|_| Exception::custom("orig_max_seq overflows i32"))?;
+        let orig_seq = self.config.rope_original_max_seq.ok_or_else(|| {
+            Exception::custom(
+                "rope_yarn_factor > 1.0 requires rope_scaling.original_max_position_embeddings",
+            )
+        })?;
+        let orig = i32::try_from(orig_seq)
+            .map_err(|_| Exception::custom("orig_max_seq overflows i32"))?;
         let factor_f = factor as f32;
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-models/src/bonsai_q1.rs` around lines 484 - 497, The current
YaRN branch uses rope_original_max_seq.unwrap_or(self.config.hidden), which
substitutes a hidden-size for a position count; change this to fail loudly or
provide a documented numeric default: when self.config.rope_yarn_factor > 1.0,
check whether self.config.rope_original_max_seq is Some and return an
Err(Exception::custom(...)) if missing (or explicitly use a constant like 4096
with a comment), then pass that validated orig into
compute_yarn_freqs(head_dim_i, base, factor_f, orig, 32.0, 1.0) and keep
yarn_get_mscale(factor_f, 1.0) unchanged; update the code around
rope_yarn_factor, rope_original_max_seq, compute_yarn_freqs and yarn_get_mscale
accordingly so we don’t silently use self.config.hidden as a position count.
crates/higgs-engine/src/model_loader.rs (1)

279-304: ⚡ Quick win

Add a negative test for bits != 1 (e.g., regular Q4 Qwen3 checkpoints).

The most important false-positive guard for is_bonsai_q1 is a Qwen3 checkpoint with a quantization block where bits != 1 (the common Q4 case). The current tests cover wrong model_type and wrong group_size, but not wrong bits. A regression that makes is_bonsai_q1 accept bits=4 would silently start rejecting all standard quantized Qwen3 models with the "MLX bits=1" error.

♻️ Proposed additional test case
         assert!(!is_bonsai_q1(wrong_group_dir.path()).unwrap());
+
+        let (q4_dir, _q4_result) = config_from_raw(
+            r#"{
+                "model_type": "qwen3",
+                "quantization": {"bits": 4, "group_size": 128}
+            }"#,
+        );
+        assert!(
+            !is_bonsai_q1(q4_dir.path()).unwrap(),
+            "regular Q4 Qwen3 must not be misclassified as Bonsai-Q1"
+        );
     }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-engine/src/model_loader.rs` around lines 279 - 304, Add a
negative test case to the existing
is_bonsai_q1_requires_qwen3_model_type_and_group_size test (or as a new test)
that constructs a qwen3 config where quantization.bits != 1 (e.g., bits: 4) and
asserts that is_bonsai_q1(path) returns false; update the test harness using
config_from_raw and reuse the pattern with a variable like (qwen4_dir,
_qwen4_result) to call is_bonsai_q1 and
assert!(!is_bonsai_q1(qwen4_dir.path()).unwrap()) so the function is guarded
against accepting non-1 bits Qwen3 checkpoints.
crates/higgs-models/src/lib.rs (1)

220-222: 💤 Low value

Silent mask-drop is a documented trap — consider warning on non-None.

The inline comment is good, but callers passing a non-causal mask (e.g., a packing mask for batched prefill or a prefix-cache-shifted mask) will get silently incorrect attention. Since BonsaiQ1 is gated behind the upstream-MLX guard and is only used through the engine layer today, this is unlikely to bite, but a rate-limited tracing::warn! when mask.is_some() would surface the misuse without breaking the API.
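One lightweight way to surface an ignored mask is a warn-once latch. The sketch below is a simplification of the suggestion: it logs once per process rather than rate-limiting, uses `eprintln!` instead of `tracing::warn!` to stay dependency-free, and `warn_if_mask_ignored` is a hypothetical name.

```rust
use std::sync::atomic::{AtomicBool, Ordering};

/// Process-wide latch so the warning is emitted at most once.
static MASK_WARNED: AtomicBool = AtomicBool::new(false);

/// Returns true if this call emitted the warning.
pub fn warn_if_mask_ignored<T>(mask: Option<&T>) -> bool {
    if mask.is_none() {
        return false;
    }
    // `swap` returns the previous value, so only the first caller sees `false`.
    if MASK_WARNED.swap(true, Ordering::Relaxed) {
        return false;
    }
    eprintln!("BonsaiQ1 ignores the supplied mask and builds its own causal mask");
    true
}
```

The forward path stays unchanged; the check would simply run before the call into the model.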

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-models/src/lib.rs` around lines 220 - 222, In the match arm
handling (Self::BonsaiQ1(m), AnyCache::KV(c)) where you call m.forward(inputs,
c), add a conditional check for the incoming mask (e.g., if mask.is_some()) and
emit a rate-limited tracing::warn! that the provided mask will be ignored
because BonsaiQ1 builds its own causal mask; keep behavior unchanged but log the
misuse to help callers. Locate the match arm in lib.rs (the BonsaiQ1 /
AnyCache::KV branch) and insert the warning before invoking m.forward, using a
rate limiter or tracing's throttling helpers to avoid log spam.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@crates/higgs-models/src/lib.rs`:
- Around line 243-245: The BonsaiQ1 match arm currently returns an error
preventing forward_chunked from calling self.forward_hidden on intermediate
chunks; change that arm to delegate to the existing final-norm computation in
forward_trunk_free so chunked prefill works for long prompts—i.e. in the match
arm matching (Self::BonsaiQ1(_), AnyCache::KV(_)) call self.forward_trunk_free
with the same arguments used by forward_hidden (ensuring you apply the same
gpu.final_norm / rms_norm path) and return its result instead of
Err(Exception::custom(...)); this keeps BonsaiQ1 behavior consistent with
forward_trunk_free and allows forward_chunked to progress across intermediate
chunks.

In `@crates/higgs-models/src/yarn.rs`:
- Around line 14-18: Replace the i16 round-trip casts that silently saturate
values with direct as f32 casts: in yarn_find_correction_dim replace
f32::from(i16::try_from(dim).unwrap_or(i16::MAX)) and
f32::from(i16::try_from(max_pos).unwrap_or(i16::MAX)) with dim as f32 and
max_pos as f32 respectively; apply the same pattern inside compute_yarn_freqs
for dim_f, the "2 * i" frequency computation, and any low/high/idx index-to-f32
conversions (use i as f32, idx as f32, etc.) to avoid truncation, and make the
analogous fixes in deepseek_v2.rs's compute_yarn_freqs; the file-level
cast_precision_loss lint is already allowed so use plain as f32 casts.
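The saturation described in the yarn.rs comment is easy to reproduce in isolation. In this sketch the argument type `i32` and the function names are assumptions; only the cast expressions are taken from the review text.

```rust
/// The flagged pattern: values outside the i16 range collapse to i16::MAX.
pub fn via_i16(x: i32) -> f32 {
    f32::from(i16::try_from(x).unwrap_or(i16::MAX))
}

/// The suggested replacement: a plain cast keeps the magnitude.
pub fn direct(x: i32) -> f32 {
    x as f32
}
```

A head dimension such as 128 survives both paths, but a context length of 40960 becomes 32767.0 through the i16 round trip, so only large `max_pos` values are affected.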

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fabe0441-75ac-47ab-848d-778845ef88ed

📥 Commits

Reviewing files that changed from the base of the PR and between 60d7cb4 and 4af2603.

⛔ Files ignored due to path filters (1)
  • Cargo.lock is excluded by !**/*.lock
📒 Files selected for processing (7)
  • crates/higgs-engine/src/model_loader.rs
  • crates/higgs-models/Cargo.toml
  • crates/higgs-models/src/bonsai_q1.rs
  • crates/higgs-models/src/cache.rs
  • crates/higgs-models/src/lib.rs
  • crates/higgs-models/src/yarn.rs
  • docs/BONSAI_Q1.md

Comment thread crates/higgs-models/src/lib.rs Outdated
Comment thread crates/higgs-models/src/yarn.rs
@panbanda

panbanda commented May 6, 2026

Owner

Hope you don't mind me jumping in here. If you prefer, I can make a sister PR next time, but I was able to get this working on my machine without changing upstream. Would love it if you tried it out.

@panbanda panbanda merged commit fe43aab into panbanda:main May 6, 2026
6 checks passed
@github-actions github-actions Bot mentioned this pull request May 6, 2026
dusterbloom added a commit to dusterbloom/higgs that referenced this pull request May 20, 2026
After rebasing PR-6a onto origin/main, the new `AnyModel::BonsaiQ1`
variant landed via panbanda#142 (merged) makes the explicit-enumeration arms
in `embed_token_ids` and `forward_all_logits_from_hidden` non-exhaustive
(E0004 x2).

Add `Self::BonsaiQ1(_)` to the existing union of non-Qwen3Next variants
that return the same "only implemented for Qwen3Next" Err — preserving
the DRY pattern (one error message per dispatcher across all 8 non-tap
variants) and the exhaustive-enumeration invariant (no `_ =>` catch-all,
so future variants will trip the same compiler check).

No runtime behaviour change: `BonsaiQ1` already cannot reach a DFlash
spec-decode flow (no tap API plumbing exists for it), so this just
formalises the rejection at the dispatcher boundary with a clear error.
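The exhaustive-enumeration pattern this commit message relies on can be shown with a toy enum. The variant list here is a hypothetical subset without payloads, and the return type is simplified; only the shape of the match mirrors the description above.

```rust
/// Hypothetical, payload-free subset of the model enum.
#[derive(Debug)]
pub enum AnyModel {
    Qwen3Next,
    Qwen3,
    Llama,
    BonsaiQ1,
}

/// Dispatcher with no `_ =>` arm: adding a new variant to `AnyModel`
/// makes this match non-exhaustive and fails compilation with E0004,
/// forcing an explicit decision for the new variant.
pub fn embed_token_ids(model: &AnyModel) -> Result<&'static str, String> {
    match model {
        AnyModel::Qwen3Next => Ok("embedded"),
        AnyModel::Qwen3 | AnyModel::Llama | AnyModel::BonsaiQ1 => {
            Err("embed_token_ids is only implemented for Qwen3Next".to_string())
        }
    }
}
```

The trade-off is a little more typing per variant in exchange for the compiler flagging every dispatcher that has not yet considered a newly added model.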

Tests:
- cargo check -p higgs-models: green (was 2x E0004)
- cargo clippy -p {higgs-models, higgs-engine, higgs} --tests -- -Dwarnings: clean
- cargo test: 1030/1030 across three crates (no regression)