Skip to content

feat(dflash): DFlash drafter foundation — module + tap APIs + dispatch#147

Open
dusterbloom wants to merge 6 commits into
panbanda:mainfrom
dusterbloom:dusterbloom/dflash-baseline
Open

feat(dflash): DFlash drafter foundation — module + tap APIs + dispatch#147
dusterbloom wants to merge 6 commits into
panbanda:mainfrom
dusterbloom:dusterbloom/dflash-baseline

Conversation

@dusterbloom

@dusterbloom dusterbloom commented May 5, 2026

Copy link
Copy Markdown
Contributor

Summary

Lands the DFlash block-diffusion speculative-decoding foundation. PR-6a of the feat/magic-canvas split — the model-side surface needed for DFlash; the engine-level draft-verify loop is deferred to PR-6b because it needs runtime verification against a real drafter checkpoint.

What's in this PR (3 commits, ~2.4K net lines)

Commit Adds
ad7edea1 DFlash drafter module: 0.5B block-diffusion model (config, dual-stream attention, GDN-state save/restore, KV-only rollback, drafter loader)
62a73522 Tap APIs on Qwen3NextCausalLM (forward_with_taps, forward_with_taps_tape, replay_tape_rollback, embed_token_ids, forward_all_logits_from_hidden, project_logits); GatedDeltaNet::{forward_stateless, forward_with_tape, replay_from_tape}; pub struct GdnLayerTape; Metal kernel FFI for tape replay
1c3c0d37 AnyModel/AnyCache polymorphic dispatchers; dflash::accept_prefix (greedy spec-decode acceptance + 5 unit tests); engine::model_loader::load_dflash_drafter

What's deferred to follow-ups

  • SimpleEngine::generate_dflash_inner (draft-verify loop) — needs Carnice-9B + 0.5B drafter checkpoint for runtime verification
  • dflash_cpu.rs (CPU BLAS drafter) — pulls in BLAS helpers from diffusion.rs; lands with engine glue in PR-6b
  • dflash_ane.rs (ANE-accelerated drafter) — feature-gated; PR-8 territory
  • DFlash test suite (~3.8K lines) — depends on engine glue being live

Adaptations to origin/main

  • SteppingKeyValueCache::rollback(i32)trim_by(usize) (renamed in feat(cache): AnyCache::trim_by dispatcher for spec-decode rollback #143)
  • ANE-feature paths in project_logits / forward_with_tape stripped — lm_head_ane/dense_lm_head/ane_handle/ane_kernels fields don't exist yet on main; ported in PR-8
  • FFI error handler uses thread_local! matching the existing local pattern (vs feat/magic-canvas's Mutex<Option<String>>)

Hygiene

  • No file-level blanket allows added (existing #![allow(clippy::items_after_test_module)] preserved)
  • Function-scoped #[allow(...)] only on the four genuinely-numerical Metal-kernel functions, each with a one-line justification
  • unwrap_used refactored to ? propagation throughout
  • All non-Qwen3Next match arms enumerate variants explicitly
  • pub type TapsTapeOutput resolves clippy::type_complexity

Test plan

  • cargo check -p higgs-models — clean
  • cargo clippy --all-targets --all-features -- -D warnings — clean
  • cargo fmt --check — clean
  • cargo test -p higgs-models --lib — 335/335 pass (5 new accept_prefix tests)
  • cargo test -p higgs-engine --lib — 228/228 pass
  • cargo test -p higgs --lib -- --test-threads=1 — 449/449 pass
  • Engine-level draft-verify loop: deferred to PR-6b (needs real DFlash drafter checkpoint)

Context

Part of the feat/magic-canvas PR split. Prior PRs in the chain: #141 (fused MoE), #142 (Bonsai-Q1 fp16 +2.83×), #143 (AnyCache::trim_by), #144 (DraftModel trait), #145 (PLD drafter, 1.84× headline).

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features
    • Added speculative-decoding DFlash drafter to accelerate token generation.
    • Exposed tap-layer outputs and new model APIs to support hybrid inference and verification flows.
    • Added cache-management utilities for efficient verification, rollback, trimming, and hybrid cache handling.
  • Chores
    • Engine now provides a public loader to instantiate the DFlash drafter from model files.

Review Change Stack

@coderabbitai

coderabbitai Bot commented May 5, 2026

Copy link
Copy Markdown
Contributor
📝 Walkthrough

Walkthrough

Adds a DFlash speculative-decoding drafter: configuration, MLP/dual‑stream attention, decoder layers, drafter forward with per‑layer KV caches, GDN/KV backup and trimming utilities, model loader, accept_prefix helper with tests, and public AnyModel/AnyCache APIs plus engine loader wiring.

Changes

DFlash Speculative Decoding Implementation

Layer / File(s) Summary
Data Shape & Configuration
crates/higgs-models/src/dflash.rs
Defines DFlashConfig/DFlashSubConfig, accessors, defaults, and DEFAULT_DECODE_BLOCK_SIZE.
Core Primitives (MLP & Attention)
crates/higgs-models/src/dflash.rs
Implements DFlashMLP (SwiGLU-like) and DFlashAttention (dual-stream, separate RoPE offsets, cached context K/V).
Decoder Layer
crates/higgs-models/src/dflash.rs
Adds DFlashDecoderLayer combining RMSNorm, dual-stream attention, residuals, and MLP.
Drafter Forward & Cache
crates/higgs-models/src/dflash.rs
Implements DFlashDrafter::new, make_cache, and forward that projects taps, iterates layers, updates per-layer KV caches, and returns drafter hidden states.
GDN / KV Utilities
crates/higgs-models/src/dflash.rs
Adds GdnStateBackup::save/restore_and_rollback and rollback_kv_only for hybrid cache state management.
Cache Trimming Helpers
crates/higgs-models/src/dflash.rs
Adds crop_drafter_cache and trim_drafter_cache to index/truncate drafter K/V caches.
Model Loading
crates/higgs-models/src/dflash.rs
load_dflash_drafter reads config.json, constructs DFlashDrafter, and loads safetensors weights.
Token Acceptance & Tests
crates/higgs-models/src/dflash.rs
Adds accept_prefix and unit tests for matching, rejection, partial, empty, and debug assertions.
Public API Integration
crates/higgs-models/src/lib.rs
Exports dflash module, adds TapsTapeOutput, extends AnyModel with forward_with_taps, forward_with_taps_tape, embed_token_ids, forward_all_logits_from_hidden, and extends AnyCache with as_hybrid/as_hybrid_mut.
Engine Integration
crates/higgs-engine/src/model_loader.rs
Imports dflash and adds pub fn load_dflash_drafter(...) mapping model errors to EngineError::Model.

Sequence Diagram(s)

sequenceDiagram
    participant App as Application
    participant Loader as Model Loader
    participant Drafter as DFlash Drafter
    participant Cache as KV Cache
    participant Verify as Verification

    App->>Loader: load_dflash_drafter(path)
    Loader->>Loader: read config.json
    Loader->>Drafter: DFlashDrafter::new(config)
    Drafter->>Drafter: build MLP & Attention layers
    Loader->>Drafter: load weights from safetensors
    Loader-->>App: DFlashDrafter

    App->>Drafter: forward(noise, taps, cache)
    Drafter->>Drafter: normalize & project taps
    Drafter->>Drafter: apply decoder layers iteratively
    Drafter->>Cache: accumulate K/V per layer
    Drafter-->>App: draft_logits

    App->>Verify: run target model verification
    Verify-->>App: verify_argmax

    App->>App: accept_prefix(draft_tokens, verify_argmax)
    App->>Cache: crop / trim / rollback cache
    App-->>App: accepted_tokens + bonus_token
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 I hop through configs, heads, and taps,
building noise and stitching gaps,
Dual streams hum and caches grow,
tokens matched where drafts may go,
A rabbit cheers: DFlash lights the maps!

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The pull request title accurately summarizes the main change: adding the DFlash drafter foundation with module, tap APIs, and dispatch mechanisms. It is concise, clear, and directly reflects the core objective of the PR.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
crates/higgs-models/src/dflash.rs (1)

535-548: ⚡ Quick win

Use the proper ModelError variants instead of wrapping unrelated errors into ModelError::Io.

All three error sources should leverage the automatic From implementations:

  • std::fs::read_to_string error → use ? for ModelError::Io
  • serde_json::from_str error → use ? for ModelError::Json
  • DFlashDrafter::new error (returns Exception) → use ? for ModelError::Mlx

Currently, lines 538–546 incorrectly wrap serde_json::Error and Exception inside std::io::Error::other(), losing type information and conflating unrelated error types.

♻️ Refactor
 pub fn load_dflash_drafter(model_path: &Path) -> Result<DFlashDrafter, ModelError> {
     let config_path = model_path.join("config.json");
-    let config_str = std::fs::read_to_string(&config_path)
-        .map_err(|e| ModelError::Io(std::io::Error::other(format!("reading config.json: {e}"))))?;
-    let config: DFlashConfig = serde_json::from_str(&config_str)
-        .map_err(|e| ModelError::Io(std::io::Error::other(format!("parsing config.json: {e}"))))?;
+    let config_str = std::fs::read_to_string(&config_path)?;
+    let config: DFlashConfig = serde_json::from_str(&config_str)?;
 
-    let mut drafter = DFlashDrafter::new(config)
-        .map_err(|e| ModelError::Io(std::io::Error::other(e.to_string())))?;
+    let mut drafter = DFlashDrafter::new(config)?;
 
     crate::load_safetensors_weights(&mut drafter, model_path)?;
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-models/src/dflash.rs` around lines 535 - 548, The function
load_dflash_drafter currently wraps all errors as ModelError::Io; instead, use
the existing From implementations and the ? operator so each error maps to the
correct variant: call std::fs::read_to_string(&config_path)? to yield
ModelError::Io, call serde_json::from_str(&config_str)? to yield
ModelError::Json, and propagate the Exception returned by
DFlashDrafter::new(config)? so it becomes ModelError::Mlx; keep
crate::load_safetensors_weights(&mut drafter, model_path)? unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@crates/higgs-models/src/dflash.rs`:
- Around line 81-86: The doc comment on DEFAULT_DECODE_BLOCK_SIZE incorrectly
claims an environment override via HIGGS_DFLASH_BLOCK_SIZE; update the comment
for the constant DEFAULT_DECODE_BLOCK_SIZE to either remove the "Overridable via
`HIGGS_DFLASH_BLOCK_SIZE`" clause or replace it with the correct behavior (e.g.,
state that this is a compile-time constant and not overridden at runtime), and
if an env-var override is intended, implement and reference the actual
runtime/config code that reads HIGGS_DFLASH_BLOCK_SIZE instead of claiming it in
the doc for DEFAULT_DECODE_BLOCK_SIZE.
- Around line 32-49: Add doc comments to the public DFlashConfig struct and each
of its public fields (hidden_size, num_hidden_layers, num_attention_heads,
num_key_value_heads, head_dim, intermediate_size, rms_norm_eps, rope_theta,
block_size, vocab_size) describing expected ranges/types/units and relationships
(e.g., that head_dim defaults via default_head_dim, rms_norm_eps and rope_theta
use serde defaults default_rms_norm_eps/default_rope_theta, block_size relates
to runtime DEFAULT_DECODE_BLOCK_SIZE and should be set accordingly, and
vocab_size must match the target model's vocabulary), and document that
dflash_config is an internal sub-config; include default behavior and any
constraints (e.g., integer sizes, compatibility between
num_attention_heads/num_key_value_heads and head_dim) so consumers understand
valid values and runtime implications for DFlashConfig.
- Around line 441-485: GdnStateBackup currently saves/restores (conv_state,
ssm_state, offset) but omits ArraysCache.conv_pos causing corrupted conv buffer
after rollback; update the struct’s states type from Vec<(Option<Array>,
Option<Array>, i32)> to Vec<(Option<Array>, Option<Array>, i32, i32)>, modify
save() (inside GdnStateBackup::save) to push ac.conv_pos into the tuple together
with conv_state, ssm_state, offset, and update restore_and_rollback() to set
ac.conv_pos = *conv_pos when matching Some(LayerCache::Arrays(ac)), ensuring
conv_state, ssm_state, conv_pos and offset are all restored.

---

Nitpick comments:
In `@crates/higgs-models/src/dflash.rs`:
- Around line 535-548: The function load_dflash_drafter currently wraps all
errors as ModelError::Io; instead, use the existing From implementations and the
? operator so each error maps to the correct variant: call
std::fs::read_to_string(&config_path)? to yield ModelError::Io, call
serde_json::from_str(&config_str)? to yield ModelError::Json, and propagate the
Exception returned by DFlashDrafter::new(config)? so it becomes ModelError::Mlx;
keep crate::load_safetensors_weights(&mut drafter, model_path)? unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bbf0fb3d-d30e-457c-b678-f91b6ec41e26

📥 Commits

Reviewing files that changed from the base of the PR and between 3d5a136 and 1c3c0d3.

📒 Files selected for processing (4)
  • crates/higgs-engine/src/model_loader.rs
  • crates/higgs-models/src/dflash.rs
  • crates/higgs-models/src/lib.rs
  • crates/higgs-models/src/qwen3_next.rs

Comment thread crates/higgs-models/src/dflash.rs
Comment thread crates/higgs-models/src/dflash.rs Outdated
Comment thread crates/higgs-models/src/dflash.rs
dusterbloom and others added 6 commits May 20, 2026 15:17
Adds `crates/higgs-models/src/dflash.rs` from feat/magic-canvas — the
0.5B drafter that produces 16 draft tokens per round via a single
non-causal forward pass on hidden states tapped from 5 target-model
layers.

Architecture (8 decoder layers, dual-stream attention) is verbatim from
the magic-canvas baseline `c1f85ade` (final stable state, before WIP
ANE work). Wire-up into `SimpleEngine` lands in the follow-up commit.

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Two call sites converted with
    `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * Workspace clippy (nursery: `as_conversions`,
    `cast_possible_truncation`, `doc_markdown`, `assigning_clones`,
    `explicit_iter_loop`, `unnecessary_cast`, `shadow_unrelated`,
    `redundant_pattern_matching`, `missing_const_for_fn`) — all 30
    errors fixed in-place: `i32::try_from` for tensor-shape casts,
    `clone_from` for in-place clones, `filter_map(Option::as_mut)` for
    `iter().filter_map(if-let)` patterns, backticks on doc items.
    No file-level allows.

The original DFlash test suite (~3.8K lines, 30+ end-to-end tests)
depends on tap APIs (`forward_with_taps_tape`, `replay_tape_rollback`,
`forward_all_logits_from_hidden`) and `crate::diffusion::accept_prefix`
that aren't on `origin/main` yet. Tests are deferred to a follow-up PR
alongside the qwen3_next tap-API surface — there's a comment block at
the bottom of `dflash.rs` flagging this.

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the model-side surface that the DFlash drafter speculates against —
hidden-state taps during forward, GDN innovation tape for cheap rollback,
and helpers for embedding lookup + lm_head application in isolation.

Methods added on `Qwen3NextCausalLM`:

  * `forward_with_taps` — forward returning logits AND vec of hidden
    states at specified target layers; the drafter conditions on these.

  * `forward_with_taps_stateless` — same, but does NOT mutate the
    recurrent (GDN) state. Used during verify when state advancement is
    handled separately.

  * `forward_with_taps_tape` — forward that records each GDN layer's
    innovation into a `GdnLayerTape`. Enables ~5ms replay vs ~30ms
    rerun for partial-accept rollback.

  * `replay_tape_rollback` — restore GDN state to a tape position
    without re-running the full model.

  * `embed_token_ids` — apply the embedding layer alone (drafter input).

  * `forward_all_logits_from_hidden` — apply lm_head alone (target's
    verification of drafter outputs).

  * `project_logits` (private helper) — lm_head with origin/main's
    available projection paths only (ANE + dense_lm_head fields don't
    exist here yet; ported in PR-8).

Methods added on `GatedDeltaNet`:

  * `forward_stateless` — GDN forward without state mutation.

  * `forward_with_tape` — GDN forward that captures the per-step
    innovation into the tape.

  * `replay_from_tape` — apply a tape to recompute SSM state to a target
    position. Annotated `#[allow(dead_code)]` until the engine glue
    drives it (next commit).

New public type `GdnLayerTape` exposes the per-layer innovation record.

Metal kernel infrastructure ported alongside:

  * `tape_replay_kernel_ffi` + `TAPE_REPLAY_KERNEL` static + Metal source
  * `gated_delta_kernel_ffi_with_tape` + matching kernel
  * `gated_delta_kernel_ffi_stateless` (thin wrapper over existing FFI;
    discards the new state, matches caller semantics in `forward_stateless`)

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Call site in `replay_tape_rollback`
    converted with `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * `lm_head_ane`, `dense_lm_head`, `ane_handle`, `ane_kernels` fields
    don't exist on this branch — ANE-feature paths stripped to the
    plain Metal/MLX path. Fields ported in PR-8.

  * Error handler uses `thread_local! RefCell<Option<String>>` instead
    of feat/magic-canvas's `Mutex<Option<String>>` — matches the
    branch's existing FFI error pattern.

Senior-Rust hygiene:

  * No file-level blanket allows added.
  * Function-scoped `#[allow(...)]` on the four genuine numerical
    kernel functions (`forward_stateless`, `forward_with_tape`,
    `forward_with_taps_tape`, `replay_tape_rollback`), each with a
    one-line justification comment.
  * `unwrap_used` never allowed — refactored to `?` propagation or
    `expect("reason")` at the two call sites.
  * Mechanical clippy refactors throughout: `find_map` for
    `filter_map(..).next()`, `clone_from` for `assigning_clones`,
    `if let` for single-pattern `match`, backticks for `doc_markdown`.

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ader

Surfaces the qwen3_next tap APIs through the polymorphic `AnyModel`
enum so engine code can call them without matching variants directly,
adds the greedy speculative-decode acceptance helper, and exposes a
`load_dflash_drafter` entry point on the engine's model_loader.

`AnyModel` (in `higgs-models/src/lib.rs`):

  * `forward_with_taps` — dispatches Qwen3Next + Hybrid; errors otherwise
  * `forward_with_taps_tape` — same, returns `TapsTapeOutput` (logits +
    tap hiddens + per-layer GDN tape) via a public type alias to
    placate `clippy::type_complexity`
  * `embed_token_ids` — Qwen3Next-only
  * `forward_all_logits_from_hidden` — Qwen3Next-only

  All non-Qwen3Next arms enumerate every variant explicitly to satisfy
  `clippy::wildcard_enum_match_arm` (no `_ =>` catch-alls).

`AnyCache`:

  * `as_hybrid` / `as_hybrid_mut` — borrow the inner hybrid layer-cache
    slice/vec for engine glue that needs to inspect GDN state. Returns
    `Result<_, Exception>` rather than panicking when called on a `KV`
    cache, so the verify path in `SimpleEngine::generate_dflash_inner`
    can propagate via `?`.

`dflash::accept_prefix`:

  * Greedy speculative-decode acceptance: longest-matching prefix of
    `draft` against `verify_argmax`, plus one bonus token at the
    diverge point (or after the last accept).
  * 5 unit tests covering full match, first-token reject, partial
    match, empty draft, and the debug-only length assertion.
  * Inlined here rather than ported from `feat/magic-canvas:diffusion.rs`
    to avoid pulling in the 9970-line diffusion module for a 16-line
    helper.

`engine::model_loader::load_dflash_drafter`:

  * Thin `Result` adapter over `higgs_models::dflash::load_dflash_drafter`,
    converting `ModelError` → `EngineError`. The `SimpleEngine::load_with_dflash`
    call site lands in the next commit.

Verification on origin/main:
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 335/335 pass (5 new accept_prefix tests)
  * `cargo test -p higgs-engine --lib` — 228/228 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

The remaining piece — `SimpleEngine::generate_dflash_inner` (the
draft-verify loop wired into `generate_inner`) — lands as a follow-up
commit. It needs end-to-end verification against a real DFlash drafter
checkpoint (Carnice-9B + 0.5B drafter); shipping it without that
runtime test would risk silent correctness regressions in the verify
path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_match

CI's clippy is one minor version ahead of my local toolchain and flags
the `if rollback > 0 { ... }` body inside the `Some(LayerCache::KV(kv))`
match arm. Two call sites:

  * `dflash.rs` — `GdnStateBackup::restore_and_rollback`
  * `qwen3_next.rs` — `Qwen3NextCausalLM::replay_tape_rollback`

Convert to a match guard and add an explicit no-op arm for the
guard-fails-and-`None` case so the match is exhaustive without a
wildcard. No behaviour change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- GdnStateBackup now saves and restores ArraysCache::conv_pos alongside
  conv_state/ssm_state/offset. Without it, conv buffer state was corrupted
  after a verify-rollback.
- load_dflash_drafter uses ? directly so std::io::Error → ModelError::Io,
  serde_json::Error → ModelError::Json, Exception → ModelError::Mlx via
  the existing From impls (was wrapping all three as ModelError::Io).
- DEFAULT_DECODE_BLOCK_SIZE doc no longer claims a HIGGS_DFLASH_BLOCK_SIZE
  env override that was never implemented.
- DFlashConfig and its public fields now have brief one-line doc comments.

Validation: cargo fmt clean, cargo clippy clean (higgs-models, higgs-engine,
higgs). Full test run skipped (disk space too tight to relink test binary);
no existing tests exercise GdnStateBackup, conv_pos in paged_prefix_cache.rs
is a separate snapshot system unaffected by this change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After rebasing PR-6a onto origin/main, the new `AnyModel::BonsaiQ1`
variant landed via panbanda#142 (merged) makes the explicit-enumeration arms
in `embed_token_ids` and `forward_all_logits_from_hidden` non-exhaustive
(E0004 x2).

Add `Self::BonsaiQ1(_)` to the existing union of non-Qwen3Next variants
that return the same "only implemented for Qwen3Next" Err — preserving
the DRY pattern (one error message per dispatcher across all 8 non-tap
variants) and the exhaustive-enumeration invariant (no `_ =>` catch-all,
so future variants will trip the same compiler check).

No runtime behaviour change: `BonsaiQ1` already cannot reach a DFlash
spec-decode flow (no tap API plumbing exists for it), so this just
formalises the rejection at the dispatcher boundary with a clear error.

Tests:
- cargo check -p higgs-models: green (was 2x E0004)
- cargo clippy -p {higgs-models, higgs-engine, higgs} --tests -- -Dwarnings: clean
- cargo test: 1030/1030 across three crates (no regression)
@dusterbloom dusterbloom force-pushed the dusterbloom/dflash-baseline branch from 42051b6 to 49c4bc4 Compare May 20, 2026 15:42

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@crates/higgs-models/src/dflash.rs`:
- Around line 356-365: Add a doc comment for the public field `config` on the
`DFlashDrafter` struct: describe what the `DFlashConfig` controls, any important
invariants or defaults consumers should know, and how it affects behavior of the
drafter (e.g., model sizing, tokenization, or runtime options). Update the `pub
config: DFlashConfig` field with a concise /// doc comment that mentions that
this is exported and links to `DFlashConfig` for full details.
- Around line 530-536: The docstring for trim_drafter_cache is incorrect
relative to how DFlashAttention::forward (which only writes context K/V into
cache at the positions in its block, see the forward write logic) populates the
cache; update the trim_drafter_cache documentation and contract so it clearly
states that the function removes the last n cached context K/V entries (i.e.,
trims recent context positions from the end), not "noise K/V" from the front,
and where appropriate adjust any caller expectations (or add a short assertion
in trim_drafter_cache that n is <= number of populated entries) to prevent
accidental deletion of verified context.
- Around line 407-439: In forward, before zipping self.layers with cache,
validate that cache.len() equals self.layers.len() (or at least the expected
length for cached entries) and return an Err(Exception::custom(...)) if they
differ; this prevents silently skipping layers or leaving stale cache entries
when using the zip in the loop over
self.layers.iter_mut().zip(cache.iter_mut()), so add that length check using the
unique symbols forward, self.layers, cache and cache_offset and fail fast with a
clear error message when lengths mismatch.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 96de88e6-5f57-4132-bf13-4dd37aad15d5

📥 Commits

Reviewing files that changed from the base of the PR and between 42051b6 and 49c4bc4.

📒 Files selected for processing (4)
  • crates/higgs-engine/src/model_loader.rs
  • crates/higgs-models/src/dflash.rs
  • crates/higgs-models/src/lib.rs
  • crates/higgs-models/src/qwen3_next.rs

Comment on lines +356 to +365
pub struct DFlashDrafter {
#[param]
fc: nn::Linear,
#[param]
hidden_norm: nn::RmsNorm,
#[param]
layers: Vec<DFlashDecoderLayer>,
#[param]
norm: nn::RmsNorm,
pub config: DFlashConfig,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Document the public config field.

DFlashDrafter has a struct-level doc comment, but its public config field is still undocumented even though it is part of the exported API.

As per coding guidelines, "Add doc comments on public structs/fields in Rust when changing user-facing behavior".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-models/src/dflash.rs` around lines 356 - 365, Add a doc comment
for the public field `config` on the `DFlashDrafter` struct: describe what the
`DFlashConfig` controls, any important invariants or defaults consumers should
know, and how it affects behavior of the drafter (e.g., model sizing,
tokenization, or runtime options). Update the `pub config: DFlashConfig` field
with a concise /// doc comment that mentions that this is exported and links to
`DFlashConfig` for full details.

Comment on lines +407 to +439
pub fn forward(
&mut self,
noise: &Array,
taps: &[Array],
cache: &mut [Option<(Array, Array)>],
) -> Result<Array, Exception> {
if taps.len() != self.config.num_taps() {
return Err(Exception::custom(format!(
"expected {} taps, got {}",
self.config.num_taps(),
taps.len()
)));
}

// Cache offset = current cached seq length (0 on first round)
let cache_offset = cache
.first()
.and_then(|c| c.as_ref())
.and_then(|(k, _)| k.shape().get(2).copied())
.unwrap_or(0);

// Concatenate tap hidden states: [B, T, num_taps * hidden_size]
let tap_refs: Vec<&Array> = taps.iter().collect();
let target_cat = ops::concatenate_axis(&tap_refs, -1)?;

// Project + norm: [B, T, hidden_size]
let target_projected = self.fc.forward(&target_cat)?;
let target_hidden = self.hidden_norm.forward(&target_projected)?;

let mut h = noise.clone();
for (layer, lc) in self.layers.iter_mut().zip(cache.iter_mut()) {
h = layer.forward(&h, &target_hidden, lc, cache_offset)?;
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject mismatched cache lengths before zipping layers.

Line 437 uses zip, so a short cache silently skips trailing decoder layers and a long cache leaves stale entries behind. That returns wrong hidden states without surfacing any error from this public API.

Suggested fix
     if taps.len() != self.config.num_taps() {
         return Err(Exception::custom(format!(
             "expected {} taps, got {}",
             self.config.num_taps(),
             taps.len()
         )));
     }
+    if cache.len() != self.layers.len() {
+        return Err(Exception::custom(format!(
+            "expected {} cache slots, got {}",
+            self.layers.len(),
+            cache.len()
+        )));
+    }
 
     // Cache offset = current cached seq length (0 on first round)
     let cache_offset = cache
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-models/src/dflash.rs` around lines 407 - 439, In forward, before
zipping self.layers with cache, validate that cache.len() equals
self.layers.len() (or at least the expected length for cached entries) and
return an Err(Exception::custom(...)) if they differ; this prevents silently
skipping layers or leaving stale cache entries when using the zip in the loop
over self.layers.iter_mut().zip(cache.iter_mut()), so add that length check
using the unique symbols forward, self.layers, cache and cache_offset and fail
fast with a clear error message when lengths mismatch.

Comment on lines +530 to +536
/// Trim `n` entries from the END of the drafter KV cache.
///
/// Reference: `trim_draft_cache(draft_cache, block_size)` in `dflash-mlx`.
/// After each draft forward, the cache has `prev + ctx_len + block_size` entries.
/// Trimming `block_size` removes the noise K/V while keeping the accumulated
/// target context K/V that conditions future rounds.
pub fn trim_drafter_cache(cache: &mut [Option<(Array, Array)>], n: i32) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

trim_drafter_cache's contract no longer matches the cache contents.

DFlashAttention::forward only persists context K/V into cache on Lines 267-276, so this helper trims the last n cached context positions. The current doc says it removes "noise K/V", which is the opposite behavior and is likely to steer the follow-up engine loop into deleting verified context.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/higgs-models/src/dflash.rs` around lines 530 - 536, The docstring for
trim_drafter_cache is incorrect relative to how DFlashAttention::forward (which
only writes context K/V into cache at the positions in its block, see the
forward write logic) populates the cache; update the trim_drafter_cache
documentation and contract so it clearly states that the function removes the
last n cached context K/V entries (i.e., trims recent context positions from the
end), not "noise K/V" from the front, and where appropriate adjust any caller
expectations (or add a short assertion in trim_drafter_cache that n is <= number
of populated entries) to prevent accidental deletion of verified context.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant