Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ jobs:
toolchain: stable

- name: Install cargo-llvm-cov
uses: taiki-e/install-action@51cd0b8c0499559d9a4d75c0f5c67bec3a894ec8 # v2
uses: taiki-e/install-action@cca35edeb1d01366c2843b68fc3ca441446d73d3 # v2
with:
tool: cargo-llvm-cov

Expand Down
76 changes: 45 additions & 31 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/higgs-bench/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
sysinfo = "0.32"
tokio = { workspace = true }
toml = "0.8"
toml = "1.0"

[build-dependencies]
built = { version = "0.8", features = ["git2"] }
Expand Down
78 changes: 78 additions & 0 deletions crates/higgs-models/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,29 @@ pub enum AnyCache {
Hybrid(Vec<Option<LayerCache>>),
}

impl AnyCache {
    /// Roll back the newest `count` tokens from every trimmable layer cache.
    ///
    /// Meant to be called after the verify step of speculative decoding, when
    /// rejected draft tokens have to be dropped again. Empty (`None`) layer
    /// slots are skipped. In the `Hybrid` variant only `LayerCache::KV` layers
    /// are trimmed; recurrent (SSM) layers are deliberately left as they are,
    /// because their state cannot be rewound by an offset alone.
    pub fn trim_by(&mut self, count: usize) {
        match self {
            Self::KV(layers) => {
                for kv in layers.iter_mut().flatten() {
                    kv.trim_by(count);
                }
            }
            Self::Hybrid(layers) => {
                // Narrow the populated slots down to attention (KV) layers;
                // every other layer kind is filtered out and never touched.
                let kv_layers = layers.iter_mut().flatten().filter_map(|layer| {
                    if let LayerCache::KV(kv) = layer {
                        Some(kv)
                    } else {
                        None
                    }
                });
                for kv in kv_layers {
                    kv.trim_by(count);
                }
            }
        }
    }
}

/// Unified model wrapper dispatching to the correct architecture.
pub enum AnyModel {
/// Standard transformer architectures: Llama, Mistral, Qwen2/2.5, Qwen3.
Expand Down Expand Up @@ -1189,6 +1212,7 @@ fn remap_quantized_key(key: &str) -> Option<String> {
#[allow(clippy::panic, clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
use super::*;
use crate::cache::KeyValueCache;

fn params(temp: f32, top_p: f32) -> SamplingParams {
SamplingParams {
Expand Down Expand Up @@ -1695,4 +1719,58 @@ mod tests {
assert!((vals[1] - 2.0).abs() < 1e-5);
assert!((vals[2] - 4.5).abs() < 1e-5);
}

// --- AnyCache::trim_by tests ---

#[test]
fn any_cache_trim_by_kv_dispatches_to_each_layer() {
    // Fixture: populated slot, empty slot, populated slot. Both caches start
    // at offset 0, so trimming is expected to saturate there; the point of the
    // test is that dispatch walks `None` and `Some(_)` slots without panicking.
    let slots = vec![
        Some(cache::SteppingKeyValueCache::new()),
        None,
        Some(cache::SteppingKeyValueCache::new()),
    ];
    let mut cache = AnyCache::KV(slots);

    cache.trim_by(5);

    match &cache {
        AnyCache::KV(layers) => {
            // The slot count itself must be unaffected by trimming.
            assert_eq!(layers.len(), 3);
            for kv in layers.iter().flatten() {
                assert_eq!(kv.offset(), 0);
            }
        }
        AnyCache::Hybrid(_) => panic!("expected KV variant"),
    }
}

#[test]
fn any_cache_trim_by_hybrid_skips_arrays_layers() {
    // A hybrid cache holds both trimmable `LayerCache::KV` layers and
    // recurrent `LayerCache::Arrays` layers. Trimming must reach the former
    // and leave the latter exactly as it was.
    let mut recurrent = qwen3_next::ArraysCache::new();
    recurrent.offset = 7;
    let slots = vec![
        Some(LayerCache::KV(cache::SteppingKeyValueCache::new())),
        Some(LayerCache::Arrays(recurrent)),
        None,
    ];
    let mut cache = AnyCache::Hybrid(slots);

    cache.trim_by(3);

    match &cache {
        AnyCache::Hybrid(layers) => {
            assert_eq!(layers.len(), 3);

            // Slot 0: KV layer. It started at offset 0, so the trim saturates at 0.
            if let Some(Some(LayerCache::KV(kv))) = layers.first() {
                assert_eq!(kv.offset(), 0);
            } else {
                panic!("expected first layer to be KV variant");
            }

            // Slot 1: recurrent state. Its offset must come through unchanged.
            if let Some(Some(LayerCache::Arrays(state))) = layers.get(1) {
                assert_eq!(state.offset, 7, "Arrays layer offset must NOT be trimmed");
            } else {
                panic!("expected second layer to be Arrays variant");
            }
        }
        AnyCache::KV(_) => panic!("expected Hybrid variant"),
    }
}
}
Loading