From 68399e56560201782344fbfb9cb7879573cc22e8 Mon Sep 17 00:00:00 2001
From: Peter Clemente III
Date: Fri, 29 May 2026 13:14:48 -0400
Subject: [PATCH] Wave2 polish: clean clippy/fmt, add CI, finish dead diagnostic as a test

- Resolve all clippy warnings; lib + binaries now pass
  `cargo clippy --all-targets -- -D warnings`:
  - `div_ceil` / `is_multiple_of` for manual ceil-div and modulo checks
    (lib.rs, realdata-bench.rs)
  - regroup two unreadable RNG-seed literals in tests
  - module-level `#![allow(clippy::needless_range_loop)]` on the four
    numeric kernels (attention/outlier/grouped/twobit) where the loop
    counter indexes packed sign-word slices and magnitude/score arrays in
    lockstep; documented why explicit indexing is kept
- Replace the dead `_ensure_v2_view_compiles` helper (was sitting after the
  test module under `#[allow(dead_code)]`) with a real unit test for
  `GroupedKv::flatten_to_v2`, and drop the now-unused `CompressedKv` import
- Normalize formatting repo-wide with `cargo fmt` (repo was not fmt-clean)
- Add GitHub Actions CI (.github/workflows/ci.yml): fmt check,
  clippy -D warnings, release build, and tests on stable

No behavior changes to the compression/attention logic.

