|
| 1 | +# Copyright 2025 - Pruna AI GmbH. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import functools |
| 18 | +from collections.abc import Iterable |
| 19 | +from typing import Any, Dict |
| 20 | + |
| 21 | +from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter |
| 22 | + |
| 23 | +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase |
| 24 | +from pruna.algorithms.base.tags import AlgorithmTag as tags |
| 25 | +from pruna.config.hyperparameters import UnconstrainedHyperparameter |
| 26 | +from pruna.config.smash_config import SmashConfigPrefixWrapper |
| 27 | +from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm |
| 28 | +from pruna.engine.save import SAVE_FUNCTIONS |
| 29 | + |
# Names of the kvpress press classes this integration exposes as choices for the
# "press_type" hyperparameter. Each name is resolved to its class via
# getattr(kvpress, name) in KVPress.import_algorithm_packages, so every entry
# must exist as a top-level attribute of the installed kvpress package.
# Wrapper presses (ChunkPress, AdaKVPress, PerLayerCompressionPress) and
# ThinKPress are deliberately excluded (see the KVPress class docstring).
# NOTE(review): "KVzapPress" and "KVzipPress" are both listed — confirm both
# really exist in kvpress and this is not a typo'd duplicate.
PRESS_TYPES = [
    "CompactorPress",
    "CURPress",
    "ExpectedAttentionPress",
    "ExpectedAttentionStatsPress",
    "FastKVzipPress",
    "FinchPress",
    "KnormPress",
    "KVzapPress",
    "KVzipPress",
    "KeyDiffPress",
    "LagKVPress",
    "LeverageScorePress",
    "NonCausalAttnPress",
    "ObservedAttentionPress",
    "PyramidKVPress",
    "QFilterPress",
    "RandomPress",
    "SnapKVPress",
    "StreamingLLMPress",
    "TOVAPress",
]
| 52 | + |
| 53 | + |
class KVPress(PrunaAlgorithmBase):
    """
    Compress the KV cache of causal language models using KVPress.

    KVPress is a library by NVIDIA that provides over 20 compression strategies (presses) for
    reducing the memory footprint of the key-value cache during long-context inference. Each press
    scores and prunes KV pairs after the prefill phase according to a chosen importance criterion.

    This integration covers all scorer and standalone presses. Wrapper presses (e.g., ChunkPress,
    AdaKVPress, PerLayerCompressionPress) that require a nested scorer press as input are not
    included, as well as ThinKPress which compresses along the channel dimension with a different
    parameter interface.
    """

    algorithm_name: str = "kvpress"
    group_tags: list[tags] = [tags.PRUNER]
    # The wrapped generate method cannot be serialized, so saving re-applies the
    # algorithm to a freshly loaded model; _apply is therefore written to be
    # idempotent (see the double-wrap guard below).
    save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.reapply
    references: dict[str, str] = {
        "GitHub": "https://github.com/NVIDIA/kvpress",
        "Article": "https://huggingface.co/blog/nvidia/kvpress",
    }
    required_install: str = "pip install pruna[kvpress]"
    tokenizer_required: bool = False
    processor_required: bool = False
    dataset_required: bool = False
    runs_on: list[str] = ["cuda"]
    compatible_before: Iterable[str] = [
        "awq", "gptq", "half", "hqq", "llm_int8",
        "quanto", "sage_attn", "torchao", "moe_kernel_tuner",
    ]
    compatible_after: Iterable[str] = ["torch_compile", "moe_kernel_tuner"]

    def get_hyperparameters(self) -> list:
        """
        Configure all algorithm-specific hyperparameters with ConfigSpace.

        Returns
        -------
        list
            The hyperparameters.
        """
        return [
            CategoricalHyperparameter(
                "press_type",
                choices=PRESS_TYPES,
                default_value="ExpectedAttentionPress",
                meta={"desc": "The KV cache compression strategy to use."},
            ),
            UniformFloatHyperparameter(
                "compression_ratio",
                lower=0.0,
                upper=1.0,
                default_value=0.5,
                meta={"desc": "Fraction of KV pairs to remove. 0.0 means no compression."},
            ),
            # Press-specific options (e.g. window sizes) vary per press class, so
            # they are passed through untyped rather than enumerated here.
            UnconstrainedHyperparameter(
                "press_kwargs",
                default_value=None,
                meta={"desc": "Additional keyword arguments passed to the press constructor."},
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a causal language model or a pipeline wrapping one.

        Parameters
        ----------
        model : Any
            The model to check.

        Returns
        -------
        bool
            True if the model is compatible with KV cache compression, False otherwise.
        """
        return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model)

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Wrap the model's generate method to apply KV cache compression via a press context manager.

        Parameters
        ----------
        model : Any
            The causal language model to compress.
        smash_config : SmashConfigPrefixWrapper
            The algorithm-prefixed configuration containing press_type, compression_ratio, and press_kwargs.

        Returns
        -------
        Any
            The model with its generate method wrapped to compress the KV cache on each call.
        """
        if is_transformers_pipeline_with_causal_lm(model):
            return self._apply_to_model_within_transformers_pipeline(model, smash_config)

        imported_modules = self.import_algorithm_packages()

        press_type = smash_config["press_type"]
        compression_ratio = smash_config["compression_ratio"]
        # press_kwargs may legitimately be None (the default); normalize to {}.
        press_kwargs = smash_config["press_kwargs"] or {}

        press_cls = imported_modules[press_type]
        press = press_cls(compression_ratio=compression_ratio, **press_kwargs)

        # Idempotency guard: if this model was already smashed (e.g. when the
        # algorithm is re-applied on load via SAVE_FUNCTIONS.reapply), wrap the
        # stored original generate rather than the existing wrapper, so press
        # context managers do not stack on repeated application.
        original_generate = getattr(model, "_kvpress_original_generate", None)
        if original_generate is None:
            original_generate = model.generate

        @functools.wraps(original_generate)
        def generate_with_press(*args, **kwargs):
            # Entering the press registers forward hooks that prune the KV
            # cache after prefill; exiting removes them again.
            with press(model):
                return original_generate(*args, **kwargs)

        model.generate = generate_with_press
        # Keep references for unwrapping/inspection by later re-applications.
        model._kvpress_original_generate = original_generate
        model._kvpress_press = press

        return model

    def import_algorithm_packages(self) -> dict[str, Any]:
        """
        Lazily import kvpress and collect all supported press classes.

        Returns
        -------
        dict[str, Any]
            A dictionary mapping press class names to their classes.
        """
        import kvpress

        return {name: getattr(kvpress, name) for name in PRESS_TYPES}
0 commit comments