Skip to content

dusterbloom: feat(dflash): DFlash drafter foundation — module + tap APIs + dispatch#17

Open
dusterbloom wants to merge 24 commits into
mainfrom
dusterbloom/dflash-baseline
Open

dusterbloom: feat(dflash): DFlash drafter foundation — module + tap APIs + dispatch#17
dusterbloom wants to merge 24 commits into
mainfrom
dusterbloom/dflash-baseline

Conversation

@dusterbloom

Copy link
Copy Markdown
Owner

Summary

Lands the DFlash block-diffusion speculative-decoding foundation ported from feat/magic-canvas. This is PR-6a of the magic-canvas split — the module + model surgery + dispatch surface; the engine-level draft-verify loop (SimpleEngine::generate_dflash_inner) is deferred to PR-6b because it needs end-to-end runtime verification against a real DFlash drafter checkpoint that this PR can't exercise from CI.

What's in this PR

Three commits, all on a clean branch off origin/main:

Commit Adds Net lines
ad7edea1 crates/higgs-models/src/dflash.rs — the 0.5B drafter (config, dual-stream attention, GDN-state save/restore, KV-only rollback, drafter loader) +561
62a73522 forward_with_taps, forward_with_taps_stateless, forward_with_taps_tape, replay_tape_rollback, embed_token_ids, forward_all_logits_from_hidden, project_logits on Qwen3NextCausalLM; forward_stateless, forward_with_tape, replay_from_tape on GatedDeltaNet; pub struct GdnLayerTape; Metal kernel FFI (tape_replay_kernel_ffi, gated_delta_kernel_ffi_with_tape, gated_delta_kernel_ffi_stateless) +1657
1c3c0d37 AnyModel::{forward_with_taps, forward_with_taps_tape, embed_token_ids, forward_all_logits_from_hidden} dispatchers; AnyCache::{as_hybrid, as_hybrid_mut}; dflash::accept_prefix (16-line greedy spec-decode acceptance + 5 unit tests); engine::model_loader::load_dflash_drafter +193

Total: ~2.4K lines net new, no public-API regressions on origin/main.

What's NOT in this PR (deferred)

  • SimpleEngine::generate_dflash_inner — the draft-verify loop. The feat/magic-canvas glue assumed a cpu_engine + ANE executor that we're stripping (defer to PR-8 ANE work) and a struct shape that has since evolved on origin/main. Best done with a Carnice-9B + 0.5B drafter checkpoint loaded so we can verify the verify-path correctness, not just the compile.
  • dflash_cpu.rs (CPU BLAS drafter) — depends on 7 BLAS helpers from diffusion.rs (~9970 lines, mostly unrelated). Defer to PR-6b together with the engine glue.
  • dflash_ane.rs (ANE-accelerated drafter) — feature-gated; PR-8 territory.
  • DFlash test suite (~3.8K lines on feat/magic-canvas) — depends on the engine glue being live. Will follow.

