feat(qwen35): Qwen3.6 MoE MTP speculative decode + checkpoint memory-safety fix#183
dusterbloom wants to merge 4 commits into
Qwen3.6-A3B ships its MTP layer as a full MoE decoder layer (router gate +
256 stacked experts + shared expert/gate) with a quantized fc, both bundled
and as standalone drafter sidecars (e.g. mlx-community
Qwen3.6-35B-A3B-MTP-4bit). The existing MtpHead/DenseMtpHead are dense-MLP only,
so these
checkpoints could not speculate.
- MoeMtpHead / MoeMtpTransformerLayer: full attention + SparseMoeBlock,
quantized fc (these sidecars ship fc.{weight,scales,biases} triples).
Constructed at the checkpoint's uniform quantization; the main model's
gate_quantization override is deliberately NOT applied (sidecar router
gates are quantized at the default width).
- Layout detection: MoE-structured MTP keys classify as MoeQuantized and
enable the head (use_moe_mtp). Truly unprefixed sidecar keys (fc.weight,
layers.0....) are mtp.-prefixed at detection and load time
(normalize_sidecar_mtp_key), so mlx-community drafters work as
mtp.safetensors drop-ins.
- Loading: mtp.* -> moe_mtp.* param remap through both the fused and direct
loaders; no dense rmsnorm adjustment for MoE targets.
- Forward: MoE branches in mtp_step_hidden and mtp_advance_many; has_mtp /
make_mtp_cache cover the new head.
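The `mtp.*` -> `moe_mtp.*` remap described in the Loading bullet can be pictured with a minimal sketch. The function name `remap_mtp_to_moe_mtp` is hypothetical and only stands in for the PR's own remap helper, whose real signature is not shown here; the sketch assumes the remap is a plain prefix swap and that every other key passes through unchanged.

```rust
/// Hypothetical stand-in for the PR's key-remap helper (not its real API).
/// Maps a checkpoint key in the `mtp.*` namespace to the model parameter
/// namespace `moe_mtp.*`; any other key is returned unchanged.
fn remap_mtp_to_moe_mtp(checkpoint_key: &str) -> String {
    match checkpoint_key.strip_prefix("mtp.") {
        // Sidecar key: swap the namespace, keep the rest of the path.
        Some(rest) => format!("moe_mtp.{rest}"),
        // Backbone key: already matches the model parameter name.
        None => checkpoint_key.to_owned(),
    }
}
```

Under these assumptions a quantized triple such as `mtp.fc.weight`, `mtp.fc.scales`, `mtp.fc.biases` lands on `moe_mtp.fc.*`, while backbone keys are left alone.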
Measured on Qwen3.6-35B-A3B-4bit + MTP drafter (M-series, kv-bits 4):
short/structured output 60-62 tok/s at 100% accept (vs 40 tok/s baseline,
+50%); long prose 37-41 tok/s at 60-71% accept (breakeven). Outputs verified
exact (drafts go through the verify path).
Tests: MoE layout classification, mtp->moe_mtp key remap, sidecar key
normalization (aux-only, idempotent).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… bits)
Mixed-precision checkpoints (e.g. mlx-community OptiQ quants) assign different
bit-widths per GDN projection on sensitive layers. Two loader gaps turned those
into hard load failures:
- The mixed-bit detector only compared the in_proj_a/in_proj_b pair; a
mismatched in_proj_qkv/in_proj_z pair slipped through to the fused loader,
which then failed concatenating packed shapes like (8192,256) vs (4096,512).
Check both fusion pairs.
- In separate-GDN mode the fused in_proj_qkvz/in_proj_ba QLinears are still
constructed (as unused placeholders; the forward dispatches on
use_separate_projections), and the direct loader's completeness check flagged
them as missing weights, rejecting every checkpoint that *requires* separate
projections. Exempt the unused fused placeholders.
Note: fully running OptiQ-style quants also needs per-projection quantization
plumbed into every QLinear (their overrides span attention, shared experts,
etc.). This fix makes the GDN layer detection/loading correct and turns the
failure mode from a crash into a clean report.
Test: mixed in_proj_qkv/in_proj_z bits force separate GDN projections.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
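A minimal sketch of the "check both fusion pairs" rule from the commit above. The struct `GdnProjectionBits` and both function names are illustrative, not the PR's code. `packed_cols` assumes quantized weights are packed into 32-bit words, so a row of `in_features` values at `bits` bits occupies `in_features * bits / 32` columns; with an assumed input width of 2048 this gives 256 columns at 4-bit and 512 at 8-bit, which is consistent with the mismatched shapes quoted in the commit message.

```rust
/// Per-projection quantization bit-widths for one GDN layer.
/// Field names mirror the projection names in the commit message;
/// the struct itself is an illustrative assumption.
#[derive(Debug, Clone, Copy)]
struct GdnProjectionBits {
    in_proj_qkv: u8,
    in_proj_z: u8,
    in_proj_a: u8,
    in_proj_b: u8,
}

/// Returns true when either fusion pair mixes bit-widths. In that case the
/// packed weights have different column counts and cannot be concatenated,
/// so the layer has to keep separate projections.
fn needs_separate_projections(bits: GdnProjectionBits) -> bool {
    let qkvz_mismatch = bits.in_proj_qkv != bits.in_proj_z;
    let ba_mismatch = bits.in_proj_a != bits.in_proj_b;
    qkvz_mismatch || ba_mismatch
}

/// Column count of one packed row, assuming values are packed into
/// 32-bit words (an assumption about the storage layout).
fn packed_cols(in_features: usize, bits: u8) -> usize {
    in_features * bits as usize / 32
}
```

Checking only the `a`/`b` pair, as the old detector did, would return `false` for a layer with 4-bit `in_proj_qkv` and 8-bit `in_proj_z`, which is the case that reached the fused loader.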
…tive-decode double-free
Speculative-decode checkpoints shared MLX buffers with the live cache. The
backbone rollback did `*cache = base_cache` (shallow clone) and the MTP head
cache was shallow-cloned too; the live cache's in-place `slice_update` during
verify then let MLX donate/free a buffer the checkpoint still held,
double-freeing on drop. That is the `malloc: pointer being freed was not
allocated` abort that crashed MTP decode at ~44 tokens.
Backbone: roll KV layers back by offset (`trim_by`), never clone, so there is no
buffer aliasing. Only hybrid SSM/recurrent state (can't be offset-trimmed)
still clone-restores via `AnyCache::deep_clone`.
MTP head + hybrid clones: deep-copy via `deep_clone_mtp_cache` /
`SteppingKeyValueCache::deep_clone`, which allocate fresh buffers.
deep_clone itself was unsafe: `Array::deep_clone` copies straight from the
buffer pointer (valid only once evaluated), but the cache stores lazy
`slice_update` results, so cloning read an unmaterialized pointer and
segfaulted. `eval_deep_clone` forces eval first.
Tests: 2 deep_clone unit tests (lazy-pointer + live-update independence);
higgs-models lib suite 371 passed.
Soak: Qwen3.6-35B-A3B MTP server, 5x200-tok requests = 432 mtp_cycle
iterations / ~1300 cache deep-clones, accept 62-69%, zero aborts (RED crashed
~44).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
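The offset-based rollback for backbone KV layers can be illustrated with a toy cache. Everything here is a simplified assumption: `ToyKvCache` is a plain `Vec<f32>` with a logical length, and only the method name `trim_by` is borrowed from the commit message. The point of the sketch is that rejecting draft tokens only moves the offset back, so no second handle to the buffer is ever created.

```rust
/// Toy KV cache: a preallocated buffer plus a logical length (`offset`).
/// Illustrative only; the real cache holds MLX arrays.
struct ToyKvCache {
    buf: Vec<f32>,
    offset: usize,
}

impl ToyKvCache {
    fn with_capacity(cap: usize) -> Self {
        Self { buf: vec![0.0; cap], offset: 0 }
    }

    /// Write `vals` in place at the current offset and advance it.
    fn append(&mut self, vals: &[f32]) {
        let end = self.offset + vals.len();
        assert!(end <= self.buf.len(), "toy cache overflow");
        self.buf[self.offset..end].copy_from_slice(vals);
        self.offset = end;
    }

    /// Roll back up to `n` entries by moving the offset; returns how many
    /// entries were actually trimmed. The buffer itself is untouched.
    fn trim_by(&mut self, n: usize) -> usize {
        let trimmed = n.min(self.offset);
        self.offset -= trimmed;
        trimmed
    }

    /// The entries that are logically part of the cache.
    fn live(&self) -> &[f32] {
        &self.buf[..self.offset]
    }
}

/// Append all draft entries, then drop the rejected tail by offset.
/// Returns the number of entries rolled back.
fn verify_and_rollback(cache: &mut ToyKvCache, drafts: &[f32], accepted: usize) -> usize {
    cache.append(drafts);
    let rejected = drafts.len().saturating_sub(accepted);
    cache.trim_by(rejected)
}
```

State that cannot be expressed as "the first `offset` entries", such as the recurrent state mentioned above, does not fit this scheme, which is why the commit keeps a deep-clone path for it.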
📝 Walkthrough
This PR implements safe speculative decoding with deep-cloned KV cache checkpoints and adds MoE-structured MTP head support. KV cache deep-cloning infrastructure materializes lazy arrays into independent buffers, enabling safe snapshot/restore semantics during speculative verification. MTP engine cycles now use this infrastructure instead of shallow cloning. Additionally, MoE MTP heads are now detected in checkpoints, loaded alongside existing variants, and executed with corresponding cache management and weight remapping.
Changes: safe speculative decoding with deep-cloned cache checkpoints and MoE MTP support.
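The idea of materializing a lazy array before copying it can be sketched with a toy type. `ToyLazy` and `raw_deep_clone` are hypothetical names, and only `eval_deep_clone` echoes a name from this PR. In safe Rust the "read an unmaterialized buffer" case is modeled as `None` rather than a segfault.

```rust
/// Toy lazy array: either a pending computation or a materialized buffer.
enum ToyLazy {
    Pending(Box<dyn Fn() -> Vec<f32>>),
    Ready(Vec<f32>),
}

impl ToyLazy {
    /// Force evaluation so that a concrete buffer exists.
    fn eval(&mut self) {
        if let ToyLazy::Pending(compute) = self {
            let values = compute();
            *self = ToyLazy::Ready(values);
        }
    }

    /// Copy straight from the buffer. Returns `None` when nothing has been
    /// materialized yet (the case that crashed in the real code).
    fn raw_deep_clone(&self) -> Option<Vec<f32>> {
        match self {
            ToyLazy::Ready(values) => Some(values.clone()),
            ToyLazy::Pending(_) => None,
        }
    }

    /// Evaluate first, then copy: always yields an independent buffer.
    fn eval_deep_clone(&mut self) -> Vec<f32> {
        self.eval();
        match self {
            ToyLazy::Ready(values) => values.clone(),
            ToyLazy::Pending(_) => unreachable!("eval() always materializes"),
        }
    }
}
```

The ordering is the whole point: the copy is only meaningful once the value it copies from exists.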
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
crates/higgs-models/src/qwen3_next.rs (1)
4527-4529: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
`load_qwen3_next_model` still skips the `mtp.*` → `moe_mtp.*` remap.
After `MoeQuantized` detection, this path constructs `moe_mtp`, then immediately calls the generic safetensor loader under the assumption that checkpoint keys already match model params exactly. They do not: the new MoE sidecars still live under the `mtp.*` namespace, and the only remap helper (`moe_mtp_param_key`) is wired into the qwen3.5 direct/fused loaders further down in this file. That means plain Qwen3Next/Qwen3.6-A3B loads will miss the MoE MTP weights or fail the completeness check.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@crates/higgs-models/src/qwen3_next.rs`:
- Around line 4961-4970: The function normalize_sidecar_mtp_key is
over-prefixing namespaced keys (e.g. "language_model.mtp.layers..."); change the
condition so you only add "mtp." for auxiliary files when the key is truly
un-namespaced—i.e., it does not already start with "mtp." and does not contain
any '.' namespace separator. Update normalize_sidecar_mtp_key to check is_aux &&
!key.starts_with("mtp.") && !key.contains('.') before returning
format!("mtp.{key}"), so qwen35_checkpoint_param_key can still recognize and
strip/remap legitimately namespaced keys.
---
Outside diff comments:
In `@crates/higgs-models/src/qwen3_next.rs`:
- Around line 4527-4529: The load_qwen3_next_model path detects MoeQuantized and
constructs the moe_mtp sidecar but then calls crate::load_safetensors_weights
assuming keys match; however safetensor checkpoints still use the mtp.*
namespace so the moe MTP params are never remapped or loaded. Update
load_qwen3_next_model to apply the same remapping used elsewhere: use the
existing moe_mtp_param_key helper (or equivalent remap function) to translate
incoming safetensor keys from mtp.* → moe_mtp.* before calling
crate::load_safetensors_weights (or pass a key-remap callback into that loader)
so the MoE MTP weights are picked up and the completeness check succeeds.
📒 Files selected for processing (4)
- crates/higgs-engine/src/mtp.rs
- crates/higgs-models/src/cache.rs
- crates/higgs-models/src/lib.rs
- crates/higgs-models/src/qwen3_next.rs
CodeRabbit review:
- normalize_sidecar_mtp_key: only prefix truly un-namespaced sidecar keys.
Gate on !is_mtp_key() instead of !starts_with("mtp.") so already-namespaced
keys (e.g. language_model.mtp.*) aren't mangled into unmatchable
mtp.language_model.mtp.*. Extends the existing unit test to cover it.
- load_qwen3_next_model: route through a new MTP-aware load_qwen3_next_weights
instead of the plain loader. maybe_disable_mtp_without_checkpoint_weights can
select the dense/MoE head (params dense_mtp.* / moe_mtp.*) while the checkpoint
ships the head under mtp.*; the plain loader did no remap and silently left the
draft head uninitialized. Backbone keys still match directly, so behaviour is
unchanged for the common Quantized layout.
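The first bullet's gating rule can be sketched as follows. Both signatures are assumptions: the real `normalize_sidecar_mtp_key` and `is_mtp_key` are not shown in this conversation, and the sketch guesses that `is_mtp_key` means "some dot-separated segment of the key is `mtp`". Note that a rule of "prefix only keys without any `.`" would also skip `fc.weight`, which the first commit lists as a key that needs the prefix.

```rust
/// Hypothetical predicate: does the key already sit in an MTP namespace?
/// Assumed here to mean that one dot-separated segment is exactly `mtp`.
fn is_mtp_key(key: &str) -> bool {
    key.split('.').any(|segment| segment == "mtp")
}

/// Sketch of the normalization rule: prefix `mtp.` only for keys that come
/// from an auxiliary (sidecar) file and are not already MTP-namespaced.
fn normalize_sidecar_mtp_key(key: &str, is_aux: bool) -> String {
    if is_aux && !is_mtp_key(key) {
        format!("mtp.{key}")
    } else {
        key.to_owned()
    }
}
```

With this gate the function is idempotent (a second pass sees an `mtp` segment and leaves the key alone), and keys from the main checkpoint files are never touched.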
Lint CI was failing at the fmt step, masking clippy -Dwarnings errors that the
feature commits introduced. Fixed so the full Lint job passes:
- cargo fmt (AUXILIARY_SAFETENSORS_FILES wrap, in_proj filter closure)
- clippy::shadow_reuse on the three sidecar-key loaders (file convention)
- backtick bare `MoE` in MoE-MTP doc comments (doc_markdown)
- LayerCache wildcard -> explicit Arrays(_) arm (match_wildcard_for_single_variants)
- allow large_enum_variant on AnyModel (singleton dispatch handle; boxing every
variant would add forward-path indirection for no real benefit)
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
What
Adds Qwen3.6-A3B (MoE) multi-token-prediction speculative decode, and fixes a memory-safety bug in the speculative-decode checkpoint path.
Three commits:
1. `feat(qwen35): MoE-structured MTP head`: loads the `mtp.safetensors` sidecar as a `MoeMtpHead` and wires it into the speculative-decode loop for Qwen3.6-A3B-style checkpoints.
2. `fix(qwen35): survive mixed-bit GDN quants`: handles OptiQ-style per-projection bit widths so mixed-bit GDN checkpoints load instead of erroring.
3. `fix(mtp): eval-before-deep_clone + trim-based rollback`: the memory-safety fix (below).
The memory-safety bug
Speculative-decode checkpoints shared MLX buffers with the live cache. The backbone rollback did `*cache = base_cache` (a shallow `clone()`), and the MTP head cache was shallow-cloned too. The live cache's in-place `slice_update` during verify then let MLX donate/free a buffer the checkpoint still referenced, double-freeing on drop. That is the `malloc: pointer being freed was not allocated` abort that crashed MTP decode at ~44 tokens.
Fix:
- Backbone: roll KV layers back by offset (`trim_by`), never clone, so there is no buffer aliasing. Only hybrid SSM/recurrent state (which can't be offset-trimmed) still clone-restores, via `AnyCache::deep_clone`.
- MTP head + hybrid clones: deep-copy via `deep_clone_mtp_cache` / `SteppingKeyValueCache::deep_clone`, which allocate fresh buffers.
- `deep_clone` itself was unsafe: `Array::deep_clone` copies straight from the buffer pointer (valid only once evaluated), but the cache stores lazy `slice_update` results, so cloning read an unmaterialized pointer and segfaulted. `eval_deep_clone` forces eval first.
Testing
- `cargo test -p higgs-models --lib`: 356 passed, 0 failed.
- Unit tests: `deep_clone_preserves_contents_and_offset` (faithful copy) and `deep_clone_checkpoint_survives_live_in_place_update` (independence under a live in-place update). Both fail (segfault) without the fix.
- Soak: Qwen3.6-35B-A3B MTP server, 5x200-tok requests = 432 `mtp_cycle` iterations / ~1300 cache deep-clones, accept-rate 62–69%, zero aborts. Pre-fix this aborted at ~44 tokens.
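The property checked by the live-update independence test can be modeled in safe Rust with a reference-counted buffer. `ToyArray` is an illustrative assumption and not the MLX array type: its derived `Clone` copies the handle only (a shallow clone), while `deep_clone` allocates a fresh buffer. Safe Rust cannot reproduce the double-free itself, so the sketch shows the aliasing that made it possible.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Toy array handle. `Clone` copies the handle, not the buffer, so a
/// cloned "checkpoint" still aliases the live data.
#[derive(Clone)]
struct ToyArray {
    data: Rc<RefCell<Vec<f32>>>,
}

impl ToyArray {
    fn new(values: Vec<f32>) -> Self {
        Self { data: Rc::new(RefCell::new(values)) }
    }

    /// In-place write into the shared buffer, starting at `start`.
    fn slice_update(&self, start: usize, vals: &[f32]) {
        let mut buf = self.data.borrow_mut();
        buf[start..start + vals.len()].copy_from_slice(vals);
    }

    /// Allocate a fresh buffer holding a copy of the current contents.
    fn deep_clone(&self) -> Self {
        ToyArray::new(self.data.borrow().clone())
    }

    fn to_vec(&self) -> Vec<f32> {
        self.data.borrow().clone()
    }

    /// True when both handles point at the same underlying buffer.
    fn shares_buffer_with(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.data, &other.data)
    }
}
```

A checkpoint taken with `clone()` observes every later `slice_update` on the live array, whereas one taken with `deep_clone()` keeps the contents it had at snapshot time.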