Co-Authored-By: Claude Opus 4.8 (1M context) --- .github/workflows/ci.yml | 44 +++++++++++++ src/attention.rs | 44 ++++++++++--- src/bin/kv-bench.rs | 29 +++++++-- src/bin/realdata-bench.rs | 117 +++++++++++++++++++++++++--------- src/grouped.rs | 130 ++++++++++++++++++++++++++++---------- src/lib.rs | 16 +++-- src/outlier.rs | 58 +++++++++++++---- src/quantize.rs | 44 +++++++++---- src/twobit.rs | 110 ++++++++++++++++++++++---------- 9 files changed, 452 insertions(+), 140 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ec06061 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,44 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: -D warnings + +jobs: + test: + name: build / test / lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + + - name: Cache cargo registry and build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} + restore-keys: ${{ runner.os }}-cargo- + + - name: Format check + run: cargo fmt --all -- --check + + - name: Clippy (deny warnings) + run: cargo clippy --all-targets -- -D warnings + + - name: Build (release) + run: cargo build --release --all-targets + + - name: Test + run: cargo test --all-targets diff --git a/src/attention.rs b/src/attention.rs index 35b0278..54b6e57 100644 --- a/src/attention.rs +++ b/src/attention.rs @@ -23,19 +23,29 @@ //! No K or V f32 row is ever assembled. Per-row work for scores is O(words_per_row) //! reads (not D reads), per-row work for output is O(words_per_row) too. 
+// These attention kernels are deliberate index-math: the loop counter indexes +// a per-row slice of packed sign-words together with the magnitude/score/weight +// arrays in lockstep. Rewriting as zipped iterators obscures the bit-packing +// layout and the i-outer/d-inner cache strategy, so we keep explicit indexing. +#![allow(clippy::needless_range_loop)] + use crate::quantize::CompressedKv; use crate::{sign_words_for, BITS_PER_WORD}; /// Reference (slow) binary attention: decompresses K/V to f32, runs standard /// attention. Used as a correctness oracle in tests; do not use in production. pub fn binary_attention_naive( - q: &[f32], // length d_k - k: &CompressedKv, // n_keys × d_k - v: &CompressedKv, // n_keys × d_v - inv_sqrt_dk: f32, // 1/sqrt(d_k), precomputed by caller + q: &[f32], // length d_k + k: &CompressedKv, // n_keys × d_k + v: &CompressedKv, // n_keys × d_v + inv_sqrt_dk: f32, // 1/sqrt(d_k), precomputed by caller ) -> Vec { assert_eq!(q.len(), k.d, "q.len() != k.d ({} vs {})", q.len(), k.d); - assert_eq!(k.n_rows, v.n_rows, "k.n_rows != v.n_rows ({} vs {})", k.n_rows, v.n_rows); + assert_eq!( + k.n_rows, v.n_rows, + "k.n_rows != v.n_rows ({} vs {})", + k.n_rows, v.n_rows + ); let n = k.n_rows; let dv = v.d; @@ -79,7 +89,11 @@ pub fn binary_attention_fast( inv_sqrt_dk: f32, ) -> Vec { assert_eq!(q.len(), k.d, "q.len() != k.d ({} vs {})", q.len(), k.d); - assert_eq!(k.n_rows, v.n_rows, "k.n_rows != v.n_rows ({} vs {})", k.n_rows, v.n_rows); + assert_eq!( + k.n_rows, v.n_rows, + "k.n_rows != v.n_rows ({} vs {})", + k.n_rows, v.n_rows + ); let n = k.n_rows; let dk = k.d; let dv = v.d; @@ -200,7 +214,12 @@ fn sum_q_where_bit_set(q: &[f32], row_words: &[u64], dk: usize) -> f32 { let base = full_words * BITS_PER_WORD; while word != 0 { let bit = word.trailing_zeros() as usize; - debug_assert!(base + bit < dk, "tail word leaked bit at idx {} ≥ dk={}", base + bit, dk); + debug_assert!( + base + bit < dk, + "tail word leaked bit at idx {} ≥ dk={}", + base 
+ bit, + dk + ); sum += q[base + bit]; word &= word - 1; } @@ -216,7 +235,14 @@ mod tests { use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; - fn full_attention(q: &[f32], k_full: &[f32], v_full: &[f32], dk: usize, dv: usize, inv_sqrt_dk: f32) -> Vec { + fn full_attention( + q: &[f32], + k_full: &[f32], + v_full: &[f32], + dk: usize, + dv: usize, + inv_sqrt_dk: f32, + ) -> Vec { let n = k_full.len() / dk; assert_eq!(v_full.len(), n * dv); let mut scores = vec![0f32; n]; @@ -241,7 +267,7 @@ mod tests { #[test] fn naive_and_fast_agree() { - let mut rng = ChaCha8Rng::seed_from_u64(0xdeadbeef_42); + let mut rng = ChaCha8Rng::seed_from_u64(0x0000_00de_adbe_ef42); let n = 32; let dk = 64; let dv = 64; diff --git a/src/bin/kv-bench.rs b/src/bin/kv-bench.rs index c7c7f62..412cbce 100644 --- a/src/bin/kv-bench.rs +++ b/src/bin/kv-bench.rs @@ -1,8 +1,8 @@ //! Quick CLI: generate synthetic Gaussian KV rows, compress, report //! compression ratio + reconstruction MSE + per-row signal-to-quantization-noise. 
use kv_compressor::{compress_rows, decompress_rows}; -use rand::SeedableRng; use rand::Rng; +use rand::SeedableRng; use rand_chacha::ChaCha8Rng; fn main() { @@ -13,7 +13,10 @@ fn main() { let n_rows = n_tokens * n_kv_heads; let d = head_dim; - println!("synthetic Gaussian KV: n_rows={n_rows} d={d} total elements={}", n_rows * d); + println!( + "synthetic Gaussian KV: n_rows={n_rows} d={d} total elements={}", + n_rows * d + ); let mut rng = ChaCha8Rng::seed_from_u64(0x5157f1d7c0deba5e); let input: Vec = (0..n_rows * d) @@ -47,11 +50,25 @@ fn main() { let throughput_c = (input.len() * 4) as f64 / dt_c.as_secs_f64() / 1e9; let throughput_d = (input.len() * 4) as f64 / dt_d.as_secs_f64() / 1e9; - println!("compressed bytes: {} ({:.1} MB)", c.bytes(), c.bytes() as f64 / 1024.0 / 1024.0); - println!("baseline q8 bytes: {} ({:.1} MB)", c.baseline_bytes(), c.baseline_bytes() as f64 / 1024.0 / 1024.0); + println!( + "compressed bytes: {} ({:.1} MB)", + c.bytes(), + c.bytes() as f64 / 1024.0 / 1024.0 + ); + println!( + "baseline q8 bytes: {} ({:.1} MB)", + c.baseline_bytes(), + c.baseline_bytes() as f64 / 1024.0 / 1024.0 + ); println!("compression ratio: {:.2}×", c.ratio()); println!("reconstruction MSE: {:.6}", mse); println!("SNR (vs Gaussian unit signal): {:.2} dB", snr_db); - println!("compress throughput: {:.2} GB/s ({:?})", throughput_c, dt_c); - println!("decompress throughput: {:.2} GB/s ({:?})", throughput_d, dt_d); + println!( + "compress throughput: {:.2} GB/s ({:?})", + throughput_c, dt_c + ); + println!( + "decompress throughput: {:.2} GB/s ({:?})", + throughput_d, dt_d + ); } diff --git a/src/bin/realdata-bench.rs b/src/bin/realdata-bench.rs index 8f8ccce..bd3411e 100644 --- a/src/bin/realdata-bench.rs +++ b/src/bin/realdata-bench.rs @@ -6,9 +6,8 @@ //! 
cargo run --release --bin realdata-bench -- ./dumps/qwen3_1p7b_n8.npz use anyhow::{anyhow, Context, Result}; use kv_compressor::{ - binary_attention_fast, compress_grouped, compress_outlier, compress_rows, - compress_twobit, grouped_attention_fast, outlier_attention_fast, - twobit_attention_fast, + binary_attention_fast, compress_grouped, compress_outlier, compress_rows, compress_twobit, + grouped_attention_fast, outlier_attention_fast, twobit_attention_fast, }; use ndarray::{s, Array1, Array4}; use ndarray_npy::NpzReader; @@ -23,11 +22,22 @@ fn cosine(a: &[f32], b: &[f32]) -> f32 { } fn l2_dist(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::().sqrt() + a.iter() + .zip(b) + .map(|(x, y)| (x - y) * (x - y)) + .sum::() + .sqrt() } /// Standard scaled-dot-product attention over real K/V (no compression). -fn full_attention_per_head(q: &[f32], k_full: &[f32], v_full: &[f32], dk: usize, dv: usize, inv_sqrt_dk: f32) -> Vec { +fn full_attention_per_head( + q: &[f32], + k_full: &[f32], + v_full: &[f32], + dk: usize, + dv: usize, + inv_sqrt_dk: f32, +) -> Vec { let n = k_full.len() / dk; let mut scores = vec![0f32; n]; for i in 0..n { @@ -67,7 +77,9 @@ fn full_attention_per_head(q: &[f32], k_full: &[f32], v_full: &[f32], dk: usize, } fn main() -> Result<()> { - let path = env::args().nth(1).ok_or_else(|| anyhow!("usage: realdata-bench "))?; + let path = env::args() + .nth(1) + .ok_or_else(|| anyhow!("usage: realdata-bench "))?; println!("loading {}", path); let mut npz = NpzReader::new(File::open(&path).context(path.clone())?)?; @@ -92,7 +104,7 @@ fn main() -> Result<()> { // Infer per-head dim. Common Qwen3 convention: head_dim = 128, GQA → d_k = n_kv_heads * 128, d_q = n_q_heads * 128. 
let head_dim: usize = 128; - if d_q % head_dim != 0 || d_k % head_dim != 0 || d_v % head_dim != 0 { + if d_q % head_dim != 0 || !d_k.is_multiple_of(head_dim) || !d_v.is_multiple_of(head_dim) { return Err(anyhow!( "dims not divisible by assumed head_dim=128 (d_q={}, d_k={}, d_v={})", d_q, @@ -128,7 +140,14 @@ fn main() -> Result<()> { ("v4 k=4 ", 4, 1, true), ("v4 k=8 ", 8, 1, true), ]; - println!("variants: {}", variants.iter().map(|v| v.0.trim()).collect::>().join(" | ")); + println!( + "variants: {}", + variants + .iter() + .map(|v| v.0.trim()) + .collect::>() + .join(" | ") + ); let mut layer_cos_sum: Vec> = variants.iter().map(|_| vec![0f64; nl]).collect(); let mut layer_l2_sum: Vec> = variants.iter().map(|_| vec![0f64; nl]).collect(); @@ -155,13 +174,16 @@ fn main() -> Result<()> { v_head_f32.push(v_layer[[t, kvh * head_dim + d]]); } } - let words_per_row = (head_dim + 63) / 64; + let words_per_row = head_dim.div_ceil(64); // Pre-compress every variant once per (sample, layer, kv-head). 
let v1_ck = compress_rows(&k_head_f32, n_tok, head_dim); let v1_cv = compress_rows(&v_head_f32, n_tok, head_dim); - let mut v2_buf: Vec<(kv_compressor::OutlierKv, kv_compressor::OutlierKv)> = Vec::new(); - let mut v3_buf: Vec<(kv_compressor::GroupedKv, kv_compressor::GroupedKv)> = Vec::new(); - let mut v4_buf: Vec<(kv_compressor::TwoBitKv, kv_compressor::TwoBitKv)> = Vec::new(); + let mut v2_buf: Vec<(kv_compressor::OutlierKv, kv_compressor::OutlierKv)> = + Vec::new(); + let mut v3_buf: Vec<(kv_compressor::GroupedKv, kv_compressor::GroupedKv)> = + Vec::new(); + let mut v4_buf: Vec<(kv_compressor::TwoBitKv, kv_compressor::TwoBitKv)> = + Vec::new(); for &(_, k_out, ng, twobit) in &variants { if twobit { v4_buf.push(( @@ -169,7 +191,9 @@ fn main() -> Result<()> { compress_twobit(&v_head_f32, n_tok, head_dim, k_out), )); } else if ng == 1 { - if k_out == 0 { continue; } + if k_out == 0 { + continue; + } v2_buf.push(( compress_outlier(&k_head_f32, n_tok, head_dim, k_out), compress_outlier(&v_head_f32, n_tok, head_dim, k_out), @@ -217,7 +241,9 @@ fn main() -> Result<()> { mag: ck_full_t.mag[..n_keys * 2].to_vec(), outlier_idx: ck_full_t.outlier_idx[..n_keys * k_out].to_vec(), outlier_val: ck_full_t.outlier_val[..n_keys * k_out].to_vec(), - n_rows: n_keys, d: head_dim, k_outliers: k_out, + n_rows: n_keys, + d: head_dim, + k_outliers: k_out, }; let cv_pref = kv_compressor::TwoBitKv { signs: cv_full_t.signs[..n_keys * words_per_row].to_vec(), @@ -225,19 +251,23 @@ fn main() -> Result<()> { mag: cv_full_t.mag[..n_keys * 2].to_vec(), outlier_idx: cv_full_t.outlier_idx[..n_keys * k_out].to_vec(), outlier_val: cv_full_t.outlier_val[..n_keys * k_out].to_vec(), - n_rows: n_keys, d: head_dim, k_outliers: k_out, + n_rows: n_keys, + d: head_dim, + k_outliers: k_out, }; twobit_attention_fast(&q_vec, &ck_pref, &cv_pref, inv_sqrt_dk) } else if ng == 1 && k_out == 0 { let ck_pref = kv_compressor::CompressedKv { signs: v1_ck.signs[..n_keys * words_per_row].to_vec(), mag: 
v1_ck.mag[..n_keys].to_vec(), - n_rows: n_keys, d: head_dim, + n_rows: n_keys, + d: head_dim, }; let cv_pref = kv_compressor::CompressedKv { signs: v1_cv.signs[..n_keys * words_per_row].to_vec(), mag: v1_cv.mag[..n_keys].to_vec(), - n_rows: n_keys, d: head_dim, + n_rows: n_keys, + d: head_dim, }; binary_attention_fast(&q_vec, &ck_pref, &cv_pref, inv_sqrt_dk) } else if ng == 1 { @@ -248,14 +278,18 @@ fn main() -> Result<()> { mag: ck_full_o.mag[..n_keys].to_vec(), outlier_idx: ck_full_o.outlier_idx[..n_keys * k_out].to_vec(), outlier_val: ck_full_o.outlier_val[..n_keys * k_out].to_vec(), - n_rows: n_keys, d: head_dim, k_outliers: k_out, + n_rows: n_keys, + d: head_dim, + k_outliers: k_out, }; let cv_pref = kv_compressor::OutlierKv { signs: cv_full_o.signs[..n_keys * words_per_row].to_vec(), mag: cv_full_o.mag[..n_keys].to_vec(), outlier_idx: cv_full_o.outlier_idx[..n_keys * k_out].to_vec(), outlier_val: cv_full_o.outlier_val[..n_keys * k_out].to_vec(), - n_rows: n_keys, d: head_dim, k_outliers: k_out, + n_rows: n_keys, + d: head_dim, + k_outliers: k_out, }; outlier_attention_fast(&q_vec, &ck_pref, &cv_pref, inv_sqrt_dk) } else { @@ -266,14 +300,20 @@ fn main() -> Result<()> { mag: ck_full_g.mag[..n_keys * ng].to_vec(), outlier_idx: ck_full_g.outlier_idx[..n_keys * k_out].to_vec(), outlier_val: ck_full_g.outlier_val[..n_keys * k_out].to_vec(), - n_rows: n_keys, d: head_dim, n_groups: ng, k_outliers: k_out, + n_rows: n_keys, + d: head_dim, + n_groups: ng, + k_outliers: k_out, }; let cv_pref = kv_compressor::GroupedKv { signs: cv_full_g.signs[..n_keys * words_per_row].to_vec(), mag: cv_full_g.mag[..n_keys * ng].to_vec(), outlier_idx: cv_full_g.outlier_idx[..n_keys * k_out].to_vec(), outlier_val: cv_full_g.outlier_val[..n_keys * k_out].to_vec(), - n_rows: n_keys, d: head_dim, n_groups: ng, k_outliers: k_out, + n_rows: n_keys, + d: head_dim, + n_groups: ng, + k_outliers: k_out, }; grouped_attention_fast(&q_vec, &ck_pref, &cv_pref, inv_sqrt_dk) }; @@ -288,31 +328,50 @@ fn 
main() -> Result<()> { } println!("\n=== per-variant summary ==="); - println!("{:<14} | {:>8} | {:>8} | {:>10}", "variant", "ratio×", "avg_cos", "avg_l2"); + println!( + "{:<14} | {:>8} | {:>8} | {:>10}", + "variant", "ratio×", "avg_cos", "avg_l2" + ); println!("{:-<14}-+-{:->8}-+-{:->8}-+-{:->10}", "", "", "", ""); let baseline_bytes = head_dim + 2; for (vi, &(label, k_out, ng, twobit)) in variants.iter().enumerate() { let comp_bytes = if twobit { // signs + levels (both ceil(d/64) u64) + 2 f32 mags + outliers - 2 * ((head_dim + 63) / 64) * 8 + 2 * 4 + k_out * (2 + 4) + 2 * head_dim.div_ceil(64) * 8 + 2 * 4 + k_out * (2 + 4) } else { - ((head_dim + 63) / 64) * 8 + ng * 4 + k_out * (2 + 4) + head_dim.div_ceil(64) * 8 + ng * 4 + k_out * (2 + 4) }; let ratio = baseline_bytes as f32 / comp_bytes as f32; let total_cos: f64 = layer_cos_sum[vi].iter().sum(); let total_l2: f64 = layer_l2_sum[vi].iter().sum(); let total_n: usize = layer_count[vi].iter().sum(); - let cos = if total_n > 0 { total_cos / total_n as f64 } else { 0.0 }; - let l2 = if total_n > 0 { total_l2 / total_n as f64 } else { 0.0 }; - println!("{:<14} | {:>8.2} | {:>8.4} | {:>10.4}", label, ratio, cos, l2); + let cos = if total_n > 0 { + total_cos / total_n as f64 + } else { + 0.0 + }; + let l2 = if total_n > 0 { + total_l2 / total_n as f64 + } else { + 0.0 + }; + println!( + "{:<14} | {:>8.2} | {:>8.4} | {:>10.4}", + label, ratio, cos, l2 + ); } let last = variants.len() - 1; - println!("\n=== per-layer cos sim, variant: {} ===", variants[last].0.trim()); + println!( + "\n=== per-layer cos sim, variant: {} ===", + variants[last].0.trim() + ); println!("{:>5} | {:>8} | {:>8}", "layer", "cos", "l2"); for li in 0..nl { let n = layer_count[last][li]; - if n == 0 { continue; } + if n == 0 { + continue; + } let cos = layer_cos_sum[last][li] / n as f64; let l2 = layer_l2_sum[last][li] / n as f64; println!("{:>5} | {:>8.4} | {:>8.4}", li, cos, l2); diff --git a/src/grouped.rs b/src/grouped.rs index 
ec9445b..80debf3 100644 --- a/src/grouped.rs +++ b/src/grouped.rs @@ -34,7 +34,11 @@ //! Outlier correction unchanged from v2 — outlier overrides the inlier //! contribution for its specific channel. -use crate::quantize::CompressedKv; +// Deliberate index-math kernels: the loop counter indexes packed sign-word +// slices alongside per-group magnitude/score/weight arrays in lockstep. See +// attention.rs for the rationale; iterator rewrites would obscure the layout. +#![allow(clippy::needless_range_loop)] + use crate::{sign_words_for, BITS_PER_WORD}; #[derive(Debug, Clone)] @@ -106,9 +110,24 @@ pub fn compress_grouped( n_groups: usize, k_outliers: usize, ) -> GroupedKv { - assert_eq!(input.len(), n_rows * d, "compress_grouped: input length mismatch"); - assert!(n_groups > 0 && n_groups <= d, "n_groups ({}) must be in 1..=d ({})", n_groups, d); - assert_eq!(d % n_groups, 0, "d ({}) must be divisible by n_groups ({})", d, n_groups); + assert_eq!( + input.len(), + n_rows * d, + "compress_grouped: input length mismatch" + ); + assert!( + n_groups > 0 && n_groups <= d, + "n_groups ({}) must be in 1..=d ({})", + n_groups, + d + ); + assert_eq!( + d % n_groups, + 0, + "d ({}) must be divisible by n_groups ({})", + d, + n_groups + ); assert!(k_outliers <= d, "k_outliers must be ≤ d"); let words_per_row = sign_words_for(d); @@ -139,10 +158,9 @@ pub fn compress_grouped( for (i, &x) in row.iter().enumerate() { abs_pairs.push((i as u16, x.abs())); } - abs_pairs.select_nth_unstable_by( - k_eff - 1, - |a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal), - ); + abs_pairs.select_nth_unstable_by(k_eff - 1, |a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); abs_pairs[..k_eff].sort_by_key(|p| p.0); for (slot, &(idx, _)) in abs_pairs[..k_eff].iter().enumerate() { outlier_idx[r * k_outliers + slot] = idx; @@ -171,13 +189,26 @@ pub fn compress_grouped( n_inl += 1; } } - let m = if n_inl > 0 { (sum_abs / n_inl as f64) as f32 } else { 0.0 }; + let m 
= if n_inl > 0 { + (sum_abs / n_inl as f64) as f32 + } else { + 0.0 + }; debug_assert!(m.is_finite() && m >= 0.0); mag[r * n_groups + g] = m; } } - GroupedKv { signs, mag, outlier_idx, outlier_val, n_rows, d, n_groups, k_outliers } + GroupedKv { + signs, + mag, + outlier_idx, + outlier_val, + n_rows, + d, + n_groups, + k_outliers, + } } pub fn decompress_grouped(c: &GroupedKv) -> Vec { @@ -198,7 +229,9 @@ pub fn decompress_grouped(c: &GroupedKv) -> Vec { // Outlier overrides for slot in 0..c.k_outliers { let idx = c.outlier_idx[r * c.k_outliers + slot]; - if idx == u16::MAX { continue; } + if idx == u16::MAX { + continue; + } row_out[idx as usize] = c.outlier_val[r * c.k_outliers + slot]; } } @@ -288,7 +321,9 @@ pub fn grouped_attention_fast( let bit = word.trailing_zeros() as usize; let c = base + bit; debug_assert!(c < dk, "K signs leaked tail bit at c={} dk={}", c, dk); - if c >= dk { break; } // release-mode defensive (compressor guarantees this) + if c >= dk { + break; + } // release-mode defensive (compressor guarantees this) let g = c / kgs; pos_sum_buf[g] += q[c]; word &= word - 1; @@ -303,7 +338,9 @@ pub fn grouped_attention_fast( // Outlier corrections for slot in 0..kk { let idx = k.outlier_idx[i * kk + slot]; - if idx == u16::MAX { continue; } + if idx == u16::MAX { + continue; + } let val = k.outlier_val[i * kk + slot]; let g = idx as usize / kgs; let mag_g = row_mag[g]; @@ -334,7 +371,9 @@ pub fn grouped_attention_fast( let bit = word.trailing_zeros() as usize; let d = base + bit; debug_assert!(d < dv, "V signs leaked tail bit at d={} dv={}", d, dv); - if d >= dv { break; } // release-mode defensive + if d >= dv { + break; + } // release-mode defensive let g = d / vgs; output[d] += 2.0 * scaled_w_g[g]; word &= word - 1; @@ -356,7 +395,9 @@ pub fn grouped_attention_fast( let row_mag = &v.mag[i * vg..(i + 1) * vg]; for slot in 0..vk { let idx = v.outlier_idx[i * vk + slot]; - if idx == u16::MAX { continue; } + if idx == u16::MAX { + continue; + } let 
val = v.outlier_val[i * vk + slot]; let g = idx as usize / vgs; let mag_g = row_mag[g]; @@ -383,15 +424,19 @@ mod tests { 0.1, -0.1, 0.2, -0.2, // group 0: small 10.0, -20.0, 15.0, -8.0, // group 1: large ]; - let c = compress_grouped(&input, 1, 8, /*n_groups=*/2, /*k=*/0); + let c = compress_grouped(&input, 1, 8, /*n_groups=*/ 2, /*k=*/ 0); let r = decompress_grouped(&c); // Group 0 mag = mean(|0.1|+|0.1|+|0.2|+|0.2|)/4 = 0.15 // Group 1 mag = mean(|10|+|20|+|15|+|8|)/4 = 13.25 assert_relative_eq!(c.mag[0], 0.15, max_relative = 1e-4); assert_relative_eq!(c.mag[1], 13.25, max_relative = 1e-4); // Check that group-0 reconstructs as ±0.15, group-1 as ±13.25 - for i in 0..4 { assert_relative_eq!(r[i].abs(), 0.15, max_relative = 1e-4); } - for i in 4..8 { assert_relative_eq!(r[i].abs(), 13.25, max_relative = 1e-4); } + for i in 0..4 { + assert_relative_eq!(r[i].abs(), 0.15, max_relative = 1e-4); + } + for i in 4..8 { + assert_relative_eq!(r[i].abs(), 13.25, max_relative = 1e-4); + } } #[test] @@ -406,7 +451,7 @@ mod tests { let mut v = vec![0f32; n * dv]; for i in 0..n { for d in 0..dk { - let g = d / 16; // 4 groups of 16 + let g = d / 16; // 4 groups of 16 let scale = (g + 1) as f32 * 0.5; k[i * dk + d] = (rng.r#gen::() - 0.5) * scale; v[i * dv + d] = (rng.r#gen::() - 0.5) * scale; @@ -432,13 +477,13 @@ mod tests { let mut data = vec![0f32; n * d]; for r in 0..n { for c in 0..d { - let g = c / 32; // 4 groups of 32 + let g = c / 32; // 4 groups of 32 let scale = (g + 1) as f32 * 0.5; data[r * d + c] = (rng.r#gen::() - 0.5) * scale; } } - let v2 = crate::compress_outlier(&data, n, d, /*k=*/2); - let v3 = compress_grouped(&data, n, d, /*g=*/4, /*k=*/2); + let v2 = crate::compress_outlier(&data, n, d, /*k=*/ 2); + let v3 = compress_grouped(&data, n, d, /*g=*/ 4, /*k=*/ 2); let r2 = crate::decompress_outlier(&v2); let r3 = decompress_grouped(&v3); let cos = |a: &[f32], b: &[f32]| -> f32 { @@ -449,14 +494,42 @@ mod tests { }; let cos_v2 = cos(&data, &r2); let cos_v3 = 
cos(&data, &r3); - assert!(cos_v3 > cos_v2, "v3 should beat v2 on multi-scale; got v2={cos_v2} v3={cos_v3}"); + assert!( + cos_v3 > cos_v2, + "v3 should beat v2 on multi-scale; got v2={cos_v2} v3={cos_v3}" + ); + } + + #[test] + fn flatten_to_v2_averages_group_mags_and_preserves_outliers() { + // Two groups with distinct scales; flatten_to_v2 should collapse the + // per-group magnitudes into one scalar per row = mean of the groups, + // while carrying the sign/outlier data across unchanged. + let input = vec![ + 0.1, -0.1, 0.2, -0.2, // group 0: small (mag 0.15) + 10.0, -20.0, 15.0, -8.0, // group 1: large (mag 13.25) + ]; + // k=0 so no channels are pulled out as outliers; each group's mag is the + // clean mean(|group|) and flatten_to_v2 averages them directly. + let g = compress_grouped(&input, 1, 8, /*n_groups=*/ 2, /*k=*/ 0); + let v2 = g.flatten_to_v2(); + // Scalar mag = mean(0.15, 13.25) = 6.7 + assert_relative_eq!(v2.mag[0], (0.15 + 13.25) / 2.0, max_relative = 1e-4); + // Structural fields are passed through unchanged. + assert_eq!(v2.n_rows, g.n_rows); + assert_eq!(v2.d, g.d); + assert_eq!(v2.k_outliers, g.k_outliers); + assert_eq!(v2.signs, g.signs); + assert_eq!(v2.outlier_idx, g.outlier_idx); + assert_eq!(v2.outlier_val, g.outlier_val); } #[test] fn grouped_handles_g1_equals_v2() { // n_groups=1 should give cosine-equivalent to v2 (same scalar mag per row). - let input: Vec = (0..1024).map(|i| ((i as f32 * 0.137).sin())).collect(); - let n = 16; let d = 64; + let input: Vec = (0..1024).map(|i| (i as f32 * 0.137).sin()).collect(); + let n = 16; + let d = 64; let v3 = compress_grouped(&input, n, d, 1, 2); // Per-row mag matches v2 mean(|inlier|) when n_groups=1. 
let v2 = crate::compress_outlier(&input, n, d, 2); @@ -465,10 +538,3 @@ mod tests { } } } - -// suppress dead-code warning for the diagnostic helper while bench wires up -#[allow(dead_code)] -fn _ensure_v2_view_compiles(g: &GroupedKv) -> CompressedKv { - let _v = g.flatten_to_v2(); - CompressedKv { signs: vec![], mag: vec![], n_rows: 0, d: 0 } -} diff --git a/src/lib.rs b/src/lib.rs index 31d8915..da9c0b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,15 +40,23 @@ pub mod quantize; pub mod twobit; pub use attention::{binary_attention_fast, binary_attention_naive}; -pub use grouped::{compress_grouped, decompress_grouped, grouped_attention_fast, grouped_attention_naive, GroupedKv}; -pub use outlier::{compress_outlier, decompress_outlier, outlier_attention_fast, outlier_attention_naive, OutlierKv}; +pub use grouped::{ + compress_grouped, decompress_grouped, grouped_attention_fast, grouped_attention_naive, + GroupedKv, +}; +pub use outlier::{ + compress_outlier, decompress_outlier, outlier_attention_fast, outlier_attention_naive, + OutlierKv, +}; pub use quantize::{compress_rows, decompress_rows, CompressedKv}; -pub use twobit::{compress_twobit, decompress_twobit, twobit_attention_fast, twobit_attention_naive, TwoBitKv}; +pub use twobit::{ + compress_twobit, decompress_twobit, twobit_attention_fast, twobit_attention_naive, TwoBitKv, +}; /// Number of bits in one sign word. pub const BITS_PER_WORD: usize = 64; /// How many u64 sign-words are needed for a row of `d` elements. pub const fn sign_words_for(d: usize) -> usize { - (d + BITS_PER_WORD - 1) / BITS_PER_WORD + d.div_ceil(BITS_PER_WORD) } diff --git a/src/outlier.rs b/src/outlier.rs index 8f2441e..ef8ed20 100644 --- a/src/outlier.rs +++ b/src/outlier.rs @@ -29,6 +29,11 @@ //! where ŝ[i,j] = +1 if outlier value was ≥0 else −1 (matches the bit we //! also packed into `signs`). Same idea for the V-side output. 
+// Deliberate index-math kernels: the loop counter indexes packed sign-word +// slices alongside magnitude/score/weight arrays in lockstep. See attention.rs +// for the rationale; iterator rewrites would obscure the bit-packing layout. +#![allow(clippy::needless_range_loop)] + use crate::quantize::CompressedKv; use crate::{sign_words_for, BITS_PER_WORD}; @@ -90,8 +95,17 @@ impl OutlierKv { /// (caller should use `quantize::compress_rows` directly in that case; /// we still handle it for unit-test parity). pub fn compress_outlier(input: &[f32], n_rows: usize, d: usize, k_outliers: usize) -> OutlierKv { - assert_eq!(input.len(), n_rows * d, "compress_outlier: input length mismatch"); - assert!(k_outliers <= d, "k_outliers ({}) must be ≤ d ({})", k_outliers, d); + assert_eq!( + input.len(), + n_rows * d, + "compress_outlier: input length mismatch" + ); + assert!( + k_outliers <= d, + "k_outliers ({}) must be ≤ d ({})", + k_outliers, + d + ); assert!(d > 0 && n_rows > 0); let words_per_row = sign_words_for(d); @@ -126,10 +140,9 @@ pub fn compress_outlier(input: &[f32], n_rows: usize, d: usize, k_outliers: usiz } // select_nth_unstable_by puts the k-th largest at position k_eff-1 (0-indexed), // with strictly-larger items in positions 0..k_eff-1. - abs_pairs.select_nth_unstable_by( - k_eff - 1, - |a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal), - ); + abs_pairs.select_nth_unstable_by(k_eff - 1, |a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); // The top-K are now in abs_pairs[..k_eff], unsorted. Sort that slice // by index for cache-friendly correction in attention (sequential access). 
abs_pairs[..k_eff].sort_by_key(|p| p.0); @@ -156,8 +169,15 @@ pub fn compress_outlier(input: &[f32], n_rows: usize, d: usize, k_outliers: usiz n_inliers += 1; } } - let m = if n_inliers > 0 { (sum_abs / n_inliers as f64) as f32 } else { 0.0 }; - debug_assert!(m.is_finite() && m >= 0.0, "row {r} mag must be finite ≥ 0, got {m}"); + let m = if n_inliers > 0 { + (sum_abs / n_inliers as f64) as f32 + } else { + 0.0 + }; + debug_assert!( + m.is_finite() && m >= 0.0, + "row {r} mag must be finite ≥ 0, got {m}" + ); mag[r] = m; } @@ -388,8 +408,12 @@ mod tests { let dk = 64; let dv = 64; let q: Vec = (0..dk).map(|_| rng.r#gen::() - 0.5).collect(); - let mut k: Vec = (0..n * dk).map(|_| (rng.r#gen::() - 0.5) * 0.5).collect(); - let mut v: Vec = (0..n * dv).map(|_| (rng.r#gen::() - 0.5) * 0.5).collect(); + let mut k: Vec = (0..n * dk) + .map(|_| (rng.r#gen::() - 0.5) * 0.5) + .collect(); + let mut v: Vec = (0..n * dv) + .map(|_| (rng.r#gen::() - 0.5) * 0.5) + .collect(); // Inject outliers per row at random positions. 
for i in 0..n { let pos1 = i % dk; @@ -416,7 +440,9 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(0xdab1ed); let n = 8; let d = 128; - let mut data: Vec = (0..n * d).map(|_| (rng.r#gen::() - 0.5) * 0.1).collect(); + let mut data: Vec = (0..n * d) + .map(|_| (rng.r#gen::() - 0.5) * 0.1) + .collect(); for i in 0..n { data[i * d + i % d] = 50.0; // huge outlier per row } @@ -434,7 +460,13 @@ mod tests { }; let cos_v1 = cos(&data, &r1); let cos_v2 = cos(&data, &r2); - assert!(cos_v2 > cos_v1, "v2 should beat v1; got v1={cos_v1} v2={cos_v2}"); - assert!(cos_v2 > 0.9, "v2 cos sim should be > 0.9 with outlier extraction; got {cos_v2}"); + assert!( + cos_v2 > cos_v1, + "v2 should beat v1; got v1={cos_v1} v2={cos_v2}" + ); + assert!( + cos_v2 > 0.9, + "v2 cos sim should be > 0.9 with outlier extraction; got {cos_v2}" + ); } } diff --git a/src/quantize.rs b/src/quantize.rs index 3fe2526..e28bccf 100644 --- a/src/quantize.rs +++ b/src/quantize.rs @@ -22,8 +22,7 @@ pub struct CompressedKv { impl CompressedKv { /// Bytes used by this compressed representation (excludes Vec overhead). pub fn bytes(&self) -> usize { - self.signs.len() * std::mem::size_of::() - + self.mag.len() * std::mem::size_of::() + self.signs.len() * std::mem::size_of::() + self.mag.len() * std::mem::size_of::() } /// Bytes one row would consume in the q8_0 baseline (D bytes payload + 2 bytes scale). 
@@ -68,7 +67,10 @@ pub fn compress_rows(input: &[f32], n_rows: usize, d: usize) -> CompressedKv { sum += x.abs() as f64; } let m = (sum / d as f64) as f32; - debug_assert!(m.is_finite() && m >= 0.0, "row {r} magnitude must be finite and non-negative, got {m}"); + debug_assert!( + m.is_finite() && m >= 0.0, + "row {r} magnitude must be finite and non-negative, got {m}" + ); mag[r] = m; // Pack signs LSB-first into words @@ -121,11 +123,15 @@ mod tests { for (orig, recon) in input.iter().zip(r.iter()) { // Skip the orig==0 case since the convention is "zero treated as positive"; // covered separately in `zero_treated_as_positive`. - if *orig == 0.0 { continue; } + if *orig == 0.0 { + continue; + } let want_positive = *orig > 0.0; let got_positive = *recon > 0.0; - assert_eq!(want_positive, got_positive, - "sign mismatch at orig={orig} recon={recon}"); + assert_eq!( + want_positive, got_positive, + "sign mismatch at orig={orig} recon={recon}" + ); } } @@ -140,7 +146,12 @@ mod tests { let words_per_row = sign_words_for(4); let row_words = &c.signs[..words_per_row]; // Lower 4 bits should all be set - assert_eq!(row_words[0] & 0b1111, 0b1111, "expected all 4 sign bits set, got {:b}", row_words[0] & 0b1111); + assert_eq!( + row_words[0] & 0b1111, + 0b1111, + "expected all 4 sign bits set, got {:b}", + row_words[0] & 0b1111 + ); } #[test] @@ -150,14 +161,19 @@ mod tests { // read by decompress_rows. 
let n_rows = 3; let d = 100; - let input: Vec = (0..n_rows * d).map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }).collect(); + let input: Vec = (0..n_rows * d) + .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }) + .collect(); let c = compress_rows(&input, n_rows, d); let recon = decompress_rows(&c); assert_eq!(recon.len(), n_rows * d); for (i, (a, b)) in input.iter().zip(recon.iter()).enumerate() { let want_positive = *a >= 0.0; let got_positive = *b >= 0.0; - assert_eq!(want_positive, got_positive, "sign mismatch at i={i}: orig={a} recon={b}"); + assert_eq!( + want_positive, got_positive, + "sign mismatch at i={i}: orig={a} recon={b}" + ); } // Trailing bits in last word of each row must be zero (not garbage from accidental writes) let words_per_row = sign_words_for(d); // = 2 @@ -165,8 +181,12 @@ mod tests { let last_word = c.signs[r * words_per_row + words_per_row - 1]; // Used bits: 100 - 64 = 36 → bits [0..36] are valid, bits [36..64] must be zero. let trailing_mask = !((1u64 << 36) - 1); - assert_eq!(last_word & trailing_mask, 0, - "row {r} trailing bits in last word are non-zero: {:b}", last_word); + assert_eq!( + last_word & trailing_mask, + 0, + "row {r} trailing bits in last word are non-zero: {:b}", + last_word + ); } } @@ -180,7 +200,7 @@ mod tests { #[test] fn multi_row_independence() { let input = vec![ - 1.0, -1.0, 1.0, -1.0, // row 0: mean(|.|) = 1.0 + 1.0, -1.0, 1.0, -1.0, // row 0: mean(|.|) = 1.0 10.0, -10.0, 10.0, -10.0, // row 1: mean(|.|) = 10.0 ]; let c = compress_rows(&input, 2, 4); diff --git a/src/twobit.rs b/src/twobit.rs index 97fa5fa..65cf12b 100644 --- a/src/twobit.rs +++ b/src/twobit.rs @@ -40,6 +40,11 @@ //! Three popcount-walks per key: walk(level_words), walk(sign & level), walk(sign & ¬level). //! Same outlier correction as v2/v3. +// Deliberate index-math kernels: the loop counter indexes packed sign-word +// slices alongside magnitude/score/weight arrays in lockstep. 
See attention.rs +// for the rationale; iterator rewrites would obscure the bit-packing layout. +#![allow(clippy::needless_range_loop)] + use crate::quantize::CompressedKv; use crate::{sign_words_for, BITS_PER_WORD}; @@ -79,7 +84,11 @@ impl TwoBitKv { /// Threshold per row = median(|x| over inliers). /// `small_mag = mean(|x|)` over inliers below threshold; `large_mag` over those at/above. pub fn compress_twobit(input: &[f32], n_rows: usize, d: usize, k_outliers: usize) -> TwoBitKv { - assert_eq!(input.len(), n_rows * d, "compress_twobit: input length mismatch"); + assert_eq!( + input.len(), + n_rows * d, + "compress_twobit: input length mismatch" + ); assert!(k_outliers <= d); let words_per_row = sign_words_for(d); @@ -112,10 +121,9 @@ pub fn compress_twobit(input: &[f32], n_rows: usize, d: usize, k_outliers: usize for (i, &x) in row.iter().enumerate() { abs_pairs.push((i as u16, x.abs())); } - abs_pairs.select_nth_unstable_by( - k_eff - 1, - |a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal), - ); + abs_pairs.select_nth_unstable_by(k_eff - 1, |a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); abs_pairs[..k_eff].sort_by_key(|p| p.0); for (slot, &(idx, _)) in abs_pairs[..k_eff].iter().enumerate() { outlier_idx[r * k_outliers + slot] = idx; @@ -172,16 +180,37 @@ pub fn compress_twobit(input: &[f32], n_rows: usize, d: usize, k_outliers: usize n_small += 1; } } - let small_mag = if n_small > 0 { (sum_small / n_small as f64) as f32 } else { 0.0 }; - let large_mag = if n_large > 0 { (sum_large / n_large as f64) as f32 } else { 0.0 }; + let small_mag = if n_small > 0 { + (sum_small / n_small as f64) as f32 + } else { + 0.0 + }; + let large_mag = if n_large > 0 { + (sum_large / n_large as f64) as f32 + } else { + 0.0 + }; debug_assert!(small_mag.is_finite() && small_mag >= 0.0); debug_assert!(large_mag.is_finite() && large_mag >= 0.0); - debug_assert!(large_mag >= small_mag, "large_mag must be ≥ small_mag (row {})", r); + 
debug_assert!( + large_mag >= small_mag, + "large_mag must be ≥ small_mag (row {})", + r + ); mag[r * 2] = small_mag; mag[r * 2 + 1] = large_mag; } - TwoBitKv { signs, levels, mag, outlier_idx, outlier_val, n_rows, d, k_outliers } + TwoBitKv { + signs, + levels, + mag, + outlier_idx, + outlier_val, + n_rows, + d, + k_outliers, + } } pub fn decompress_twobit(c: &TwoBitKv) -> Vec { @@ -202,7 +231,9 @@ pub fn decompress_twobit(c: &TwoBitKv) -> Vec { // Outlier overrides for slot in 0..c.k_outliers { let idx = c.outlier_idx[r * c.k_outliers + slot]; - if idx == u16::MAX { continue; } + if idx == u16::MAX { + continue; + } row_out[idx as usize] = c.outlier_val[r * c.k_outliers + slot]; } } @@ -210,12 +241,7 @@ pub fn decompress_twobit(c: &TwoBitKv) -> Vec { } /// Reference (slow) 2-bit attention. -pub fn twobit_attention_naive( - q: &[f32], - k: &TwoBitKv, - v: &TwoBitKv, - inv_sqrt_dk: f32, -) -> Vec { +pub fn twobit_attention_naive(q: &[f32], k: &TwoBitKv, v: &TwoBitKv, inv_sqrt_dk: f32) -> Vec { assert_eq!(q.len(), k.d); assert_eq!(k.n_rows, v.n_rows); let n = k.n_rows; @@ -256,12 +282,7 @@ pub fn twobit_attention_naive( /// Same trick as v3: walk set bits in (signs AND levels) → +large add, walk /// signs without levels → +small add, then subtract per-key (small + large) /// contributions for the negative-bit corrections. 
-pub fn twobit_attention_fast( - q: &[f32], - k: &TwoBitKv, - v: &TwoBitKv, - inv_sqrt_dk: f32, -) -> Vec { +pub fn twobit_attention_fast(q: &[f32], k: &TwoBitKv, v: &TwoBitKv, inv_sqrt_dk: f32) -> Vec { assert_eq!(q.len(), k.d); assert_eq!(k.n_rows, v.n_rows); let n = k.n_rows; @@ -297,7 +318,9 @@ pub fn twobit_attention_fast( while bits != 0 { let b = bits.trailing_zeros() as usize; let c = base + b; - if c < dk { pos_large += q[c]; } + if c < dk { + pos_large += q[c]; + } bits &= bits - 1; } // pos_small = walk bits of (s & ¬l) @@ -305,7 +328,9 @@ pub fn twobit_attention_fast( while bits != 0 { let b = bits.trailing_zeros() as usize; let c = base + b; - if c < dk { pos_small += q[c]; } + if c < dk { + pos_small += q[c]; + } bits &= bits - 1; } // q_total_large = walk bits of l @@ -313,23 +338,28 @@ pub fn twobit_attention_fast( while bits != 0 { let b = bits.trailing_zeros() as usize; let c = base + b; - if c < dk { q_total_large += q[c]; } + if c < dk { + q_total_large += q[c]; + } bits &= bits - 1; } } let q_total_small = q_total - q_total_large; - let mut s = small * (2.0 * pos_small - q_total_small) - + large * (2.0 * pos_large - q_total_large); + let mut s = + small * (2.0 * pos_small - q_total_small) + large * (2.0 * pos_large - q_total_large); // Outlier corrections. Binary baseline contributed ±mag (small or large // depending on level bit). The outlier value overrides. 
So: // delta = q[idx] * (val − reconstructed_mag) for slot in 0..kk { let idx = k.outlier_idx[i * kk + slot]; - if idx == u16::MAX { continue; } + if idx == u16::MAX { + continue; + } let val = k.outlier_val[i * kk + slot]; - let level_bit = (row_levels[idx as usize / BITS_PER_WORD] >> (idx as usize % BITS_PER_WORD)) & 1; + let level_bit = + (row_levels[idx as usize / BITS_PER_WORD] >> (idx as usize % BITS_PER_WORD)) & 1; let mag = if level_bit == 1 { large } else { small }; let sign_mag = if val >= 0.0 { mag } else { -mag }; s += q[idx as usize] * (val - sign_mag); @@ -404,7 +434,11 @@ pub fn twobit_attention_fast( // Mask off bits beyond dv in the last word if base + BITS_PER_WORD > dv { let valid = dv - base; - let mask = if valid >= BITS_PER_WORD { !0u64 } else { (1u64 << valid) - 1 }; + let mask = if valid >= BITS_PER_WORD { + !0u64 + } else { + (1u64 << valid) - 1 + }; bits &= mask; } while bits != 0 { @@ -428,9 +462,12 @@ pub fn twobit_attention_fast( let w_i = weights[i]; for slot in 0..vk { let idx = v.outlier_idx[i * vk + slot]; - if idx == u16::MAX { continue; } + if idx == u16::MAX { + continue; + } let val = v.outlier_val[i * vk + slot]; - let level_bit = (row_levels[idx as usize / BITS_PER_WORD] >> (idx as usize % BITS_PER_WORD)) & 1; + let level_bit = + (row_levels[idx as usize / BITS_PER_WORD] >> (idx as usize % BITS_PER_WORD)) & 1; let mag = if level_bit == 1 { large } else { small }; let sign_mag = if val >= 0.0 { mag } else { -mag }; output[idx as usize] += w_i * (val - sign_mag); @@ -455,7 +492,7 @@ mod tests { // Construct a row with bimodal distribution. let input: Vec = vec![ 0.1, -0.2, 0.15, -0.18, // small cluster - 5.0, -4.5, 6.0, -5.5, // large cluster + 5.0, -4.5, 6.0, -5.5, // large cluster ]; let c = compress_twobit(&input, 1, 8, 0); // small_mag ≈ mean(0.1, 0.2, 0.15, 0.18) = 0.1575 @@ -500,7 +537,7 @@ mod tests { #[test] fn twobit_beats_v1_on_bimodal() { // Bimodal data: should benefit substantially from 2-bit representation. 
- let mut rng = ChaCha8Rng::seed_from_u64(0xb1_0d_a1_07_d3_a4_b1_eu64); + let mut rng = ChaCha8Rng::seed_from_u64(0x0b10_da10_7d3a_4b1e_u64); let n = 8; let d = 128; let mut data = vec![0f32; n * d]; @@ -523,6 +560,9 @@ mod tests { }; let cos_v1 = cos(&data, &r1); let cos_v4 = cos(&data, &r4); - assert!(cos_v4 > cos_v1, "v4 should beat v1 on bimodal data; got v1={cos_v1} v4={cos_v4}"); + assert!( + cos_v4 > cos_v1, + "v4 should beat v1 on bimodal data; got v1={cos_v1} v4={cos_v4}" + ); } }