Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 147 additions & 10 deletions src/sketches/countsketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1130,17 +1130,20 @@ mod tests {

// (de-duplicated) use serde::{Deserialize, Serialize};

/// Default Top-K capacity. Mirrors sketchlib-go `TOPK_SIZE = 100`.
///
/// This is the maximum number of `(key, count)` entries the Top-K
/// update path keeps on a `CountSketch`: once the list holds this
/// many entries, a new key is admitted only by replacing the current
/// smallest-count entry.
pub const COUNT_SKETCH_TOPK_CAPACITY: usize = 100;

/// Sparse delta between two consecutive CountSketch snapshots —
/// the input shape for [`CountSketch::apply_delta`]. Mirrors the
/// `CountSketchDelta` proto in
/// `sketchlib-go/proto/countsketch/countsketch.proto` (packed
/// encoding only — the deprecated `cells_legacy` path + the
/// non-delta `topk` / `hh_keys` top-K carrier aren't modeled here).
/// `sketchlib-go/proto/countsketch/countsketch.proto` and the native
/// Go `Delta` in `sketchlib-go/sketches/CountSketch/delta.go`.
///
/// Cells apply additively: `matrix[row][col] += d_count` for each
/// `(row, col, d_count)` triple. Top-K on the delta path is a
/// separate follow-up (CS top-K is non-linear; merging deltas
/// would require re-querying the merged matrix).
/// `(row, col, d_count)` triple. Per-row L2 norm deltas (`l2`) are
/// carried on the delta for downstream error accounting but are not
/// applied by `apply_delta`. Heavy-hitter candidate keys (`hh_keys`)
/// are re-estimated against the post-merge matrix and fed into the
/// receiver's Top-K list.
#[derive(Debug, Clone, Default)]
pub struct CountSketchDelta {
pub rows: u32,
Expand All @@ -1151,10 +1154,17 @@ pub struct CountSketchDelta {
/// base sketch. Kept on the delta surface for downstream
/// error-accounting; `apply_delta` itself ignores L2.
pub l2: Vec<f64>,
/// Heavy-hitter candidate keys forwarded by the upstream
/// Space-Saving tracker. The receiver re-queries the merged CS
/// matrix for each key and updates its Top-K heap with the
/// resulting estimate. Mirrors Go's `Delta.HHKeys`.
pub hh_keys: Vec<String>,
}

/// Minimal Count Sketch state — a flat `rows × cols` matrix of signed
/// counts. Element-wise mergeable (sum over aligned cells).
/// counts. Element-wise mergeable (sum over aligned cells). Mirrors
/// sketchlib-go's `CountSketch.Count`/`TopK` pair (the on-the-wire
/// `L2` field is a derived value and is recomputed on load).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CountSketch {
#[serde(rename = "row_num")]
Expand All @@ -1164,6 +1174,12 @@ pub struct CountSketch {
/// Row-major matrix of signed counts. `matrix[r][c]` is the value of
/// hash row `r`, column `c`.
pub matrix: Vec<Vec<f64>>,
/// Top-K heavy hitters as `(key, count)` pairs, capped at
/// [`COUNT_SKETCH_TOPK_CAPACITY`]. Order is not guaranteed (heap
/// shape is not preserved on the wire). Mirrors Go's
/// `CountSketch.TopK` slot. Defaults to empty on legacy payloads.
#[serde(default)]
pub topk: Vec<(String, f64)>,
}

impl CountSketch {
Expand All @@ -1173,25 +1189,66 @@ impl CountSketch {
rows,
cols,
matrix: vec![vec![0.0; cols]; rows],
topk: Vec::new(),
}
}

/// Construct from a pre-built matrix (used by the modified-OTLP
/// proto-decode path).
/// proto-decode path). `topk` starts out empty; callers that need a
/// pre-populated Top-K list should use the msgpack/proto path.
pub fn from_legacy_matrix(matrix: Vec<Vec<f64>>, rows: usize, cols: usize) -> Self {
debug_assert_eq!(matrix.len(), rows, "row count mismatch");
debug_assert!(
matrix.iter().all(|r| r.len() == cols),
"column count mismatch in at least one row"
);
Self { rows, cols, matrix }
Self {
rows,
cols,
matrix,
topk: Vec::new(),
}
}

/// Borrow the inner count matrix.
///
/// Returns a shared reference to `self.matrix` (no copy is made), so
/// `sketch()[r][c]` reads the value stored for hash row `r`, column `c`.
pub fn sketch(&self) -> &Vec<Vec<f64>> {
&self.matrix
}

/// Offer `(key, count)` to the in-memory Top-K list.
///
/// The list never grows past [`COUNT_SKETCH_TOPK_CAPACITY`] entries
/// through this method. Three cases, checked in order:
///
/// 1. `key` is already tracked: its stored count is raised to `count`
///    if `count` is larger, otherwise left alone (max semantics).
/// 2. The list still has room: `(key, count)` is appended.
/// 3. The list is full: the entry with the smallest count is replaced,
///    but only when `count` is strictly greater than that minimum.
///
/// Used by `apply_delta` to refresh Top-K from `hh_keys`.
fn topk_update(&mut self, key: &str, count: f64) {
    // Case 1: key already tracked -- keep the larger of the two counts.
    if let Some(pos) = self.topk.iter().position(|(k, _)| k == key) {
        let stored = &mut self.topk[pos].1;
        if count > *stored {
            *stored = count;
        }
        return;
    }

    // Case 2: spare capacity -- just append.
    if self.topk.len() < COUNT_SKETCH_TOPK_CAPACITY {
        self.topk.push((key.to_owned(), count));
        return;
    }

    // Case 3: list is full. Locate the entry with the smallest count;
    // incomparable pairs (NaN) are treated as equal, and among equal
    // minima the earliest entry is the one selected.
    let weakest = self
        .topk
        .iter()
        .enumerate()
        .min_by(|(_, lhs), (_, rhs)| {
            lhs.1
                .partial_cmp(&rhs.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(idx, entry)| (idx, entry.1));

    // Evict only when the newcomer strictly beats the current minimum.
    match weakest {
        Some((idx, floor)) if count > floor => {
            self.topk[idx] = (key.to_owned(), count);
        }
        _ => {}
    }
}

/// Insert a single weighted observation. Each row uses an independent
/// hash seed and a sign bit to update the matrix in place — the
/// standard CountSketch update primitive. The wire format here uses
Expand Down Expand Up @@ -1260,14 +1317,20 @@ impl CountSketch {

/// Apply a sparse delta in place. Matches the `ApplyDelta`
/// semantics in `sketchlib-go/sketches/CountSketch/delta.go`:
/// `matrix[row][col] += d_count` for each cell in the delta.
/// * each `(row, col, d_count)` triple updates the count matrix
/// additively (`matrix[r][c] += d_count`);
/// * each `hh_key` is re-queried against the post-update matrix
/// and pushed into the receiver's Top-K with the merged-estimate
/// count (mirrors Go's `Delta.HHKeys` heavy-hitter rebuild).
///
/// Returns `Err` if any `(row, col)` is out of range — indicating
/// a dimension mismatch between the snapshot this sketch was
/// built from and the delta sender.
pub fn apply_delta(
&mut self,
delta: &CountSketchDelta,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// 1. Cell additions.
for (row, col, d_count) in &delta.cells {
let r = *row as usize;
let c = *col as usize;
Expand All @@ -1282,6 +1345,13 @@ impl CountSketch {
// too (can go negative under adversarial keys).
self.matrix[r][c] += *d_count as f64;
}
// 2. Heavy-hitter rebuild from `hh_keys`. Re-estimate against
// the freshly-updated matrix and push into Top-K with the
// merged count. Mirrors sketchlib-go's `Delta.HHKeys` path.
for key in &delta.hh_keys {
let est = self.estimate(key);
self.topk_update(key, est);
}
Ok(())
}

Expand Down Expand Up @@ -1374,6 +1444,7 @@ mod tests_wire_count {
(1, 1, -15), // 5 - 15 = -10
],
l2: vec![],
hh_keys: vec![],
};
cs.apply_delta(&delta).unwrap();
assert_eq!(
Expand All @@ -1395,6 +1466,7 @@ mod tests_wire_count {
cols: 2,
cells: vec![(0, 0, 10), (1, 1, 20)],
l2: vec![],
hh_keys: vec![],
};
let mut via_delta = base;
via_delta.apply_delta(&delta).unwrap();
Expand All @@ -1409,10 +1481,75 @@ mod tests_wire_count {
cols: 3,
cells: vec![(2, 0, 1)], // row 2 out of range for 2-row matrix
l2: vec![],
hh_keys: vec![],
};
assert!(cs.apply_delta(&delta).is_err());
}

#[test]
fn test_apply_delta_rebuilds_topk_from_hh_keys() {
    // Seed the matrix through the in-process `update` path with two
    // known keys, then apply a delta that carries no cells at all --
    // only `hh_keys`. `apply_delta` should estimate each key against
    // the matrix and record the result in `topk`. Mirrors
    // sketchlib-go's `Delta.HHKeys` heavy-hitter rebuild path.
    let mut cs = CountSketch::new(3, 16);
    cs.update("alpha", 5.0);
    cs.update("beta", 3.0);

    let hh_only = CountSketchDelta {
        rows: 3,
        cols: 16,
        cells: vec![],
        l2: vec![],
        hh_keys: vec!["alpha".to_string(), "beta".to_string()],
    };
    cs.apply_delta(&hh_only).unwrap();

    assert_eq!(cs.topk.len(), 2);

    // Look up the recorded count for a key, if it made it into Top-K.
    let count_of = |name: &str| {
        cs.topk
            .iter()
            .find(|(k, _)| k == name)
            .map(|(_, v)| *v)
    };
    let alpha_count = count_of("alpha").expect("alpha should be in topk");
    let beta_count = count_of("beta").expect("beta should be in topk");

    // "alpha" went in with weight 5 and "beta" with weight 3, so
    // alpha's estimate is expected to come out larger, barring
    // signed-counter cancellation in this small 3x16 matrix.
    assert!(
        alpha_count > beta_count,
        "alpha={alpha_count} beta={beta_count}"
    );
}

#[test]
fn test_apply_delta_hh_keys_topk_capacity() {
    // Offer more heavy-hitter candidates than the Top-K can hold and
    // check that the list stays capped at COUNT_SKETCH_TOPK_CAPACITY.
    // Only the size bound is asserted here; which entries survive the
    // overflow is not checked by this test.
    let mut cs = CountSketch::new(3, 1024);
    let n = COUNT_SKETCH_TOPK_CAPACITY + 5;
    let keys: Vec<String> = (0..n).map(|i| format!("k{i:04}")).collect();

    // Insert every key into the matrix with weight `i + 1`, so each
    // candidate carries some mass when it is re-estimated.
    for (i, k) in keys.iter().enumerate() {
        cs.update(k, (i + 1) as f64);
    }

    let delta = CountSketchDelta {
        rows: 3,
        cols: 1024,
        cells: vec![],
        l2: vec![],
        // `keys` is not used again below, so hand it over by move
        // rather than cloning all the strings.
        hh_keys: keys,
    };
    cs.apply_delta(&delta).unwrap();

    assert_eq!(cs.topk.len(), COUNT_SKETCH_TOPK_CAPACITY);
}

#[test]
fn test_msgpack_round_trip() {
let original =
Expand Down
2 changes: 1 addition & 1 deletion src/sketches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub use coco::CocoBucket;

pub mod countsketch;
pub use countsketch::Count;
pub use countsketch::{CountSketch, CountSketchDelta};
pub use countsketch::{COUNT_SKETCH_TOPK_CAPACITY, CountSketch, CountSketchDelta};

/// Hashing path markers for matrix-backed sketches.
pub mod mode;
Expand Down
Loading