Adaptations from feat/magic-canvasorigin/main

  • SteppingKeyValueCache::rollback(i32) was renamed trim_by(usize) on origin/main (PR feat(cache): AnyCache::trim_by dispatcher for spec-decode rollback panbanda/higgs#143). Call sites in dflash.rs and replay_tape_rollback converted with unsigned_abs().try_into().unwrap_or(usize::MAX).
  • Qwen3NextCausalLM's lm_head_ane, dense_lm_head, ane_handle, ane_kernels fields don't exist on this branch — ANE-feature paths in project_logits and forward_with_tape stripped to plain Metal/MLX. Fields ported in PR-8.
  • FFI error handler uses thread_local! RefCell<Option<String>> (matching this branch's existing FFI pattern) rather than feat/magic-canvas's Mutex<Option<String>>.

Senior-Rust hygiene

  • No file-level blanket allows added. Origin/main's pre-existing #![allow(clippy::items_after_test_module)] preserved.
  • Function-scoped #[allow(...)] only on four genuinely-numerical kernel functions (forward_stateless, forward_with_tape, forward_with_taps_tape, replay_tape_rollback), each with a one-line documented justification (Metal-kernel dispatch, tensor-shape indices, hot-path casts).
  • unwrap_used never allowed — refactored to ? propagation throughout.
  • All non-Qwen3Next match arms in AnyModel enumerate variants explicitly (no _ => catch-alls).
  • clippy::type_complexity resolved with pub type TapsTapeOutput.

Test plan

  • cargo check -p higgs-models — clean
  • cargo clippy --all-targets --all-features -- -D warnings — clean
  • cargo fmt --check — clean
  • cargo test -p higgs-models --lib — 335/335 pass (5 new accept_prefix tests)
  • cargo test -p higgs-engine --lib — 228/228 pass
  • cargo test -p higgs --lib -- --test-threads=1 — 449/449 pass
  • Engine-level: defer to PR-6b with real drafter checkpoint

Context

This is part of the feat/magic-canvas PR split. PRs already shipped against panbanda upstream:

🤖 Generated with Claude Code

panbanda and others added 20 commits May 6, 2026 05:21
fix(deps): update rust crate toml to v1
…l-action-digest

chore(deps): update taiki-e/install-action digest to cca35ed
…file

chore(deps): update rust crate tokio to v1.52.2
…-lockfile

chore(deps): update rust crate tower-http to v0.6.9
…anbanda#143)

Adds AnyCache::trim_by to roll back KV layers for speculative decode while leaving hybrid Arrays state untouched.\n\nCI: https://github.com/panbanda/higgs/actions/runs/25312580791
…#148)

* feat(qwen3_next): mixed-bit Qwen3.5 GDN BA loading fallback

Adds a fallback path for loading Qwen3.5 models with mixed-bit GDN
projection weights (some layers q4, some q8 — common in unsloth's
dynamic-quant variants). The default fused-projection loader fuses
`in_proj_a` + `in_proj_b` into a single matmul; mixed-bit weights
have incompatible shapes and the fusion fails.

Behaviour:

  1. Detect via `is_mixed_bit_gdn_ba_fusion_error` — matches a
     `ModelError::ShapeMismatch` whose message contains both
     `in_proj_ba` and `requires separate GDN projections`.

  2. On detection, retry the load with
     `args.use_separate_gdn_projections = true`, taking the
     `load_qwen3_5_moe_weights_direct` path. Forward dispatches go from
     2 to 4 GDN ops per layer — slightly slower but correct.

  3. Forced separate (via `args.use_separate_gdn_projections` config or
     `HIGGS_SEPARATE_GDN_PROJ` env var) skips the fused attempt
     entirely.

Also adds:

  * `qwen3_5_quantization_config` — parses `{group_size, bits}` from
    the per-layer `quantization` map in `config.json`.
  * `qwen3_5_mixed_ba_quantization_layers` — scans for the layers
    where `in_proj_a` and `in_proj_b` differ in bits or group_size.
  * `can_concatenate_axis0` — guard used inside
    `load_qwen3_5_moe_weights_fused` to emit the diagnostic
    `ShapeMismatch` error rather than panicking on the concat.
  * `load_qwen3_5_model_with_gdn_fallback` — private helper called by
    both `load_qwen3_5_model` (dense) and `load_qwen3_5_moe_model`
    (MoE), unifying the fallback path.

Adaptations from feat/magic-canvas → origin/main:

  * The dense `load_qwen3_5_model` previously only honoured the env
    var; now it honours `args.use_separate_gdn_projections` too,
    matching the MoE path. Strict improvement: the config flag is set
    only by the env var or by mixed-bit detection.

  * No `unwrap()`, no `as` casts (use `i32::try_from`); match arms
    enumerate variants. No file-level allows added.

Verification on origin/main (rustc 1.95.0):

  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 333/333 pass (3 new)

Source: feat/magic-canvas commit `061e500c`. Direct cherry-pick had 5
conflict regions because origin/main has evolved the load functions
independently; this is a manual surgical port that preserves
origin/main's structure while adding the fallback behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(qwen3_next): preserve explicit GDN projection config

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
…panbanda#141)

* perf(models): fused MoE gate+up — 3→2 expert matmuls per layer

SwitchMlpWeights::forward_gather_fused() lazy-concatenates gate+up
weights on first call, then dispatches a single gather_qmm instead
of two separate calls. FfnBlock::forward() now routes through the
fused path instead of forward_gather_global_sort().

Measured on 35B-A3B-3bit M4 base:
- S=1 decode:   27ms → 17ms  (−37%)
- S=16 verify: 253ms → 112ms (−56%)
- MoE/layer at K=1: 0.47ms (down from ~0.68ms)

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>

* style: cargo fmt qwen3_next.rs

Reflow `let fw/fs/fb = ops::concatenate_axis(..)` from broken-line
indentation back onto single lines so `cargo fmt --all -- --check`
passes in CI.

* fix(clippy): backtick MoE/gather_qmm doc + safe top_k u32 cast

Two errors flagged by `-D clippy::doc-markdown` and
`-D clippy::cast-sign-loss`/`-D clippy::as-conversions`:

- Backtick `MoE` and `gather_qmm` in the `fused_gate_up` doc comment.
- Replace `top_k as u32` with the same `u32::try_from(top_k).map_err(...)`
  pattern already used by `forward_gather_global_sort`.

* fix(qwen3_next): gate MoE gate-up fusion behind opt-in

---------

Co-authored-by: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
…anda#142)

* feat(bonsai_q1): add upstream-guarded packed engine

* fix(bonsai-q1): address review feedback

---------

Co-authored-by: Jonathan Reyes <me@jonathanreyes.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
* fix: default qwen3.6 to non-thinking mode

* fix: address qwen thinking review feedback
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Adds `crates/higgs-models/src/dflash.rs` from feat/magic-canvas — the
0.5B drafter that produces 16 draft tokens per round via a single
non-causal forward pass on hidden states tapped from 5 target-model
layers.

Architecture (8 decoder layers, dual-stream attention) is verbatim from
the magic-canvas baseline `c1f85ade` (final stable state, before WIP
ANE work). Wire-up into `SimpleEngine` lands in the follow-up commit.

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Two call sites converted with
    `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * Workspace clippy (nursery: `as_conversions`,
    `cast_possible_truncation`, `doc_markdown`, `assigning_clones`,
    `explicit_iter_loop`, `unnecessary_cast`, `shadow_unrelated`,
    `redundant_pattern_matching`, `missing_const_for_fn`) — all 30
    errors fixed in-place: `i32::try_from` for tensor-shape casts,
    `clone_from` for in-place clones, `filter_map(Option::as_mut)` for
    `iter().filter_map(if-let)` patterns, backticks on doc items.
    No file-level allows.

The original DFlash test suite (~3.8K lines, 30+ end-to-end tests)
depends on tap APIs (`forward_with_taps_tape`, `replay_tape_rollback`,
`forward_all_logits_from_hidden`) and `crate::diffusion::accept_prefix`
that aren't on `origin/main` yet. Tests are deferred to a follow-up PR
alongside the qwen3_next tap-API surface — there's a comment block at
the bottom of `dflash.rs` flagging this.

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the model-side surface that the DFlash drafter speculates against —
hidden-state taps during forward, GDN innovation tape for cheap rollback,
and helpers for embedding lookup + lm_head application in isolation.

Methods added on `Qwen3NextCausalLM`:

  * `forward_with_taps` — forward returning logits AND vec of hidden
    states at specified target layers; the drafter conditions on these.

  * `forward_with_taps_stateless` — same, but does NOT mutate the
    recurrent (GDN) state. Used during verify when state advancement is
    handled separately.

  * `forward_with_taps_tape` — forward that records each GDN layer's
    innovation into a `GdnLayerTape`. Enables ~5ms replay vs ~30ms
    rerun for partial-accept rollback.

  * `replay_tape_rollback` — restore GDN state to a tape position
    without re-running the full model.

  * `embed_token_ids` — apply the embedding layer alone (drafter input).

  * `forward_all_logits_from_hidden` — apply lm_head alone (target's
    verification of drafter outputs).

  * `project_logits` (private helper) — lm_head with origin/main's
    available projection paths only (ANE + dense_lm_head fields don't
    exist here yet; ported in PR-8).

Methods added on `GatedDeltaNet`:

  * `forward_stateless` — GDN forward without state mutation.

  * `forward_with_tape` — GDN forward that captures the per-step
    innovation into the tape.

  * `replay_from_tape` — apply a tape to recompute SSM state to a target
    position. Annotated `#[allow(dead_code)]` until the engine glue
    drives it (next commit).

New public type `GdnLayerTape` exposes the per-layer innovation record.

Metal kernel infrastructure ported alongside:

  * `tape_replay_kernel_ffi` + `TAPE_REPLAY_KERNEL` static + Metal source
  * `gated_delta_kernel_ffi_with_tape` + matching kernel
  * `gated_delta_kernel_ffi_stateless` (thin wrapper over existing FFI;
    discards the new state, matches caller semantics in `forward_stateless`)

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Call site in `replay_tape_rollback`
    converted with `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * `lm_head_ane`, `dense_lm_head`, `ane_handle`, `ane_kernels` fields
    don't exist on this branch — ANE-feature paths stripped to the
    plain Metal/MLX path. Fields ported in PR-8.

  * Error handler uses `thread_local! RefCell<Option<String>>` instead
    of feat/magic-canvas's `Mutex<Option<String>>` — matches the
    branch's existing FFI error pattern.

Senior-Rust hygiene:

  * No file-level blanket allows added.
  * Function-scoped `#[allow(...)]` on the four genuine numerical
    kernel functions (`forward_stateless`, `forward_with_tape`,
    `forward_with_taps_tape`, `replay_tape_rollback`), each with a
    one-line justification comment.
  * `unwrap_used` never allowed — refactored to `?` propagation or
    `expect("reason")` at the two call sites.
  * Mechanical clippy refactors throughout: `find_map` for
    `filter_map(..).next()`, `clone_from` for `assigning_clones`,
    `if let` for single-pattern `match`, backticks for `doc_markdown`.

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ader

Surfaces the qwen3_next tap APIs through the polymorphic `AnyModel`
enum so engine code can call them without matching variants directly,
adds the greedy speculative-decode acceptance helper, and exposes a
`load_dflash_drafter` entry point on the engine's model_loader.

`AnyModel` (in `higgs-models/src/lib.rs`):

  * `forward_with_taps` — dispatches Qwen3Next + Hybrid; errors otherwise
  * `forward_with_taps_tape` — same, returns `TapsTapeOutput` (logits +
    tap hiddens + per-layer GDN tape) via a public type alias to
    placate `clippy::type_complexity`
  * `embed_token_ids` — Qwen3Next-only
  * `forward_all_logits_from_hidden` — Qwen3Next-only

  All non-Qwen3Next arms enumerate every variant explicitly to satisfy
  `clippy::wildcard_enum_match_arm` (no `_ =>` catch-alls).

`AnyCache`:

  * `as_hybrid` / `as_hybrid_mut` — borrow the inner hybrid layer-cache
    slice/vec for engine glue that needs to inspect GDN state. Returns
    `Result<_, Exception>` rather than panicking when called on a `KV`
    cache, so the verify path in `SimpleEngine::generate_dflash_inner`
    can propagate via `?`.

`dflash::accept_prefix`:

  * Greedy speculative-decode acceptance: longest-matching prefix of
    `draft` against `verify_argmax`, plus one bonus token at the
    diverge point (or after the last accept).
  * 5 unit tests covering full match, first-token reject, partial
    match, empty draft, and the debug-only length assertion.
  * Inlined here rather than ported from `feat/magic-canvas:diffusion.rs`
    to avoid pulling in the 9970-line diffusion module for a 16-line
    helper.

`engine::model_loader::load_dflash_drafter`:

  * Thin `Result` adapter over `higgs_models::dflash::load_dflash_drafter`,
    converting `ModelError` → `EngineError`. The `SimpleEngine::load_with_dflash`
    call site lands in the next commit.

Verification on origin/main:
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 335/335 pass (5 new accept_prefix tests)
  * `cargo test -p higgs-engine --lib` — 228/228 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

The remaining piece — `SimpleEngine::generate_dflash_inner` (the
draft-verify loop wired into `generate_inner`) — lands as a follow-up
commit. It needs end-to-end verification against a real DFlash drafter
checkpoint (Carnice-9B + 0.5B drafter); shipping it without that
runtime test would risk silent correctness regressions in the verify
path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_match

CI's clippy is one minor version ahead of my local toolchain and flags
the `if rollback > 0 { ... }` body inside the `Some(LayerCache::KV(kv))`
match arm. Two call sites:

  * `dflash.rs` — `GdnStateBackup::restore_and_rollback`
  * `qwen3_next.rs` — `Qwen3NextCausalLM::replay_tape_rollback`

Convert to a match guard and add an explicit no-op arm for the
guard-fails-and-`None` case so the match is exhaustive without a
wildcard. No behaviour change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- GdnStateBackup now saves and restores ArraysCache::conv_pos alongside
  conv_state/ssm_state/offset. Without it, conv buffer state was corrupted
  after a verify-rollback.
- load_dflash_drafter uses ? directly so std::io::Error → ModelError::Io,
  serde_json::Error → ModelError::Json, Exception → ModelError::Mlx via
  the existing From impls (was wrapping all three as ModelError::Io).
- DEFAULT_DECODE_BLOCK_SIZE doc no longer claims a HIGGS_DFLASH_BLOCK_SIZE
  env override that was never implemented.
- DFlashConfig and its public fields now have brief one-line doc comments.

Validation: cargo fmt clean, cargo clippy clean (higgs-models, higgs-engine,
higgs). Full test run skipped (disk space too tight to relink test binary);
no existing tests exercise GdnStateBackup, conv_pos in paged_prefix_cache.rs
is a separate snapshot system unaffected by this change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After rebasing PR-6a onto origin/main, the new `AnyModel::BonsaiQ1`
variant landed via panbanda#142 (merged) makes the explicit-enumeration arms
in `embed_token_ids` and `forward_all_logits_from_hidden` non-exhaustive
(E0004 x2).

Add `Self::BonsaiQ1(_)` to the existing union of non-Qwen3Next variants
that return the same "only implemented for Qwen3Next" Err — preserving
the DRY pattern (one error message per dispatcher across all 8 non-tap
variants) and the exhaustive-enumeration invariant (no `_ =>` catch-all,
so future variants will trip the same compiler check).

No runtime behaviour change: `BonsaiQ1` already cannot reach a DFlash
spec-decode flow (no tap API plumbing exists for it), so this just
formalises the rejection at the dispatcher boundary with a clear error.

Tests:
- cargo check -p higgs-models: green (was 2x E0004)
- cargo clippy -p {higgs-models, higgs-engine, higgs} --tests -- -Dwarnings: clean
- cargo test: 1030/1030 across three crates (no regression)
@dusterbloom dusterbloom force-pushed the dusterbloom/dflash-baseline branch from 42051b6 to 49c4bc4 Compare May 20, 2026 15:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants