/ok to test f1c09e8
Hi @saranshagarwal202 ! Before the manual review of the PR, we’re currently testing an auto-review pipeline. Here’s the generated report. We haven’t fully validated it yet, but it may still be a useful reference. Feel free to let us know if you find it helpful!

PR Review: #196 — Add Cache Merging (CAMPress)
Author: saranshagarwal202

Checklist

Naming & Conventions
Major Issues
Minor Issues

Code Quality
Major Issues
Minor Issues

Paper Alignment
Paper checked: https://openreview.net/forum?id=LCTmppB165
Divergences from paper
Minor discrepancies

Repo Consistency
Issues

Summary
Must fix:
Should fix:
Looks good:
Thank you for your input. I am working on the necessary modifications and am shifting this PR to draft until I finish.
Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
- replaced `batch` with `bsz`
- replaced `kv_heads` with `num_key_value_heads`
- replaced `__init__` with `__post_init__` with assertions
- replaced `dev = scores.device` with `device = scores.device`
- added a comment under `_torch_merge` to explain the fallback

Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
Algorithm:
1. ACCUMULATE: Each decoding step, accumulate per-KV-head
attention weights into a running sum across generation steps.
2. TRIGGER: Every `compression_interval` steps (when cache
exceeds `target_size`), trigger bulk eviction:
a. SCORE: Use base_press (e.g. KnormPress) to score all
cached tokens. Select the bottom-k as eviction candidates.
b. SPLIT: From eviction candidates, pick the top-k by score
(preferring later sequence positions on ties) as merge
tokens. The rest are pure evictions.
c. MERGE: For each merge token m:
- Find its target window: the next `merge_budget` tokens
in the kept set after m's position.
- Compute merge_prob = attn(m) / mean(attn[window_start:])
using cumulative attention sums.
- Sample Bernoulli(merge_prob). If accepted, scatter-add
value(m) / num_targets equally into each target's value.
d. PRUNE: Physically remove all evicted tokens (both merged
and pure-evicted) from keys, values, and attention sums.
Only `target_size` tokens remain.
3. RESET step counter to 0, continue decoding.
Merge implementations:
- _torch_merge: Fully vectorized with cumsum + scatter_add_
- _triton_merge: Triton kernel
Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
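The MERGE step (2c) above can be sketched in plain Python. This is an illustrative, per-head rendering under assumed list-based inputs, not the vectorized `_torch_merge`; the function name and argument names are hypothetical.

```python
import random

def merge_step(values, attn_sum, merge_positions, kept_positions, merge_budget, seed=0):
    """Illustrative sketch of CaM's merge step for one KV head.

    values: list of value vectors (list[list[float]]), one per cached token
    attn_sum: cumulative attention weight per token across decode steps
    merge_positions: token indices selected for merging
    kept_positions: sorted indices of tokens that survive eviction
    """
    rng = random.Random(seed)
    for m in merge_positions:
        # Target window: the next `merge_budget` kept tokens after position m
        window = [p for p in kept_positions if p > m][:merge_budget]
        if not window:
            continue
        # merge_prob = attn(m) / mean(attn over the window), clamped to [0, 1]
        mean_attn = sum(attn_sum[p] for p in window) / len(window)
        prob = min(1.0, attn_sum[m] / mean_attn) if mean_attn > 0 else 0.0
        # Bernoulli trial: on acceptance, spread value(m) equally over the window
        if rng.random() < prob:
            share = 1.0 / len(window)
            for p in window:
                values[p] = [v + share * vm for v, vm in zip(values[p], values[m])]
    return values
```

The real implementation replaces the Python loops with a single `scatter_add_` over precomputed window indices, which is what makes `_torch_merge` fully vectorized.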
…s as factory methods Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
Force-pushed from 39b6eb8 to c103ed4
Hi @alessiodevoto, thank you for the feedback. I’ve finished a major refactor of the implementation; I've detailed the specific changes in the PR description above. Looking forward to your thoughts on the new structure!
PR description
Adds `CAMPress`, a decoding-time KV cache compression method based on Cache Merging (CaM) from "CaM: Cache Merging for Memory-Efficient LLMs Inference" (ICML 2024).

What is CaM?
Standard KV cache eviction discards evicted tokens' value vectors entirely, introducing perturbation to the attention output. CaM mitigates this by merging evicted tokens' values into neighboring kept tokens before pruning.
CaM is not a standalone eviction policy — it wraps any `ScorerPress` (e.g., `KnormPress`, `StreamingLLMPress`) via the `base_press` parameter and acts as a merge-before-prune extension of `DecodingPress`.

Algorithm
At each compression step (every `compression_interval` decoding steps):
- The `n_to_evict` lowest-scored tokens are marked for eviction; the top `target_size` are kept.
- The `k` tokens with the highest scores (ties broken by later sequence position) are selected for merging, where `k = compression_interval` (the number of new tokens since the last compression).
- For each merge token, the `merge_budget` kept tokens immediately after it (in sequence order) form its merge window.
- The merge probability is `clamp(Ā_i / avg(Ā_{j:j+m}), 0, 1)`, where `Ā = Σ(k=1..t) A_k` is the cumulative attention across all decode steps.
Setting `compression_interval=1` recovers the original per-step CaM behavior from the paper.

Evaluation Results (RULER benchmark, Qwen3-8B, data_ratio=0.5)
ci = compression_interval, ts = target_size. Base press: KnormPress for both.
CAMPress shows modest but consistent improvements over plain `DecodingPress` (+2.0 avg at ts=2048, +2.0 avg at ts=3048), with notable gains on retrieval-heavy tasks (niah_multiquery, niah_multivalue, vt).

Why are improvements modest on RULER?
The first compression cycle evicts the bulk of the cache (e.g., ~2000 tokens from a 4k prefill down to `target_size=2048`), but only `k = compression_interval` tokens (e.g., 8) are selected for merging — meaning ~99.6% of tokens evicted in the first cycle are not merged. Subsequent compressions evict exactly `compression_interval` tokens, all of which get merged.

Since RULER generates short outputs (few compression cycles after the first), most information loss comes from the unmerged first compression. We expect CAMPress to show larger gains on tasks with longer generation (more compression cycles where 100% of evicted tokens are merged), such as reasoning/think-mode models or long-form text generation.
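The ~99.6% figure follows directly from the numbers above; the exact prefill length is assumed to be 4096 here for illustration.

```python
# Assumed example numbers from the discussion: 4096-token prefill,
# target_size=2048, compression_interval=8.
prefill_len = 4096
target_size = 2048
compression_interval = 8

first_cycle_evicted = prefill_len - target_size        # 2048 tokens evicted
first_cycle_merged = compression_interval              # only k=8 get merged
unmerged_fraction = 1 - first_cycle_merged / first_cycle_evicted
print(f"{unmerged_fraction:.1%}")  # → 99.6%
```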
Testing

Tests are added in `test_decoding_compression.py` as factory methods alongside `DecodingPress`, since `CAMPress` extends `DecodingPress` (not a standalone `ScorerPress`). All existing `DecodingPress` tests now run parametrically for both `DecodingPress` and `CAMPress`.

Checklist
Before submitting a PR, please make sure:

- Tests are working (`make test`)
- Code is formatted correctly (`make style`; on errors, try fixing with `make format`)
- Copyright header is included
- All commits are signed-off using `git commit -s`
- (new press) `mypress_press.py` is in the `presses` directory
- (new press) `MyPress` is in `__init__.py`
- (new press) `README.md` is updated with a 1-liner about the new press in the Available presses section
- (new press) New press is in the `default_presses` list in `tests/default_presses.py`
- (new press) A docstring is provided that follows the same structure as the existing ones