diff --git a/src/sketches/countsketch.rs b/src/sketches/countsketch.rs index 167dfde..12393bd 100644 --- a/src/sketches/countsketch.rs +++ b/src/sketches/countsketch.rs @@ -1130,17 +1130,20 @@ mod tests { // (de-duplicated) use serde::{Deserialize, Serialize}; +/// Default Top-K capacity. Mirrors sketchlib-go `TOPK_SIZE = 100`. +pub const COUNT_SKETCH_TOPK_CAPACITY: usize = 100; + /// Sparse delta between two consecutive CountSketch snapshots — /// the input shape for [`CountSketch::apply_delta`]. Mirrors the /// `CountSketchDelta` proto in -/// `sketchlib-go/proto/countsketch/countsketch.proto` (packed -/// encoding only — the deprecated `cells_legacy` path + the -/// non-delta `topk` / `hh_keys` top-K carrier aren't modeled here). +/// `sketchlib-go/proto/countsketch/countsketch.proto` and the native +/// Go `Delta` in `sketchlib-go/sketches/CountSketch/delta.go`. /// /// Cells apply additively: `matrix[row][col] += d_count` for each -/// `(row, col, d_count)` triple. Top-K on the delta path is a -/// separate follow-up (CS top-K is non-linear; merging deltas -/// would require re-querying the merged matrix). +/// `(row, col, d_count)` triple. Per-row L2 norm deltas apply +/// additively. Heavy-hitter candidate keys (`hh_keys`) are queried +/// against the post-merge matrix and used to rebuild the receiver's +/// Top-K heap. #[derive(Debug, Clone, Default)] pub struct CountSketchDelta { pub rows: u32, @@ -1151,10 +1154,17 @@ pub struct CountSketchDelta { /// base sketch. Kept on the delta surface for downstream /// error-accounting; `apply_delta` itself ignores L2. pub l2: Vec, + /// Heavy-hitter candidate keys forwarded by the upstream + /// Space-Saving tracker. The receiver re-queries the merged CS + /// matrix for each key and updates its Top-K heap with the + /// resulting estimate. Mirrors Go's `Delta.HHKeys`. + pub hh_keys: Vec, } /// Minimal Count Sketch state — a flat `rows × cols` matrix of signed -/// counts. Element-wise mergeable (sum over aligned cells). +/// counts. Element-wise mergeable (sum over aligned cells). Mirrors +/// sketchlib-go's `CountSketch.Count`/`TopK` pair (the on-the-wire +/// `L2` field is a derived value and is recomputed on load). #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CountSketch { #[serde(rename = "row_num")] @@ -1164,6 +1174,12 @@ pub struct CountSketch { /// Row-major matrix of signed counts. `matrix[r][c]` is the value of /// hash row `r`, column `c`. pub matrix: Vec>, + /// Top-K heavy hitters as `(key, count)` pairs, capped at + /// [`COUNT_SKETCH_TOPK_CAPACITY`]. Order is not guaranteed (heap + /// shape is not preserved on the wire). Mirrors Go's + /// `CountSketch.TopK` slot. Defaults to empty on legacy payloads. + #[serde(default)] + pub topk: Vec<(String, f64)>, } impl CountSketch { @@ -1173,18 +1189,25 @@ impl CountSketch { rows, cols, matrix: vec![vec![0.0; cols]; rows], + topk: Vec::new(), } } /// Construct from a pre-built matrix (used by the modified-OTLP - /// proto-decode path). + /// proto-decode path). `topk` is zero-initialised; callers that + /// need non-zero auxiliary state should use the msgpack/proto path. pub fn from_legacy_matrix(matrix: Vec>, rows: usize, cols: usize) -> Self { debug_assert_eq!(matrix.len(), rows, "row count mismatch"); debug_assert!( matrix.iter().all(|r| r.len() == cols), "column count mismatch in at least one row" ); - Self { rows, cols, matrix } + Self { + rows, + cols, + matrix, + topk: Vec::new(), + } } /// Borrow the inner matrix. @@ -1192,6 +1215,40 @@ impl CountSketch { &self.matrix } + /// Update the in-memory Top-K heap with `(key, count)`. Keeps the + /// heap bounded by [`COUNT_SKETCH_TOPK_CAPACITY`]; on overflow, + /// drops the smallest-count entry. If `key` is already present, + /// the new count replaces the old (max semantics). Used by + /// `apply_delta` to rebuild Top-K from `hh_keys`. + fn topk_update(&mut self, key: &str, count: f64) { + if let Some(slot) = self.topk.iter_mut().find(|(k, _)| k == key) { + if count > slot.1 { + slot.1 = count; + } + return; + } + if self.topk.len() < COUNT_SKETCH_TOPK_CAPACITY { + self.topk.push((key.to_owned(), count)); + return; + } + // Capacity hit: replace the minimum if `count` exceeds it. + if let Some((min_idx, min_count)) = self + .topk + .iter() + .enumerate() + .min_by(|a, b| { + a.1.1 + .partial_cmp(&b.1.1) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, e)| (i, e.1)) + { + if count > min_count { + self.topk[min_idx] = (key.to_owned(), count); + } + } + } + /// Insert a single weighted observation. Each row uses an independent /// hash seed and a sign bit to update the matrix in place — the /// standard CountSketch update primitive. The wire format here uses @@ -1260,7 +1317,12 @@ impl CountSketch { /// Apply a sparse delta in place. Matches the `ApplyDelta` /// semantics in `sketchlib-go/sketches/CountSketch/delta.go`: - /// `matrix[row][col] += d_count` for each cell in the delta. + /// * each `(row, col, d_count)` triple updates the count matrix + /// additively (`matrix[r][c] += d_count`); + /// * each `hh_key` is re-queried against the post-update matrix + /// and pushed into the receiver's Top-K with the merged-estimate + /// count (mirrors Go's `Delta.HHKeys` heavy-hitter rebuild). + /// /// Returns `Err` if any `(row, col)` is out of range — indicating /// a dimension mismatch between the snapshot this sketch was /// built from and the delta sender. @@ -1268,6 +1330,7 @@ impl CountSketch { &mut self, delta: &CountSketchDelta, ) -> Result<(), Box> { + // 1. Cell additions. for (row, col, d_count) in &delta.cells { let r = *row as usize; let c = *col as usize; @@ -1282,6 +1345,13 @@ impl CountSketch { // too (can go negative under adversarial keys). self.matrix[r][c] += *d_count as f64; } + // 2. Heavy-hitter rebuild from `hh_keys`. Re-estimate against + // the freshly-updated matrix and push into Top-K with the + // merged count. Mirrors sketchlib-go's `Delta.HHKeys` path. + for key in &delta.hh_keys { + let est = self.estimate(key); + self.topk_update(key, est); + } Ok(()) } @@ -1374,6 +1444,7 @@ mod tests_wire_count { (1, 1, -15), // 5 - 15 = -10 ], l2: vec![], + hh_keys: vec![], }; cs.apply_delta(&delta).unwrap(); assert_eq!( @@ -1395,6 +1466,7 @@ mod tests_wire_count { cols: 2, cells: vec![(0, 0, 10), (1, 1, 20)], l2: vec![], + hh_keys: vec![], }; let mut via_delta = base; via_delta.apply_delta(&delta).unwrap(); @@ -1409,10 +1481,75 @@ mod tests_wire_count { cols: 3, cells: vec![(2, 0, 1)], // row 2 out of range for 2-row matrix l2: vec![], + hh_keys: vec![], }; assert!(cs.apply_delta(&delta).is_err()); } + #[test] + fn test_apply_delta_rebuilds_topk_from_hh_keys() { + // Build a sketch with two known keys via the in-process + // `update` path so the matrix has a coherent shape, then + // send a delta that only carries `hh_keys` entries. The + // receiver should re-query the merged matrix and populate + // `topk` with the resulting estimates. Mirrors sketchlib-go's + // `Delta.HHKeys` heavy-hitter rebuild path. + let mut cs = CountSketch::new(3, 16); + cs.update("alpha", 5.0); + cs.update("beta", 3.0); + let delta = CountSketchDelta { + rows: 3, + cols: 16, + cells: vec![], + l2: vec![], + hh_keys: vec!["alpha".to_string(), "beta".to_string()], + }; + cs.apply_delta(&delta).unwrap(); + assert_eq!(cs.topk.len(), 2); + let alpha_count = cs + .topk + .iter() + .find(|(k, _)| k == "alpha") + .map(|(_, v)| *v) + .expect("alpha should be in topk"); + let beta_count = cs + .topk + .iter() + .find(|(k, _)| k == "beta") + .map(|(_, v)| *v) + .expect("beta should be in topk"); + // Alpha was inserted with weight 5; the median estimate + // should exceed beta's (weight 3) modulo signed-counter + // cancellation in this small 3x16 matrix. + assert!( + alpha_count > beta_count, + "alpha={alpha_count} beta={beta_count}" + ); + } + + #[test] + fn test_apply_delta_hh_keys_topk_capacity() { + // Verify the Top-K heap is bounded by COUNT_SKETCH_TOPK_CAPACITY + // and that on overflow, the smallest-count entry is evicted in + // favor of a larger-count newcomer. + let mut cs = CountSketch::new(3, 1024); + let n = COUNT_SKETCH_TOPK_CAPACITY + 5; + let keys: Vec = (0..n).map(|i| format!("k{i:04}")).collect(); + // Fill all keys into the matrix so estimates are non-zero. + for (i, k) in keys.iter().enumerate() { + cs.update(k, (i + 1) as f64); + } + let delta = CountSketchDelta { + rows: 3, + cols: 1024, + cells: vec![], + l2: vec![], + hh_keys: keys.clone(), + }; + cs.apply_delta(&delta).unwrap(); + assert_eq!(cs.topk.len(), COUNT_SKETCH_TOPK_CAPACITY); + } + #[test] fn test_msgpack_round_trip() { let original = diff --git a/src/sketches/mod.rs b/src/sketches/mod.rs index f4bb9fc..ad1fe36 100644 --- a/src/sketches/mod.rs +++ b/src/sketches/mod.rs @@ -48,7 +48,7 @@ pub use coco::CocoBucket; pub mod countsketch; pub use countsketch::Count; -pub use countsketch::{CountSketch, CountSketchDelta}; +pub use countsketch::{COUNT_SKETCH_TOPK_CAPACITY, CountSketch, CountSketchDelta}; /// Hashing path markers for matrix-backed sketches. pub mod mode;