diff --git a/README.md b/README.md index d23ee89a..00eb78e8 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ Finally we provide wrapper presses that can be combined with other presses: - `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allow for compression during decoding, see decoding section in this README. - `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allow to compress both during prefilling and during decoding. - `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evict keys and values with scores below a given threshold of any `ScorerPress` instead of relying on top-k scores. Support both prefilling and decoding (if decoding=True), but only supports dense-prefill and not sparse-prefill. +- `CAMPress` ([source](kvpress/presses/cam_press.py), [paper](https://openreview.net/forum?id=LCTmppB165)): A decoding press that merges the values of evicted tokens into kept tokens to preserve information. 
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression) diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py index e81e7fda..37265233 100644 --- a/evaluation/evaluate_registry.py +++ b/evaluation/evaluate_registry.py @@ -15,6 +15,7 @@ from kvpress import ( AdaKVPress, BlockPress, + CAMPress, ChunkKVPress, CompactorPress, ComposedPress, @@ -106,6 +107,10 @@ "compactor": CompactorPress(), "adakv_compactor": AdaKVPress(CompactorPress()), "no_press": None, + "cam_streaming_llm": CAMPress(base_press=StreamingLLMPress()), + "cam_knorm": CAMPress(base_press=KnormPress()), + "cam_adakv_snapkv": CAMPress(base_press=AdaKVPress(SnapKVPress())), + "cam_tova": CAMPress(base_press=TOVAPress()), "decoding_knorm": DecodingPress(base_press=KnormPress()), "decoding_streaming_llm": DecodingPress(base_press=StreamingLLMPress()), "decoding_tova": DecodingPress(base_press=TOVAPress()), diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 0f08519e..1cc8c590 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -7,6 +7,7 @@ from kvpress.presses.adakv_press import AdaKVPress from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress from kvpress.presses.block_press import BlockPress +from kvpress.presses.cam_press import CAMPress from kvpress.presses.chunk_press import ChunkPress from kvpress.presses.chunkkv_press import ChunkKVPress from kvpress.presses.compactor_press import CompactorPress @@ -75,6 +76,7 @@ "KeyDiffPress", "KVzipPress", "ExpectedAttentionStatsPress", + "CAMPress", "DecodingPress", "PrefillDecodingPress", "CompactorPress", diff --git a/kvpress/presses/cam_press.py b/kvpress/presses/cam_press.py new file mode 100644 index 00000000..b8aa606a --- /dev/null +++ b/kvpress/presses/cam_press.py @@ -0,0 
+1,355 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from transformers import QuantizedCache +from transformers.models.llama.modeling_llama import repeat_kv, rotate_half + +from kvpress.presses.adakv_press import AdaKVPress +from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.scorer_press import ScorerPress +from kvpress.utils import extract_keys_and_values, get_prerope_query_states + +logger = logging.getLogger(__name__) + + +@dataclass +class CAMPress(DecodingPress): + """ + Cache Merging (CaM) KV cache compression during decoding. + + Instead of simply evicting low-importance tokens, CaM merges their value vectors + into sequential neighbors before pruning. A Bernoulli merge mask, derived from the + ratio of the evicted token's cumulative attention to the mean attention of its merge + window, decides whether each merge occurs. This reduces the output + perturbation caused by cache eviction. + + This implementation extends the original per-step algorithm to support batched + eviction: tokens accumulate over ``compression_interval`` steps, then a bulk + merge-and-prune pass is applied. Setting ``compression_interval=1`` creates + the original per-step CaM behavior. + + Based on CaM (https://openreview.net/forum?id=LCTmppB165). + + Parameters + ---------- + base_press : ScorerPress + The scorer press used to compute importance scores for tokens. + compression_interval : int, default=512 + Number of decoding steps between compression, i.e. compression will be applied + every compression_interval steps. + target_size : int, default=2048 + Target number of tokens to keep after compression. + hidden_states_buffer_size : int, default=256 + Maximum number of hidden states to keep before compression. 
Larger values use + more GPU memory. Note: Some presses don't need buffered hidden states and can + set this to 0 to use only the current hidden state for compression scoring. + merge_budget : int, default=32 + Number of sequential kept-token neighbors each evicted token's value is merged + into. Smaller values concentrate the merged information; larger values spread it + more evenly. + """ + + base_press: ScorerPress | AdaKVPress + compression_interval: int = 512 + target_size: int = 2048 + hidden_states_buffer_size: int = 256 + merge_budget: int = 32 + + def __post_init__(self): + super().__post_init__() + assert self.merge_budget > 0, "merge_budget must be positive " + + # To maintain cumulative attention sum across generation steps + self._running_attn_sum: dict[int, torch.Tensor] = {} + + def compress( + self, + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attentions: torch.Tensor, + kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Merge evicted tokens' values into kept neighbors, then prune. + + Overrides `DecodingPress.compress` to implement the CaM merge-before-prune + strategy instead of plain eviction. + + Args: + module: The transformer attention module being compressed. + hidden_states: Buffered hidden states from recent decoding steps + (shape: [batch, buffer_len, hidden_dim]). + keys: Key cache (shape: [batch, n_kv_heads, seq_len, head_dim]). + values: Value cache (shape: [batch, n_kv_heads, seq_len, head_dim]). + attentions: Cumulative attention scores summed over generation steps + (shape: [batch, n_kv_heads, seq_len]). + kwargs: Additional keyword arguments forwarded to the base press scorer. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Compressed (keys, values) with seq_len + reduced to ``target_size``. + + Algorithm: + 1. **Score & select** — The base press scores every cached token. 
The + ``n_to_evict`` lowest-scored tokens are marked for eviction; the top + ``target_size`` are kept. + 2. **Pick merge candidates** — Among the evicted set, the ``k`` tokens with + the highest scores (with ties broken by later sequence position) are + selected for merging, where ``k = layer_step_counts[layer_idx]`` + (the number of new tokens since the last compression). + 3. **Cascading merge targets** — For each merge candidate, the + ``merge_budget`` kept tokens immediately after it (in sequence order) + form its merge window. + 4. **Merge probability** — The ratio of each merge token's cumulative + attention to the mean cumulative attention of its window is computed + ``clamp(A_i / avg(A_{j:j+m}), 0, 1)``. + 5. **Bernoulli sampling** — A binary merge mask is drawn from the + probability above. Tokens that pass the mask have their value vectors + divided by the window size and scatter-added into the window targets. + 6. **Physical pruning** — Evicted key/value entries are removed from the + cache, and the cumulative attention buffer is pruned to match. 
+ """ + + layer_idx = int(module.layer_idx) + cache_len = keys.shape[2] + + n_to_evict = cache_len - self.target_size + + target_compression_ratio = self._find_target_compression_ratio(cache_len, self.target_size) + + if n_to_evict <= 0: + return keys, values + + # Temporary override base press ratio to get correct topK scores + old_cr = self.base_press.compression_ratio + self.base_press.compression_ratio = target_compression_ratio + scores = self.base_press.score(module, hidden_states, keys, values, None, kwargs) + self.base_press.compression_ratio = old_cr + + bsz, num_key_value_heads, seq_len, head_dim = keys.shape + + evict_indices = scores[:, 0, :].topk(n_to_evict, dim=-1, largest=False).indices + evict_indices = torch.sort(evict_indices, dim=-1).values + + evict_scores = scores[:, 0, :].gather(-1, evict_indices) + # Flip so later sequence positions come first; stable sort preserves this order for ties + k = self.layer_step_counts[layer_idx] + order = evict_scores.flip(-1).argsort(dim=-1, descending=True, stable=True)[:, :k] + merge_indices = evict_indices.gather(-1, n_to_evict - 1 - order) + merge_indices = torch.sort(merge_indices, dim=-1).values + + kept_indices = scores[:, 0, :].topk(self.target_size, dim=-1).indices + kept_indices = torch.sort(kept_indices, dim=-1).values + + n_to_merge = merge_indices.shape[1] + + base_idx_first = torch.searchsorted(kept_indices, merge_indices[:, 0:1], right=True) + target_starts = torch.arange(n_to_merge, device=kept_indices.device).unsqueeze(0) + base_idx_first + + # 2. Build target window indices: [bsz, n_to_merge, merge_budget] + offsets = torch.arange(self.merge_budget, device=kept_indices.device) + window_idx = target_starts.unsqueeze(-1) + offsets.view(1, 1, -1) + valid_mask = window_idx < self.target_size + window_idx = window_idx.clamp(max=self.target_size - 1) + target_positions = kept_indices.gather(1, window_idx.view(bsz, -1)).view(bsz, n_to_merge, self.merge_budget) + + # 3. 
Actual budget per merge token + actual_budget = valid_mask.sum(dim=-1) + + # 4. Window mean via cumsum (from target_start to min(target_start + merge_budget, seq_len)) + attn_cumsum = torch.nn.functional.pad(attentions.cumsum(dim=-1), (1, 0)) + start_sum = attn_cumsum.gather(2, target_starts.unsqueeze(1).expand(-1, num_key_value_heads, -1)) + window_end = (target_starts + self.merge_budget).clamp(max=seq_len) + end_sum = attn_cumsum.gather(2, window_end.unsqueeze(1).expand(-1, num_key_value_heads, -1)) + window_sum = end_sum - start_sum + window_len = (window_end - target_starts).unsqueeze(1) + mean_attn = window_sum / window_len + + # 5. Merge probability + merge_token_attn = attentions.gather(2, merge_indices.unsqueeze(1).expand(-1, num_key_value_heads, -1)) + merge_prob = merge_token_attn / mean_attn + merge_prob = torch.where(torch.isnan(merge_prob), torch.zeros_like(merge_prob), merge_prob) + merge_prob = torch.where(torch.isinf(merge_prob), torch.ones_like(merge_prob), merge_prob) + merge_prob = merge_prob.clamp(0, 1) + + # 6. Bernoulli sampling + merge_mask = torch.bernoulli(merge_prob) + + # 7. 
Build contributions and scatter-add + merge_values = values.gather( + 2, merge_indices.view(bsz, 1, n_to_merge, 1).expand(-1, num_key_value_heads, -1, head_dim) + ) + scale = (merge_mask / actual_budget.unsqueeze(1)).unsqueeze(-1) + scale = torch.where(actual_budget.unsqueeze(1).unsqueeze(-1) == 0, torch.zeros_like(scale), scale) + contributions = merge_values * scale + contributions = contributions.unsqueeze(3).expand(-1, -1, -1, self.merge_budget, -1) + contributions = contributions * valid_mask.view(bsz, 1, n_to_merge, self.merge_budget, 1) + contributions = contributions.reshape(bsz, num_key_value_heads, n_to_merge * self.merge_budget, head_dim) + scatter_idx = target_positions.view(bsz, 1, n_to_merge * self.merge_budget, 1).expand( + -1, num_key_value_heads, -1, head_dim + ) + + values.scatter_add_(2, scatter_idx, contributions) + + # Physical Pruning + kept_indices_expand = kept_indices.view(bsz, 1, self.target_size, 1).expand( + bsz, num_key_value_heads, self.target_size, head_dim + ) + keys = keys.gather(2, kept_indices_expand).contiguous() + values = values.gather(2, kept_indices_expand).contiguous() + + # prune cumulative attentions + expanded_indices = kept_indices.unsqueeze(1).expand(bsz, num_key_value_heads, -1) + self._running_attn_sum[layer_idx] = self._running_attn_sum[layer_idx].gather(2, expanded_indices).contiguous() + + return keys, values + + def forward_hook( + self, + module: nn.Module, + input: list[torch.Tensor], + kwargs: dict, + output: list, + ): + """ + Forward hook that manages cumulative attention tracking and interval-based compression. + + Extends `DecodingPress.forward_hook` with per-step attention accumulation. + + This hook: + 1. Detects when we're in decoding phase (not prefilling) + 2. Accumulates hidden states in a buffer + 3. Accumulates cumulative attention A_bar = sum(A^k) in a buffer + 4. Applies compression every N steps + 5. 
Clears the buffer after compression + """ + hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_values"] + q_len = hidden_states.shape[1] + layer_idx = int(module.layer_idx) + + # Only operate during decoding + if kwargs["cache_position"][-1] <= q_len: + return output + + # All hidden_states_buffer code is borrowed from DecodingPress + self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone()) + + cache_layer = cache.layers[module.layer_idx] + keys, values = extract_keys_and_values(cache, layer_idx) + bsz, num_key_value_heads, seq_len, _ = keys.shape + + # Accumulate Cumulative Attention over generation steps + attentions = output[1] if len(output) > 1 and output[1] is not None else None + if attentions is None: + attentions = self._compute_current_token_attention(module, hidden_states, keys, kwargs) + else: + attentions = attentions[:, :, -1:, :] + + attentions = self._aggregate_attention_per_kv_head(attentions, num_key_value_heads) + + if attentions is not None: + attn_squeezed = attentions.squeeze(2) + + if layer_idx not in self._running_attn_sum: + self._running_attn_sum[layer_idx] = attn_squeezed.clone() + else: + # Pad running sum for the new token growth + prev_len = self._running_attn_sum[layer_idx].shape[-1] + pad_len = seq_len - prev_len + + if pad_len > 0: + pad = torch.zeros( + (bsz, num_key_value_heads, pad_len), device=attn_squeezed.device, dtype=attn_squeezed.dtype + ) + self._running_attn_sum[layer_idx] = torch.cat([self._running_attn_sum[layer_idx], pad], dim=-1) + + self._running_attn_sum[layer_idx] += attn_squeezed + + self.layer_step_counts[layer_idx] += 1 + + # Trigger interval-based bulk eviction + if (self.layer_step_counts[layer_idx] >= self.compression_interval and seq_len > self.target_size) or ( + q_len >= self.target_size + ): + + # Apply compression using cumulative attention scores and buffered hidden states + attn_squeezed = self._running_attn_sum[layer_idx] + buffered_hidden_states = 
torch.cat(self.hidden_states_buffer[layer_idx], dim=1) + keys, values = self.compress(module, buffered_hidden_states, keys, values, attn_squeezed, kwargs) + + # Update cache with compressed keys and values + if isinstance(cache, QuantizedCache): + cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key) + cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value) + cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index] + cache_layer.cumulative_length = keys.shape[2] + else: + cache_layer.keys = keys + cache_layer.values = values + + self.layer_step_counts[layer_idx] = 0 + # Always clear the buffer after compression - otherwise there's a mismatch between + # hidden states buffer and kv cache + self.hidden_states_buffer[layer_idx] = [] + + self.hidden_states_buffer[layer_idx] = ( + self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size :] + if self.hidden_states_buffer_size > 0 + else [] + ) + + return output + + def reset(self): + """Reset per-sequence state.""" + super().reset() + self._running_attn_sum = {} + + @staticmethod + def _compute_current_token_attention( + module: nn.Module, + hidden_states: torch.Tensor, + keys: torch.Tensor, + kwargs: dict, + ) -> torch.Tensor: + """Compute softmax attention from the last query token to all cached keys.""" + _, num_key_value_heads, cache_len, head_dim = keys.shape + num_query_heads = module.config.num_attention_heads + num_key_value_groups = num_query_heads // num_key_value_heads + + query_states = get_prerope_query_states(module, hidden_states) + query_states = query_states[:, :, -1:, :] + + cos, sin = kwargs["position_embeddings"] + cos = cos[:, -1:, :].unsqueeze(1) + sin = sin[:, -1:, :].unsqueeze(1) + query_states = (query_states * cos) + (rotate_half(query_states) * sin) + + keys_repeated = repeat_kv(keys, 
num_key_value_groups) + scores = torch.matmul(query_states, keys_repeated.transpose(-2, -1)) / math.sqrt(head_dim) + return torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype) + + @staticmethod + def _aggregate_attention_per_kv_head( + attentions: torch.Tensor, + num_key_value_heads: int, + ) -> torch.Tensor: + """Average attention scores across query heads that share a KV head.""" + num_query_heads = attentions.shape[1] + if num_query_heads == num_key_value_heads: + return attentions + group_size = num_query_heads // num_key_value_heads + bsz, _, seq_q, seq_k = attentions.shape + return attentions.reshape(bsz, num_key_value_heads, group_size, seq_q, seq_k).mean(dim=2) diff --git a/tests/test_decoding_compression.py b/tests/test_decoding_compression.py index dd30c7f4..ad998d3c 100644 --- a/tests/test_decoding_compression.py +++ b/tests/test_decoding_compression.py @@ -1,8 +1,8 @@ -# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """ -Test script to verify that DecodingPress actually compresses during decoding. +Test script to verify that DecodingPress and CAMPress compress during decoding. 
""" import logging @@ -11,6 +11,7 @@ from transformers import DynamicCache, pipeline from kvpress import ( + CAMPress, CompactorPress, DecodingPress, KnormPress, @@ -26,19 +27,35 @@ logger = logging.getLogger(__name__) +def make_decoding_press(target_size: int, compression_interval: int, base_press: ScorerPress): + return DecodingPress( + base_press=base_press, + compression_interval=compression_interval, + target_size=target_size, + ) + + +def make_cam_press(target_size: int, compression_interval: int, base_press: ScorerPress): + return CAMPress( + base_press=base_press, + compression_interval=compression_interval, + target_size=target_size, + ) + + +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) @pytest.mark.parametrize("token_buffer_size", [32, 64, 128]) -def test_decoding_compression(token_buffer_size): - """Test that DecodingPress compresses the cache during decoding.""" +def test_decoding_compression(press_factory, token_buffer_size): + """Test that decoding presses compress the cache during decoding.""" # Initialize pipeline with a small model pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") # Create a DecodingPress with KnormPress - press = DecodingPress( - base_press=KnormPress(compression_ratio=0.5), # Remove 50% of tokens - compression_interval=4, # Compress every 4 tokens - target_size=token_buffer_size, - ) + + press = press_factory( + target_size=token_buffer_size, compression_interval=4, base_press=KnormPress(compression_ratio=0.5) + ) # Remove 50% of tokens # Create cache cache = DynamicCache() @@ -61,7 +78,8 @@ def test_decoding_compression(token_buffer_size): ) -def test_prefill_decoding_press_calls_both_phases(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_prefill_decoding_press_calls_both_phases(press_factory): """Test that 
PrefillDecodingPress calls both prefilling and decoding presses.""" # Initialize pipeline @@ -70,7 +88,7 @@ def test_prefill_decoding_press_calls_both_phases(): # Create PrefillDecodingPress with both presses combined_press = PrefillDecodingPress( prefilling_press=KnormPress(compression_ratio=0.6), # Compress to 60% during prefill - decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48), + decoding_press=press_factory(target_size=48, compression_interval=3, base_press=KnormPress()), ) # Test context and question @@ -95,14 +113,14 @@ def test_prefill_decoding_press_calls_both_phases(): ) -def test_decoding_press_without_prefill(): - """Test that DecodingPress works correctly when used standalone (no prefill compression).""" +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_decoding_press_without_prefill(press_factory): + """Test that decoding presses work correctly when used standalone (no prefill compression).""" # Initialize pipeline pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") - # Create DecodingPress only - decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64) + press = press_factory(target_size=64, compression_interval=5, base_press=KnormPress(compression_ratio=0.4)) # Test context and question context = "The quick brown fox jumps over the lazy dog. 
" * 8 @@ -110,7 +128,7 @@ def test_decoding_press_without_prefill(): # Run pipeline cache = DynamicCache() - pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25) + pipe(context, question=question, press=press, cache=cache, max_new_tokens=25) # Check that cache was compressed during decoding for layer_idx, cache_layer in enumerate(cache.layers): @@ -125,7 +143,8 @@ def test_decoding_press_without_prefill(): ) -def test_prefill_decoding_press_decoding_only(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_prefill_decoding_press_decoding_only(press_factory): """Test PrefillDecodingPress with only decoding press (no prefill compression).""" # Initialize pipeline @@ -134,8 +153,8 @@ def test_prefill_decoding_press_decoding_only(): # Create PrefillDecodingPress with only decoding press combined_press = PrefillDecodingPress( prefilling_press=None, - decoding_press=DecodingPress( - base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56 + decoding_press=press_factory( + target_size=56, compression_interval=4, base_press=KnormPress(compression_ratio=0.6) ), ) @@ -160,7 +179,8 @@ def test_prefill_decoding_press_decoding_only(): ) -def test_decoding_press_equivalence(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_decoding_press_equivalence(press_factory): """Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only.""" # Set random seed for reproducibility @@ -170,13 +190,13 @@ def test_decoding_press_equivalence(): pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") # Create standalone decoding press - decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52) + decoding_press = press_factory(target_size=52, 
compression_interval=3, base_press=KnormPress(compression_ratio=0.5)) # Create PrefillDecodingPress with only decoding press combined_press = PrefillDecodingPress( prefilling_press=None, - decoding_press=DecodingPress( - base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52 + decoding_press=press_factory( + target_size=52, compression_interval=3, base_press=KnormPress(compression_ratio=0.5) ), ) @@ -217,8 +237,9 @@ def test_decoding_press_equivalence(): """ +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) @pytest.mark.parametrize("press_config", default_presses) -def test_all_presses_work_with_decoding_press(press_config): +def test_all_presses_work_with_decoding_press(press_factory, press_config): """Test that all default presses work as base presses for DecodingPress.""" # Initialize pipeline @@ -246,7 +267,7 @@ def test_all_presses_work_with_decoding_press(press_config): return # Create DecodingPress with this base press - decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48) + decoding_press = press_factory(target_size=48, compression_interval=3, base_press=base_press) # Test context and question context = "The quick brown fox jumps over the lazy dog. 
" * 8 @@ -271,7 +292,8 @@ def test_all_presses_work_with_decoding_press(press_config): ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501 -def test_compression_actually_reduces_memory(): +@pytest.mark.parametrize("press_factory", [make_decoding_press, make_cam_press], ids=["DecodingPress", "CAMPress"]) +def test_compression_actually_reduces_memory(press_factory): """Test that compression actually reduces memory usage compared to no compression.""" pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto") @@ -283,12 +305,7 @@ def test_compression_actually_reduces_memory(): cache_uncompressed = DynamicCache() result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25) - # Run with compression - press = DecodingPress( - base_press=KnormPress(compression_ratio=0.3), # Aggressive compression - compression_interval=3, - target_size=40, - ) + press = press_factory(target_size=40, compression_interval=3, base_press=KnormPress(compression_ratio=0.3)) cache_compressed = DynamicCache() result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25)