Skip to content

Commit e1d32e2

Browse files
committed
Add WaterSIC KV-cache calibration config and update package exports
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent c13e5b9 commit e1d32e2

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

modelopt/torch/quantization/algorithms/watersic_kv/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,8 @@
1616
"""WaterSIC KV-cache quantization algorithm."""
1717

1818
from __future__ import annotations
19+
20+
from .config import WaterSICKVCalibConfig
21+
from .kv_quantizer import WaterSICKVHelper, WaterSICKVState
22+
23+
# Public API of this package: the calibration config re-exported from ``.config``
# and the helper/state classes re-exported from ``.kv_quantizer``.
__all__ = ["WaterSICKVCalibConfig", "WaterSICKVHelper", "WaterSICKVState"]
modelopt/torch/quantization/algorithms/watersic_kv/config.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Configuration for the WaterSIC KV-cache quantization algorithm."""
17+
18+
from __future__ import annotations
19+
20+
from typing import Literal
21+
22+
from modelopt.torch.opt.config import ModeloptField
23+
from modelopt.torch.quantization.config import QuantizeAlgorithmConfig
24+
25+
26+
class WaterSICKVCalibConfig(QuantizeAlgorithmConfig):
    """Configuration for WaterSIC KV-cache quantization.

    WaterSIC (Water-filling Successive Interference Cancellation) is a
    rate-adaptive quantization method for KV-cache compression.  It
    applies the ZSIC algorithm with optional KL-aware importance
    weighting and LMMSE shrinkage correction to minimize attention-output
    distortion at a target bits-per-element budget.

    Numeric fields carry range constraints so that nonsensical values are
    rejected when the config is constructed rather than surfacing later
    during calibration:

    * ``target_rate`` and ``importance_clip`` must be strictly positive.
    * ``n_rescaler_iters`` must be non-negative.
    * ``sample_frac`` must be ``None`` (disabled) or lie in ``(0, 1]``.

    Reference: "WaterSIC: Water-filling Successive Interference
    Cancellation for KV-Cache Quantization" (2024).
    """

    method: Literal["watersic_kv"] = ModeloptField(
        "watersic_kv",
        title="Calibration algorithm identifier.",
        description="Fixed identifier for the WaterSIC KV-cache calibration method.",
    )

    target_rate: float = ModeloptField(
        default=2.0,
        gt=0.0,
        title="Target bits per element.",
        description=(
            "Average number of bits per quantized KV-cache element.  The binary "
            "search over the ZSIC damping parameter c is driven to hit this rate."
        ),
    )

    kl_aware: bool = ModeloptField(
        default=False,
        title="Enable KL-aware importance weighting.",
        description=(
            "When True, per-token importance weights derived from the attention "
            "distribution are folded into the Hessian so that tokens with higher "
            "attention mass receive tighter quantization."
        ),
    )

    importance_clip: float = ModeloptField(
        default=50.0,
        gt=0.0,
        title="Importance weight clipping ratio.",
        description=(
            "Maximum ratio by which a single token's importance weight may exceed "
            "the mean weight.  Clips extreme outlier tokens to prevent them from "
            "dominating the Hessian estimate."
        ),
    )

    use_lmmse: bool = ModeloptField(
        default=True,
        title="Apply LMMSE shrinkage correction.",
        description=(
            "When True, the LMMSE (Linear Minimum Mean-Squared Error) shrinkage "
            "correction is applied after ZSIC quantization to partially undo "
            "quantization bias and reduce reconstruction NMSE."
        ),
    )

    n_rescaler_iters: int = ModeloptField(
        default=0,
        ge=0,
        title="Diagonal rescaler optimization iterations.",
        description=(
            "Number of coordinate-descent iterations for the diagonal rescaler "
            "that adjusts per-column scale factors after LMMSE.  Set to 0 to "
            "disable the rescaler (faster but slightly higher distortion)."
        ),
    )

    # A fraction of rows only makes sense in (0, 1]: zero or a negative value
    # would select no rows for the binary search, and anything above 1 asks for
    # more rows than exist.  ``None`` still means "no subsampling".
    sample_frac: float | None = ModeloptField(
        default=None,
        gt=0.0,
        le=1.0,
        title="Row subsampling fraction for binary search.",
        description=(
            "If set, only this fraction of rows (KV heads) are used during the "
            "binary search for c.  Full rows are then quantized with the found c.  "
            "Speeds up calibration on large KV caches at a small accuracy cost.  "
            "Must lie in the interval (0, 1]; leave as None to use all rows."
        ),
    )

    use_sequential: bool = ModeloptField(
        default=True,
        title="Enable sequential layer-by-layer calibration.",
        description=(
            "When True, the WaterSIC calibration is applied layer-by-layer in "
            "decoder-block order so that each layer's quantized KV representation "
            "is propagated to subsequent layers before they are calibrated."
        ),
    )

tests/unit/torch/quantization/test_watersic_kv.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,37 @@ def test_state_creation(self):
362362
assert state.gamma is gamma
363363
assert state.perm is None
364364
assert state.rate == 2.5
365+
366+
367+
# ---------------------------------------------------------------------------
368+
# TestWaterSICKVCalibConfig
369+
# ---------------------------------------------------------------------------
370+
371+
372+
class TestWaterSICKVCalibConfig:
    """Unit tests for ``WaterSICKVCalibConfig`` defaults, overrides and round-tripping."""

    def test_defaults(self):
        """A config built with no arguments exposes every documented default."""
        from modelopt.torch.quantization.algorithms.watersic_kv.config import WaterSICKVCalibConfig

        cfg = WaterSICKVCalibConfig()
        assert cfg.method == "watersic_kv"
        assert cfg.target_rate == 2.0
        assert cfg.kl_aware is False
        assert cfg.importance_clip == 50.0
        assert cfg.use_lmmse is True
        assert cfg.n_rescaler_iters == 0
        assert cfg.sample_frac is None
        assert cfg.use_sequential is True

    def test_custom_values(self):
        """Explicit keyword arguments override the corresponding defaults."""
        from modelopt.torch.quantization.algorithms.watersic_kv.config import WaterSICKVCalibConfig

        cfg = WaterSICKVCalibConfig(target_rate=4.0, kl_aware=True, importance_clip=20.0)
        assert cfg.target_rate == 4.0
        assert cfg.kl_aware is True
        assert cfg.importance_clip == 20.0

    def test_serialization_roundtrip(self):
        """Dumping a config and rebuilding it from the dump preserves its fields."""
        from modelopt.torch.quantization.algorithms.watersic_kv.config import WaterSICKVCalibConfig

        cfg = WaterSICKVCalibConfig(target_rate=3.0, kl_aware=True)
        data = cfg.model_dump()
        cfg2 = WaterSICKVCalibConfig(**data)
        # Overridden fields survive the round trip ...
        assert cfg2.target_rate == 3.0
        assert cfg2.kl_aware is True
        # ... and so do fields that were left at their defaults.
        assert cfg2.method == "watersic_kv"
        assert cfg2.use_lmmse is True
        assert cfg2.sample_frac is None

0 commit comments

Comments
 (0)