feat(dflash): DFlash drafter foundation — module + tap APIs + dispatch#147
dusterbloom wants to merge 6 commits into
📝 Walkthrough
Adds a DFlash speculative-decoding drafter: configuration, MLP/dual-stream attention, decoder layers, drafter forward with per-layer KV caches, GDN/KV backup and trimming utilities, model loader, accept_prefix helper with tests, and public AnyModel/AnyCache APIs plus engine loader wiring.
Changes
DFlash Speculative Decoding Implementation
Sequence Diagram(s)
sequenceDiagram
participant App as Application
participant Loader as Model Loader
participant Drafter as DFlash Drafter
participant Cache as KV Cache
participant Verify as Verification
App->>Loader: load_dflash_drafter(path)
Loader->>Loader: read config.json
Loader->>Drafter: DFlashDrafter::new(config)
Drafter->>Drafter: build MLP & Attention layers
Loader->>Drafter: load weights from safetensors
Loader-->>App: DFlashDrafter
App->>Drafter: forward(noise, taps, cache)
Drafter->>Drafter: normalize & project taps
Drafter->>Drafter: apply decoder layers iteratively
Drafter->>Cache: accumulate K/V per layer
Drafter-->>App: draft_logits
App->>Verify: run target model verification
Verify-->>App: verify_argmax
App->>App: accept_prefix(draft_tokens, verify_argmax)
App->>Cache: crop / trim / rollback cache
App-->>App: accepted_tokens + bonus_token
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 5 passed
Actionable comments posted: 3
🧹 Nitpick comments (1)
crates/higgs-models/src/dflash.rs (1)
535-548: ⚡ Quick win
Use the proper `ModelError` variants instead of wrapping unrelated errors into `ModelError::Io`.
All three error sources should leverage the automatic `From` implementations:
- `std::fs::read_to_string` error → use `?` for `ModelError::Io`
- `serde_json::from_str` error → use `?` for `ModelError::Json`
- `DFlashDrafter::new` error (returns `Exception`) → use `?` for `ModelError::Mlx`
Currently, lines 538–546 incorrectly wrap `serde_json::Error` and `Exception` inside `std::io::Error::other()`, losing type information and conflating unrelated error types.
♻️ Refactor
     pub fn load_dflash_drafter(model_path: &Path) -> Result<DFlashDrafter, ModelError> {
         let config_path = model_path.join("config.json");
    -    let config_str = std::fs::read_to_string(&config_path)
    -        .map_err(|e| ModelError::Io(std::io::Error::other(format!("reading config.json: {e}"))))?;
    -    let config: DFlashConfig = serde_json::from_str(&config_str)
    -        .map_err(|e| ModelError::Io(std::io::Error::other(format!("parsing config.json: {e}"))))?;
    +    let config_str = std::fs::read_to_string(&config_path)?;
    +    let config: DFlashConfig = serde_json::from_str(&config_str)?;

    -    let mut drafter = DFlashDrafter::new(config)
    -        .map_err(|e| ModelError::Io(std::io::Error::other(e.to_string())))?;
    +    let mut drafter = DFlashDrafter::new(config)?;
         crate::load_safetensors_weights(&mut drafter, model_path)?;
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@crates/higgs-models/src/dflash.rs` around lines 535 - 548, The function load_dflash_drafter currently wraps all errors as ModelError::Io; instead, use the existing From implementations and the ? operator so each error maps to the correct variant: call std::fs::read_to_string(&config_path)? to yield ModelError::Io, call serde_json::from_str(&config_str)? to yield ModelError::Json, and propagate the Exception returned by DFlashDrafter::new(config)? so it becomes ModelError::Mlx; keep crate::load_safetensors_weights(&mut drafter, model_path)? unchanged.
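The `?`-plus-`From` pattern the nitpick asks for can be illustrated with the standard library alone. This is a minimal sketch, not the crate's code: `LoadErrorSketch`, `parse_block_size`, and `load_block_size` are hypothetical names, and `ParseIntError` stands in for `serde_json::Error` so the example has no external dependencies.

```rust
use std::fs;
use std::io;
use std::num::ParseIntError;
use std::path::Path;

/// Hypothetical error enum with one variant per failure source.
#[derive(Debug)]
pub enum LoadErrorSketch {
    Io(io::Error),
    Parse(ParseIntError),
}

// Each `From` impl lets `?` pick the matching variant automatically.
impl From<io::Error> for LoadErrorSketch {
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}

impl From<ParseIntError> for LoadErrorSketch {
    fn from(e: ParseIntError) -> Self {
        Self::Parse(e)
    }
}

/// Parse failures surface as `Parse`, not as a stringified `Io` error.
pub fn parse_block_size(text: &str) -> Result<u32, LoadErrorSketch> {
    Ok(text.trim().parse::<u32>()?)
}

/// Read failures surface as `Io`; parse failures keep their own variant.
pub fn load_block_size(path: &Path) -> Result<u32, LoadErrorSketch> {
    let text = fs::read_to_string(path)?;
    parse_block_size(&text)
}
```

Callers can then match on the variant to tell a missing file from a malformed one, which is the type information the original `std::io::Error::other(...)` wrapping discarded.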
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@crates/higgs-models/src/dflash.rs`:
- Around line 81-86: The doc comment on DEFAULT_DECODE_BLOCK_SIZE incorrectly
claims an environment override via HIGGS_DFLASH_BLOCK_SIZE; update the comment
for the constant DEFAULT_DECODE_BLOCK_SIZE to either remove the "Overridable via
`HIGGS_DFLASH_BLOCK_SIZE`" clause or replace it with the correct behavior (e.g.,
state that this is a compile-time constant and not overridden at runtime), and
if an env-var override is intended, implement and reference the actual
runtime/config code that reads HIGGS_DFLASH_BLOCK_SIZE instead of claiming it in
the doc for DEFAULT_DECODE_BLOCK_SIZE.
- Around line 32-49: Add doc comments to the public DFlashConfig struct and each
of its public fields (hidden_size, num_hidden_layers, num_attention_heads,
num_key_value_heads, head_dim, intermediate_size, rms_norm_eps, rope_theta,
block_size, vocab_size) describing expected ranges/types/units and relationships
(e.g., that head_dim defaults via default_head_dim, rms_norm_eps and rope_theta
use serde defaults default_rms_norm_eps/default_rope_theta, block_size relates
to runtime DEFAULT_DECODE_BLOCK_SIZE and should be set accordingly, and
vocab_size must match the target model's vocabulary), and document that
dflash_config is an internal sub-config; include default behavior and any
constraints (e.g., integer sizes, compatibility between
num_attention_heads/num_key_value_heads and head_dim) so consumers understand
valid values and runtime implications for DFlashConfig.
- Around line 441-485: GdnStateBackup currently saves/restores (conv_state,
ssm_state, offset) but omits ArraysCache.conv_pos causing corrupted conv buffer
after rollback; update the struct’s states type from Vec<(Option<Array>,
Option<Array>, i32)> to Vec<(Option<Array>, Option<Array>, i32, i32)>, modify
save() (inside GdnStateBackup::save) to push ac.conv_pos into the tuple together
with conv_state, ssm_state, offset, and update restore_and_rollback() to set
ac.conv_pos = *conv_pos when matching Some(LayerCache::Arrays(ac)), ensuring
conv_state, ssm_state, conv_pos and offset are all restored.
---
Nitpick comments:
In `@crates/higgs-models/src/dflash.rs`:
- Around line 535-548: The function load_dflash_drafter currently wraps all
errors as ModelError::Io; instead, use the existing From implementations and the
? operator so each error maps to the correct variant: call
std::fs::read_to_string(&config_path)? to yield ModelError::Io, call
serde_json::from_str(&config_str)? to yield ModelError::Json, and propagate the
Exception returned by DFlashDrafter::new(config)? so it becomes ModelError::Mlx;
keep crate::load_safetensors_weights(&mut drafter, model_path)? unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: bbf0fb3d-d30e-457c-b678-f91b6ec41e26
📒 Files selected for processing (4)
- crates/higgs-engine/src/model_loader.rs
- crates/higgs-models/src/dflash.rs
- crates/higgs-models/src/lib.rs
- crates/higgs-models/src/qwen3_next.rs
Adds `crates/higgs-models/src/dflash.rs` from feat/magic-canvas — the
0.5B drafter that produces 16 draft tokens per round via a single
non-causal forward pass on hidden states tapped from 5 target-model
layers.
Architecture (8 decoder layers, dual-stream attention) is verbatim from
the magic-canvas baseline `c1f85ade` (final stable state, before WIP
ANE work). Wire-up into `SimpleEngine` lands in the follow-up commit.
Adaptations from feat/magic-canvas → origin/main:
* `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
on origin/main (PR panbanda#143). Two call sites converted with
`unsigned_abs().try_into().unwrap_or(usize::MAX)`.
* Workspace clippy (nursery: `as_conversions`,
`cast_possible_truncation`, `doc_markdown`, `assigning_clones`,
`explicit_iter_loop`, `unnecessary_cast`, `shadow_unrelated`,
`redundant_pattern_matching`, `missing_const_for_fn`) — all 30
errors fixed in-place: `i32::try_from` for tensor-shape casts,
`clone_from` for in-place clones, `filter_map(Option::as_mut)` for
`iter().filter_map(if-let)` patterns, backticks on doc items.
No file-level allows.
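The two cast patterns named above can be sketched in isolation. The helper names `rollback_to_trim` and `dim_to_i32` are hypothetical; only the conversion expressions come from the commit message. Note that `unsigned_abs()` takes the magnitude, so a negative count maps to a positive trim length; whether the real call sites rely on that is not stated here.

```rust
use std::convert::{TryFrom, TryInto};

/// Convert a signed rollback count into the `usize` a `trim_by`-style API
/// expects, without an `as` cast. Saturates if `usize` is narrower than `u32`.
pub fn rollback_to_trim(n: i32) -> usize {
    n.unsigned_abs().try_into().unwrap_or(usize::MAX)
}

/// Checked conversion for tensor-shape values: `None` instead of silent
/// truncation when the dimension does not fit in `i32`.
pub fn dim_to_i32(dim: usize) -> Option<i32> {
    i32::try_from(dim).ok()
}
```

Both helpers avoid the `as_conversions` and `cast_possible_truncation` lints because every lossy case is handled explicitly.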
The original DFlash test suite (~3.8K lines, 30+ end-to-end tests)
depends on tap APIs (`forward_with_taps_tape`, `replay_tape_rollback`,
`forward_all_logits_from_hidden`) and `crate::diffusion::accept_prefix`
that aren't on `origin/main` yet. Tests are deferred to a follow-up PR
alongside the qwen3_next tap-API surface — there's a comment block at
the bottom of `dflash.rs` flagging this.
Verification on origin/main:
* `cargo check -p higgs-models` — clean
* `cargo clippy --all-targets --all-features -- -D warnings` — clean
* `cargo fmt --check` — clean
* `cargo test -p higgs-models --lib` — 330/330 pass
* `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the model-side surface that the DFlash drafter speculates against —
hidden-state taps during forward, GDN innovation tape for cheap rollback,
and helpers for embedding lookup + lm_head application in isolation.
Methods added on `Qwen3NextCausalLM`:
* `forward_with_taps` — forward returning logits AND vec of hidden
states at specified target layers; the drafter conditions on these.
* `forward_with_taps_stateless` — same, but does NOT mutate the
recurrent (GDN) state. Used during verify when state advancement is
handled separately.
* `forward_with_taps_tape` — forward that records each GDN layer's
innovation into a `GdnLayerTape`. Enables ~5ms replay vs ~30ms
rerun for partial-accept rollback.
* `replay_tape_rollback` — restore GDN state to a tape position
without re-running the full model.
* `embed_token_ids` — apply the embedding layer alone (drafter input).
* `forward_all_logits_from_hidden` — apply lm_head alone (target's
verification of drafter outputs).
* `project_logits` (private helper) — lm_head with origin/main's
available projection paths only (ANE + dense_lm_head fields don't
exist here yet; ported in PR-8).
Methods added on `GatedDeltaNet`:
* `forward_stateless` — GDN forward without state mutation.
* `forward_with_tape` — GDN forward that captures the per-step
innovation into the tape.
* `replay_from_tape` — apply a tape to recompute SSM state to a target
position. Annotated `#[allow(dead_code)]` until the engine glue
drives it (next commit).
New public type `GdnLayerTape` exposes the per-layer innovation record.
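The idea behind tape-based rollback can be shown on a toy scalar recurrence. This is an illustrative sketch only: the real GDN state update, the contents of `GdnLayerTape`, and the Metal kernels are not reproduced, and every name below is hypothetical. The toy rule is `state += beta * (x - state)`, with `x - state` recorded as the per-step innovation.

```rust
/// Run the toy recurrence and record each step's innovation on a tape.
pub fn forward_with_tape_sketch(mut state: f64, beta: f64, inputs: &[f64]) -> (f64, Vec<f64>) {
    let mut tape = Vec::with_capacity(inputs.len());
    for &x in inputs {
        let innovation = x - state; // the expensive part in a real layer
        state += beta * innovation;
        tape.push(innovation);
    }
    (state, tape)
}

/// Rebuild the state after `accepted` steps from a pre-block snapshot,
/// reusing the recorded innovations instead of recomputing them.
pub fn replay_from_tape_sketch(snapshot: f64, beta: f64, tape: &[f64], accepted: usize) -> f64 {
    tape.iter()
        .take(accepted)
        .fold(snapshot, |state, innovation| state + beta * innovation)
}
```

With a snapshot of `0.0`, `beta = 0.5`, and inputs `[2.0, 4.0, 8.0]`, replaying the first two tape entries gives the state the forward pass had after two steps, which is what a partial accept needs.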
Metal kernel infrastructure ported alongside:
* `tape_replay_kernel_ffi` + `TAPE_REPLAY_KERNEL` static + Metal source
* `gated_delta_kernel_ffi_with_tape` + matching kernel
* `gated_delta_kernel_ffi_stateless` (thin wrapper over existing FFI;
discards the new state, matches caller semantics in `forward_stateless`)
Adaptations from feat/magic-canvas → origin/main:
* `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
on origin/main (PR panbanda#143). Call site in `replay_tape_rollback`
converted with `unsigned_abs().try_into().unwrap_or(usize::MAX)`.
* `lm_head_ane`, `dense_lm_head`, `ane_handle`, `ane_kernels` fields
don't exist on this branch — ANE-feature paths stripped to the
plain Metal/MLX path. Fields ported in PR-8.
* Error handler uses `thread_local! RefCell<Option<String>>` instead
of feat/magic-canvas's `Mutex<Option<String>>` — matches the
branch's existing FFI error pattern.
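A thread-local error slot of the shape described above might look like the following. This is a sketch under assumptions: `LAST_ERROR`, `record_error`, and `take_error` are hypothetical names, and the real handler is driven from FFI callbacks rather than called directly.

```rust
use std::cell::RefCell;

thread_local! {
    /// One error slot per thread; no locking needed because nothing is shared.
    static LAST_ERROR: RefCell<Option<String>> = RefCell::new(None);
}

/// Store the most recent error message for the current thread.
pub fn record_error(msg: &str) {
    LAST_ERROR.with(|slot| *slot.borrow_mut() = Some(msg.to_owned()));
}

/// Take the stored message, leaving the slot empty.
pub fn take_error() -> Option<String> {
    LAST_ERROR.with(|slot| slot.borrow_mut().take())
}
```

Compared with a global `Mutex<Option<String>>`, an error recorded on one thread is never observed by another, at the cost of requiring the reader to run on the same thread as the failing call.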
Senior-Rust hygiene:
* No file-level blanket allows added.
* Function-scoped `#[allow(...)]` on the four genuine numerical
kernel functions (`forward_stateless`, `forward_with_tape`,
`forward_with_taps_tape`, `replay_tape_rollback`), each with a
one-line justification comment.
* `unwrap_used` never allowed — refactored to `?` propagation or
`expect("reason")` at the two call sites.
* Mechanical clippy refactors throughout: `find_map` for
`filter_map(..).next()`, `clone_from` for `assigning_clones`,
`if let` for single-pattern `match`, backticks for `doc_markdown`.
Verification on origin/main:
* `cargo check -p higgs-models` — clean
* `cargo clippy --all-targets --all-features -- -D warnings` — clean
* `cargo fmt --check` — clean
* `cargo test -p higgs-models --lib` — 330/330 pass
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ader
Surfaces the qwen3_next tap APIs through the polymorphic `AnyModel`
enum so engine code can call them without matching variants directly,
adds the greedy speculative-decode acceptance helper, and exposes a
`load_dflash_drafter` entry point on the engine's model_loader.
`AnyModel` (in `higgs-models/src/lib.rs`):
* `forward_with_taps` — dispatches Qwen3Next + Hybrid; errors otherwise
* `forward_with_taps_tape` — same, returns `TapsTapeOutput` (logits +
tap hiddens + per-layer GDN tape) via a public type alias to
placate `clippy::type_complexity`
* `embed_token_ids` — Qwen3Next-only
* `forward_all_logits_from_hidden` — Qwen3Next-only
All non-Qwen3Next arms enumerate every variant explicitly to satisfy
`clippy::wildcard_enum_match_arm` (no `_ =>` catch-alls).
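The exhaustive-enumeration style can be sketched on a reduced enum. `ModelSketch` and `Weights` are hypothetical, the `Llama` variant is included only to pad the list, and returning the ids unchanged stands in for a real embedding lookup.

```rust
/// Placeholder for a loaded model's parameters.
pub struct Weights;

/// Reduced stand-in for a polymorphic model enum.
pub enum ModelSketch {
    Qwen3Next(Weights),
    Hybrid(Weights),
    Llama(Weights),
    BonsaiQ1(Weights),
}

impl ModelSketch {
    /// Only one variant supports the call; every other variant is listed
    /// explicitly so that adding a new variant is a compile error (E0004)
    /// here rather than a silent fall-through into a `_ =>` arm.
    pub fn embed_token_ids(&self, ids: &[u32]) -> Result<Vec<u32>, String> {
        match self {
            Self::Qwen3Next(_) => Ok(ids.to_vec()),
            Self::Hybrid(_) | Self::Llama(_) | Self::BonsaiQ1(_) => {
                Err("embed_token_ids is only implemented for Qwen3Next".to_owned())
            }
        }
    }
}
```

The unsupported variants share one arm, so there is still a single error message per dispatcher.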
`AnyCache`:
* `as_hybrid` / `as_hybrid_mut` — borrow the inner hybrid layer-cache
slice/vec for engine glue that needs to inspect GDN state. Returns
`Result<_, Exception>` rather than panicking when called on a `KV`
cache, so the verify path in `SimpleEngine::generate_dflash_inner`
can propagate via `?`.
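A fallible accessor of this kind might be shaped as follows. This is a sketch with hypothetical types: `CacheSketch` uses `f32` placeholders instead of real layer caches, and `String` stands in for `Exception`.

```rust
/// Reduced stand-in for a cache enum with a plain-KV and a hybrid variant.
pub enum CacheSketch {
    KV(Vec<f32>),
    Hybrid(Vec<Option<f32>>),
}

impl CacheSketch {
    /// Borrow the hybrid layers, or report the mismatch instead of panicking.
    pub fn as_hybrid(&self) -> Result<&[Option<f32>], String> {
        match self {
            Self::Hybrid(layers) => Ok(layers),
            Self::KV(_) => Err("as_hybrid called on a KV cache".to_owned()),
        }
    }

    /// Mutable counterpart of `as_hybrid`.
    pub fn as_hybrid_mut(&mut self) -> Result<&mut Vec<Option<f32>>, String> {
        match self {
            Self::Hybrid(layers) => Ok(layers),
            Self::KV(_) => Err("as_hybrid_mut called on a KV cache".to_owned()),
        }
    }
}

/// A caller can forward the mismatch with `?` rather than unwrapping.
pub fn count_populated(cache: &CacheSketch) -> Result<usize, String> {
    let layers = cache.as_hybrid()?;
    Ok(layers.iter().flatten().count())
}
```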
`dflash::accept_prefix`:
* Greedy speculative-decode acceptance: longest-matching prefix of
`draft` against `verify_argmax`, plus one bonus token at the
diverge point (or after the last accept).
* 5 unit tests covering full match, first-token reject, partial
match, empty draft, and the debug-only length assertion.
* Inlined here rather than ported from `feat/magic-canvas:diffusion.rs`
to avoid pulling in the 9970-line diffusion module for a 16-line
helper.
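The acceptance rule described above can be sketched in a few lines. This is not the `dflash::accept_prefix` source: the name, the `u32` token type, the return shape, and the assumption that `verify_argmax` carries one more entry than `draft` are all illustrative choices.

```rust
/// Greedy acceptance sketch: count how many leading draft tokens agree with
/// the verifier's argmax, then take the verifier's token at the first
/// disagreement (or just past the last accepted token) as the bonus token.
///
/// Assumes `verify_argmax[i]` is the target model's choice at the position
/// where `draft[i]` was proposed. The bonus is `None` if no entry exists there.
pub fn accept_prefix_sketch(draft: &[u32], verify_argmax: &[u32]) -> (usize, Option<u32>) {
    let accepted = draft
        .iter()
        .zip(verify_argmax)
        .take_while(|(d, v)| d == v)
        .count();
    let bonus = verify_argmax.get(accepted).copied();
    (accepted, bonus)
}
```

For a draft of `[1, 2, 3]` verified as `[1, 7, 3, 4]`, one token is accepted and `7` becomes the bonus token, so each round emits at least one token even when the first draft token is rejected.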
`engine::model_loader::load_dflash_drafter`:
* Thin `Result` adapter over `higgs_models::dflash::load_dflash_drafter`,
converting `ModelError` → `EngineError`. The `SimpleEngine::load_with_dflash`
call site lands in the next commit.
Verification on origin/main:
* `cargo clippy --all-targets --all-features -- -D warnings` — clean
* `cargo fmt --check` — clean
* `cargo test -p higgs-models --lib` — 335/335 pass (5 new accept_prefix tests)
* `cargo test -p higgs-engine --lib` — 228/228 pass
* `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass
The remaining piece — `SimpleEngine::generate_dflash_inner` (the
draft-verify loop wired into `generate_inner`) — lands as a follow-up
commit. It needs end-to-end verification against a real DFlash drafter
checkpoint (Carnice-9B + 0.5B drafter); shipping it without that
runtime test would risk silent correctness regressions in the verify
path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_match
CI's clippy is one minor version ahead of my local toolchain and flags
the `if rollback > 0 { ... }` body inside the `Some(LayerCache::KV(kv))`
match arm. Two call sites:
* `dflash.rs` — `GdnStateBackup::restore_and_rollback`
* `qwen3_next.rs` — `Qwen3NextCausalLM::replay_tape_rollback`
Convert to a match guard and add an explicit no-op arm for the
guard-fails-and-`None` case so the match is exhaustive without a
wildcard. No behaviour change.
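The guard-based form can be sketched on a reduced cache enum. `LayerCacheSketch` and `rollback_layers` are hypothetical, and truncating a `Vec<u8>` stands in for trimming a real KV cache.

```rust
/// Reduced stand-in for a per-layer cache.
pub enum LayerCacheSketch {
    KV(Vec<u8>),
    Arrays(u32),
}

/// Trim `rollback` trailing positions from every KV layer.
pub fn rollback_layers(layers: &mut [Option<LayerCacheSketch>], rollback: usize) {
    for layer in layers.iter_mut() {
        match layer {
            // The guard replaces an `if rollback > 0 { ... }` nested in the arm.
            Some(LayerCacheSketch::KV(kv)) if rollback > 0 => {
                let keep = kv.len().saturating_sub(rollback);
                kv.truncate(keep);
            }
            // Recurrent state is assumed to be restored elsewhere.
            Some(LayerCacheSketch::Arrays(_)) => {}
            // Explicit no-op arm: guard failed, or the slot is empty.
            Some(LayerCacheSketch::KV(_)) | None => {}
        }
    }
}
```

Because the guarded arm does not count toward exhaustiveness, the final arm has to name `KV(_)` again; that keeps the match exhaustive without a wildcard.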
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- GdnStateBackup now saves and restores ArraysCache::conv_pos alongside conv_state/ssm_state/offset. Without it, conv buffer state was corrupted after a verify-rollback.
- load_dflash_drafter uses ? directly so std::io::Error → ModelError::Io, serde_json::Error → ModelError::Json, Exception → ModelError::Mlx via the existing From impls (was wrapping all three as ModelError::Io).
- DEFAULT_DECODE_BLOCK_SIZE doc no longer claims a HIGGS_DFLASH_BLOCK_SIZE env override that was never implemented.
- DFlashConfig and its public fields now have brief one-line doc comments.
Validation: cargo fmt clean, cargo clippy clean (higgs-models, higgs-engine, higgs). Full test run skipped (disk space too tight to relink test binary); no existing tests exercise GdnStateBackup, conv_pos in paged_prefix_cache.rs is a separate snapshot system unaffected by this change.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
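The `conv_pos` fix amounts to snapshotting four values per layer instead of three. The sketch below uses hypothetical types (`ArraysCacheSketch`, `BackupSketch`) with `Vec<f32>` placeholders for arrays; field names follow the review comment, everything else is an assumption.

```rust
/// Toy stand-in for a recurrent-layer cache.
#[derive(Clone, Debug, PartialEq)]
pub struct ArraysCacheSketch {
    pub conv_state: Option<Vec<f32>>,
    pub ssm_state: Option<Vec<f32>>,
    pub offset: i32,
    pub conv_pos: i32,
}

/// (conv_state, ssm_state, offset, conv_pos) for one layer.
type SavedState = (Option<Vec<f32>>, Option<Vec<f32>>, i32, i32);

pub struct BackupSketch {
    states: Vec<SavedState>,
}

impl BackupSketch {
    /// Snapshot every field that a verify pass may mutate.
    pub fn save(caches: &[ArraysCacheSketch]) -> Self {
        let states = caches
            .iter()
            .map(|ac| (ac.conv_state.clone(), ac.ssm_state.clone(), ac.offset, ac.conv_pos))
            .collect();
        Self { states }
    }

    /// Restore all four fields; leaving `conv_pos` out would leave the conv
    /// buffer's write position pointing past the restored contents.
    pub fn restore(&self, caches: &mut [ArraysCacheSketch]) {
        for (ac, (conv, ssm, offset, conv_pos)) in caches.iter_mut().zip(&self.states) {
            ac.conv_state.clone_from(conv);
            ac.ssm_state.clone_from(ssm);
            ac.offset = *offset;
            ac.conv_pos = *conv_pos;
        }
    }
}
```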
After rebasing PR-6a onto origin/main, the new `AnyModel::BonsaiQ1` variant that landed via panbanda#142 (merged) makes the explicit-enumeration arms in `embed_token_ids` and `forward_all_logits_from_hidden` non-exhaustive (E0004 x2).
Add `Self::BonsaiQ1(_)` to the existing union of non-Qwen3Next variants that return the same "only implemented for Qwen3Next" Err. This preserves the DRY pattern (one error message per dispatcher across all 8 non-tap variants) and the exhaustive-enumeration invariant (no `_ =>` catch-all, so future variants will trip the same compiler check).
No runtime behaviour change: `BonsaiQ1` already cannot reach a DFlash spec-decode flow (no tap API plumbing exists for it), so this just formalises the rejection at the dispatcher boundary with a clear error.
Tests:
- cargo check -p higgs-models: green (was 2x E0004)
- cargo clippy -p {higgs-models, higgs-engine, higgs} --tests -- -Dwarnings: clean
- cargo test: 1030/1030 across three crates (no regression)
42051b6 to 49c4bc4
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@crates/higgs-models/src/dflash.rs`:
- Around line 356-365: Add a doc comment for the public field `config` on the
`DFlashDrafter` struct: describe what the `DFlashConfig` controls, any important
invariants or defaults consumers should know, and how it affects behavior of the
drafter (e.g., model sizing, tokenization, or runtime options). Update the `pub
config: DFlashConfig` field with a concise /// doc comment that mentions that
this is exported and links to `DFlashConfig` for full details.
- Around line 530-536: The docstring for trim_drafter_cache is incorrect
relative to how DFlashAttention::forward (which only writes context K/V into
cache at the positions in its block, see the forward write logic) populates the
cache; update the trim_drafter_cache documentation and contract so it clearly
states that the function removes the last n cached context K/V entries (i.e.,
trims recent context positions from the end), not "noise K/V" from the front,
and where appropriate adjust any caller expectations (or add a short assertion
in trim_drafter_cache that n is <= number of populated entries) to prevent
accidental deletion of verified context.
- Around line 407-439: In forward, before zipping self.layers with cache,
validate that cache.len() equals self.layers.len() (or at least the expected
length for cached entries) and return an Err(Exception::custom(...)) if they
differ; this prevents silently skipping layers or leaving stale cache entries
when using the zip in the loop over
self.layers.iter_mut().zip(cache.iter_mut()), so add that length check using the
unique symbols forward, self.layers, cache and cache_offset and fail fast with a
clear error message when lengths mismatch.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 96de88e6-5f57-4132-bf13-4dd37aad15d5
📒 Files selected for processing (4)
- crates/higgs-engine/src/model_loader.rs
- crates/higgs-models/src/dflash.rs
- crates/higgs-models/src/lib.rs
- crates/higgs-models/src/qwen3_next.rs
    pub struct DFlashDrafter {
        #[param]
        fc: nn::Linear,
        #[param]
        hidden_norm: nn::RmsNorm,
        #[param]
        layers: Vec<DFlashDecoderLayer>,
        #[param]
        norm: nn::RmsNorm,
        pub config: DFlashConfig,
🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Document the public config field.
DFlashDrafter has a struct-level doc comment, but its public config field is still undocumented even though it is part of the exported API.
As per coding guidelines, "Add doc comments on public structs/fields in Rust when changing user-facing behavior".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@crates/higgs-models/src/dflash.rs` around lines 356 - 365, Add a doc comment
for the public field `config` on the `DFlashDrafter` struct: describe what the
`DFlashConfig` controls, any important invariants or defaults consumers should
know, and how it affects behavior of the drafter (e.g., model sizing,
tokenization, or runtime options). Update the `pub config: DFlashConfig` field
with a concise /// doc comment that mentions that this is exported and links to
`DFlashConfig` for full details.
    pub fn forward(
        &mut self,
        noise: &Array,
        taps: &[Array],
        cache: &mut [Option<(Array, Array)>],
    ) -> Result<Array, Exception> {
        if taps.len() != self.config.num_taps() {
            return Err(Exception::custom(format!(
                "expected {} taps, got {}",
                self.config.num_taps(),
                taps.len()
            )));
        }

        // Cache offset = current cached seq length (0 on first round)
        let cache_offset = cache
            .first()
            .and_then(|c| c.as_ref())
            .and_then(|(k, _)| k.shape().get(2).copied())
            .unwrap_or(0);

        // Concatenate tap hidden states: [B, T, num_taps * hidden_size]
        let tap_refs: Vec<&Array> = taps.iter().collect();
        let target_cat = ops::concatenate_axis(&tap_refs, -1)?;

        // Project + norm: [B, T, hidden_size]
        let target_projected = self.fc.forward(&target_cat)?;
        let target_hidden = self.hidden_norm.forward(&target_projected)?;

        let mut h = noise.clone();
        for (layer, lc) in self.layers.iter_mut().zip(cache.iter_mut()) {
            h = layer.forward(&h, &target_hidden, lc, cache_offset)?;
        }
Reject mismatched cache lengths before zipping layers.
Line 437 uses zip, so a short cache silently skips trailing decoder layers and a long cache leaves stale entries behind. That returns wrong hidden states without surfacing any error from this public API.
Suggested fix
         if taps.len() != self.config.num_taps() {
             return Err(Exception::custom(format!(
                 "expected {} taps, got {}",
                 self.config.num_taps(),
                 taps.len()
             )));
         }
    +    if cache.len() != self.layers.len() {
    +        return Err(Exception::custom(format!(
    +            "expected {} cache slots, got {}",
    +            self.layers.len(),
    +            cache.len()
    +        )));
    +    }
         // Cache offset = current cached seq length (0 on first round)
         let cache_offset = cache
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@crates/higgs-models/src/dflash.rs` around lines 407 - 439, In forward, before
zipping self.layers with cache, validate that cache.len() equals
self.layers.len() (or at least the expected length for cached entries) and
return an Err(Exception::custom(...)) if they differ; this prevents silently
skipping layers or leaving stale cache entries when using the zip in the loop
over self.layers.iter_mut().zip(cache.iter_mut()), so add that length check
using the unique symbols forward, self.layers, cache and cache_offset and fail
fast with a clear error message when lengths mismatch.
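The failure mode is easy to reproduce outside the drafter, because `Iterator::zip` stops at the shorter input without signalling anything. The sketch below is illustrative only: `run_layers` is a hypothetical toy in which each "layer" adds a bias, and `String` stands in for `Exception`.

```rust
/// `zip` yields only as many pairs as the shorter side provides.
pub fn zipped_pairs(layers: usize, slots: usize) -> usize {
    (0..layers).zip(0..slots).count()
}

/// Toy layer stack with the fail-fast length check in front of the zip.
pub fn run_layers(
    layer_bias: &[i32],
    cache: &mut [Option<i32>],
    input: i32,
) -> Result<i32, String> {
    if cache.len() != layer_bias.len() {
        return Err(format!(
            "expected {} cache slots, got {}",
            layer_bias.len(),
            cache.len()
        ));
    }
    let mut h = input;
    for (bias, slot) in layer_bias.iter().zip(cache.iter_mut()) {
        h += bias;
        *slot = Some(h); // each layer records its output in its own slot
    }
    Ok(h)
}
```

With 8 layers and 5 cache slots, `zipped_pairs` reports 5: the last three layers would never run. The check turns that silent truncation into an error before any slot is written.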
    /// Trim `n` entries from the END of the drafter KV cache.
    ///
    /// Reference: `trim_draft_cache(draft_cache, block_size)` in `dflash-mlx`.
    /// After each draft forward, the cache has `prev + ctx_len + block_size` entries.
    /// Trimming `block_size` removes the noise K/V while keeping the accumulated
    /// target context K/V that conditions future rounds.
    pub fn trim_drafter_cache(cache: &mut [Option<(Array, Array)>], n: i32) {
🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
trim_drafter_cache's contract no longer matches the cache contents.
DFlashAttention::forward only persists context K/V into cache on Lines 267-276, so this helper trims the last n cached context positions. The current doc says it removes "noise K/V", which is the opposite behavior and is likely to steer the follow-up engine loop into deleting verified context.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@crates/higgs-models/src/dflash.rs` around lines 530 - 536, The docstring for
trim_drafter_cache is incorrect relative to how DFlashAttention::forward (which
only writes context K/V into cache at the positions in its block, see the
forward write logic) populates the cache; update the trim_drafter_cache
documentation and contract so it clearly states that the function removes the
last n cached context K/V entries (i.e., trims recent context positions from the
end), not "noise K/V" from the front, and where appropriate adjust any caller
expectations (or add a short assertion in trim_drafter_cache that n is <= number
of populated entries) to prevent accidental deletion of verified context.
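The contract the review asks for (trim the last `n` cached positions, refuse to over-trim) can be sketched on plain vectors. This is an assumption-laden toy: `KvSlot` holds one `f32` per cached position instead of real K/V arrays, `trim_cache_end` is a hypothetical name, and returning an error rather than asserting is a design choice of the sketch.

```rust
/// One layer's cached (keys, values), one element per cached position.
pub type KvSlot = Option<(Vec<f32>, Vec<f32>)>;

/// Remove the last `n` cached positions from every populated layer.
/// Validates all layers first so a failure leaves the cache untouched.
pub fn trim_cache_end(cache: &mut [KvSlot], n: usize) -> Result<(), String> {
    for (idx, slot) in cache.iter().enumerate() {
        if let Some((keys, _)) = slot {
            if n > keys.len() {
                return Err(format!(
                    "layer {}: cannot trim {} of {} cached positions",
                    idx,
                    n,
                    keys.len()
                ));
            }
        }
    }
    for (keys, values) in cache.iter_mut().flatten() {
        keys.truncate(keys.len() - n);
        values.truncate(values.len().saturating_sub(n));
    }
    Ok(())
}
```

Stating the bound explicitly makes it harder for a caller to pass `block_size` on the belief that only noise entries sit at the end of the cache.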
Summary
Lands the DFlash block-diffusion speculative-decoding foundation. PR-6a of the `feat/magic-canvas` split: the model-side surface needed for DFlash; the engine-level draft-verify loop is deferred to PR-6b because it needs runtime verification against a real drafter checkpoint.
What's in this PR (3 commits, ~2.4K net lines)
- `ad7edea1`
- `62a73522`: `Qwen3NextCausalLM` (`forward_with_taps`, `forward_with_taps_tape`, `replay_tape_rollback`, `embed_token_ids`, `forward_all_logits_from_hidden`, `project_logits`); `GatedDeltaNet::{forward_stateless, forward_with_tape, replay_from_tape}`; `pub struct GdnLayerTape`; Metal kernel FFI for tape replay
- `1c3c0d37`: `AnyModel` / `AnyCache` polymorphic dispatchers; `dflash::accept_prefix` (greedy spec-decode acceptance + 5 unit tests); `engine::model_loader::load_dflash_drafter`
What's deferred to follow-ups
- `SimpleEngine::generate_dflash_inner` (draft-verify loop): needs Carnice-9B + 0.5B drafter checkpoint for runtime verification
- `dflash_cpu.rs` (CPU BLAS drafter): pulls in BLAS helpers from `diffusion.rs`; lands with engine glue in PR-6b
- `dflash_ane.rs` (ANE-accelerated drafter): feature-gated; PR-8 territory
Adaptations to origin/main
- `SteppingKeyValueCache::rollback(i32)` → `trim_by(usize)` (renamed in feat(cache): AnyCache::trim_by dispatcher for spec-decode rollback #143)
- ANE-feature paths in `project_logits` / `forward_with_tape` stripped: `lm_head_ane` / `dense_lm_head` / `ane_handle` / `ane_kernels` fields don't exist yet on main; ported in PR-8
- Error handler uses `thread_local!`, matching the existing local pattern (vs `feat/magic-canvas`'s `Mutex<Option<String>>`)
Hygiene
- No file-level blanket allows added (`#![allow(clippy::items_after_test_module)]` preserved)
- Function-scoped `#[allow(...)]` only on the four genuinely-numerical Metal-kernel functions, each with a one-line justification
- `unwrap_used` refactored to `?` propagation throughout
- `pub type TapsTapeOutput` resolves `clippy::type_complexity`
Test plan
- `cargo check -p higgs-models`: clean
- `cargo clippy --all-targets --all-features -- -D warnings`: clean
- `cargo fmt --check`: clean
- `cargo test -p higgs-models --lib`: 335/335 pass (5 new `accept_prefix` tests)
- `cargo test -p higgs-engine --lib`: 228/228 pass
- `cargo test -p higgs --lib -- --test-threads=1`: 449/449 pass
Context
Part of the `feat/magic-canvas` PR split. Prior PRs in the chain: #141 (fused MoE), #142 (Bonsai-Q1 fp16 +2.83×), #143 (AnyCache::trim_by), #144 (DraftModel trait), #145 (PLD drafter, 1.84× headline).
🤖 Generated with Claude Code
Summary by CodeRabbit