
Add Cache Merging (CAMPress)#196

Open
saranshagarwal202 wants to merge 13 commits into NVIDIA:main from saranshagarwal202:feature/cam-press

Conversation


@saranshagarwal202 saranshagarwal202 commented Mar 12, 2026

PR description

Adds CAMPress, a decoding-time KV cache compression method based on Cache Merging (CaM) from "CaM: Cache Merging for Memory-Efficient LLMs Inference" (ICML 2024).

Note: This is a major rewrite of the original submission, addressing all feedback from the auto-review. Key changes: extends DecodingPress instead of standalone implementation, uses cumulative attention (paper Algorithm 1), uses avg instead of max for reference attention (paper Eq. 14), removed Triton kernel for simplicity, and switched to @dataclass with __post_init__.

What is CaM?

Standard KV cache eviction discards evicted tokens' value vectors entirely, introducing perturbation to the attention output. CaM mitigates this by merging evicted tokens' values into neighboring kept tokens before pruning.

CaM is not a standalone eviction policy — it wraps any ScorerPress (e.g., KnormPress, StreamingLLMPress) via the base_press parameter and acts as a merge-before-prune extension of DecodingPress.

Algorithm

At each compression step (every compression_interval decoding steps):

  1. Score & select — The base press scores every cached token. The n_to_evict lowest-scored tokens are marked for eviction; the top target_size are kept.
  2. Pick merge candidates — Among the evicted set, the k tokens with the highest scores (ties broken by later sequence position) are selected for merging, where k = compression_interval (the number of new tokens since the last compression).
  3. Cascading merge targets — For each merge candidate, the merge_budget kept tokens immediately after it (in sequence order) form its merge window.
  4. Merge probability — The ratio of each merge token's cumulative attention to the mean cumulative attention of its window: clamp(Ā_i / avg(Ā_{j:j+m}), 0, 1), where Ā = Σ(k=1..t) A_k is the cumulative attention across all decode steps.
  5. Bernoulli sampling — A binary merge mask is drawn from the probability above. Tokens that pass the mask have their value vectors divided by the window size and scatter-added into the window targets.
  6. Physical pruning — Evicted key/value entries are removed from the cache, and the cumulative attention buffer is pruned to match.

Setting compression_interval=1 recovers the original per-step CaM behavior from the paper.
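
The merge-probability and Bernoulli-sampling steps (4–5 above) can be sketched in self-contained NumPy; `merge_step` and all variable names below are illustrative toys, not the PR's actual code:

```python
import numpy as np

def merge_step(values, cum_attn, merge_idx, kept_idx, merge_budget, rng):
    """Toy sketch of the merge-before-prune step (steps 4-5 above).

    values:    (seq_len, head_dim) value vectors
    cum_attn:  (seq_len,) cumulative attention sums across decode steps
    merge_idx: evicted tokens selected as merge candidates
    kept_idx:  indices of tokens that survive this compression
    """
    values = values.copy()
    kept_idx = np.sort(kept_idx)
    for i in merge_idx:
        # Merge window: the next `merge_budget` kept tokens after position i.
        window = kept_idx[kept_idx > i][:merge_budget]
        if len(window) == 0:
            continue
        # Merge probability: clamp(A_i / avg(A_window), 0, 1).
        p = np.clip(cum_attn[i] / cum_attn[window].mean(), 0.0, 1.0)
        # Bernoulli sampling: merge only if the draw succeeds.
        if rng.random() < p:
            # Divide V_i by the window size and scatter-add into the targets.
            values[window] += values[i] / len(window)
    return values[kept_idx]  # physical pruning (step 6)
```

A token whose cumulative attention is well above its window's mean clamps to probability 1 and always has its value mass redistributed; a token that attracted little attention is merged only rarely.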

Evaluation Results (RULER benchmark, Qwen3-8B, data_ratio=0.5)

| Metric | CaM (ci=8, ts=2048) | Decoding (ci=8, ts=2048) | CaM (ci=4, ts=3048) | Decoding (ci=4, ts=3048) |
|---|---|---|---|---|
| cwe | 30.40 | 32.73 | 43.17 | 36.47 |
| fwe | 87.79 | 92.88 | 91.73 | 94.91 |
| niah_multikey_1 | 99.61 | 99.61 | 99.61 | 99.61 |
| niah_multikey_2 | 98.67 | 98.67 | 82.22 | 89.78 |
| niah_multikey_3 | 4.35 | 4.35 | 55.65 | 52.61 |
| niah_multiquery | 53.63 | 45.98 | 96.08 | 75.00 |
| niah_multivalue | 92.17 | 84.76 | 100.00 | 99.90 |
| niah_single_1 | 100.00 | 100.00 | 100.00 | 100.00 |
| niah_single_2 | 100.00 | 100.00 | 99.25 | 100.00 |
| niah_single_3 | 96.64 | 95.80 | 99.58 | 99.58 |
| qa_1 | 82.07 | 81.27 | 82.47 | 81.67 |
| qa_2 | 65.73 | 65.73 | 65.73 | 66.13 |
| vt | 89.39 | 72.63 | 97.73 | 91.74 |
| Average | 76.96 | 74.95 | 85.63 | 83.65 |
| it/s | 4.29 | 3.92 | 3.12 | 2.73 |

ci = compression_interval, ts = target_size. Base press: KnormPress for both.

CAMPress shows modest but consistent improvements over plain DecodingPress (+2.0 avg at ts=2048, +2.0 avg at ts=3048), with notable gains on retrieval-heavy tasks (niah_multiquery, niah_multivalue, vt).

Why are improvements modest on RULER?

The first compression cycle evicts the bulk of the cache (e.g., ~2000 tokens from a 4k prefill down to target_size=2048), but only k=compression_interval tokens (e.g., 8) are selected for merging — meaning ~99.6% of tokens evicted in the first cycle are not merged. Subsequent compressions evict exactly compression_interval tokens, all of which get merged.

Since RULER generates short outputs (few compression cycles after the first), most information loss comes from the unmerged first compression. We expect CAMPress to show larger gains on tasks with longer generation (more compression cycles where 100% of evicted tokens are merged), such as reasoning/think-mode models or long-form text generation.
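
As a sanity check on the "~99.6%" figure above, the arithmetic (using the assumed example numbers from this analysis) works out as:

```python
# Assumed figures from the example above: a 4k prefill pruned down to
# target_size=2048, with compression_interval=8.
prefill_len, target_size, compression_interval = 4096, 2048, 8

# The first cycle evicts everything above target_size, but only
# k = compression_interval of those tokens are selected for merging.
first_evicted = prefill_len - target_size            # 2048 tokens
unmerged_frac = 1 - compression_interval / first_evicted
print(f"{unmerged_frac:.1%} of first-cycle evictions go unmerged")  # 99.6%
```

Every later cycle evicts exactly `compression_interval` tokens, all of which are merge candidates, so the unmerged fraction drops to zero after the first trigger.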

Testing

Tests are added in test_decoding_compression.py as factory methods alongside DecodingPress, since CAMPress extends DecodingPress (not a standalone ScorerPress). All existing DecodingPress tests now run parametrically for both DecodingPress and CAMPress.

Checklist

Before submitting a PR, please make sure:

  • Tests are working (make test)

  • Code is formatted correctly (make style, on errors try fix with make format)

  • Copyright header is included

  • All commits are signed-off using git commit -s

  • (new press) mypress_press.py is in the presses directory

  • (new press) MyPress is in __init__.py

  • (new press) README.md is updated with a 1 liner about the new press in the Available presses section

  • (new press) New press is in the default_presses list in tests/default_presses.py

  • (new press) A docstring is provided that follows the same structure as the existing ones


copy-pr-bot bot commented Mar 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@maxjeblick
Collaborator

/ok to test f1c09e8

@maxjeblick maxjeblick self-requested a review March 15, 2026 13:27
@alessiodevoto
Collaborator

Hi @saranshagarwal202! Before the manual review of the PR, we’re currently testing an auto-review pipeline. Here’s the generated report. We haven’t fully validated it yet, but it may still be a useful reference. Feel free to let us know if you find it helpful!

PR Review: #196 — Add Cache Merging (CAMPress)

Author: saranshagarwal202
Paper checked: https://openreview.net/forum?id=LCTmppB165 (CaM: Cache Merging for Memory-Efficient LLMs Inference, ICML 2024)


Checklist

| Item | Status | Notes |
|---|---|---|
| Exported in __init__.py | ✅ | Import line + __all__ entry both present |
| README.md updated | ✅ | 1-liner added in wrapper presses section |
| Added to default_presses.py | ✅ | Added with StreamingLLMPress base at compression_ratio 0.2 and 0.8 |

Naming & Conventions

Major Issues

  • cam_press.py:163 — batch used instead of the canonical bsz (batch, kv_heads, _ = scores.shape). All other presses use bsz.
  • cam_press.py:163 — kv_heads used instead of the canonical num_key_value_heads. Should be consistent with expected_attention_press.py:145.
  • cam_press.py:68-73 — Custom __init__ defined in a @dataclass class. DecodingPress (the parent) uses __post_init__ for initialization side-effects. By defining a custom __init__, the dataclass machinery is bypassed and the parent's __post_init__ is never called, skipping its validation assertions (isinstance(base_press, ScorerPress), etc.). Should use __post_init__ instead.

Minor Issues

  • cam_press.py:161 — dev = scores.device — dev is cryptic. Other presses don't use this abbreviation; inline scores.device or use device.
  • cam_press.py:246 — _torch_merge has four nested Python loops (n_evicted × batch × heads × budget). A comment noting this is the slow fallback path (and why) would help readers understand the intent vs. the Triton path.

Code Quality

Major Issues

  • cam_press.py:195-214 — Bernoulli merge probability uses single-step attention (attn_squeezed from the current decode step). The paper's Algorithm 1 specifies using cumulative attention Ā = Σ(k=1..t) A_k across all decode steps. Using only the current step's attention is an algorithmic divergence that likely degrades performance at longer sequences. This is the most significant correctness concern.

  • cam_press.py:210 — ref_attn = per_token_target_attn.max(dim=-1).values uses max of neighbor attention as the reference. The paper uses average (avg(Ā_{j:j+m})). The PR body explicitly describes this as "maximum," which is a documented divergence — but it should be clearly explained in the docstring why max was chosen over avg and whether this affects convergence guarantees from Theorem 3.3.

  • cam_press.py:230-237 — The special case if n_to_evict == 1 and merge_mask is not None and merge_mask.sum() == 0: pass does nothing — pass simply falls through to the gather at the bottom, which would happen anyway. This dead branch is confusing and suggests the merge is correctly skipped, but the values of the non-merged token were already zeroed out earlier at line 218. This zero-value entry will then be gathered into the retained cache. Is this intentional? If the single evicted token has no merge, should it be kept as-is or zeroed?

  • cam_press.py:68-73 — CAMPress.__init__ does not call super().__post_init__() (or any parent init). DecodingPress.__post_init__ validates base_press type. This validation is silently skipped — passing a non-ScorerPress base_press will fail later at runtime with an obscure error rather than a clear assertion.

Minor Issues

  • cam_press.py:144-145 — _first_eviction_done is tracked and logged but never actually changes the compression logic. If it's only for debug logging, it doesn't need to be a dict — it could be a single bool since it's set True for all layers at the same conceptual time.
  • cam_press.py:250 — Magic threshold 1e-12 in _torch_merge for detecting a zeroed-out value should be a named constant (e.g., _ZERO_VALUE_THRESHOLD).
  • cam_press.py:86-87 — Docstring is missing a step-by-step algorithm description. Given the non-trivial merging logic (especially the Bernoulli mechanism), a numbered algorithm description similar to ExpectedAttentionPress (lines 26-31) would aid reviewers and users.
  • cam_press.py:70 — base_press: ScorerPress = None — default None but type annotation says ScorerPress. Should be Optional[ScorerPress] = None or, better, make it required with no default (since a None base_press will crash at score time).

Paper Alignment

Paper checked: https://openreview.net/forum?id=LCTmppB165

Divergences from paper

  • cam_press.py:195-214 — Paper Algorithm 1 uses cumulative attention Ā = Σ(k=1..t) A_k across all decode steps. Implementation uses only the current decode step's attention weights. This is a significant divergence — cumulative scoring is central to the paper's Theorem 3.2 guarantee (that merging can achieve zero output perturbation).

  • cam_press.py:210 — Paper Eq. (14) uses avg(Ā_{j:j+m}) as the reference. Implementation uses max(neighbor_attn). The PR body acknowledges this but the docstring does not. Theorem 3.3's condition avg(A_{j:j+m}) < 2A_i (merging beats eviction) is not preserved under the max formulation.

Minor discrepancies

  • README.md:142 — README cites https://arxiv.org/abs/2309.17453 for the paper, while the class docstring and PR body cite https://openreview.net/forum?id=LCTmppB165 (the peer-reviewed ICML version). These are the same paper but the arXiv preprint may differ from the published version. Should use the OpenReview/ICML URL consistently.
  • cam_press.py:95-100 — per_token_targets uses .clamp(max=cache_len - 1) for out-of-range neighbors, meaning the last valid position accumulates extra contributions from clamped targets. The paper's algorithm only distributes to m valid successors. Dividing by actual_budget (total budget) rather than the count of valid targets means the total energy added is V_i * valid_count / budget < V_i, attenuating the merge at sequence boundaries. This is a minor boundary effect.

Repo Consistency

Issues

  • cam_press.py:50-73 — CAMPress is the only press in the repo that combines @dataclass with a fully overriding custom __init__. Every other press (including DecodingPress) uses @dataclass with __post_init__ for side-effect initialization. This inconsistency makes CAMPress behave differently under copy/pickle and breaks expected dataclass semantics.

  • cam_press.py:122-138 — CAMPress defines a score() method on a DecodingPress subclass. DecodingPress is not a ScorerPress and has no score() method; only ScorerPress subclasses are expected to define score(). The pattern of temporarily mutating base_press.compression_ratio (lines 131-136) is borrowed from DecodingPress.compress() but adds a parallel code path. This could cause unexpected interactions if base_press is shared.

  • evaluation/evaluate.py:284-286 — The isinstance(press, CAMPress) check is placed before the generic elif isinstance(press, DecodingPress) block. This ordering is correct (since CAMPress is a subclass), but it means CAMPress also sets compression_ratio via press.compression_ratio = compression_ratio — the same attribute that DecodingPress does NOT set directly. This is correct behavior but should be noted: CAMPress.compression_ratio is used differently from DecodingPress.target_size (ratio vs. absolute size), which is a design inconsistency.


Summary

Must fix:

  1. Cumulative attention is not tracked — the implementation uses single-step attention, diverging from the paper's Algorithm 1. This is the most important algorithmic issue.
  2. Replace custom __init__ with __post_init__ to follow the dataclass pattern and restore DecodingPress validation.
  3. Rename batch → bsz and kv_heads → num_key_value_heads to match repo naming conventions.
  4. Clarify (in docstring) the deliberate choice of max vs. avg for ref_attn and its effect on the paper's theoretical guarantees.

Should fix:

  1. The if n_to_evict == 1 and merge_mask.sum() == 0: pass dead branch needs a comment or removal — the zeroing-then-gathering of non-merged tokens looks unintentional.
  2. base_press: ScorerPress = None default should be Optional[ScorerPress] = None or required.
  3. Add a numbered algorithm description to the class docstring.
  4. Use consistent paper URL (OpenReview ICML vs. arXiv) across README and docstring.

Looks good:

  • All CONTRIBUTING.md checklist items are completed correctly.
  • The Triton fallback architecture is clean and the kernel handles valid_targets correctly.
  • The reset() method is properly implemented and called via __call__ context manager.
  • evaluate.py CAMPress branch is correctly ordered before the generic DecodingPress branch.
  • Test cases cover both low (0.2) and high (0.8) compression ratios.

@saranshagarwal202
Author

Thank you for your input; I am working on the necessary modifications. I am shifting this PR to draft until I finish.

@saranshagarwal202 saranshagarwal202 marked this pull request as draft March 19, 2026 03:36
Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
- replaced `batch` with `bsz`
- replaced `kv_heads` with `num_key_value_heads`
- replaced `__init__` with `__post_init__` with assertions
- replaced `dev = scores.device` with `device = scores.device`
- Added comment under `_torch_merge` to explain fallback

Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
Algorithm:
  1. ACCUMULATE: Each decoding step, accumulate per-KV-head
     attention weights into a running sum across generation steps.

  2. TRIGGER: Every `compression_interval` steps (when cache
     exceeds `target_size`), trigger bulk eviction:

     a. SCORE: Use base_press (e.g. KnormPress) to score all
        cached tokens. Select the bottom-k as eviction candidates.

     b. SPLIT: From eviction candidates, pick the top-k by score
        (preferring later sequence positions on ties) as merge
        tokens. The rest are pure evictions.

     c. MERGE: For each merge token m:
        - Find its target window: the next `merge_budget` tokens
          in the kept set after m's position.
        - Compute merge_prob = attn(m) / mean(attn[window_start:])
          using cumulative attention sums.
        - Sample Bernoulli(merge_prob). If accepted, scatter-add
          value(m) / num_targets equally into each target's value.

     d. PRUNE: Physically remove all evicted tokens (both merged
        and pure-evicted) from keys, values, and attention sums.
        Only `target_size` tokens remain.

  3. RESET step counter to 0, continue decoding.

Merge implementations:
  - _torch_merge: Fully vectorized with cumsum + scatter_add_
  - _triton_merge: Triton kernel
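
The cumsum + scatter_add_ design mentioned for _torch_merge can be illustrated with a toy NumPy equivalent (shapes and values below are assumptions for illustration, not the PR's code; np.add.at plays the role of PyTorch's scatter_add_):

```python
import numpy as np

values = np.zeros((5, 2))                      # (kept_seq, head_dim)
contrib = np.array([[1.0, 1.0], [2.0, 2.0]])   # accepted merge tokens' values
targets = np.array([[1, 2], [2, 3]])           # each token's merge window

# Divide each value by its window size, repeat once per target, then
# accumulate with an unbuffered scatter-add (which handles repeated
# indices correctly, unlike plain fancy-index assignment).
m = targets.shape[1]
per_target = np.repeat(contrib / m, m, axis=0)  # (n_tokens * m, head_dim)
np.add.at(values, targets.reshape(-1), per_target)
# values[2] now accumulates 0.5 + 1.0 from the two overlapping windows
```

The unbuffered accumulation matters: overlapping merge windows hit the same target index, and a buffered `values[targets] += ...` would silently drop all but one of the repeated contributions.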

Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
…s as factory methods

Signed-off-by: Saransh Agrawal <saranshagarwal2020@gmail.com>
@saranshagarwal202 saranshagarwal202 marked this pull request as ready for review March 26, 2026 18:07
@saranshagarwal202
Author

Hi @alessiodevoto, thank you for the feedback. I’ve finished a major refactor of the implementation; I've detailed the specific changes in the PR description above. Looking forward to your thoughts on the new structure!
