Skip to content

Commit 7b0bb08

Browse files
committed
Integrate WaterSIC KV-cache in hf_ptq.py
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 9015ac7 commit 7b0bb08

File tree

4 files changed

+47
-54
lines changed

4 files changed

+47
-54
lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 35 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
127127
"nvfp4": "NVFP4_KV_CFG",
128128
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
129129
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
130+
"watersic_kv": "WATERSIC_KV_CFG",
130131
}
131132

132133
# Formats that use use_constant_amax (no calibration needed).
@@ -384,7 +385,7 @@ def forward_step(model, batch):
384385
# We need to explicitly set up KV cache quantization after auto_quantize
385386
enable_quant_kv_cache = args.kv_cache_qformat != "none"
386387
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
387-
if enable_quant_kv_cache:
388+
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
388389
kv_cache_quant_cfg = copy.deepcopy(
389390
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
390391
)
@@ -403,6 +404,16 @@ def forward_step(model, batch):
403404
[{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg],
404405
):
405406
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
407+
408+
# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
409+
if args.kv_cache_qformat == "watersic_kv":
410+
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
411+
if args.watersic_target_rate is not None:
412+
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
413+
if args.watersic_kl_aware:
414+
watersic_cfg["algorithm"]["kl_aware"] = True
415+
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)
416+
406417
return language_model
407418

408419

@@ -423,7 +434,7 @@ def load_model(args: argparse.Namespace):
423434
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
424435
)
425436
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
426-
if args.kv_cache_qformat != "none":
437+
if args.kv_cache_qformat not in {"none", "watersic_kv"}:
427438
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
428439
quant_cfg,
429440
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
@@ -652,6 +663,15 @@ def mono_quantize(
652663
else:
653664
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)
654665

