Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# CI workflow: checks formatting, lints with clippy, builds in release mode,
# and runs the test suite on a single Ubuntu runner.
name: CI

# Triggers: pushes to main, and pull requests.
on:
push:
branches: [main]
pull_request:

# Workflow-level environment shared by the steps below.
env:
# Force coloured cargo output in the job log.
CARGO_TERM_COLOR: always
# Make rustc treat every warning as an error in all cargo invocations.
RUSTFLAGS: -D warnings

jobs:
test:
name: build / test / lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

# Stable toolchain plus the two components the fmt and clippy steps need.
- name: Install stable toolchain
uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy

# Caches the cargo registry, git checkouts, and the target directory.
# The key is derived from the runner OS and a hash of every Cargo.toml;
# restore-keys falls back to the most recent cache with the same OS prefix.
# NOTE(review): resolved dependency versions live in Cargo.lock, not
# Cargo.toml. If a lockfile is committed, hashing it would invalidate the
# cache on dependency bumps as well — confirm whether Cargo.lock is tracked.
- name: Cache cargo registry and build
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-

# Fails if any file differs from rustfmt output; does not rewrite files.
- name: Format check
run: cargo fmt --all -- --check

# Lints lib, bins, tests, benches, and examples; any clippy warning fails the job.
- name: Clippy (deny warnings)
run: cargo clippy --all-targets -- -D warnings

- name: Build (release)
run: cargo build --release --all-targets

# NOTE(review): `--all-targets` does not include doctests; if the crate has
# documentation examples, a separate `cargo test --doc` step would be needed
# to exercise them — confirm.
# This step uses the default (non-release) profile, so it compiles separately
# from the release build above.
- name: Test
run: cargo test --all-targets
44 changes: 35 additions & 9 deletions src/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,29 @@
//! No K or V f32 row is ever assembled. Per-row work for scores is O(words_per_row)
//! reads (not D reads); per-row work for output is O(words_per_row) too.

// These attention kernels are deliberate index-math: the loop counter indexes
// a per-row slice of packed sign-words together with the magnitude/score/weight
// arrays in lockstep. Rewriting as zipped iterators obscures the bit-packing
// layout and the i-outer/d-inner cache strategy, so we keep explicit indexing.
#![allow(clippy::needless_range_loop)]

use crate::quantize::CompressedKv;
use crate::{sign_words_for, BITS_PER_WORD};

