Skip to content

Commit b210fdb

Browse files
feat: integrate KVPress for KV cache compression (#366) (#623)
* feat: integrate KVPress for KV cache compression Add NVIDIA KVPress as an optional dependency, enabling 31 KV cache compression strategies for causal language models. Includes algorithm class, test suite, and compatibility updates across existing LLM algorithms. * feat: bump kvpress to >=0.5.2, add FastKVzipPress kvpress 0.5.2 relaxes the datasets<3 constraint and reverts to transformers>=4.56, resolving the dependency conflict. uv sync --extra kvpress now works without workarounds. * feat: add press_kwargs for press-specific parameters Allow passing additional keyword arguments to the press constructor via the press_kwargs hyperparameter, enabling fine-grained control over press-specific settings like window_size, n_sink, etc. * fix: compatibility, press_kwargs, unit tests, remove wrappers - Replace tags.QUANTIZER with explicit LLM algorithm names to avoid false symmetry matches with diffuser algorithms - Fix SmashConfig.add() dict flattening: only flatten when key is a registered algorithm name, not for dict-valued hyperparameters - Remove wrapper/special presses from PRESS_TYPES (CriticalKVPress and others that don't accept compression_ratio directly) - Add unit tests for press type validation and kwargs forwarding - Add SnapKV integration test with press_kwargs * feat: add KV_CACHER tag, replace explicit kvpress references Add a new KV_CACHER algorithm tag for KV cache compression algorithms, separate from CACHER (used by diffuser cachers). Use the tag in all LLM algorithm compatibility lists instead of explicit "kvpress" strings. * refactor: rename KV_CACHER tag to KV_COMPRESSOR, improve docstrings * docs: document excluded wrapper presses in kvpress docstring * refactor: remove KV_COMPRESSOR tag, reference kvpress by name Drop the dedicated KV_COMPRESSOR tag and use tags.PRUNER as kvpress's group tag, matching how other pruners are categorized. 
Replace all tags.KV_COMPRESSOR references in compatible_before/after lists with the string "kvpress" to align with the repo convention of naming specific algorithms in compatibility lists. * fix: handle transformers pipeline in kvpress _apply Add pipeline guard at the top of _apply to delegate to _apply_to_model_within_transformers_pipeline when the model is a TextGenerationPipeline, matching the pattern used by gptq, torch_compile, and other algorithms. * ci: register requires_kvpress marker for optional extra * test: mark kvpress as require_kvpress --------- Co-authored-by: Gaspar Rochette <gaspar.rochette@pruna.ai>
1 parent 6d1d9b0 commit b210fdb

15 files changed

Lines changed: 249 additions & 9 deletions

File tree

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ conflicts = [
8787
[{ extra = "intel" }, { extra = "stable-fast" }, { extra = "stable-fast-extraindex" }],
8888
[{ extra = "intel" }, { extra = "full" }, { extra = "stable-fast-extraindex" }],
8989
[{ extra = "intel" }, { extra = "vllm" }],
90+
[{ extra = "kvpress" }, { extra = "vbench" }],
9091
]
9192

9293
[tool.uv.sources]
@@ -248,6 +249,9 @@ intel = [
248249
"torch>=2.7.0,<2.9.0",
249250
"torchvision>=0.22.0,<0.24.0",
250251
]
252+
kvpress = [
253+
"kvpress>=0.5.2",
254+
]
251255

252256
[build-system]
253257
requires = ["hatchling"]

src/pruna/algorithms/gptq_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class GPTQ(PrunaAlgorithmBase):
4646
processor_required: bool = False
4747
runs_on: list[str] = ["cuda"]
4848
dataset_required: bool = True
49-
compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]
49+
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "kvpress"]
5050
required_install: str = (
5151
"You must first install the base package with ``pip install pruna`` "
5252
"before installing the GPTQ extension with ``pip install pruna[gptq] --extra-index-url https://prunaai.pythonanywhere.com/``"

src/pruna/algorithms/half.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class Half(PrunaAlgorithmBase):
5050
"stable_fast",
5151
"torch_compile",
5252
"ifw",
53+
"kvpress",
5354
"whisper_s2t",
5455
"sage_attn",
5556
"hyper",

src/pruna/algorithms/hqq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class HQQ(PrunaAlgorithmBase):
6363
runs_on: list[str] = ["cuda"]
6464
dataset_required: bool = False
6565
compatible_before: Iterable[str] = ["torch_structured", "moe_kernel_tuner"]
66-
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "moe_kernel_tuner"]
66+
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "kvpress", "moe_kernel_tuner"]
6767
disjointly_compatible_before: Iterable[str] = []
6868
disjointly_compatible_after: Iterable[str] = ["torchao"]
6969

src/pruna/algorithms/huggingface_llm_int8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class LLMInt8(PrunaAlgorithmBase):
5858
runs_on: list[str] = ["cuda", "accelerate"]
5959
save_fn: None = None
6060
compatible_before: Iterable[str] = ["moe_kernel_tuner"]
61-
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "moe_kernel_tuner"]
61+
compatible_after: Iterable[str] = ["torch_compile", "sage_attn", "kvpress", "moe_kernel_tuner"]
6262

6363
def get_hyperparameters(self) -> list:
6464
"""

src/pruna/algorithms/kvpress.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import functools
18+
from collections.abc import Iterable
19+
from typing import Any, Dict
20+
21+
from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter
22+
23+
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
24+
from pruna.algorithms.base.tags import AlgorithmTag as tags
25+
from pruna.config.hyperparameters import UnconstrainedHyperparameter
26+
from pruna.config.smash_config import SmashConfigPrefixWrapper
27+
from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm
28+
from pruna.engine.save import SAVE_FUNCTIONS
29+
30+
# Press classes exposed as `press_type` choices: scorer and standalone presses
# whose constructors accept `compression_ratio` directly. Wrapper presses that
# require a nested press (ChunkPress, AdaKVPress, PerLayerCompressionPress, ...)
# and ThinKPress (channel-dimension compression) are deliberately not listed.
PRESS_TYPES = list(
    (
        "CompactorPress",
        "CURPress",
        "ExpectedAttentionPress",
        "ExpectedAttentionStatsPress",
        "FastKVzipPress",
        "FinchPress",
        "KnormPress",
        "KVzapPress",
        "KVzipPress",
        "KeyDiffPress",
        "LagKVPress",
        "LeverageScorePress",
        "NonCausalAttnPress",
        "ObservedAttentionPress",
        "PyramidKVPress",
        "QFilterPress",
        "RandomPress",
        "SnapKVPress",
        "StreamingLLMPress",
        "TOVAPress",
    )
)
52+
53+
54+
class KVPress(PrunaAlgorithmBase):
    """
    Compress the KV cache of causal language models using KVPress.

    KVPress is a library by NVIDIA that provides over 20 compression strategies (presses) for
    reducing the memory footprint of the key-value cache during long-context inference. Each press
    scores and prunes KV pairs after the prefill phase according to a chosen importance criterion.

    This integration covers all scorer and standalone presses. Wrapper presses (e.g., ChunkPress,
    AdaKVPress, PerLayerCompressionPress) that require a nested scorer press as input are not
    included, as well as ThinKPress which compresses along the channel dimension with a different
    parameter interface.
    """

    algorithm_name: str = "kvpress"
    group_tags: list[tags] = [tags.PRUNER]
    # Reapply on load: the press wrapper is not serializable, so the algorithm is
    # re-run on the reloaded base model (see the double-wrap guard in _apply).
    save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.reapply
    references: dict[str, str] = {
        "GitHub": "https://github.com/NVIDIA/kvpress",
        "Article": "https://huggingface.co/blog/nvidia/kvpress",
    }
    required_install: str = "pip install pruna[kvpress]"
    tokenizer_required: bool = False
    processor_required: bool = False
    dataset_required: bool = False
    runs_on: list[str] = ["cuda"]
    compatible_before: Iterable[str] = [
        "awq", "gptq", "half", "hqq", "llm_int8",
        "quanto", "sage_attn", "torchao", "moe_kernel_tuner",
    ]
    compatible_after: Iterable[str] = ["torch_compile", "moe_kernel_tuner"]

    def get_hyperparameters(self) -> list:
        """
        Configure all algorithm-specific hyperparameters with ConfigSpace.

        Returns
        -------
        list
            The hyperparameters: the press class to use, the fraction of the KV
            cache to prune, and optional press-specific constructor kwargs.
        """
        return [
            CategoricalHyperparameter(
                "press_type",
                choices=PRESS_TYPES,
                default_value="ExpectedAttentionPress",
                meta={"desc": "The KV cache compression strategy to use."},
            ),
            UniformFloatHyperparameter(
                "compression_ratio",
                lower=0.0,
                upper=1.0,
                default_value=0.5,
                meta={"desc": "Fraction of KV pairs to remove. 0.0 means no compression."},
            ),
            # Unconstrained because valid keys (window_size, n_sink, ...) differ per press.
            UnconstrainedHyperparameter(
                "press_kwargs",
                default_value=None,
                meta={"desc": "Additional keyword arguments passed to the press constructor."},
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a causal language model or a pipeline wrapping one.

        Parameters
        ----------
        model : Any
            The model to check.

        Returns
        -------
        bool
            True if the model is compatible with KV cache compression, False otherwise.
        """
        return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model)

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Wrap the model's generate method to apply KV cache compression via a press context manager.

        Parameters
        ----------
        model : Any
            The causal language model to compress.
        smash_config : SmashConfigPrefixWrapper
            The algorithm-prefixed configuration containing press_type, compression_ratio, and press_kwargs.

        Returns
        -------
        Any
            The model with its generate method wrapped to compress the KV cache on each call.
        """
        # Pipelines delegate to the shared helper that unwraps/rewraps the inner model,
        # matching the pattern used by gptq, torch_compile, and other algorithms.
        if is_transformers_pipeline_with_causal_lm(model):
            return self._apply_to_model_within_transformers_pipeline(model, smash_config)

        imported_modules = self.import_algorithm_packages()

        press_type = smash_config["press_type"]
        compression_ratio = smash_config["compression_ratio"]
        # press_kwargs may be None (the hyperparameter default); normalize to {}.
        press_kwargs = smash_config["press_kwargs"] or {}

        press_cls = imported_modules[press_type]
        press = press_cls(compression_ratio=compression_ratio, **press_kwargs)

        # save_fn is SAVE_FUNCTIONS.reapply, so _apply may run on a model that was
        # already smashed. Rewrap the pristine generate (stored on first application)
        # instead of the current one, otherwise press context managers would stack.
        original_generate = getattr(model, "_kvpress_original_generate", model.generate)

        @functools.wraps(original_generate)
        def generate_with_press(*args, **kwargs):
            # Entering the press context installs the compression hooks for this call only.
            with press(model):
                return original_generate(*args, **kwargs)

        model.generate = generate_with_press
        # Keep handles for un-wrapping / inspection by later tooling.
        model._kvpress_original_generate = original_generate
        model._kvpress_press = press

        return model

    def import_algorithm_packages(self) -> dict[str, Any]:
        """
        Lazily import kvpress and collect all supported press classes.

        Returns
        -------
        dict[str, Any]
            A dictionary mapping press class names to their classes.
        """
        import kvpress

        return {name: getattr(kvpress, name) for name in PRESS_TYPES}

src/pruna/algorithms/llm_compressor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class LLMCompressor(PrunaAlgorithmBase):
5252
dataset_required: bool = True
5353
runs_on: list[str] = ["cuda"]
5454
compatible_before: Iterable[str] = ["moe_kernel_tuner"]
55-
compatible_after: Iterable[str] = ["sage_attn", "moe_kernel_tuner"]
55+
compatible_after: Iterable[str] = ["sage_attn", "kvpress", "moe_kernel_tuner"]
5656
required_install = "``uv pip install 'pruna[awq]'``"
5757

5858
def get_hyperparameters(self) -> list:

src/pruna/algorithms/moe_kernel_tuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ class MoeKernelTuner(PrunaAlgorithmBase):
5050
dataset_required: bool = False
5151
compatible_before: Iterable[str] = [
5252
"awq", "deepcache", "diffusers_int8", "fastercache", "flash_attn3",
53-
"fora", "hqq", "hqq_diffusers", "llm_int8", "pab", "padding_pruning",
53+
"fora", "hqq", "hqq_diffusers", "kvpress", "llm_int8", "pab", "padding_pruning",
5454
"qkv_diffusers", "quanto", "reduce_noe", "ring_attn", "sage_attn",
5555
"torch_compile", "torchao",
5656
]
5757
compatible_after: Iterable[str] = [
5858
"awq", "deepcache", "diffusers_int8", "fastercache", "flash_attn3",
59-
"fora", "hqq", "hqq_diffusers", "llm_int8", "pab", "padding_pruning",
59+
"fora", "hqq", "hqq_diffusers", "kvpress", "llm_int8", "pab", "padding_pruning",
6060
"qkv_diffusers", "quanto", "ring_attn", "sage_attn",
6161
"torch_compile", "torchao",
6262
]

src/pruna/algorithms/quanto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class Quanto(PrunaAlgorithmBase):
5252
compatible_before: Iterable[str] = ["qkv_diffusers", "moe_kernel_tuner"]
5353
compatible_after: Iterable[str] = [
5454
"deepcache",
55+
"kvpress",
5556
"sage_attn",
5657
"text_to_image_distillation_inplace_perp",
5758
"text_to_image_distillation_lora",

src/pruna/algorithms/sage_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class SageAttn(PrunaAlgorithmBase):
5050
runs_on: list[str] = ["cuda", "accelerate"]
5151
dataset_required: bool = False
5252
compatible_before: Iterable[str | tags] = [tags.QUANTIZER, "moe_kernel_tuner"]
53-
compatible_after: Iterable[str | tags] = ["torch_compile", tags.CACHER, "moe_kernel_tuner"]
53+
compatible_after: Iterable[str | tags] = ["torch_compile", tags.CACHER, "kvpress", "moe_kernel_tuner"]
5454

5555
def model_check_fn(self, model: Any) -> bool:
5656
"""

0 commit comments

Comments
 (0)