666+
# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
667+
if args.kv_cache_qformat == "watersic_kv":
668+
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
669+
if args.watersic_target_rate is not None:
670+
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
671+
if args.watersic_kl_aware:
672+
watersic_cfg["algorithm"]["kl_aware"] = True
673+
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)
674+
655675
# For VL models, update full_model to use the quantized language model
656676
if is_nemotron_vl_model:
657677
language_model_lineage = get_language_model_from_vl(full_model)
@@ -1083,7 +1103,8 @@ def quantize_main(
10831103
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
10841104

10851105
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
1086-
if enable_quant_kv_cache:
1106+
# WaterSIC KV-cache uses a separate quantization pass, so skip merging here.
1107+
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
10871108
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
10881109
quant_cfg,
10891110
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
@@ -1242,6 +1263,17 @@ def parse_args() -> argparse.Namespace:
12421263
"Other formats (fp8, nvfp4, etc.) use data-driven calibration."
12431264
),
12441265
)
1266+
parser.add_argument(
1267+
"--watersic_target_rate",
1268+
type=float,
1269+
default=None,
1270+
help="Target bits per element for WaterSIC KV-cache quantization (default: 2.0)",
1271+
)
1272+
parser.add_argument(
1273+
"--watersic_kl_aware",
1274+
action="store_true",
1275+
help="Enable KL-aware importance weighting for WaterSIC KV-cache quantization",
1276+
)
12451277
parser.add_argument(
12461278
"--export_fmt",
12471279
required=False,

modelopt/torch/quantization/algorithms/watersic_kv/kv_quantizer.py

Lines changed: 0 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,6 @@
3232

3333
from .zsic import _compute_hessian_cholesky, binary_search_c, damp_for_rate, watersic_quantize
3434

35-
# ---------------------------------------------------------------------------
36-
# Data structures
37-
# ---------------------------------------------------------------------------
38-
3935

4036
@dataclass
4137
class WaterSICKVState:
@@ -53,11 +49,6 @@ class WaterSICKVState:
5349
"""Achieved coding rate (bits per element)."""
5450

5551

56-
# ---------------------------------------------------------------------------
57-
# Importance weighting
58-
# ---------------------------------------------------------------------------
59-
60-
6152
def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Tensor:
6253
"""Derive per-token importance weights from an attention probability matrix.
6354
@@ -90,11 +81,6 @@ def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Ten
9081
return w.sqrt().unsqueeze(1) # (N, 1)
9182

9283

93-
# ---------------------------------------------------------------------------
94-
# KL divergence in logit space
95-
# ---------------------------------------------------------------------------
96-
97-
9884
def kl_divergence_logits(
9985
Q: Tensor,
10086
K: Tensor,
@@ -142,11 +128,6 @@ def kl_divergence_logits(
142128
return (kl_per_query.mean() / math.log(2)).item()
143129

144130

145-
# ---------------------------------------------------------------------------
146-
# WaterSICKVHelper
147-
# ---------------------------------------------------------------------------
148-
149-
150131
class WaterSICKVHelper:
151132
"""Hook-based helper that captures Q/K activations and runs WaterSIC quantisation.
152133
@@ -178,7 +159,6 @@ def __init__(
178159

179160
self._original_fn = None
180161

181-
# ----- patching --------------------------------------------------
182162

183163
def setup(self):
184164
"""Patch ``_quantized_attention`` on the module instance to capture Q/K."""
@@ -220,7 +200,6 @@ def cleanup(self):
220200
if "_quantized_attention" in vars(self.module):
221201
delattr(self.module, "_quantized_attention")
222202

223-
# ----- quantisation -----------------------------------------------
224203

225204
def quantize(
226205
self,
@@ -342,7 +321,6 @@ def quantize(
342321

343322
return state
344323

345-
# ----- cleanup -----------------------------------------------------
346324

347325
def free(self):
348326
"""Release collected calibration data."""

modelopt/torch/quantization/algorithms/watersic_kv/zsic.py

Lines changed: 0 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,6 @@
2727
import torch
2828
from torch import Tensor
2929

30-
# ---------------------------------------------------------------------------
31-
# Helpers
32-
# ---------------------------------------------------------------------------
33-
3430

3531
def damp_for_rate(target_rate: float, base: float = 1e-4, knee: float = 5.0) -> float:
3632
"""Return a damping coefficient that decays for rates above *knee*.
@@ -57,11 +53,6 @@ def compute_output_nmse(W: Tensor, W_q: Tensor, A: Tensor) -> float:
5753
return (err.norm() ** 2 / ref.norm() ** 2).item()
5854

5955

60-
# ---------------------------------------------------------------------------
61-
# Hessian / Cholesky
62-
# ---------------------------------------------------------------------------
63-
64-
6556
def _compute_hessian_cholesky(
6657
A: Tensor,
6758
damp_pct: float = 1e-4,
@@ -111,11 +102,6 @@ def _compute_hessian_cholesky(
111102
return Sigma_X, L, perm
112103

113104

114-
# ---------------------------------------------------------------------------
115-
# Rescaler optimisation
116-
# ---------------------------------------------------------------------------
117-
118-
119105
def _optimize_rescalers(
120106
W_hat_0: Tensor,
121107
W: Tensor,
@@ -157,11 +143,6 @@ def _optimize_rescalers(
157143
return t.unsqueeze(1) * W_hat_0 * gamma.unsqueeze(0)
158144

159145

160-
# ---------------------------------------------------------------------------
161-
# Core sequential coding
162-
# ---------------------------------------------------------------------------
163-
164-
165146
def zsic_quantize(
166147
W: Tensor,
167148
A: Tensor,
@@ -234,11 +215,6 @@ def zsic_quantize(
234215
return W_hat, rate, nmse, Z, gamma
235216

236217

237-
# ---------------------------------------------------------------------------
238-
# WaterSIC interface
239-
# ---------------------------------------------------------------------------
240-
241-
242218
def watersic_quantize(
243219
W: Tensor,
244220
A: Tensor,
@@ -302,11 +278,6 @@ def watersic_quantize(
302278
return W_hat, rate, nmse, Z, gamma
303279

304280

305-
# ---------------------------------------------------------------------------
306-
# Binary search for c
307-
# ---------------------------------------------------------------------------
308-
309-
310281
def binary_search_c(
311282
W: Tensor,
312283
A: Tensor,

modelopt/torch/quantization/config.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -741,6 +741,17 @@ def _nvfp4_selective_quant_cfg(
741741
"algorithm": "max",
742742
}
743743

744+
WATERSIC_KV_CFG = {
745+
"quant_cfg": [
746+
{"quantizer_name": "*", "enable": False},
747+
{"quantizer_name": "*[kv]_bmm_quantizer", "enable": True},
748+
],
749+
"algorithm": {
750+
"method": "watersic_kv",
751+
"target_rate": 2.0,
752+
},
753+
}
754+
744755
NVFP4_SVDQUANT_DEFAULT_CFG = _nvfp4_selective_quant_cfg(
745756
["*"], algorithm={"method": "svdquant", "lowrank": 32}
746757
)
@@ -833,6 +844,7 @@ def _nvfp4_selective_quant_cfg(
833844
"MAMBA_MOE_FP8_CONSERVATIVE_CFG",
834845
"MAMBA_MOE_FP8_AGGRESSIVE_CFG",
835846
"NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG",
847+
"WATERSIC_KV_CFG",
836848
}
837849

838850
BiasType = Literal["static", "dynamic"]

0 commit comments

Comments (0)