/// Reference (slow) binary attention: decompresses K/V to f32, runs standard
/// attention. Used as a correctness oracle in tests; do not use in production.
pub fn binary_attention_naive(
q: &[f32], // length d_k
k: &CompressedKv, // n_keys × d_k
v: &CompressedKv, // n_keys × d_v
inv_sqrt_dk: f32, // 1/sqrt(d_k), precomputed by caller
q: &[f32], // length d_k
k: &CompressedKv, // n_keys × d_k
v: &CompressedKv, // n_keys × d_v
inv_sqrt_dk: f32, // 1/sqrt(d_k), precomputed by caller
) -> Vec<f32> {
assert_eq!(q.len(), k.d, "q.len() != k.d ({} vs {})", q.len(), k.d);
assert_eq!(k.n_rows, v.n_rows, "k.n_rows != v.n_rows ({} vs {})", k.n_rows, v.n_rows);
assert_eq!(
k.n_rows, v.n_rows,
"k.n_rows != v.n_rows ({} vs {})",
k.n_rows, v.n_rows
);
let n = k.n_rows;
let dv = v.d;

Expand Down Expand Up @@ -79,7 +89,11 @@ pub fn binary_attention_fast(
inv_sqrt_dk: f32,
) -> Vec<f32> {
assert_eq!(q.len(), k.d, "q.len() != k.d ({} vs {})", q.len(), k.d);
assert_eq!(k.n_rows, v.n_rows, "k.n_rows != v.n_rows ({} vs {})", k.n_rows, v.n_rows);
assert_eq!(
k.n_rows, v.n_rows,
"k.n_rows != v.n_rows ({} vs {})",
k.n_rows, v.n_rows
);
let n = k.n_rows;
let dk = k.d;
let dv = v.d;
Expand Down Expand Up @@ -200,7 +214,12 @@ fn sum_q_where_bit_set(q: &[f32], row_words: &[u64], dk: usize) -> f32 {
let base = full_words * BITS_PER_WORD;
while word != 0 {
let bit = word.trailing_zeros() as usize;
debug_assert!(base + bit < dk, "tail word leaked bit at idx {} ≥ dk={}", base + bit, dk);
debug_assert!(
base + bit < dk,
"tail word leaked bit at idx {} ≥ dk={}",
base + bit,
dk
);
sum += q[base + bit];
word &= word - 1;
}
Expand All @@ -216,7 +235,14 @@ mod tests {
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;

fn full_attention(q: &[f32], k_full: &[f32], v_full: &[f32], dk: usize, dv: usize, inv_sqrt_dk: f32) -> Vec<f32> {
fn full_attention(
q: &[f32],
k_full: &[f32],
v_full: &[f32],
dk: usize,
dv: usize,
inv_sqrt_dk: f32,
) -> Vec<f32> {
let n = k_full.len() / dk;
assert_eq!(v_full.len(), n * dv);
let mut scores = vec![0f32; n];
Expand All @@ -241,7 +267,7 @@ mod tests {

#[test]
fn naive_and_fast_agree() {
let mut rng = ChaCha8Rng::seed_from_u64(0xdeadbeef_42);
let mut rng = ChaCha8Rng::seed_from_u64(0x0000_00de_adbe_ef42);
let n = 32;
let dk = 64;
let dv = 64;
Expand Down
29 changes: 23 additions & 6 deletions src/bin/kv-bench.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Quick CLI: generate synthetic Gaussian KV rows, compress, report
//! compression ratio + reconstruction MSE + per-row signal-to-quantization-noise.
use kv_compressor::{compress_rows, decompress_rows};
use rand::SeedableRng;
use rand::Rng;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;

fn main() {
Expand All @@ -13,7 +13,10 @@ fn main() {
let n_rows = n_tokens * n_kv_heads;
let d = head_dim;

println!("synthetic Gaussian KV: n_rows={n_rows} d={d} total elements={}", n_rows * d);
println!(
"synthetic Gaussian KV: n_rows={n_rows} d={d} total elements={}",
n_rows * d
);

let mut rng = ChaCha8Rng::seed_from_u64(0x5157f1d7c0deba5e);
let input: Vec<f32> = (0..n_rows * d)
Expand Down Expand Up @@ -47,11 +50,25 @@ fn main() {
let throughput_c = (input.len() * 4) as f64 / dt_c.as_secs_f64() / 1e9;
let throughput_d = (input.len() * 4) as f64 / dt_d.as_secs_f64() / 1e9;

println!("compressed bytes: {} ({:.1} MB)", c.bytes(), c.bytes() as f64 / 1024.0 / 1024.0);
println!("baseline q8 bytes: {} ({:.1} MB)", c.baseline_bytes(), c.baseline_bytes() as f64 / 1024.0 / 1024.0);
println!(
"compressed bytes: {} ({:.1} MB)",
c.bytes(),
c.bytes() as f64 / 1024.0 / 1024.0
);
println!(
"baseline q8 bytes: {} ({:.1} MB)",
c.baseline_bytes(),
c.baseline_bytes() as f64 / 1024.0 / 1024.0
);
println!("compression ratio: {:.2}×", c.ratio());
println!("reconstruction MSE: {:.6}", mse);
println!("SNR (vs Gaussian unit signal): {:.2} dB", snr_db);
println!("compress throughput: {:.2} GB/s ({:?})", throughput_c, dt_c);
println!("decompress throughput: {:.2} GB/s ({:?})", throughput_d, dt_d);
println!(
"compress throughput: {:.2} GB/s ({:?})",
throughput_c, dt_c
);
println!(
"decompress throughput: {:.2} GB/s ({:?})",
throughput_d, dt_d
);
}
Loading
Loading