feat(bonsai-q1): packed engine scaffold with upstream MLX guard#142
Conversation
📝 WalkthroughWalkthroughIntroduces Bonsai-Q1 1-bit quantization support across the Higgs engine and models crates, including CPU and GPU representations, forward and decode paths, YaRN RoPE utilities, KV cache mutation support, and runtime detection to reject unsupported configurations. ChangesBonsai-Q1 Architecture Integration
Sequence DiagramsequenceDiagram
participant Loader as Model Loader
participant Config as Config Parser
participant Engine as Bonsai-Q1 Engine
participant GPU as GPU Runtime
participant Cache as KV Cache
Loader->>Config: Check is_bonsai_q1
Config-->>Loader: Reject if unsupported
Loader->>Engine: load_bonsai_q1(dir)
Engine->>Engine: Parse config, load safetensors
Engine->>Engine: Build CPU weights, embeddings
Engine-->>Loader: BonsaiQ1Engine (CPU)
Loader->>GPU: to_gpu() convert to GPU
GPU->>GPU: Dequant weights, prep f16
GPU->>GPU: Precompute YARN freqs
GPU-->>Loader: BonsaiQ1Gpu (GPU-ready)
Loader->>Cache: Create KV cache
Cache-->>Loader: SteppingKeyValueCache
Loader->>GPU: forward(inputs, cache)
GPU->>GPU: QKV projection, RoPE, attention
GPU->>Cache: Update via key_value_arrays_mut
GPU-->>Loader: Logits
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
fad54d3 to
4af2603
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (3)
crates/higgs-models/src/bonsai_q1.rs (1)
484-497: ⚡ Quick winSuspicious fallback:
rope_original_max_seq.unwrap_or(self.config.hidden).When
rope_yarn_factor > 1.0butoriginal_max_position_embeddingsis missing fromrope_scaling, this falls back tohidden_size— a hidden dimension is the wrong type to substitute for a position count. For Bonsai-8B (hidden=4096) it happens to land on a numerically plausible value, but any other architecture would silently produce incorrect YaRN correction-range frequencies with no warning.Bonsai checkpoints always set this field, so this is latent today, but it's a footgun for future YaRN-scaled checkpoints. Either error out, or use a documented numeric default (e.g. 4096) with a comment explaining the choice.
♻️ Proposed fix — fail loudly on missing field
let (yarn_freqs, yarn_mscale) = match self.config.rope_yarn_factor { Some(factor) if factor > 1.0 => { - let orig = i32::try_from( - self.config - .rope_original_max_seq - .unwrap_or(self.config.hidden), - ) - .map_err(|_| Exception::custom("orig_max_seq overflows i32"))?; + let orig_seq = self.config.rope_original_max_seq.ok_or_else(|| { + Exception::custom( + "rope_yarn_factor > 1.0 requires rope_scaling.original_max_position_embeddings", + ) + })?; + let orig = i32::try_from(orig_seq) + .map_err(|_| Exception::custom("orig_max_seq overflows i32"))?; let factor_f = factor as f32;🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@crates/higgs-models/src/bonsai_q1.rs` around lines 484 - 497, The current YaRN branch uses rope_original_max_seq.unwrap_or(self.config.hidden), which substitutes a hidden-size for a position count; change this to fail loudly or provide a documented numeric default: when self.config.rope_yarn_factor > 1.0, check whether self.config.rope_original_max_seq is Some and return an Err(Exception::custom(...)) if missing (or explicitly use a constant like 4096 with a comment), then pass that validated orig into compute_yarn_freqs(head_dim_i, base, factor_f, orig, 32.0, 1.0) and keep yarn_get_mscale(factor_f, 1.0) unchanged; update the code around rope_yarn_factor, rope_original_max_seq, compute_yarn_freqs and yarn_get_mscale accordingly so we don’t silently use self.config.hidden as a position count.crates/higgs-engine/src/model_loader.rs (1)
279-304: ⚡ Quick winAdd a negative test for
bits != 1(e.g., regular Q4 Qwen3 checkpoints).The most important false-positive guard for
is_bonsai_q1is a Qwen3 checkpoint with aquantizationblock wherebits != 1(the common Q4 case). The current tests cover wrongmodel_typeand wronggroup_size, but not wrongbits. A regression that makesis_bonsai_q1accept bits=4 would silently start rejecting all standard quantized Qwen3 models with the "MLX bits=1" error.♻️ Proposed additional test case
assert!(!is_bonsai_q1(wrong_group_dir.path()).unwrap()); + + let (q4_dir, _q4_result) = config_from_raw( + r#"{ + "model_type": "qwen3", + "quantization": {"bits": 4, "group_size": 128} + }"#, + ); + assert!( + !is_bonsai_q1(q4_dir.path()).unwrap(), + "regular Q4 Qwen3 must not be misclassified as Bonsai-Q1" + ); }🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@crates/higgs-engine/src/model_loader.rs` around lines 279 - 304, Add a negative test case to the existing is_bonsai_q1_requires_qwen3_model_type_and_group_size test (or as a new test) that constructs a qwen3 config where quantization.bits != 1 (e.g., bits: 4) and asserts that is_bonsai_q1(path) returns false; update the test harness using config_from_raw and reuse the pattern with a variable like (qwen4_dir, _qwen4_result) to call is_bonsai_q1 and assert!(!is_bonsai_q1(qwen4_dir.path()).unwrap()) so the function is guarded against accepting non-1 bits Qwen3 checkpoints.crates/higgs-models/src/lib.rs (1)
220-222: 💤 Low valueSilent mask-drop is a documented trap — consider warning on non-None.
The inline comment is good, but callers passing a non-causal mask (e.g., a packing mask for batched-prefill or a prefix-cache-shifted mask) will get silently incorrect attention. Since BonsaiQ1 is gated behind the upstream-MLX-fork requirement and only used through the engine layer today, this is unlikely to bite, but a
tracing::warn!(rate-limited) whenmask.is_some()would surface the misuse without breaking the API.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@crates/higgs-models/src/lib.rs` around lines 220 - 222, In the match arm handling (Self::BonsaiQ1(m), AnyCache::KV(c)) where you call m.forward(inputs, c), add a conditional check for the incoming mask (e.g., if mask.is_some()) and emit a rate-limited tracing::warn! that the provided mask will be ignored because BonsaiQ1 builds its own causal mask; keep behavior unchanged but log the misuse to help callers. Locate the match arm in lib.rs (the BonsaiQ1 / AnyCache::KV branch) and insert the warning before invoking m.forward, using a rate limiter or tracing's throttling helpers to avoid log spam.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@crates/higgs-models/src/lib.rs`:
- Around line 243-245: The BonsaiQ1 match arm currently returns an error
preventing forward_chunked from calling self.forward_hidden on intermediate
chunks; change that arm to delegate to the existing final-norm computation in
forward_trunk_free so chunked prefill works for long prompts—i.e. in the match
arm matching (Self::BonsaiQ1(_), AnyCache::KV(_)) call self.forward_trunk_free
with the same arguments used by forward_hidden (ensuring you apply the same
gpu.final_norm / rms_norm path) and return its result instead of
Err(Exception::custom(...)); this keeps BonsaiQ1 behavior consistent with
forward_trunk_free and allows forward_chunked to progress across intermediate
chunks.
In `@crates/higgs-models/src/yarn.rs`:
- Around line 14-18: Replace the i16 round-trip casts that silently saturate
values with direct as f32 casts: in yarn_find_correction_dim replace
f32::from(i16::try_from(dim).unwrap_or(i16::MAX)) and
f32::from(i16::try_from(max_pos).unwrap_or(i16::MAX)) with dim as f32 and
max_pos as f32 respectively; apply the same pattern inside compute_yarn_freqs
for dim_f, the "2 * i" frequency computation, and any low/high/idx index-to-f32
conversions (use i as f32, idx as f32, etc.) to avoid truncation, and make the
analogous fixes in deepseek_v2.rs's compute_yarn_freqs; the file-level
cast_precision_loss lint is already allowed so use plain as f32 casts.
---
Nitpick comments:
In `@crates/higgs-engine/src/model_loader.rs`:
- Around line 279-304: Add a negative test case to the existing
is_bonsai_q1_requires_qwen3_model_type_and_group_size test (or as a new test)
that constructs a qwen3 config where quantization.bits != 1 (e.g., bits: 4) and
asserts that is_bonsai_q1(path) returns false; update the test harness using
config_from_raw and reuse the pattern with a variable like (qwen4_dir,
_qwen4_result) to call is_bonsai_q1 and
assert!(!is_bonsai_q1(qwen4_dir.path()).unwrap()) so the function is guarded
against accepting non-1 bits Qwen3 checkpoints.
In `@crates/higgs-models/src/bonsai_q1.rs`:
- Around line 484-497: The current YaRN branch uses
rope_original_max_seq.unwrap_or(self.config.hidden), which substitutes a
hidden-size for a position count; change this to fail loudly or provide a
documented numeric default: when self.config.rope_yarn_factor > 1.0, check
whether self.config.rope_original_max_seq is Some and return an
Err(Exception::custom(...)) if missing (or explicitly use a constant like 4096
with a comment), then pass that validated orig into
compute_yarn_freqs(head_dim_i, base, factor_f, orig, 32.0, 1.0) and keep
yarn_get_mscale(factor_f, 1.0) unchanged; update the code around
rope_yarn_factor, rope_original_max_seq, compute_yarn_freqs and yarn_get_mscale
accordingly so we don’t silently use self.config.hidden as a position count.
In `@crates/higgs-models/src/lib.rs`:
- Around line 220-222: In the match arm handling (Self::BonsaiQ1(m),
AnyCache::KV(c)) where you call m.forward(inputs, c), add a conditional check
for the incoming mask (e.g., if mask.is_some()) and emit a rate-limited
tracing::warn! that the provided mask will be ignored because BonsaiQ1 builds
its own causal mask; keep behavior unchanged but log the misuse to help callers.
Locate the match arm in lib.rs (the BonsaiQ1 / AnyCache::KV branch) and insert
the warning before invoking m.forward, using a rate limiter or tracing's
throttling helpers to avoid log spam.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: fabe0441-75ac-47ab-848d-778845ef88ed
⛔ Files ignored due to path filters (1)
Cargo.lockis excluded by!**/*.lock
📒 Files selected for processing (7)
crates/higgs-engine/src/model_loader.rscrates/higgs-models/Cargo.tomlcrates/higgs-models/src/bonsai_q1.rscrates/higgs-models/src/cache.rscrates/higgs-models/src/lib.rscrates/higgs-models/src/yarn.rsdocs/BONSAI_Q1.md
|
Hope you don't mind me jumping in here-- if you prefer i can make a sister pr next time but was able to get this working on my machine without changing upstream, would love if you tried it out? |
After rebasing PR-6a onto origin/main, the new `AnyModel::BonsaiQ1` variant landed via panbanda#142 (merged) makes the explicit-enumeration arms in `embed_token_ids` and `forward_all_logits_from_hidden` non-exhaustive (E0004 x2). Add `Self::BonsaiQ1(_)` to the existing union of non-Qwen3Next variants that return the same "only implemented for Qwen3Next" Err — preserving the DRY pattern (one error message per dispatcher across all 8 non-tap variants) and the exhaustive-enumeration invariant (no `_ =>` catch-all, so future variants will trip the same compiler check). No runtime behaviour change: `BonsaiQ1` already cannot reach a DFlash spec-decode flow (no tap API plumbing exists for it), so this just formalises the rejection at the dispatcher boundary with a clear error. Tests: - cargo check -p higgs-models: green (was 2x E0004) - cargo clippy -p {higgs-models, higgs-engine, higgs} --tests -- -Dwarnings: clean - cargo test: 1030/1030 across three crates (no regression)
Summary
Adds the Rust-side Bonsai-Q1 packed 1-bit Qwen3-shaped engine scaffold without changing the workspace MLX dependency away from upstream
oxideai/mlx-rs.Bonsai-Q1 configs are detected as:
model_type == "qwen3"quantization.bits == 1quantization.group_size == 128Because the pinned upstream MLX revision does not yet provide bits=1 affine quantization kernels,
higgs-enginenow rejects detected Bonsai-Q1 checkpoints with an explicit unsupported-model error instead of routing them into the normal Qwen3 transformer path or requiring a contributor fork.What lands in this PR
bonsai_q1::PackedQ1Linear,BonsaiQ1Engine, andBonsaiQ1GpuRust-side implementationyarnhelpers with dtype-safe scalar handlingAnyModel::BonsaiQ1dispatch/cache plumbing for future runtime enablementSteppingKeyValueCache::key_value_arrays_mutforUpdatablestate borrowingmodel_loaderwith an upstream-MLX guarddocs/BONSAI_Q1.mddocumenting why runtime routing is held backWhat does not land
dusterbloom/mlx-rs/dusterbloom/mlx-c/dusterbloom/mlxTest plan
cargo fmt --all -- --checkcargo test -p higgs-engine --lib bonsai -- --test-threads=1cargo check -p higgs-models -p higgs-enginegit diff --checkcargo clippy --all-targets --all-featuresUpstream MLX verification
Verified against upstream
oxideai/mlx-rsrevf4aa309c79b6be35255ca7d34157dfc10d9ed4c9, notdusterbloom/mlx-rs. A repo scan found nodusterbloom/mlx*dependency references inCargo.toml,Cargo.lock,crates, ordocs; the only MLX git dependency references point atoxideai/mlx-rs.This guarded PR verifies that the scaffold builds/tests cleanly and that Bonsai-Q1 checkpoints fail early with an explicit unsupported-model error on upstream MLX. Runtime Bonsai-Q1 inference is intentionally not claimed here because upstream MLX does not yet expose the required bits=1 affine kernels.
Notes
mainafter perf(models): opt-in fused MoE gate+up — 3→2 expert matmuls per layer #141/feat(cache): AnyCache::trim_by dispatcher for spec-decode rollback #143/feat(qwen3_next): mixed-bit Qwen3.5 GDN BA loading fallback #148.main, so the history no longer contains a fork-dependency switch commit.Summary by CodeRabbit
Release Notes
New Features
Documentation