From 07de591637040c9fdf75bf3138412181eb7e7e6a Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Thu, 12 Mar 2026 13:18:36 -0700 Subject: [PATCH 01/13] Add cam-press, a decoding-type press. Signed-off-by: Saransh Agrawal --- README.md | 1 + evaluation/evaluate.py | 4 + evaluation/evaluate_registry.py | 2 + kvpress/__init__.py | 2 + kvpress/presses/cam_press.py | 418 ++++++++++++++++++++++++++++++++ tests/default_presses.py | 8 + 6 files changed, 435 insertions(+) create mode 100644 kvpress/presses/cam_press.py diff --git a/README.md b/README.md index d23ee89a..a459dd65 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ Finally we provide wrapper presses that can be combined with other presses: - `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allow for compression during decoding, see decoding section in this README. - `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allow to compress both during prefilling and during decoding. - `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True), but only supports dense-prefill and not sparse-prefill. +- `CAMPress` ([source](kvpress/presses/cam_press.py), [paper](https://openreview.net/forum?id=LCTmppB165)): A decoding press that merges the KV cache of evicted tokens into kept tokens to preserve information. 
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index b7b59bac..b4d8f60a 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -21,6 +21,7 @@ from transformers import FineGrainedFP8Config, Pipeline, pipeline from kvpress import ( + CAMPress, ComposedPress, DecodingPress, DMSPress, @@ -281,6 +282,9 @@ def _setup_press(self): assert key_channel_compression_ratio is not None, "key_channel_compression_ratio must be set for ThinKPress" press.key_channel_compression_ratio = key_channel_compression_ratio logger.info(f"Set ThinKPress key_channel_compression_ratio to {key_channel_compression_ratio}") + elif isinstance(press, CAMPress): + press.compression_ratio = compression_ratio + logger.info(f"Set CAMPress compression_ratio to {compression_ratio}") elif isinstance(press, DecodingPress): press.compression_interval = self.config.compression_interval or press.compression_interval press.target_size = self.config.target_size or press.target_size diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py index e81e7fda..1c31cc03 100644 --- a/evaluation/evaluate_registry.py +++ b/evaluation/evaluate_registry.py @@ -15,6 +15,7 @@ from kvpress import ( AdaKVPress, BlockPress, + CAMPress, ChunkKVPress, CompactorPress, ComposedPress, @@ -106,6 +107,7 @@ "compactor": CompactorPress(), "adakv_compactor": AdaKVPress(CompactorPress()), "no_press": None, + "cam_streaming_llm": CAMPress(base_press=StreamingLLMPress()), "decoding_knorm": DecodingPress(base_press=KnormPress()), "decoding_streaming_llm": DecodingPress(base_press=StreamingLLMPress()), "decoding_tova": DecodingPress(base_press=TOVAPress()), diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 
0f08519e..1cc8c590 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -7,6 +7,7 @@ from kvpress.presses.adakv_press import AdaKVPress from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress from kvpress.presses.block_press import BlockPress +from kvpress.presses.cam_press import CAMPress from kvpress.presses.chunk_press import ChunkPress from kvpress.presses.chunkkv_press import ChunkKVPress from kvpress.presses.compactor_press import CompactorPress @@ -75,6 +76,7 @@ "KeyDiffPress", "KVzipPress", "ExpectedAttentionStatsPress", + "CAMPress", "DecodingPress", "PrefillDecodingPress", "CompactorPress", diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py new file mode 100644 index 00000000..ea5ed714 --- /dev/null +++ b/kvpress/presses/cam_press.py @@ -0,0 +1,418 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from __future__ import annotations + +import logging +import math +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from transformers.models.llama.modeling_llama import repeat_kv, rotate_half + +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.utils import extract_keys_and_values, get_prerope_query_states + +logger = logging.getLogger(__name__) + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +def _aggregate_attention_per_kv_head( + attentions: torch.Tensor, + num_kv_heads: int, +) -> torch.Tensor: + """Average attention scores across query heads that share a KV head.""" + num_query_heads = attentions.shape[1] + if num_query_heads == num_kv_heads: + return attentions + group_size = num_query_heads // num_kv_heads + batch, _, seq_q, seq_k = 
attentions.shape + return attentions.reshape(batch, num_kv_heads, group_size, seq_q, seq_k).mean(dim=2) + + +@dataclass +class CAMPress(DecodingPress): + """ + Cache Merging (CaM) KV cache compression during decoding. + + Evicted tokens' values are merged into their sequential neighbors using a + Bernoulli merge probability derived from relative attention scores. Keys are + pruned after merging. + + Based on CaM (https://openreview.net/forum?id=LCTmppB165). + + Parameters + ---------- + base_press : ScorerPress + Scorer used to select which tokens to evict (e.g., StreamingLLMPress). + compression_ratio : float, default=0.0 + Fraction of prefill tokens to evict during decoding. + merge_budget : int or None, default=64 + Number of sequential neighbors to merge each evicted token into. + None merges into all remaining tokens after the evicted position. + use_triton : bool, default=True + Use the Triton kernel for merging when available (CUDA only). + """ + + base_press: ScorerPress = None + compression_ratio: float = 0.0 + merge_budget: Optional[int] = 64 + use_triton: bool = True + + def __init__( + self, + base_press: ScorerPress, + compression_ratio: float = 0.0, + merge_budget: Optional[int] = 64, + use_triton: bool = True, + ): + self.base_press = base_press + self.compression_ratio = compression_ratio + self.merge_budget = merge_budget + self.use_triton = use_triton + self._target_cache_size: dict[int, int] = {} + self._first_eviction_done: dict[int, bool] = defaultdict(lambda: False) + + def post_init_from_model(self, model: PreTrainedModel): + if hasattr(self.base_press, "post_init_from_model"): + self.base_press.post_init_from_model(model) + + def reset(self): + """Reset per-sequence state.""" + self._target_cache_size = {} + self._first_eviction_done = defaultdict(lambda: False) + + def score( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: Optional[torch.Tensor], + kwargs: dict, + ) -> 
torch.Tensor: + """Delegate scoring to base_press with the compression ratio adjusted for the current cache size.""" + cache_len = keys.shape[2] + n_to_evict = cache_len - self._target_cache_size[int(module.layer_idx)] + cr = n_to_evict / cache_len if cache_len > 0 else 0.0 + + old_cr = self.base_press.compression_ratio + self.base_press.compression_ratio = cr + try: + scores = self.base_press.score(module, hidden_states, keys, values, attentions, kwargs) + finally: + self.base_press.compression_ratio = old_cr + + return scores + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: Optional[torch.Tensor], + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Merge evicted token values into sequential neighbors, then prune.""" + layer_idx = int(module.layer_idx) + head_dim = module.head_dim + cache_len = keys.shape[2] + + target_size = self._target_cache_size[layer_idx] + n_to_evict = cache_len - target_size + + scores = self.score(module, hidden_states, keys, values, attentions, kwargs) + + batch, kv_heads, _ = scores.shape + dev = scores.device + n_kept = cache_len - n_to_evict + + kept_indices = scores.topk(n_kept, dim=-1).indices + kept_indices = torch.sort(kept_indices, dim=-1).values + + all_idx = torch.arange(cache_len, device=dev) + kept_mask = torch.zeros(batch, kv_heads, cache_len, dtype=torch.bool, device=dev) + kept_mask.scatter_(2, kept_indices, True) + evicted_positions = all_idx.expand(batch, kv_heads, -1)[~kept_mask].reshape(batch, kv_heads, n_to_evict) + + effective_budget = self.merge_budget if self.merge_budget is not None else (cache_len - 1) + actual_budget = min(effective_budget, cache_len - 1) + + offsets = torch.arange(actual_budget, device=dev) + per_token_targets = (evicted_positions.unsqueeze(-1) + 1 + offsets).clamp(max=cache_len - 1) + valid_targets = (evicted_positions.unsqueeze(-1) + 1 + offsets) < cache_len + + merge_mask = None + + if 
attentions is None and actual_budget > 0 and n_to_evict > 0: + attentions = self._compute_current_token_attention(module, hidden_states, keys, kwargs) + + if attentions is not None and actual_budget > 0 and n_to_evict > 0: + attn_per_kv = _aggregate_attention_per_kv_head(attentions, kv_heads) + if attn_per_kv.shape[2] > 1: + attn_per_kv = attn_per_kv[:, :, -1:, :] + attn_squeezed = attn_per_kv.squeeze(2) + + evicted_attn = attn_squeezed.gather(2, evicted_positions) + per_token_target_attn = ( + attn_squeezed.unsqueeze(2).expand(-1, -1, n_to_evict, -1).gather(3, per_token_targets) + ) + per_token_target_attn = per_token_target_attn.masked_fill( + ~valid_targets.expand_as(per_token_target_attn), float("-inf") + ) + ref_attn = per_token_target_attn.max(dim=-1).values + + merge_prob = torch.where( + ref_attn > 0, + (evicted_attn.float() / ref_attn.float().clamp(min=1e-9)).clamp(0.0, 1.0), + torch.zeros_like(evicted_attn, dtype=torch.float32), + ).to(evicted_attn.dtype) + merge_mask = torch.bernoulli(merge_prob) + + non_merge = merge_mask < 0.5 + if non_merge.any(): + b_idx = torch.arange(batch, device=dev)[:, None, None].expand_as(evicted_positions) + h_idx = torch.arange(kv_heads, device=dev)[None, :, None].expand_as(evicted_positions) + pos_to_zero = evicted_positions[non_merge] + if pos_to_zero.numel() > 0: + values[b_idx[non_merge], h_idx[non_merge], pos_to_zero, :] = 0.0 + + is_first = not self._first_eviction_done[layer_idx] + n_merged = int(merge_mask.sum().item()) + logger.debug( + f"CaM L{layer_idx}: {'BULK' if is_first else 'step'} evict={n_to_evict}, " + f"merged={n_merged}/{n_to_evict}, mean_prob={merge_prob.mean():.3f}, " + f"cache={cache_len}->{n_kept}" + ) + else: + logger.debug(f"CaM L{layer_idx}: no attention, always-merge, evict={n_to_evict}") + + if actual_budget > 0 and n_to_evict > 0: + if n_to_evict == 1 and merge_mask is not None and merge_mask.sum() == 0: + pass + else: + if not per_token_targets.is_contiguous(): + per_token_targets = 
per_token_targets.contiguous() + + if not evicted_positions.is_contiguous(): + evicted_positions = evicted_positions.contiguous() + + valid_targets_c = valid_targets if valid_targets.is_contiguous() else valid_targets.contiguous() + + if self.use_triton and HAS_TRITON and values.is_cuda: + values = self._triton_merge( + values, evicted_positions, per_token_targets, actual_budget, valid_targets_c + ) + else: + values = self._torch_merge( + values, evicted_positions, per_token_targets, actual_budget, valid_targets_c + ) + + gather_idx = kept_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim) + keys = keys.gather(2, gather_idx).contiguous() + values = values.gather(2, gather_idx).contiguous() + + return keys, values + + @staticmethod + def _compute_current_token_attention( + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + kwargs: dict, + ) -> torch.Tensor: + """Compute softmax attention from the last query token to all cached keys.""" + _, num_kv_heads, cache_len, head_dim = keys.shape + num_query_heads = module.config.num_attention_heads + num_key_value_groups = num_query_heads // num_kv_heads + + query_states = get_prerope_query_states(module, hidden_states) + query_states = query_states[:, :, -1:, :] + + cos, sin = kwargs["position_embeddings"] + cos = cos[:, -1:, :].unsqueeze(1) + sin = sin[:, -1:, :].unsqueeze(1) + query_states = (query_states * cos) + (rotate_half(query_states) * sin) + + keys_repeated = repeat_kv(keys, num_key_value_groups) + scores = torch.matmul(query_states, keys_repeated.transpose(-2, -1)) / math.sqrt(head_dim) + return torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype) + + def forward_hook( + self, + module: nn.Module, + input: list[torch.Tensor], + kwargs: dict, + output: list, + ): + hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_values"] + q_len = hidden_states.shape[1] + layer_idx = int(module.layer_idx) + + # Only operate during decoding + if 
kwargs["cache_position"][-1] <= q_len: + return output + + if self.compression_ratio <= 0: + return output + + keys, values = extract_keys_and_values(cache, layer_idx) + cache_len = keys.shape[2] + + if layer_idx not in self._target_cache_size: + prefill_len = cache_len - 1 + self._target_cache_size[layer_idx] = max(int(prefill_len * (1 - self.compression_ratio)), 1) + + if cache_len <= self._target_cache_size[layer_idx]: + return output + + attentions = output[1] if len(output) > 1 and output[1] is not None else None + + keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs) + + cache.layers[layer_idx].keys = keys + cache.layers[layer_idx].values = values + self._first_eviction_done[layer_idx] = True + + return output + + def _torch_merge(self, values, evicted_positions, per_token_targets, actual_budget, valid_targets): + """Merge each evicted token's value into its sequential neighbors (pure PyTorch fallback).""" + n_evicted = evicted_positions.shape[2] + + for i in range(n_evicted): + for b in range(values.shape[0]): + for h in range(values.shape[1]): + p = evicted_positions[b, h, i].item() + v_evicted = values[b, h, p, :].clone() + if v_evicted.abs().sum() < 1e-12: + continue + contribution = v_evicted / actual_budget + for t in range(actual_budget): + if valid_targets[b, h, i, t].item(): + target = per_token_targets[b, h, i, t].item() + values[b, h, target, :] += contribution + + return values + + def _triton_merge(self, values, evicted_positions, per_token_targets, actual_budget, valid_targets): + """Merge each evicted token's value into its sequential neighbors (Triton kernel).""" + if not HAS_TRITON: + return self._torch_merge(values, evicted_positions, per_token_targets, actual_budget, valid_targets) + + batch_size, num_kv_heads, seq_len, head_dim = values.shape + n_evicted = evicted_positions.shape[2] + + BLOCK_D = triton.next_power_of_2(head_dim) + TILE_R = min(64, triton.next_power_of_2(actual_budget)) + grid = 
(batch_size, num_kv_heads) + + _cam_decoding_merge_kernel[grid]( + values_ptr=values, + evicted_pos_ptr=evicted_positions, + merge_targets_ptr=per_token_targets, + seq_len=seq_len, + n_evicted=n_evicted, + actual_budget=actual_budget, + v_stride_b=values.stride(0), + v_stride_h=values.stride(1), + v_stride_s=values.stride(2), + v_stride_d=values.stride(3), + ep_stride_b=evicted_positions.stride(0), + ep_stride_h=evicted_positions.stride(1), + ep_stride_e=evicted_positions.stride(2), + mt_stride_b=per_token_targets.stride(0), + mt_stride_h=per_token_targets.stride(1), + mt_stride_e=per_token_targets.stride(2), + mt_stride_t=per_token_targets.stride(3), + head_dim=head_dim, + BLOCK_D=BLOCK_D, + TILE_R=TILE_R, + ) + return values + + +if HAS_TRITON: + + @triton.jit + def _cam_decoding_merge_kernel( + values_ptr, + evicted_pos_ptr, + merge_targets_ptr, + seq_len, + n_evicted, + actual_budget, + v_stride_b, + v_stride_h, + v_stride_s, + v_stride_d, + ep_stride_b, + ep_stride_h, + ep_stride_e, + mt_stride_b, + mt_stride_h, + mt_stride_e, + mt_stride_t, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, + TILE_R: tl.constexpr, + ): + """ + Tiled scatter-add merge kernel. Grid: (batch_size, num_kv_heads). + Each evicted token scatters its contribution into its own neighbor list. 
+ """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + v_base = values_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + ep_base = evicted_pos_ptr + batch_idx * ep_stride_b + head_idx * ep_stride_h + mt_base = merge_targets_ptr + batch_idx * mt_stride_b + head_idx * mt_stride_h + + d_offsets = tl.arange(0, BLOCK_D) + d_mask = d_offsets < head_dim + + for evict_idx in tl.range(0, n_evicted): + token_pos = tl.load(ep_base + evict_idx * ep_stride_e).to(tl.int64) + v_evicted = tl.load( + v_base + token_pos * v_stride_s + d_offsets * v_stride_d, + mask=d_mask, + other=0.0, + ) + contribution = v_evicted / actual_budget + + mt_evict_base = mt_base + evict_idx * mt_stride_e + n_tiles = (actual_budget + TILE_R - 1) // TILE_R + + for tile_idx in tl.range(0, n_tiles): + t_offsets = tl.arange(0, TILE_R) + t_indices = tile_idx * TILE_R + t_offsets + t_mask = t_indices < actual_budget + + target_positions = tl.load( + mt_evict_base + t_indices * mt_stride_t, + mask=t_mask, + other=0, + ).to(tl.int64) + + valid = (target_positions < seq_len) & t_mask + ptrs = v_base + target_positions[:, None] * v_stride_s + d_offsets[None, :] * v_stride_d + mask_2d = valid[:, None] & d_mask[None, :] + + v_block = tl.load(ptrs, mask=mask_2d, other=0.0) + v_block = v_block + contribution[None, :] + tl.store(ptrs, v_block, mask=mask_2d) diff --git a/tests/default_presses.py b/tests/default_presses.py index 413f1ea6..8f7c9994 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -4,6 +4,7 @@ import numpy as np from kvpress import ( + CAMPress, CompactorPress, CURPress, DuoAttentionPress, @@ -141,4 +142,11 @@ def post_init_from_model(self, model): }, ], }, + { + "cls": CAMPress, + "kwargs": [ + {"base_press": StreamingLLMPress(), "compression_ratio": 0.2}, + {"base_press": StreamingLLMPress(), "compression_ratio": 0.8}, + ], + }, ] From 537b61f9a35459b54e0f1abfde9d01a97184c3bf Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Tue, 17 Mar 2026 13:48:14 -0700 
Subject: [PATCH 02/13] Fixed Naming & Conventions - replaced `batch` with `bsz` - replaced `kv_heads` with `num_key_value_heads` - replaced `__init__` with `__post_init__` with assertions - replaced `dev = scores.device` with `device = scores.device` - Added comment under `_torch_merge` to explain fallback Signed-off-by: Saransh Agrawal --- kvpress/presses/cam_press.py | 73 +++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py index ea5ed714..8d8cdce7 100644 --- a/kvpress/presses/cam_press.py +++ b/kvpress/presses/cam_press.py @@ -32,15 +32,15 @@ def _aggregate_attention_per_kv_head( attentions: torch.Tensor, - num_kv_heads: int, + num_key_value_heads: int, ) -> torch.Tensor: """Average attention scores across query heads that share a KV head.""" num_query_heads = attentions.shape[1] - if num_query_heads == num_kv_heads: + if num_query_heads == num_key_value_heads: return attentions - group_size = num_query_heads // num_kv_heads - batch, _, seq_q, seq_k = attentions.shape - return attentions.reshape(batch, num_kv_heads, group_size, seq_q, seq_k).mean(dim=2) + group_size = num_query_heads // num_key_value_heads + bsz, _, seq_q, seq_k = attentions.shape + return attentions.reshape(bsz, num_key_value_heads, group_size, seq_q, seq_k).mean(dim=2) @dataclass @@ -72,20 +72,20 @@ class CAMPress(DecodingPress): merge_budget: Optional[int] = 64 use_triton: bool = True - def __init__( - self, - base_press: ScorerPress, - compression_ratio: float = 0.0, - merge_budget: Optional[int] = 64, - use_triton: bool = True, - ): - self.base_press = base_press - self.compression_ratio = compression_ratio - self.merge_budget = merge_budget - self.use_triton = use_triton + def __post_init__(self): + assert isinstance(self.base_press, ScorerPress), "CAMPress requires a ScorerPress as base_press" + assert self.compression_ratio >= 0.0, "compression_ratio must be non-negative" + assert 
self.merge_budget is None or self.merge_budget > 0, "merge_budget must be positive or None" + assert isinstance(self.merge_budget, (int, type(None))), "merge_budget must be an int or None" + assert isinstance(self.use_triton, bool), "use_triton must be a boolean" + self._target_cache_size: dict[int, int] = {} self._first_eviction_done: dict[int, bool] = defaultdict(lambda: False) + if self.use_triton and not HAS_TRITON: + logger.warning(f"Triton is not available. Falling back to PyTorch merge implementation for {self.__class__.__name__}.") + + def post_init_from_model(self, model: PreTrainedModel): if hasattr(self.base_press, "post_init_from_model"): self.base_press.post_init_from_model(model) @@ -137,22 +137,22 @@ def compress( scores = self.score(module, hidden_states, keys, values, attentions, kwargs) - batch, kv_heads, _ = scores.shape - dev = scores.device + bsz, num_key_value_heads, _ = scores.shape + device = scores.device n_kept = cache_len - n_to_evict kept_indices = scores.topk(n_kept, dim=-1).indices kept_indices = torch.sort(kept_indices, dim=-1).values - all_idx = torch.arange(cache_len, device=dev) - kept_mask = torch.zeros(batch, kv_heads, cache_len, dtype=torch.bool, device=dev) + all_idx = torch.arange(cache_len, device=device) + kept_mask = torch.zeros(bsz, num_key_value_heads, cache_len, dtype=torch.bool, device=device) kept_mask.scatter_(2, kept_indices, True) - evicted_positions = all_idx.expand(batch, kv_heads, -1)[~kept_mask].reshape(batch, kv_heads, n_to_evict) + evicted_positions = all_idx.expand(bsz, num_key_value_heads, -1)[~kept_mask].reshape(bsz, num_key_value_heads, n_to_evict) effective_budget = self.merge_budget if self.merge_budget is not None else (cache_len - 1) actual_budget = min(effective_budget, cache_len - 1) - offsets = torch.arange(actual_budget, device=dev) + offsets = torch.arange(actual_budget, device=device) per_token_targets = (evicted_positions.unsqueeze(-1) + 1 + offsets).clamp(max=cache_len - 1) valid_targets = 
(evicted_positions.unsqueeze(-1) + 1 + offsets) < cache_len @@ -162,7 +162,7 @@ def compress( attentions = self._compute_current_token_attention(module, hidden_states, keys, kwargs) if attentions is not None and actual_budget > 0 and n_to_evict > 0: - attn_per_kv = _aggregate_attention_per_kv_head(attentions, kv_heads) + attn_per_kv = _aggregate_attention_per_kv_head(attentions, num_key_value_heads) if attn_per_kv.shape[2] > 1: attn_per_kv = attn_per_kv[:, :, -1:, :] attn_squeezed = attn_per_kv.squeeze(2) @@ -185,8 +185,8 @@ def compress( non_merge = merge_mask < 0.5 if non_merge.any(): - b_idx = torch.arange(batch, device=dev)[:, None, None].expand_as(evicted_positions) - h_idx = torch.arange(kv_heads, device=dev)[None, :, None].expand_as(evicted_positions) + b_idx = torch.arange(bsz, device=device)[:, None, None].expand_as(evicted_positions) + h_idx = torch.arange(num_key_value_heads, device=device)[None, :, None].expand_as(evicted_positions) pos_to_zero = evicted_positions[non_merge] if pos_to_zero.numel() > 0: values[b_idx[non_merge], h_idx[non_merge], pos_to_zero, :] = 0.0 @@ -236,9 +236,9 @@ def _compute_current_token_attention( kwargs: dict, ) -> torch.Tensor: """Compute softmax attention from the last query token to all cached keys.""" - _, num_kv_heads, cache_len, head_dim = keys.shape + _, num_key_value_heads, cache_len, head_dim = keys.shape num_query_heads = module.config.num_attention_heads - num_key_value_groups = num_query_heads // num_kv_heads + num_key_value_groups = num_query_heads // num_key_value_heads query_states = get_prerope_query_states(module, hidden_states) query_states = query_states[:, :, -1:, :] @@ -292,7 +292,10 @@ def forward_hook( return output def _torch_merge(self, values, evicted_positions, per_token_targets, actual_budget, valid_targets): - """Merge each evicted token's value into its sequential neighbors (pure PyTorch fallback).""" + """Merge each evicted token's value into its sequential neighbors (pure PyTorch fallback). 
+ + Used when Triton is unavailable (not installed) or when values are not on a CUDA device. + """ n_evicted = evicted_positions.shape[2] for i in range(n_evicted): @@ -315,12 +318,12 @@ def _triton_merge(self, values, evicted_positions, per_token_targets, actual_bud if not HAS_TRITON: return self._torch_merge(values, evicted_positions, per_token_targets, actual_budget, valid_targets) - batch_size, num_kv_heads, seq_len, head_dim = values.shape + bsz, num_key_value_heads, seq_len, head_dim = values.shape n_evicted = evicted_positions.shape[2] BLOCK_D = triton.next_power_of_2(head_dim) TILE_R = min(64, triton.next_power_of_2(actual_budget)) - grid = (batch_size, num_kv_heads) + grid = (bsz, num_key_value_heads) _cam_decoding_merge_kernel[grid]( values_ptr=values, @@ -373,15 +376,15 @@ def _cam_decoding_merge_kernel( TILE_R: tl.constexpr, ): """ - Tiled scatter-add merge kernel. Grid: (batch_size, num_kv_heads). + Tiled scatter-add merge kernel. Grid: (batch_size, num_key_value_heads). Each evicted token scatters its contribution into its own neighbor list. """ - batch_idx = tl.program_id(0) + bsz_idx = tl.program_id(0) head_idx = tl.program_id(1) - v_base = values_ptr + batch_idx * v_stride_b + head_idx * v_stride_h - ep_base = evicted_pos_ptr + batch_idx * ep_stride_b + head_idx * ep_stride_h - mt_base = merge_targets_ptr + batch_idx * mt_stride_b + head_idx * mt_stride_h + v_base = values_ptr + bsz_idx * v_stride_b + head_idx * v_stride_h + ep_base = evicted_pos_ptr + bsz_idx * ep_stride_b + head_idx * ep_stride_h + mt_base = merge_targets_ptr + bsz_idx * mt_stride_b + head_idx * mt_stride_h d_offsets = tl.arange(0, BLOCK_D) d_mask = d_offsets < head_dim From 1a1fab4fea3672c85acd42dd18fffded8017bb66 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Tue, 24 Mar 2026 16:40:45 -0700 Subject: [PATCH 03/13] Major fixes to CaM algorithm keeping it as close to paper as possible. Algorithm: 1. 
ACCUMULATE: Each decoding step, accumulate per-KV-head attention weights into a running sum across generation steps. 2. TRIGGER: Every `compression_interval` steps (when cache exceeds `target_size`), trigger bulk eviction: a. SCORE: Use base_press (e.g. KnormPress) to score all cached tokens. Select the bottom-k as eviction candidates. b. SPLIT: From eviction candidates, pick the top-k by score (preferring later sequence positions on ties) as merge tokens. The rest are pure evictions. c. MERGE: For each merge token m: - Find its target window: the next `merge_budget` tokens in the kept set after m's position. - Compute merge_prob = attn(m) / mean(attn[window_start:]) using cumulative attention sums. - Sample Bernoulli(merge_prob). If accepted, scatter-add value(m) / num_targets equally into each target's value. d. PRUNE: Physically remove all evicted tokens (both merged and pure-evicted) from keys, values, and attention sums. Only `target_size` tokens remain. 3. RESET step counter to 0, continue decoding. Merge implementations: - _torch_merge: Fully vectorized with cumsum + scatter_add_ - _triton_merge: Triton kernel Signed-off-by: Saransh Agrawal --- kvpress/presses/cam_press.py | 567 ++++++++++++++++++----------------- 1 file changed, 290 insertions(+), 277 deletions(-) diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py index 8d8cdce7..94b3dfd7 100644 --- a/kvpress/presses/cam_press.py +++ b/kvpress/presses/cam_press.py @@ -1,22 +1,21 @@ # SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 - from __future__ import annotations import logging import math from collections import defaultdict from dataclasses import dataclass -from typing import Optional import torch import torch.nn as nn -from transformers import PreTrainedModel +from transformers import QuantizedCache from transformers.models.llama.modeling_llama import repeat_kv, rotate_half from kvpress.presses.decoding_press import DecodingPress from kvpress.presses.scorer_press import ScorerPress +from kvpress.presses.adakv_press import AdaKVPress from kvpress.utils import extract_keys_and_values, get_prerope_query_states logger = logging.getLogger(__name__) @@ -30,19 +29,6 @@ HAS_TRITON = False -def _aggregate_attention_per_kv_head( - attentions: torch.Tensor, - num_key_value_heads: int, -) -> torch.Tensor: - """Average attention scores across query heads that share a KV head.""" - num_query_heads = attentions.shape[1] - if num_query_heads == num_key_value_heads: - return attentions - group_size = num_query_heads // num_key_value_heads - bsz, _, seq_q, seq_k = attentions.shape - return attentions.reshape(bsz, num_key_value_heads, group_size, seq_q, seq_k).mean(dim=2) - - @dataclass class CAMPress(DecodingPress): """ @@ -67,166 +53,161 @@ class CAMPress(DecodingPress): Use the Triton kernel for merging when available (CUDA only). 
""" - base_press: ScorerPress = None - compression_ratio: float = 0.0 - merge_budget: Optional[int] = 64 - use_triton: bool = True + base_press: ScorerPress | AdaKVPress = None + compression_interval: int = 2 + target_size: int = 3048 + hidden_states_buffer_size: int = 256 + merge_budget: int = 32 + use_triton: bool = False def __post_init__(self): - assert isinstance(self.base_press, ScorerPress), "CAMPress requires a ScorerPress as base_press" - assert self.compression_ratio >= 0.0, "compression_ratio must be non-negative" - assert self.merge_budget is None or self.merge_budget > 0, "merge_budget must be positive or None" - assert isinstance(self.merge_budget, (int, type(None))), "merge_budget must be an int or None" + assert isinstance(self.base_press, (ScorerPress, AdaKVPress)), "CAMPress requires a ScorerPress as base_press" + assert self.compression_interval > 0, "compression_interval must be greater than 0" + assert self.target_size > 0, "target_size must be greater than 0" + assert self.merge_budget > 0, "merge_budget must be positive " assert isinstance(self.use_triton, bool), "use_triton must be a boolean" - self._target_cache_size: dict[int, int] = {} - self._first_eviction_done: dict[int, bool] = defaultdict(lambda: False) + # State Variables + self.layer_step_counts = defaultdict(int) + self._running_attn_sum: dict[int, torch.Tensor] = {} if self.use_triton and not HAS_TRITON: - logger.warning(f"Triton is not available. Falling back to PyTorch merge implementation for {self.__class__.__name__}.") - - - def post_init_from_model(self, model: PreTrainedModel): - if hasattr(self.base_press, "post_init_from_model"): - self.base_press.post_init_from_model(model) - - def reset(self): - """Reset per-sequence state.""" - self._target_cache_size = {} - self._first_eviction_done = defaultdict(lambda: False) + logger.warning( + f"Triton is not available. Falling back to PyTorch merge implementation for {self.__class__.__name__}." 
+ ) - def score( + def compress( self, module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, - attentions: Optional[torch.Tensor], + attentions: torch.Tensor, kwargs: dict, - ) -> torch.Tensor: - """Delegate scoring to base_press with the compression ratio adjusted for the current cache size.""" + ) -> tuple[torch.Tensor, torch.Tensor]: + layer_idx = int(module.layer_idx) cache_len = keys.shape[2] - n_to_evict = cache_len - self._target_cache_size[int(module.layer_idx)] - cr = n_to_evict / cache_len if cache_len > 0 else 0.0 + n_to_evict = cache_len-self.target_size + + target_compression_ratio = self._find_target_compression_ratio(cache_len, self.target_size) + + + if n_to_evict <= 0: + return keys, values + + # Temporary override base press ratio to get correct topK scores old_cr = self.base_press.compression_ratio - self.base_press.compression_ratio = cr - try: - scores = self.base_press.score(module, hidden_states, keys, values, attentions, kwargs) - finally: - self.base_press.compression_ratio = old_cr + self.base_press.compression_ratio = target_compression_ratio + scores = self.base_press.score(module, hidden_states, keys, values, None, kwargs) + self.base_press.compression_ratio = old_cr - return scores + bsz, num_key_value_heads, _, head_dim = keys.shape - def compress( + evict_indices = scores[:, 0, :].topk(n_to_evict, dim=-1, largest=False).indices + evict_indices = torch.sort(evict_indices, dim=-1).values + + evict_scores = scores[:, 0, :].gather(-1, evict_indices) + # Flip so later sequence positions come first; stable sort preserves this order for ties + k = self.layer_step_counts[layer_idx] + order = evict_scores.flip(-1).argsort(dim=-1, descending=True, stable=True)[:, :k] + merge_indices = evict_indices.gather(-1, n_to_evict - 1 - order) + merge_indices = torch.sort(merge_indices, dim=-1).values + + kept_indices = scores[:, 0, :].topk(self.target_size, dim=-1).indices + kept_indices = torch.sort(kept_indices, 
dim=-1).values + + if n_to_evict > 0: + if self.use_triton and HAS_TRITON and values.is_cuda: + values = self._triton_merge(values, merge_indices, kept_indices, attentions, self.merge_budget) + else: + values = self._torch_merge(values, merge_indices, kept_indices, attentions, self.merge_budget) + + # Physical Pruning + kept_indices_expand = kept_indices.view(bsz, 1, self.target_size, 1).expand(bsz, num_key_value_heads, self.target_size, head_dim) + keys = keys.gather(2, kept_indices_expand).contiguous() + values = values.gather(2, kept_indices_expand).contiguous() + + # prune cumulative attentions + expanded_indices = kept_indices.unsqueeze(1).expand(bsz, num_key_value_heads, -1) + self._running_attn_sum[layer_idx] = self._running_attn_sum[layer_idx].gather(2, expanded_indices).contiguous() + + return keys, values + + def forward_hook( self, module: nn.Module, - hidden_states: torch.Tensor, - keys: torch.Tensor, - values: torch.Tensor, - attentions: Optional[torch.Tensor], + input: list[torch.Tensor], kwargs: dict, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Merge evicted token values into sequential neighbors, then prune.""" + output: list, + ): + hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_values"] + q_len = hidden_states.shape[1] layer_idx = int(module.layer_idx) - head_dim = module.head_dim - cache_len = keys.shape[2] - target_size = self._target_cache_size[layer_idx] - n_to_evict = cache_len - target_size + # Only operate during decoding + if kwargs["cache_position"][-1] <= q_len: + return output - scores = self.score(module, hidden_states, keys, values, attentions, kwargs) + cache_layer = cache.layers[module.layer_idx] + keys, values = extract_keys_and_values(cache, layer_idx) + bsz, num_key_value_heads, seq_len, _ = keys.shape - bsz, num_key_value_heads, _ = scores.shape - device = scores.device - n_kept = cache_len - n_to_evict + # Accumulate Cumulative Attention over generation steps + attentions = output[1] if len(output) > 1 
and output[1] is not None else None + if attentions is None: + attentions = self._compute_current_token_attention(module, hidden_states, keys, kwargs) + else: + attentions = attentions[:,:,-1:,:] - kept_indices = scores.topk(n_kept, dim=-1).indices - kept_indices = torch.sort(kept_indices, dim=-1).values + attentions = self._aggregate_attention_per_kv_head(attentions, num_key_value_heads) - all_idx = torch.arange(cache_len, device=device) - kept_mask = torch.zeros(bsz, num_key_value_heads, cache_len, dtype=torch.bool, device=device) - kept_mask.scatter_(2, kept_indices, True) - evicted_positions = all_idx.expand(bsz, num_key_value_heads, -1)[~kept_mask].reshape(bsz, num_key_value_heads, n_to_evict) + if attentions is not None: + attn_squeezed = attentions.squeeze(2) - effective_budget = self.merge_budget if self.merge_budget is not None else (cache_len - 1) - actual_budget = min(effective_budget, cache_len - 1) + if layer_idx not in self._running_attn_sum: + self._running_attn_sum[layer_idx] = attn_squeezed.clone() + else: + # Pad running sum for the new token growth + prev_len = self._running_attn_sum[layer_idx].shape[-1] + pad_len = seq_len - prev_len - offsets = torch.arange(actual_budget, device=device) - per_token_targets = (evicted_positions.unsqueeze(-1) + 1 + offsets).clamp(max=cache_len - 1) - valid_targets = (evicted_positions.unsqueeze(-1) + 1 + offsets) < cache_len + if pad_len > 0: + pad = torch.zeros( + (bsz, num_key_value_heads, pad_len), device=attn_squeezed.device, dtype=attn_squeezed.dtype + ) + self._running_attn_sum[layer_idx] = torch.cat([self._running_attn_sum[layer_idx], pad], dim=-1) - merge_mask = None + self._running_attn_sum[layer_idx] += attn_squeezed - if attentions is None and actual_budget > 0 and n_to_evict > 0: - attentions = self._compute_current_token_attention(module, hidden_states, keys, kwargs) + self.layer_step_counts[layer_idx] += 1 - if attentions is not None and actual_budget > 0 and n_to_evict > 0: - attn_per_kv = 
_aggregate_attention_per_kv_head(attentions, num_key_value_heads) - if attn_per_kv.shape[2] > 1: - attn_per_kv = attn_per_kv[:, :, -1:, :] - attn_squeezed = attn_per_kv.squeeze(2) + # Trigger interval-based bulk eviction + if (self.layer_step_counts[layer_idx] >= self.compression_interval and seq_len>self.target_size) or (q_len >= self.target_size): - evicted_attn = attn_squeezed.gather(2, evicted_positions) - per_token_target_attn = ( - attn_squeezed.unsqueeze(2).expand(-1, -1, n_to_evict, -1).gather(3, per_token_targets) - ) - per_token_target_attn = per_token_target_attn.masked_fill( - ~valid_targets.expand_as(per_token_target_attn), float("-inf") - ) - ref_attn = per_token_target_attn.max(dim=-1).values - - merge_prob = torch.where( - ref_attn > 0, - (evicted_attn.float() / ref_attn.float().clamp(min=1e-9)).clamp(0.0, 1.0), - torch.zeros_like(evicted_attn, dtype=torch.float32), - ).to(evicted_attn.dtype) - merge_mask = torch.bernoulli(merge_prob) - - non_merge = merge_mask < 0.5 - if non_merge.any(): - b_idx = torch.arange(bsz, device=device)[:, None, None].expand_as(evicted_positions) - h_idx = torch.arange(num_key_value_heads, device=device)[None, :, None].expand_as(evicted_positions) - pos_to_zero = evicted_positions[non_merge] - if pos_to_zero.numel() > 0: - values[b_idx[non_merge], h_idx[non_merge], pos_to_zero, :] = 0.0 - - is_first = not self._first_eviction_done[layer_idx] - n_merged = int(merge_mask.sum().item()) - logger.debug( - f"CaM L{layer_idx}: {'BULK' if is_first else 'step'} evict={n_to_evict}, " - f"merged={n_merged}/{n_to_evict}, mean_prob={merge_prob.mean():.3f}, " - f"cache={cache_len}->{n_kept}" - ) - else: - logger.debug(f"CaM L{layer_idx}: no attention, always-merge, evict={n_to_evict}") + attn_squeezed = self._running_attn_sum[layer_idx] + keys, values = self.compress(module, hidden_states, keys, values, attn_squeezed, kwargs) - if actual_budget > 0 and n_to_evict > 0: - if n_to_evict == 1 and merge_mask is not None and merge_mask.sum() 
== 0: - pass + # Update cache with compressed keys and values + if isinstance(cache, QuantizedCache): + cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) + cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) + cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.cumulative_length = keys.shape[2] else: - if not per_token_targets.is_contiguous(): - per_token_targets = per_token_targets.contiguous() - - if not evicted_positions.is_contiguous(): - evicted_positions = evicted_positions.contiguous() + cache_layer.keys = keys + cache_layer.values = values - valid_targets_c = valid_targets if valid_targets.is_contiguous() else valid_targets.contiguous() - - if self.use_triton and HAS_TRITON and values.is_cuda: - values = self._triton_merge( - values, evicted_positions, per_token_targets, actual_budget, valid_targets_c - ) - else: - values = self._torch_merge( - values, evicted_positions, per_token_targets, actual_budget, valid_targets_c - ) + self.layer_step_counts[layer_idx] = 0 - gather_idx = kept_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim) - keys = keys.gather(2, gather_idx).contiguous() - values = values.gather(2, gather_idx).contiguous() + return output - return keys, values + def reset(self): + """Reset per-sequence state.""" + self.layer_step_counts = defaultdict(int) + self._running_attn_sum: dict[int, torch.Tensor] = {} @staticmethod def _compute_current_token_attention( @@ -252,100 +233,115 @@ def _compute_current_token_attention( scores = torch.matmul(query_states, keys_repeated.transpose(-2, -1)) / math.sqrt(head_dim) return torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype) - def forward_hook( - self, - module: nn.Module, - input: list[torch.Tensor], - kwargs: dict, - output: list, - ): - 
hidden_states = kwargs["hidden_states"] - cache = kwargs["past_key_values"] - q_len = hidden_states.shape[1] - layer_idx = int(module.layer_idx) - - # Only operate during decoding - if kwargs["cache_position"][-1] <= q_len: - return output - - if self.compression_ratio <= 0: - return output - - keys, values = extract_keys_and_values(cache, layer_idx) - cache_len = keys.shape[2] - - if layer_idx not in self._target_cache_size: - prefill_len = cache_len - 1 - self._target_cache_size[layer_idx] = max(int(prefill_len * (1 - self.compression_ratio)), 1) - - if cache_len <= self._target_cache_size[layer_idx]: - return output - - attentions = output[1] if len(output) > 1 and output[1] is not None else None - - keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs) - - cache.layers[layer_idx].keys = keys - cache.layers[layer_idx].values = values - self._first_eviction_done[layer_idx] = True - - return output - - def _torch_merge(self, values, evicted_positions, per_token_targets, actual_budget, valid_targets): - """Merge each evicted token's value into its sequential neighbors (pure PyTorch fallback). + @staticmethod + def _aggregate_attention_per_kv_head( + attentions: torch.Tensor, + num_key_value_heads: int, + ) -> torch.Tensor: + """Average attention scores across query heads that share a KV head.""" + num_query_heads = attentions.shape[1] + if num_query_heads == num_key_value_heads: + return attentions + group_size = num_query_heads // num_key_value_heads + bsz, _, seq_q, seq_k = attentions.shape + return attentions.reshape(bsz, num_key_value_heads, group_size, seq_q, seq_k).mean(dim=2) + + def _torch_merge(self, values, merge_indices, kept_indices, attentions_per_kv, merge_budget): + bsz, num_kv_heads, seq_len, head_dim = values.shape + n_merge = merge_indices.shape[1] + n_kept = kept_indices.shape[1] + + # 1. 
Cascading target starts + base_idx_first = torch.searchsorted(kept_indices, merge_indices[:, 0:1], right=True) + target_starts = torch.arange(n_merge, device=kept_indices.device).unsqueeze(0) + base_idx_first + + # 2. Build target window indices: [bsz, n_merge, merge_budget] + offsets = torch.arange(merge_budget, device=kept_indices.device) + window_idx = target_starts.unsqueeze(-1) + offsets.view(1, 1, -1) + valid_mask = window_idx < n_kept + window_idx = window_idx.clamp(max=n_kept - 1) + target_positions = kept_indices.gather(1, window_idx.view(bsz, -1)).view(bsz, n_merge, merge_budget) + + # 3. Actual budget per merge token + actual_budget = valid_mask.sum(dim=-1) + + # 4. Suffix mean via cumsum + attn_cumsum = torch.nn.functional.pad(attentions_per_kv.cumsum(dim=-1), (1, 0)) + total_sum = attn_cumsum[:, :, -1:] + start_sum = attn_cumsum.gather(2, target_starts.unsqueeze(1).expand(-1, num_kv_heads, -1)) + suffix_sum = total_sum - start_sum + suffix_len = (seq_len - target_starts).unsqueeze(1) + mean_attn = suffix_sum / suffix_len + + # 5. Merge probability + merge_token_attn = attentions_per_kv.gather(2, merge_indices.unsqueeze(1).expand(-1, num_kv_heads, -1)) + merge_prob = merge_token_attn / mean_attn + merge_prob = torch.where(torch.isnan(merge_prob), torch.zeros_like(merge_prob), merge_prob) + merge_prob = torch.where(torch.isinf(merge_prob), torch.ones_like(merge_prob), merge_prob) + merge_prob = merge_prob.clamp(0, 1) + + # 6. Bernoulli sampling + merge_mask = torch.bernoulli(merge_prob) + + # 7. 
Build contributions and scatter-add + merge_values = values.gather(2, merge_indices.view(bsz, 1, n_merge, 1).expand(-1, num_kv_heads, -1, head_dim)) + scale = (merge_mask / actual_budget.unsqueeze(1)).unsqueeze(-1) + scale = torch.where(actual_budget.unsqueeze(1).unsqueeze(-1) == 0, torch.zeros_like(scale), scale) + contributions = merge_values * scale + contributions = contributions.unsqueeze(3).expand(-1, -1, -1, merge_budget, -1) + contributions = contributions * valid_mask.view(bsz, 1, n_merge, merge_budget, 1) + contributions = contributions.reshape(bsz, num_kv_heads, n_merge * merge_budget, head_dim) + scatter_idx = target_positions.view(bsz, 1, n_merge * merge_budget, 1).expand(-1, num_kv_heads, -1, head_dim) + + values.scatter_add_(2, scatter_idx, contributions) + return values - Used when Triton is unavailable (not installed) or when values are not on a CUDA device. - """ - n_evicted = evicted_positions.shape[2] - - for i in range(n_evicted): - for b in range(values.shape[0]): - for h in range(values.shape[1]): - p = evicted_positions[b, h, i].item() - v_evicted = values[b, h, p, :].clone() - if v_evicted.abs().sum() < 1e-12: - continue - contribution = v_evicted / actual_budget - for t in range(actual_budget): - if valid_targets[b, h, i, t].item(): - target = per_token_targets[b, h, i, t].item() - values[b, h, target, :] += contribution + def _triton_merge(self, values, merge_indices, kept_indices, attentions_per_kv, merge_budget): + """Pre-computes cascading start targets and prefix sums, then merges in a single Triton kernel.""" + bsz, num_kv_heads, _, head_dim = values.shape + n_merge = merge_indices.shape[1] + n_kept = kept_indices.shape[1] + attn_len = attentions_per_kv.shape[2] - return values + # 1. 
Pre-compute the cascading target + base_idx_first = torch.searchsorted(kept_indices, merge_indices[:, 0:1], right=True) + target_starts = torch.arange(n_merge, device=kept_indices.device).unsqueeze(0) + target_starts += base_idx_first - def _triton_merge(self, values, evicted_positions, per_token_targets, actual_budget, valid_targets): - """Merge each evicted token's value into its sequential neighbors (Triton kernel).""" - if not HAS_TRITON: - return self._torch_merge(values, evicted_positions, per_token_targets, actual_budget, valid_targets) + # 2. Prefix sum for O(1) suffix-mean in kernel: prefix_sum[i] = sum(attn[0:i]) + attn_prefix_sum = torch.nn.functional.pad(attentions_per_kv.cumsum(dim=-1), (1, 0)) - bsz, num_key_value_heads, seq_len, head_dim = values.shape - n_evicted = evicted_positions.shape[2] + # 3. Pre-sampled random values for deterministic Bernoulli inside kernel + rand_thresholds = torch.rand((bsz, num_kv_heads, n_merge), device=values.device) BLOCK_D = triton.next_power_of_2(head_dim) - TILE_R = min(64, triton.next_power_of_2(actual_budget)) - grid = (bsz, num_key_value_heads) - - _cam_decoding_merge_kernel[grid]( - values_ptr=values, - evicted_pos_ptr=evicted_positions, - merge_targets_ptr=per_token_targets, - seq_len=seq_len, - n_evicted=n_evicted, - actual_budget=actual_budget, - v_stride_b=values.stride(0), - v_stride_h=values.stride(1), - v_stride_s=values.stride(2), - v_stride_d=values.stride(3), - ep_stride_b=evicted_positions.stride(0), - ep_stride_h=evicted_positions.stride(1), - ep_stride_e=evicted_positions.stride(2), - mt_stride_b=per_token_targets.stride(0), - mt_stride_h=per_token_targets.stride(1), - mt_stride_e=per_token_targets.stride(2), - mt_stride_t=per_token_targets.stride(3), + grid = (bsz, num_kv_heads, n_merge) + + _cam_merge_kernel[grid]( + value_states_ptr=values, + merge_token_ids_ptr=merge_indices, + kept_token_ids_ptr=kept_indices, + merge_target_starts_ptr=target_starts, + attn_cumsum_ptr=attn_prefix_sum, + 
attn_weights_ptr=attentions_per_kv, + rand_thresholds_ptr=rand_thresholds, + num_kept_tokens=n_kept, + merge_budget=merge_budget, + seq_len=attn_len, + v_stride_batch=values.stride(0), + v_stride_head=values.stride(1), + v_stride_seq=values.stride(2), + v_stride_dim=values.stride(3), + idx_stride_batch=merge_indices.stride(0), + idx_stride_seq=merge_indices.stride(1), + cs_stride_batch=attn_prefix_sum.stride(0), + cs_stride_head=attn_prefix_sum.stride(1), + cs_stride_seq=attn_prefix_sum.stride(2), + attn_stride_batch=attentions_per_kv.stride(0), + attn_stride_head=attentions_per_kv.stride(1), + attn_stride_seq=attentions_per_kv.stride(2), head_dim=head_dim, BLOCK_D=BLOCK_D, - TILE_R=TILE_R, ) return values @@ -353,69 +349,86 @@ def _triton_merge(self, values, evicted_positions, per_token_targets, actual_bud if HAS_TRITON: @triton.jit - def _cam_decoding_merge_kernel( - values_ptr, - evicted_pos_ptr, - merge_targets_ptr, + def _cam_merge_kernel( + value_states_ptr, + merge_token_ids_ptr, + kept_token_ids_ptr, + merge_target_starts_ptr, + attn_cumsum_ptr, + attn_weights_ptr, + rand_thresholds_ptr, + num_kept_tokens, + merge_budget, seq_len, - n_evicted, - actual_budget, - v_stride_b, - v_stride_h, - v_stride_s, - v_stride_d, - ep_stride_b, - ep_stride_h, - ep_stride_e, - mt_stride_b, - mt_stride_h, - mt_stride_e, - mt_stride_t, + v_stride_batch, + v_stride_head, + v_stride_seq, + v_stride_dim, + idx_stride_batch, + idx_stride_seq, + cs_stride_batch, + cs_stride_head, + cs_stride_seq, + attn_stride_batch, + attn_stride_head, + attn_stride_seq, head_dim: tl.constexpr, BLOCK_D: tl.constexpr, - TILE_R: tl.constexpr, ): """ - Tiled scatter-add merge kernel. Grid: (batch_size, num_key_value_heads). - Each evicted token scatters its contribution into its own neighbor list. + CaM scatter-add merge using cumulative attention with O(1) suffix-mean. 
+ Grid: (batch_size, num_kv_heads, n_merge) """ - bsz_idx = tl.program_id(0) + batch_idx = tl.program_id(0) head_idx = tl.program_id(1) + merge_idx = tl.program_id(2) + + # Load the merge token position and target window start + merge_token_pos = tl.load(merge_token_ids_ptr + batch_idx * idx_stride_batch + merge_idx * idx_stride_seq) + target_start = tl.load(merge_target_starts_ptr + batch_idx * idx_stride_batch + merge_idx * idx_stride_seq) + + # Calculate number of target tokens (handling edge near end of sequence) + num_targets = tl.minimum(merge_budget, num_kept_tokens - target_start) + if num_targets <= 0: + return + + # Compute mean attention of suffix [target_start:] via cumulative sums (O(1)) + cumsum_base = attn_cumsum_ptr + batch_idx * cs_stride_batch + head_idx * cs_stride_head + attn_suffix_sum = tl.load(cumsum_base + seq_len * cs_stride_seq) - tl.load(cumsum_base + target_start * cs_stride_seq) + attn_suffix_len = seq_len - target_start + mean_attn = attn_suffix_sum / attn_suffix_len + + # Load attention weight for the token being merged + attn_weights_base = attn_weights_ptr + batch_idx * attn_stride_batch + head_idx * attn_stride_head + merge_token_attn = tl.load(attn_weights_base + merge_token_pos * attn_stride_seq) + + # Calculate merge probability (with nan/inf safe-guards) + if mean_attn == 0.0: + merge_prob = 1.0 if merge_token_attn > 0 else 0.0 + else: + merge_prob = merge_token_attn / mean_attn + merge_prob = tl.minimum(tl.maximum(merge_prob, 0.0), 1.0) - v_base = values_ptr + bsz_idx * v_stride_b + head_idx * v_stride_h - ep_base = evicted_pos_ptr + bsz_idx * ep_stride_b + head_idx * ep_stride_h - mt_base = merge_targets_ptr + bsz_idx * mt_stride_b + head_idx * mt_stride_h - - d_offsets = tl.arange(0, BLOCK_D) - d_mask = d_offsets < head_dim - - for evict_idx in tl.range(0, n_evicted): - token_pos = tl.load(ep_base + evict_idx * ep_stride_e).to(tl.int64) - v_evicted = tl.load( - v_base + token_pos * v_stride_s + d_offsets * v_stride_d, - 
mask=d_mask, - other=0.0, - ) - contribution = v_evicted / actual_budget + # Bernoulli draw using pre-computed random thresholds + rand_val = tl.load( + rand_thresholds_ptr + batch_idx * (tl.num_programs(1) * tl.num_programs(2)) + head_idx * tl.num_programs(2) + merge_idx + ) - mt_evict_base = mt_base + evict_idx * mt_stride_e - n_tiles = (actual_budget + TILE_R - 1) // TILE_R + if merge_prob > rand_val: + dim_offsets = tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim - for tile_idx in tl.range(0, n_tiles): - t_offsets = tl.arange(0, TILE_R) - t_indices = tile_idx * TILE_R + t_offsets - t_mask = t_indices < actual_budget + value_base = value_states_ptr + batch_idx * v_stride_batch + head_idx * v_stride_head - target_positions = tl.load( - mt_evict_base + t_indices * mt_stride_t, - mask=t_mask, - other=0, - ).to(tl.int64) + # Load value vector of the token being merged + merge_token_value = tl.load(value_base + merge_token_pos * v_stride_seq + dim_offsets * v_stride_dim, mask=dim_mask, other=0.0) + contribution = merge_token_value / num_targets - valid = (target_positions < seq_len) & t_mask - ptrs = v_base + target_positions[:, None] * v_stride_s + d_offsets[None, :] * v_stride_d - mask_2d = valid[:, None] & d_mask[None, :] + # Scatter-add contribution equally across target tokens + for i in range(merge_budget): + if i < num_targets: + target_offset = target_start + i + target_token_pos = tl.load(kept_token_ids_ptr + batch_idx * idx_stride_batch + target_offset * idx_stride_seq) - v_block = tl.load(ptrs, mask=mask_2d, other=0.0) - v_block = v_block + contribution[None, :] - tl.store(ptrs, v_block, mask=mask_2d) + target_value_ptrs = value_base + target_token_pos * v_stride_seq + dim_offsets * v_stride_dim + tl.atomic_add(target_value_ptrs, contribution, mask=dim_mask) \ No newline at end of file From e381c51f91099c170bedfc084b3c4c229c64b5fd Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Tue, 24 Mar 2026 19:05:15 -0700 Subject: [PATCH 04/13] removed 
CAMPress params, using DecodingPress params now Signed-off-by: Saransh Agrawal --- evaluation/evaluate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index b4d8f60a..b7b59bac 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -21,7 +21,6 @@ from transformers import FineGrainedFP8Config, Pipeline, pipeline from kvpress import ( - CAMPress, ComposedPress, DecodingPress, DMSPress, @@ -282,9 +281,6 @@ def _setup_press(self): assert key_channel_compression_ratio is not None, "key_channel_compression_ratio must be set for ThinKPress" press.key_channel_compression_ratio = key_channel_compression_ratio logger.info(f"Set ThinKPress key_channel_compression_ratio to {key_channel_compression_ratio}") - elif isinstance(press, CAMPress): - press.compression_ratio = compression_ratio - logger.info(f"Set CAMPress compression_ratio to {compression_ratio}") elif isinstance(press, DecodingPress): press.compression_interval = self.config.compression_interval or press.compression_interval press.target_size = self.config.target_size or press.target_size From 77668751e640d65c50ea6d20e950de56626ebe19 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Tue, 24 Mar 2026 19:05:47 -0700 Subject: [PATCH 05/13] Added CAMPress Signed-off-by: Saransh Agrawal --- evaluation/evaluate_registry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py index 1c31cc03..37265233 100644 --- a/evaluation/evaluate_registry.py +++ b/evaluation/evaluate_registry.py @@ -108,6 +108,9 @@ "adakv_compactor": AdaKVPress(CompactorPress()), "no_press": None, "cam_streaming_llm": CAMPress(base_press=StreamingLLMPress()), + "cam_knorm": CAMPress(base_press=KnormPress()), + "cam_adakv_snapkv": CAMPress(base_press=AdaKVPress(SnapKVPress())), + "cam_tova": CAMPress(base_press=TOVAPress()), "decoding_knorm": DecodingPress(base_press=KnormPress()), "decoding_streaming_llm": 
DecodingPress(base_press=StreamingLLMPress()), "decoding_tova": DecodingPress(base_press=TOVAPress()), From 8aa326d3ef4ffd060923f0a063e0843e7ecf84f4 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Tue, 24 Mar 2026 19:07:47 -0700 Subject: [PATCH 06/13] Added hidden_states_buffer Signed-off-by: Saransh Agrawal --- kvpress/presses/cam_press.py | 68 ++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py index 94b3dfd7..ea0c559c 100644 --- a/kvpress/presses/cam_press.py +++ b/kvpress/presses/cam_press.py @@ -5,7 +5,6 @@ import logging import math -from collections import defaultdict from dataclasses import dataclass import torch @@ -13,9 +12,9 @@ from transformers import QuantizedCache from transformers.models.llama.modeling_llama import repeat_kv, rotate_half +from kvpress.presses.adakv_press import AdaKVPress from kvpress.presses.decoding_press import DecodingPress from kvpress.presses.scorer_press import ScorerPress -from kvpress.presses.adakv_press import AdaKVPress from kvpress.utils import extract_keys_and_values, get_prerope_query_states logger = logging.getLogger(__name__) @@ -54,21 +53,18 @@ class CAMPress(DecodingPress): """ base_press: ScorerPress | AdaKVPress = None - compression_interval: int = 2 - target_size: int = 3048 + compression_interval: int = 512 + target_size: int = 2048 hidden_states_buffer_size: int = 256 merge_budget: int = 32 - use_triton: bool = False + use_triton: bool = True def __post_init__(self): - assert isinstance(self.base_press, (ScorerPress, AdaKVPress)), "CAMPress requires a ScorerPress as base_press" - assert self.compression_interval > 0, "compression_interval must be greater than 0" - assert self.target_size > 0, "target_size must be greater than 0" + super().__post_init__() assert self.merge_budget > 0, "merge_budget must be positive " assert isinstance(self.use_triton, bool), "use_triton must be a boolean" - # State 
Variables - self.layer_step_counts = defaultdict(int) + # To maintain cumulative attention sum across generation steps self._running_attn_sum: dict[int, torch.Tensor] = {} if self.use_triton and not HAS_TRITON: @@ -88,10 +84,9 @@ def compress( layer_idx = int(module.layer_idx) cache_len = keys.shape[2] - n_to_evict = cache_len-self.target_size + n_to_evict = cache_len - self.target_size target_compression_ratio = self._find_target_compression_ratio(cache_len, self.target_size) - if n_to_evict <= 0: return keys, values @@ -124,7 +119,9 @@ def compress( values = self._torch_merge(values, merge_indices, kept_indices, attentions, self.merge_budget) # Physical Pruning - kept_indices_expand = kept_indices.view(bsz, 1, self.target_size, 1).expand(bsz, num_key_value_heads, self.target_size, head_dim) + kept_indices_expand = kept_indices.view(bsz, 1, self.target_size, 1).expand( + bsz, num_key_value_heads, self.target_size, head_dim + ) keys = keys.gather(2, kept_indices_expand).contiguous() values = values.gather(2, kept_indices_expand).contiguous() @@ -150,6 +147,9 @@ def forward_hook( if kwargs["cache_position"][-1] <= q_len: return output + # All hidden_states_buffer code is borrowed from DecodingPress + self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone()) + cache_layer = cache.layers[module.layer_idx] keys, values = extract_keys_and_values(cache, layer_idx) bsz, num_key_value_heads, seq_len, _ = keys.shape @@ -159,7 +159,7 @@ def forward_hook( if attentions is None: attentions = self._compute_current_token_attention(module, hidden_states, keys, kwargs) else: - attentions = attentions[:,:,-1:,:] + attentions = attentions[:, :, -1:, :] attentions = self._aggregate_attention_per_kv_head(attentions, num_key_value_heads) @@ -184,10 +184,14 @@ def forward_hook( self.layer_step_counts[layer_idx] += 1 # Trigger interval-based bulk eviction - if (self.layer_step_counts[layer_idx] >= self.compression_interval and seq_len>self.target_size) or (q_len >= 
self.target_size): + if (self.layer_step_counts[layer_idx] >= self.compression_interval and seq_len > self.target_size) or ( + q_len >= self.target_size + ): + # Apply compression using cumulative attention scores and buffered hidden states attn_squeezed = self._running_attn_sum[layer_idx] - keys, values = self.compress(module, hidden_states, keys, values, attn_squeezed, kwargs) + buffered_hidden_states = torch.cat(self.hidden_states_buffer[layer_idx], dim=1) + keys, values = self.compress(module, buffered_hidden_states, keys, values, attn_squeezed, kwargs) # Update cache with compressed keys and values if isinstance(cache, QuantizedCache): @@ -201,13 +205,22 @@ def forward_hook( cache_layer.values = values self.layer_step_counts[layer_idx] = 0 + # Always clear the buffer after compression - otherwise there's a mismatch between + # hidden states buffer and kv cache + self.hidden_states_buffer[layer_idx] = [] + + self.hidden_states_buffer[layer_idx] = ( + self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :] + if self.hidden_states_buffer_size > 0 + else [] + ) return output def reset(self): """Reset per-sequence state.""" - self.layer_step_counts = defaultdict(int) - self._running_attn_sum: dict[int, torch.Tensor] = {} + super().reset() + self._running_attn_sum = {} @staticmethod def _compute_current_token_attention( @@ -394,7 +407,9 @@ def _cam_merge_kernel( # Compute mean attention of suffix [target_start:] via cumulative sums (O(1)) cumsum_base = attn_cumsum_ptr + batch_idx * cs_stride_batch + head_idx * cs_stride_head - attn_suffix_sum = tl.load(cumsum_base + seq_len * cs_stride_seq) - tl.load(cumsum_base + target_start * cs_stride_seq) + attn_suffix_sum = tl.load(cumsum_base + seq_len * cs_stride_seq) - tl.load( + cumsum_base + target_start * cs_stride_seq + ) attn_suffix_len = seq_len - target_start mean_attn = attn_suffix_sum / attn_suffix_len @@ -411,7 +426,10 @@ def _cam_merge_kernel( # Bernoulli draw using pre-computed random thresholds 
rand_val = tl.load( - rand_thresholds_ptr + batch_idx * (tl.num_programs(1) * tl.num_programs(2)) + head_idx * tl.num_programs(2) + merge_idx + rand_thresholds_ptr + + batch_idx * (tl.num_programs(1) * tl.num_programs(2)) + + head_idx * tl.num_programs(2) + + merge_idx ) if merge_prob > rand_val: @@ -421,14 +439,18 @@ def _cam_merge_kernel( value_base = value_states_ptr + batch_idx * v_stride_batch + head_idx * v_stride_head # Load value vector of the token being merged - merge_token_value = tl.load(value_base + merge_token_pos * v_stride_seq + dim_offsets * v_stride_dim, mask=dim_mask, other=0.0) + merge_token_value = tl.load( + value_base + merge_token_pos * v_stride_seq + dim_offsets * v_stride_dim, mask=dim_mask, other=0.0 + ) contribution = merge_token_value / num_targets # Scatter-add contribution equally across target tokens for i in range(merge_budget): if i < num_targets: target_offset = target_start + i - target_token_pos = tl.load(kept_token_ids_ptr + batch_idx * idx_stride_batch + target_offset * idx_stride_seq) + target_token_pos = tl.load( + kept_token_ids_ptr + batch_idx * idx_stride_batch + target_offset * idx_stride_seq + ) target_value_ptrs = value_base + target_token_pos * v_stride_seq + dim_offsets * v_stride_dim - tl.atomic_add(target_value_ptrs, contribution, mask=dim_mask) \ No newline at end of file + tl.atomic_add(target_value_ptrs, contribution, mask=dim_mask) From 1373541dbf570166969b88783fab8f8d1cd7d604 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Tue, 24 Mar 2026 19:08:34 -0700 Subject: [PATCH 07/13] Removed CAM from default presses Signed-off-by: Saransh Agrawal --- tests/default_presses.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/default_presses.py b/tests/default_presses.py index 8f7c9994..413f1ea6 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -4,7 +4,6 @@ import numpy as np from kvpress import ( - CAMPress, CompactorPress, CURPress, DuoAttentionPress, @@ -142,11 +141,4 @@ def 
post_init_from_model(self, model): }, ], }, - { - "cls": CAMPress, - "kwargs": [ - {"base_press": StreamingLLMPress(), "compression_ratio": 0.2}, - {"base_press": StreamingLLMPress(), "compression_ratio": 0.8}, - ], - }, ] From dbc67123581e58c6038bbf8335195fc875644d17 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Tue, 24 Mar 2026 19:09:56 -0700 Subject: [PATCH 08/13] Modified decoding compression tests to use both decoding and cam press as factory methods Signed-off-by: Saransh Agrawal --- tests/test_decoding_compression.py | 83 ++++++++++++++++++------------ 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py index dd30c7f4..4f2a87b2 100644 --- a/tests/test_decoding_compression.py +++ b/tests/test_decoding_compression.py @@ -1,8 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """ -Test script to verify that DecodingPress actually compresses during decoding. +Test script to verify that DecodingPress and CAMPress compress during decoding. 
""" import logging @@ -11,6 +11,7 @@ from transformers import DynamicCache, pipeline from kvpress import ( + CAMPress, CompactorPress, DecodingPress, KnormPress, @@ -26,19 +27,35 @@ logger = logging.getLogger(__name__) +def make_decoding_press(target_size: int, compression_interval: int, base_press: ScorerPress): + return DecodingPress( + base_press=base_press, + compression_interval=compression_interval, + target_size=target_size, + ) + + +def make_cam_press(target_size: int, compression_interval: int, base_press: ScorerPress): + return CAMPress( + base_press=base_press, + compression_interval=compression_interval, + target_size=target_size, + ) + + +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) @pytest.mark.parametrize("token_buffer_size", [32, 64, 128]) -def test_decoding_compression(token_buffer_size): - """Test that DecodingPress compresses the cache during decoding.""" +def test_decoding_compression(press_factory, token_buffer_size): + """Test that decoding presses compress the cache during decoding.""" # Initialize pipeline with a small model pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") # Create a DecodingPress with KnormPress - press = DecodingPress( - base_press=KnormPress(compression_ratio=0.5), # Remove 50% of tokens - compression_interval=4, # Compress every 4 tokens - target_size=token_buffer_size, - ) + + press = press_factory( + target_size=token_buffer_size, compression_interval=4, base_press=KnormPress(compression_ratio=0.5) + ) # Remove 50% of tokens # Create cache cache = DynamicCache() @@ -61,7 +78,8 @@ def test_decoding_compression(token_buffer_size): ) -def test_prefill_decoding_press_calls_both_phases(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_prefill_decoding_press_calls_both_phases(press_factory): """Test that 
PrefillDecodingPress calls both prefilling and decoding presses.""" # Initialize pipeline @@ -70,7 +88,7 @@ def test_prefill_decoding_press_calls_both_phases(): # Create PrefillDecodingPress with both presses combined_press = PrefillDecodingPress( prefilling_press=KnormPress(compression_ratio=0.6), # Compress to 60% during prefill - decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48), + decoding_press=press_factory(target_size=48, compression_interval=3, base_press=KnormPress()), ) # Test context and question @@ -95,14 +113,14 @@ def test_prefill_decoding_press_calls_both_phases(): ) -def test_decoding_press_without_prefill(): - """Test that DecodingPress works correctly when used standalone (no prefill compression).""" +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_decoding_press_without_prefill(press_factory): + """Test that decoding presses work correctly when used standalone (no prefill compression).""" # Initialize pipeline pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") - # Create DecodingPress only - decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64) + press = press_factory(target_size=64, compression_interval=5, base_press=KnormPress(compression_ratio=0.4)) # Test context and question context = "The quick brown fox jumps over the lazy dog. 
" * 8 @@ -110,7 +128,7 @@ def test_decoding_press_without_prefill(): # Run pipeline cache = DynamicCache() - pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25) + pipe(context, question=question, press=press, cache=cache, max_new_tokens=25) # Check that cache was compressed during decoding for layer_idx, cache_layer in enumerate(cache.layers): @@ -125,7 +143,8 @@ def test_decoding_press_without_prefill(): ) -def test_prefill_decoding_press_decoding_only(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_prefill_decoding_press_decoding_only(press_factory): """Test PrefillDecodingPress with only decoding press (no prefill compression).""" # Initialize pipeline @@ -134,8 +153,8 @@ def test_prefill_decoding_press_decoding_only(): # Create PrefillDecodingPress with only decoding press combined_press = PrefillDecodingPress( prefilling_press=None, - decoding_press=DecodingPress( - base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56 + decoding_press=press_factory( + target_size=56, compression_interval=4, base_press=KnormPress(compression_ratio=0.6) ), ) @@ -160,7 +179,8 @@ def test_prefill_decoding_press_decoding_only(): ) -def test_decoding_press_equivalence(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_decoding_press_equivalence(press_factory): """Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only.""" # Set random seed for reproducibility @@ -170,13 +190,15 @@ def test_decoding_press_equivalence(): pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") # Create standalone decoding press - decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52) + decoding_press = press_factory( + 
target_size=52, compression_interval=3, base_press=KnormPress(compression_ratio=0.5) + ) # Create PrefillDecodingPress with only decoding press combined_press = PrefillDecodingPress( prefilling_press=None, - decoding_press=DecodingPress( - base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52 + decoding_press=press_factory( + target_size=52, compression_interval=3, base_press=KnormPress(compression_ratio=0.5) ), ) @@ -217,8 +239,9 @@ def test_decoding_press_equivalence(): """ +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) @pytest.mark.parametrize("press_config", default_presses) -def test_all_presses_work_with_decoding_press(press_config): +def test_all_presses_work_with_decoding_press(press_factory, press_config): """Test that all default presses work as base presses for DecodingPress.""" # Initialize pipeline @@ -246,7 +269,7 @@ def test_all_presses_work_with_decoding_press(press_config): return # Create DecodingPress with this base press - decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48) + decoding_press = press_factory(target_size=48, compression_interval=3, base_press=base_press) # Test context and question context = "The quick brown fox jumps over the lazy dog. 
" * 8 @@ -271,7 +294,8 @@ def test_all_presses_work_with_decoding_press(press_config): ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501 -def test_compression_actually_reduces_memory(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_compression_actually_reduces_memory(press_factory): """Test that compression actually reduces memory usage compared to no compression.""" pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") @@ -283,12 +307,7 @@ def test_compression_actually_reduces_memory(): cache_uncompressed = DynamicCache() result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25) - # Run with compression - press = DecodingPress( - base_press=KnormPress(compression_ratio=0.3), # Aggressive compression - compression_interval=3, - target_size=40, - ) + press = press_factory(target_size=40, compression_interval=3, base_press=KnormPress(compression_ratio=0.3)) cache_compressed = DynamicCache() result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25) From b91c182a83b0ad503e69684732345486b0320662 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Wed, 25 Mar 2026 00:26:52 -0700 Subject: [PATCH 09/13] Modified mean_attn to window Signed-off-by: Saransh Agrawal --- kvpress/presses/cam_press.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py index ea0c559c..bb03be46 100644 --- a/kvpress/presses/cam_press.py +++ b/kvpress/presses/cam_press.py @@ -57,7 +57,7 @@ class CAMPress(DecodingPress): target_size: int = 2048 hidden_states_buffer_size: int = 256 merge_budget: int = 32 - use_triton: bool = True + use_triton: bool = False def __post_init__(self): 
super().__post_init__() @@ -278,13 +278,14 @@ def _torch_merge(self, values, merge_indices, kept_indices, attentions_per_kv, m # 3. Actual budget per merge token actual_budget = valid_mask.sum(dim=-1) - # 4. Suffix mean via cumsum + # 4. Window mean via cumsum (from target_start to min(target_start + merge_budget, seq_len)) attn_cumsum = torch.nn.functional.pad(attentions_per_kv.cumsum(dim=-1), (1, 0)) - total_sum = attn_cumsum[:, :, -1:] start_sum = attn_cumsum.gather(2, target_starts.unsqueeze(1).expand(-1, num_kv_heads, -1)) - suffix_sum = total_sum - start_sum - suffix_len = (seq_len - target_starts).unsqueeze(1) - mean_attn = suffix_sum / suffix_len + window_end = (target_starts + merge_budget).clamp(max=seq_len) + end_sum = attn_cumsum.gather(2, window_end.unsqueeze(1).expand(-1, num_kv_heads, -1)) + window_sum = end_sum - start_sum + window_len = (window_end - target_starts).unsqueeze(1) + mean_attn = window_sum / window_len # 5. Merge probability merge_token_attn = attentions_per_kv.gather(2, merge_indices.unsqueeze(1).expand(-1, num_kv_heads, -1)) @@ -405,13 +406,14 @@ def _cam_merge_kernel( if num_targets <= 0: return - # Compute mean attention of suffix [target_start:] via cumulative sums (O(1)) + # Compute mean attention of window [target_start:target_start+merge_budget] via cumulative sums (O(1)) cumsum_base = attn_cumsum_ptr + batch_idx * cs_stride_batch + head_idx * cs_stride_head - attn_suffix_sum = tl.load(cumsum_base + seq_len * cs_stride_seq) - tl.load( + window_end = tl.minimum(target_start + merge_budget, seq_len) + attn_window_sum = tl.load(cumsum_base + window_end * cs_stride_seq) - tl.load( cumsum_base + target_start * cs_stride_seq ) - attn_suffix_len = seq_len - target_start - mean_attn = attn_suffix_sum / attn_suffix_len + attn_window_len = window_end - target_start + mean_attn = attn_window_sum / attn_window_len # Load attention weight for the token being merged attn_weights_base = attn_weights_ptr + batch_idx * attn_stride_batch + 
head_idx * attn_stride_head From 016cc87682aa961ec60bd9710bf7cd417ccef903 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Wed, 25 Mar 2026 14:31:55 -0700 Subject: [PATCH 10/13] Added CAMPress paper link Signed-off-by: Saransh Agrawal --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a459dd65..00eb78e8 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ Finally we provide wrapper presses that can be combined with other presses: - `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allow for compression during decoding, see decoding section in this README. - `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allow to compress both during prefilling and during decoding. - `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True), but only supports dense-prefill and not sparse-prefill. -- `CAMPress` ([source](kvpress/presses/cam_press.py), [paper](https://arxiv.org/abs/2309.17453)): A decoding press that merges the kv cache of evicted tokens into keep tokens to preserve information. +- `CAMPress` ([source](kvpress/presses/cam_press.py), [paper](https://openreview.net/forum?id=LCTmppB165)): A decoding press that merges the kv cache of evicted tokens into keep tokens to preserve information. 
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) From 841b7187e59c6d7206a6e00703f6fc091bd9d806 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Wed, 25 Mar 2026 14:33:43 -0700 Subject: [PATCH 11/13] Removed triton kernel, added docstrings Signed-off-by: Saransh Agrawal --- kvpress/presses/cam_press.py | 354 ++++++++++++----------------------- 1 file changed, 124 insertions(+), 230 deletions(-) diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py index bb03be46..073f700e 100644 --- a/kvpress/presses/cam_press.py +++ b/kvpress/presses/cam_press.py @@ -19,37 +19,41 @@ logger = logging.getLogger(__name__) -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - - @dataclass class CAMPress(DecodingPress): """ Cache Merging (CaM) KV cache compression during decoding. - Evicted tokens' values are merged into their sequential neighbors using a - Bernoulli merge probability derived from relative attention scores. Keys are - pruned after merging. + Instead of simply evicting low-importance tokens, CaM merges their value vectors + into sequential neighbors before pruning. A Bernoulli merge mask, derived from the + ratio of the evicted token's cumulative attention to the mean attention of its merge + window, decides whether each merge occurs. This reduces the output + perturbation caused by cache eviction. + + This implementation extends the original per-step algorithm to support batched + eviction: tokens accumulate over ``compression_interval`` steps, then a bulk + merge-and-prune pass is applied. Setting ``compression_interval=1`` creates + the original per-step CaM behavior. Based on CaM (https://openreview.net/forum?id=LCTmppB165). 
Parameters ---------- base_press : ScorerPress - Scorer used to select which tokens to evict (e.g., StreamingLLMPress). - compression_ratio : float, default=0.0 - Fraction of prefill tokens to evict during decoding. - merge_budget : int or None, default=64 - Number of sequential neighbors to merge each evicted token into. - None merges into all remaining tokens after the evicted position. - use_triton : bool, default=True - Use the Triton kernel for merging when available (CUDA only). + The scorer press used to compute importance scores for tokens. + compression_interval : int, default=512 + Number of decoding steps between compression, i.e. compression will be applied + every compression_interval steps. + target_size : int, default=2048 + Target number of tokens to keep after compression. + hidden_states_buffer_size : int, default=256 + Maximum number of hidden states to keep before compression. Larger values use + more GPU memory. Note: Some presses don't need buffered hidden states and can + set this to 0 to use only the current hidden state for compression scoring. + merge_budget : int, default=32 + Number of sequential kept-token neighbors each evicted token's value is merged + into. Smaller values concentrate the merged information; larger values spread it + more evenly. """ base_press: ScorerPress | AdaKVPress = None @@ -57,21 +61,14 @@ class CAMPress(DecodingPress): target_size: int = 2048 hidden_states_buffer_size: int = 256 merge_budget: int = 32 - use_triton: bool = False def __post_init__(self): super().__post_init__() assert self.merge_budget > 0, "merge_budget must be positive " - assert isinstance(self.use_triton, bool), "use_triton must be a boolean" # To maintain cumulative attention sum across generation steps self._running_attn_sum: dict[int, torch.Tensor] = {} - if self.use_triton and not HAS_TRITON: - logger.warning( - f"Triton is not available. Falling back to PyTorch merge implementation for {self.__class__.__name__}." 
- ) - def compress( self, module: nn.Module, @@ -81,6 +78,47 @@ def compress( attentions: torch.Tensor, kwargs: dict, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Merge evicted tokens' values into kept neighbors, then prune. + + Overrides `DecodingPress.compress` to implement the CaM merge-before-prune + strategy instead of plain eviction. + + Args: + module: The transformer attention module being compressed. + hidden_states: Buffered hidden states from recent decoding steps + (shape: [batch, buffer_len, hidden_dim]). + keys: Key cache (shape: [batch, n_kv_heads, seq_len, head_dim]). + values: Value cache (shape: [batch, n_kv_heads, seq_len, head_dim]). + attentions: Cumulative attention scores summed over generation steps + (shape: [batch, n_kv_heads, seq_len]). + kwargs: Additional keyword arguments forwarded to the base press scorer. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Compressed (keys, values) with seq_len + reduced to ``target_size``. + + Algorithm: + 1. **Score & select** — The base press scores every cached token. The + ``n_to_evict`` lowest-scored tokens are marked for eviction; the top + ``target_size`` are kept. + 2. **Pick merge candidates** — Among the evicted set, the ``k`` tokens with + the highest scores (with ties broken by later sequence position) are + selected for merging, where ``k = layer_step_counts[layer_idx]`` + (the number of new tokens since the last compression). + 3. **Cascading merge targets** — For each merge candidate, the + ``merge_budget`` kept tokens immediately after it (in sequence order) + form its merge window. + 4. **Merge probability** — The ratio of each merge token's cumulative + attention to the mean cumulative attention of its window is computed + ``clamp(A_i / avg(A_{j:j+m}), 0, 1)``. + 5. **Bernoulli sampling** — A binary merge mask is drawn from the + probability above. Tokens that pass the mask have their value vectors + divided by the window size and scatter-added into the window targets. + 6. 
**Physical pruning** — Evicted key/value entries are removed from the + cache, and the cumulative attention buffer is pruned to match. + """ + layer_idx = int(module.layer_idx) cache_len = keys.shape[2] @@ -97,7 +135,7 @@ def compress( scores = self.base_press.score(module, hidden_states, keys, values, None, kwargs) self.base_press.compression_ratio = old_cr - bsz, num_key_value_heads, _, head_dim = keys.shape + bsz, num_key_value_heads, seq_len, head_dim = keys.shape evict_indices = scores[:, 0, :].topk(n_to_evict, dim=-1, largest=False).indices evict_indices = torch.sort(evict_indices, dim=-1).values @@ -112,11 +150,51 @@ def compress( kept_indices = scores[:, 0, :].topk(self.target_size, dim=-1).indices kept_indices = torch.sort(kept_indices, dim=-1).values - if n_to_evict > 0: - if self.use_triton and HAS_TRITON and values.is_cuda: - values = self._triton_merge(values, merge_indices, kept_indices, attentions, self.merge_budget) - else: - values = self._torch_merge(values, merge_indices, kept_indices, attentions, self.merge_budget) + n_to_merge = merge_indices.shape[1] + + base_idx_first = torch.searchsorted(kept_indices, merge_indices[:, 0:1], right=True) + target_starts = torch.arange(n_to_merge, device=kept_indices.device).unsqueeze(0) + base_idx_first + + # 2. Build target window indices: [bsz, n_to_merge, merge_budget] + offsets = torch.arange(self.merge_budget, device=kept_indices.device) + window_idx = target_starts.unsqueeze(-1) + offsets.view(1, 1, -1) + valid_mask = window_idx < self.target_size + window_idx = window_idx.clamp(max=self.target_size - 1) + target_positions = kept_indices.gather(1, window_idx.view(bsz, -1)).view(bsz, n_to_merge, self.merge_budget) + + # 3. Actual budget per merge token + actual_budget = valid_mask.sum(dim=-1) + + # 4. 
Window mean via cumsum (from target_start to min(target_start + merge_budget, seq_len)) + attn_cumsum = torch.nn.functional.pad(attentions.cumsum(dim=-1), (1, 0)) + start_sum = attn_cumsum.gather(2, target_starts.unsqueeze(1).expand(-1, num_key_value_heads, -1)) + window_end = (target_starts + self.merge_budget).clamp(max=seq_len) + end_sum = attn_cumsum.gather(2, window_end.unsqueeze(1).expand(-1, num_key_value_heads, -1)) + window_sum = end_sum - start_sum + window_len = (window_end - target_starts).unsqueeze(1) + mean_attn = window_sum / window_len + + # 5. Merge probability + merge_token_attn = attentions.gather(2, merge_indices.unsqueeze(1).expand(-1, num_key_value_heads, -1)) + merge_prob = merge_token_attn / mean_attn + merge_prob = torch.where(torch.isnan(merge_prob), torch.zeros_like(merge_prob), merge_prob) + merge_prob = torch.where(torch.isinf(merge_prob), torch.ones_like(merge_prob), merge_prob) + merge_prob = merge_prob.clamp(0, 1) + + # 6. Bernoulli sampling + merge_mask = torch.bernoulli(merge_prob) + + # 7. 
Build contributions and scatter-add + merge_values = values.gather(2, merge_indices.view(bsz, 1, n_to_merge, 1).expand(-1, num_key_value_heads, -1, head_dim)) + scale = (merge_mask / actual_budget.unsqueeze(1)).unsqueeze(-1) + scale = torch.where(actual_budget.unsqueeze(1).unsqueeze(-1) == 0, torch.zeros_like(scale), scale) + contributions = merge_values * scale + contributions = contributions.unsqueeze(3).expand(-1, -1, -1, self.merge_budget, -1) + contributions = contributions * valid_mask.view(bsz, 1, n_to_merge, self.merge_budget, 1) + contributions = contributions.reshape(bsz, num_key_value_heads, n_to_merge * self.merge_budget, head_dim) + scatter_idx = target_positions.view(bsz, 1, n_to_merge * self.merge_budget, 1).expand(-1, num_key_value_heads, -1, head_dim) + + values.scatter_add_(2, scatter_idx, contributions) # Physical Pruning kept_indices_expand = kept_indices.view(bsz, 1, self.target_size, 1).expand( @@ -138,6 +216,19 @@ def forward_hook( kwargs: dict, output: list, ): + """ + Forward hook that manages cumulative attention tracking and interval-based compression. + + Extends `DecodingPress.forward_hook` with per-step attention accumulation. + + This hook: + This hook: + 1. Detects when we're in decoding phase (not prefilling) + 2. Accumulates hidden states in a buffer + 3. Accumulates cumulative attention A_bar = sum(A^k) in a buffer + 4. Applies compression every N steps + 5. Clears the buffer after compression + """ hidden_states = kwargs["hidden_states"] cache = kwargs["past_key_values"] q_len = hidden_states.shape[1] @@ -259,200 +350,3 @@ def _aggregate_attention_per_kv_head( bsz, _, seq_q, seq_k = attentions.shape return attentions.reshape(bsz, num_key_value_heads, group_size, seq_q, seq_k).mean(dim=2) - def _torch_merge(self, values, merge_indices, kept_indices, attentions_per_kv, merge_budget): - bsz, num_kv_heads, seq_len, head_dim = values.shape - n_merge = merge_indices.shape[1] - n_kept = kept_indices.shape[1] - - # 1. 
Cascading target starts - base_idx_first = torch.searchsorted(kept_indices, merge_indices[:, 0:1], right=True) - target_starts = torch.arange(n_merge, device=kept_indices.device).unsqueeze(0) + base_idx_first - - # 2. Build target window indices: [bsz, n_merge, merge_budget] - offsets = torch.arange(merge_budget, device=kept_indices.device) - window_idx = target_starts.unsqueeze(-1) + offsets.view(1, 1, -1) - valid_mask = window_idx < n_kept - window_idx = window_idx.clamp(max=n_kept - 1) - target_positions = kept_indices.gather(1, window_idx.view(bsz, -1)).view(bsz, n_merge, merge_budget) - - # 3. Actual budget per merge token - actual_budget = valid_mask.sum(dim=-1) - - # 4. Window mean via cumsum (from target_start to min(target_start + merge_budget, seq_len)) - attn_cumsum = torch.nn.functional.pad(attentions_per_kv.cumsum(dim=-1), (1, 0)) - start_sum = attn_cumsum.gather(2, target_starts.unsqueeze(1).expand(-1, num_kv_heads, -1)) - window_end = (target_starts + merge_budget).clamp(max=seq_len) - end_sum = attn_cumsum.gather(2, window_end.unsqueeze(1).expand(-1, num_kv_heads, -1)) - window_sum = end_sum - start_sum - window_len = (window_end - target_starts).unsqueeze(1) - mean_attn = window_sum / window_len - - # 5. Merge probability - merge_token_attn = attentions_per_kv.gather(2, merge_indices.unsqueeze(1).expand(-1, num_kv_heads, -1)) - merge_prob = merge_token_attn / mean_attn - merge_prob = torch.where(torch.isnan(merge_prob), torch.zeros_like(merge_prob), merge_prob) - merge_prob = torch.where(torch.isinf(merge_prob), torch.ones_like(merge_prob), merge_prob) - merge_prob = merge_prob.clamp(0, 1) - - # 6. Bernoulli sampling - merge_mask = torch.bernoulli(merge_prob) - - # 7. 
Build contributions and scatter-add - merge_values = values.gather(2, merge_indices.view(bsz, 1, n_merge, 1).expand(-1, num_kv_heads, -1, head_dim)) - scale = (merge_mask / actual_budget.unsqueeze(1)).unsqueeze(-1) - scale = torch.where(actual_budget.unsqueeze(1).unsqueeze(-1) == 0, torch.zeros_like(scale), scale) - contributions = merge_values * scale - contributions = contributions.unsqueeze(3).expand(-1, -1, -1, merge_budget, -1) - contributions = contributions * valid_mask.view(bsz, 1, n_merge, merge_budget, 1) - contributions = contributions.reshape(bsz, num_kv_heads, n_merge * merge_budget, head_dim) - scatter_idx = target_positions.view(bsz, 1, n_merge * merge_budget, 1).expand(-1, num_kv_heads, -1, head_dim) - - values.scatter_add_(2, scatter_idx, contributions) - return values - - def _triton_merge(self, values, merge_indices, kept_indices, attentions_per_kv, merge_budget): - """Pre-computes cascading start targets and prefix sums, then merges in a single Triton kernel.""" - bsz, num_kv_heads, _, head_dim = values.shape - n_merge = merge_indices.shape[1] - n_kept = kept_indices.shape[1] - attn_len = attentions_per_kv.shape[2] - - # 1. Pre-compute the cascading target - base_idx_first = torch.searchsorted(kept_indices, merge_indices[:, 0:1], right=True) - target_starts = torch.arange(n_merge, device=kept_indices.device).unsqueeze(0) - target_starts += base_idx_first - - # 2. Prefix sum for O(1) suffix-mean in kernel: prefix_sum[i] = sum(attn[0:i]) - attn_prefix_sum = torch.nn.functional.pad(attentions_per_kv.cumsum(dim=-1), (1, 0)) - - # 3. 
Pre-sampled random values for deterministic Bernoulli inside kernel - rand_thresholds = torch.rand((bsz, num_kv_heads, n_merge), device=values.device) - - BLOCK_D = triton.next_power_of_2(head_dim) - grid = (bsz, num_kv_heads, n_merge) - - _cam_merge_kernel[grid]( - value_states_ptr=values, - merge_token_ids_ptr=merge_indices, - kept_token_ids_ptr=kept_indices, - merge_target_starts_ptr=target_starts, - attn_cumsum_ptr=attn_prefix_sum, - attn_weights_ptr=attentions_per_kv, - rand_thresholds_ptr=rand_thresholds, - num_kept_tokens=n_kept, - merge_budget=merge_budget, - seq_len=attn_len, - v_stride_batch=values.stride(0), - v_stride_head=values.stride(1), - v_stride_seq=values.stride(2), - v_stride_dim=values.stride(3), - idx_stride_batch=merge_indices.stride(0), - idx_stride_seq=merge_indices.stride(1), - cs_stride_batch=attn_prefix_sum.stride(0), - cs_stride_head=attn_prefix_sum.stride(1), - cs_stride_seq=attn_prefix_sum.stride(2), - attn_stride_batch=attentions_per_kv.stride(0), - attn_stride_head=attentions_per_kv.stride(1), - attn_stride_seq=attentions_per_kv.stride(2), - head_dim=head_dim, - BLOCK_D=BLOCK_D, - ) - return values - - -if HAS_TRITON: - - @triton.jit - def _cam_merge_kernel( - value_states_ptr, - merge_token_ids_ptr, - kept_token_ids_ptr, - merge_target_starts_ptr, - attn_cumsum_ptr, - attn_weights_ptr, - rand_thresholds_ptr, - num_kept_tokens, - merge_budget, - seq_len, - v_stride_batch, - v_stride_head, - v_stride_seq, - v_stride_dim, - idx_stride_batch, - idx_stride_seq, - cs_stride_batch, - cs_stride_head, - cs_stride_seq, - attn_stride_batch, - attn_stride_head, - attn_stride_seq, - head_dim: tl.constexpr, - BLOCK_D: tl.constexpr, - ): - """ - CaM scatter-add merge using cumulative attention with O(1) suffix-mean. 
- Grid: (batch_size, num_kv_heads, n_merge) - """ - batch_idx = tl.program_id(0) - head_idx = tl.program_id(1) - merge_idx = tl.program_id(2) - - # Load the merge token position and target window start - merge_token_pos = tl.load(merge_token_ids_ptr + batch_idx * idx_stride_batch + merge_idx * idx_stride_seq) - target_start = tl.load(merge_target_starts_ptr + batch_idx * idx_stride_batch + merge_idx * idx_stride_seq) - - # Calculate number of target tokens (handling edge near end of sequence) - num_targets = tl.minimum(merge_budget, num_kept_tokens - target_start) - if num_targets <= 0: - return - - # Compute mean attention of window [target_start:target_start+merge_budget] via cumulative sums (O(1)) - cumsum_base = attn_cumsum_ptr + batch_idx * cs_stride_batch + head_idx * cs_stride_head - window_end = tl.minimum(target_start + merge_budget, seq_len) - attn_window_sum = tl.load(cumsum_base + window_end * cs_stride_seq) - tl.load( - cumsum_base + target_start * cs_stride_seq - ) - attn_window_len = window_end - target_start - mean_attn = attn_window_sum / attn_window_len - - # Load attention weight for the token being merged - attn_weights_base = attn_weights_ptr + batch_idx * attn_stride_batch + head_idx * attn_stride_head - merge_token_attn = tl.load(attn_weights_base + merge_token_pos * attn_stride_seq) - - # Calculate merge probability (with nan/inf safe-guards) - if mean_attn == 0.0: - merge_prob = 1.0 if merge_token_attn > 0 else 0.0 - else: - merge_prob = merge_token_attn / mean_attn - merge_prob = tl.minimum(tl.maximum(merge_prob, 0.0), 1.0) - - # Bernoulli draw using pre-computed random thresholds - rand_val = tl.load( - rand_thresholds_ptr - + batch_idx * (tl.num_programs(1) * tl.num_programs(2)) - + head_idx * tl.num_programs(2) - + merge_idx - ) - - if merge_prob > rand_val: - dim_offsets = tl.arange(0, BLOCK_D) - dim_mask = dim_offsets < head_dim - - value_base = value_states_ptr + batch_idx * v_stride_batch + head_idx * v_stride_head - - # Load value 
vector of the token being merged - merge_token_value = tl.load( - value_base + merge_token_pos * v_stride_seq + dim_offsets * v_stride_dim, mask=dim_mask, other=0.0 - ) - contribution = merge_token_value / num_targets - - # Scatter-add contribution equally across target tokens - for i in range(merge_budget): - if i < num_targets: - target_offset = target_start + i - target_token_pos = tl.load( - kept_token_ids_ptr + batch_idx * idx_stride_batch + target_offset * idx_stride_seq - ) - - target_value_ptrs = value_base + target_token_pos * v_stride_seq + dim_offsets * v_stride_dim - tl.atomic_add(target_value_ptrs, contribution, mask=dim_mask) From d78febd54daaf948c529c373035a1d4e7f638508 Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Wed, 25 Mar 2026 17:28:35 -0700 Subject: [PATCH 12/13] Fixed style Signed-off-by: Saransh Agrawal --- kvpress/presses/cam_press.py | 13 ++++++++----- tests/test_decoding_compression.py | 4 +--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py index 073f700e..4e2f7312 100644 --- a/kvpress/presses/cam_press.py +++ b/kvpress/presses/cam_press.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) + @dataclass class CAMPress(DecodingPress): """ @@ -151,7 +152,7 @@ def compress( kept_indices = torch.sort(kept_indices, dim=-1).values n_to_merge = merge_indices.shape[1] - + base_idx_first = torch.searchsorted(kept_indices, merge_indices[:, 0:1], right=True) target_starts = torch.arange(n_to_merge, device=kept_indices.device).unsqueeze(0) + base_idx_first @@ -185,14 +186,18 @@ def compress( merge_mask = torch.bernoulli(merge_prob) # 7. 
Build contributions and scatter-add - merge_values = values.gather(2, merge_indices.view(bsz, 1, n_to_merge, 1).expand(-1, num_key_value_heads, -1, head_dim)) + merge_values = values.gather( + 2, merge_indices.view(bsz, 1, n_to_merge, 1).expand(-1, num_key_value_heads, -1, head_dim) + ) scale = (merge_mask / actual_budget.unsqueeze(1)).unsqueeze(-1) scale = torch.where(actual_budget.unsqueeze(1).unsqueeze(-1) == 0, torch.zeros_like(scale), scale) contributions = merge_values * scale contributions = contributions.unsqueeze(3).expand(-1, -1, -1, self.merge_budget, -1) contributions = contributions * valid_mask.view(bsz, 1, n_to_merge, self.merge_budget, 1) contributions = contributions.reshape(bsz, num_key_value_heads, n_to_merge * self.merge_budget, head_dim) - scatter_idx = target_positions.view(bsz, 1, n_to_merge * self.merge_budget, 1).expand(-1, num_key_value_heads, -1, head_dim) + scatter_idx = target_positions.view(bsz, 1, n_to_merge * self.merge_budget, 1).expand( + -1, num_key_value_heads, -1, head_dim + ) values.scatter_add_(2, scatter_idx, contributions) @@ -221,7 +226,6 @@ def forward_hook( Extends `DecodingPress.forward_hook` with per-step attention accumulation. - This hook: This hook: 1. Detects when we're in decoding phase (not prefilling) 2. 
Accumulates hidden states in a buffer @@ -349,4 +353,3 @@ def _aggregate_attention_per_kv_head( group_size = num_query_heads // num_key_value_heads bsz, _, seq_q, seq_k = attentions.shape return attentions.reshape(bsz, num_key_value_heads, group_size, seq_q, seq_k).mean(dim=2) - diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py index 4f2a87b2..ad998d3c 100644 --- a/tests/test_decoding_compression.py +++ b/tests/test_decoding_compression.py @@ -190,9 +190,7 @@ def test_decoding_press_equivalence(press_factory): pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") # Create standalone decoding press - decoding_press = press_factory( - target_size=52, compression_interval=3, base_press=KnormPress(compression_ratio=0.5) - ) + decoding_press = press_factory(target_size=52, compression_interval=3, base_press=KnormPress(compression_ratio=0.5)) # Create PrefillDecodingPress with only decoding press combined_press = PrefillDecodingPress( From c103ed49fa2a25d612cfeb4064f02fc5bca2227e Mon Sep 17 00:00:00 2001 From: Saransh Agrawal Date: Thu, 26 Mar 2026 10:06:59 -0700 Subject: [PATCH 13/13] Removed initialization from base_press Signed-off-by: Saransh Agrawal --- kvpress/presses/cam_press.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py index 4e2f7312..b8aa606a 100644 --- a/kvpress/presses/cam_press.py +++ b/kvpress/presses/cam_press.py @@ -57,7 +57,7 @@ class CAMPress(DecodingPress): more evenly. """ - base_press: ScorerPress | AdaKVPress = None + base_press: ScorerPress | AdaKVPress compression_interval: int = 512 target_size: int = 2048 hidden_states_buffer_size: int = 256