Skip to content

Commit 9b41b8e

Browse files
committed
Integrate WaterSIC KV-cache in hf_ptq.py
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 9015ac7 commit 9b41b8e

File tree

9 files changed

+101
-162
lines changed

9 files changed

+101
-162
lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
127127
"nvfp4": "NVFP4_KV_CFG",
128128
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
129129
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
130+
"watersic_kv": "WATERSIC_KV_CFG",
130131
}
131132

132133
# Formats that use use_constant_amax (no calibration needed).
@@ -384,7 +385,7 @@ def forward_step(model, batch):
384385
# We need to explicitly set up KV cache quantization after auto_quantize
385386
enable_quant_kv_cache = args.kv_cache_qformat != "none"
386387
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
387-
if enable_quant_kv_cache:
388+
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
388389
kv_cache_quant_cfg = copy.deepcopy(
389390
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
390391
)
@@ -403,6 +404,16 @@ def forward_step(model, batch):
403404
[{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg],
404405
):
405406
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
407+
408+
# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
409+
if args.kv_cache_qformat == "watersic_kv":
410+
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
411+
if args.watersic_target_rate is not None:
412+
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
413+
if args.watersic_kl_aware:
414+
watersic_cfg["algorithm"]["kl_aware"] = True
415+
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)
416+
406417
return language_model
407418

408419

@@ -423,7 +434,7 @@ def load_model(args: argparse.Namespace):
423434
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
424435
)
425436
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
426-
if args.kv_cache_qformat != "none":
437+
if args.kv_cache_qformat not in {"none", "watersic_kv"}:
427438
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
428439
quant_cfg,
429440
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
@@ -652,6 +663,15 @@ def mono_quantize(
652663
else:
653664
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)
654665

666+
# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
667+
if args.kv_cache_qformat == "watersic_kv":
668+
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
669+
if args.watersic_target_rate is not None:
670+
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
671+
if args.watersic_kl_aware:
672+
watersic_cfg["algorithm"]["kl_aware"] = True
673+
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)
674+
655675
# For VL models, update full_model to use the quantized language model
656676
if is_nemotron_vl_model:
657677
language_model_lineage = get_language_model_from_vl(full_model)
@@ -1083,7 +1103,8 @@ def quantize_main(
10831103
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
10841104

10851105
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
1086-
if enable_quant_kv_cache:
1106+
# WaterSIC KV-cache uses a separate quantization pass, so skip merging here.
1107+
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
10871108
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
10881109
quant_cfg,
10891110
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
@@ -1242,6 +1263,17 @@ def parse_args() -> argparse.Namespace:
12421263
"Other formats (fp8, nvfp4, etc.) use data-driven calibration."
12431264
),
12441265
)
1266+
parser.add_argument(
1267+
"--watersic_target_rate",
1268+
type=float,
1269+
default=None,
1270+
help="Target bits per element for WaterSIC KV-cache quantization (default: 2.0)",
1271+
)
1272+
parser.add_argument(
1273+
"--watersic_kl_aware",
1274+
action="store_true",
1275+
help="Enable KL-aware importance weighting for WaterSIC KV-cache quantization",
1276+
)
12451277
parser.add_argument(
12461278
"--export_fmt",
12471279
required=False,

modelopt/torch/quantization/algorithms/watersic_kv/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@
1818
from __future__ import annotations
1919

2020
from .config import WaterSICKVCalibConfig
21-
from .kv_quantizer import WaterSICKVHelper, WaterSICKVState
21+
from .helper import WaterSICKVHelper, WaterSICKVState
2222

2323
__all__ = ["WaterSICKVCalibConfig", "WaterSICKVHelper", "WaterSICKVState"]

modelopt/torch/quantization/algorithms/watersic_kv/config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ class WaterSICKVCalibConfig(QuantizeAlgorithmConfig):
105105
)
106106

107107
use_sequential: bool = ModeloptField(
108-
default=True,
108+
default=False,
109109
title="Enable sequential layer-by-layer calibration.",
110110
description=(
111-
"When True, the WaterSIC calibration is applied layer-by-layer in "
112-
"decoder-block order so that each layer's quantized KV representation "
113-
"is propagated to subsequent layers before they are calibrated."
111+
"Must be False for WaterSIC. Unlike weight quantization, KV-cache "
112+
"quantization does not have progressive error accumulation between "
113+
"layers, so sequential calibration is not needed."
114114
),
115115
)

modelopt/torch/quantization/algorithms/watersic_kv/kv_quantizer.py renamed to modelopt/torch/quantization/algorithms/watersic_kv/helper.py

Lines changed: 20 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@
3232

3333
from .zsic import _compute_hessian_cholesky, binary_search_c, damp_for_rate, watersic_quantize
3434

35-
# ---------------------------------------------------------------------------
36-
# Data structures
37-
# ---------------------------------------------------------------------------
38-
3935

4036
@dataclass
4137
class WaterSICKVState:
@@ -53,11 +49,6 @@ class WaterSICKVState:
5349
"""Achieved coding rate (bits per element)."""
5450

5551

56-
# ---------------------------------------------------------------------------
57-
# Importance weighting
58-
# ---------------------------------------------------------------------------
59-
60-
6152
def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Tensor:
6253
"""Derive per-token importance weights from an attention probability matrix.
6354
@@ -90,63 +81,6 @@ def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Ten
9081
return w.sqrt().unsqueeze(1) # (N, 1)
9182

9283

93-
# ---------------------------------------------------------------------------
94-
# KL divergence in logit space
95-
# ---------------------------------------------------------------------------
96-
97-
98-
def kl_divergence_logits(
99-
Q: Tensor,
100-
K: Tensor,
101-
K_q: Tensor,
102-
temperature: float = 1.0,
103-
) -> float:
104-
"""Compute the KL divergence between attention distributions induced by *K* and *K_q*.
105-
106-
Uses the logit identity to avoid materialising the full attention matrix:
107-
108-
KL(P || P_q) = E_x[ P^T (s - s_q) + logsumexp(s_q) - logsumexp(s) ]
109-
110-
where ``s = Q K^T / temperature`` and ``s_q = Q K_q^T / temperature``.
111-
112-
Parameters
113-
----------
114-
Q : Tensor (..., S, D)
115-
K : Tensor (..., N, D)
116-
K_q : Tensor (..., N, D)
117-
temperature : float
118-
119-
Returns:
120-
-------
121-
kl : float
122-
Mean KL divergence in **bits** (i.e. divided by ln 2).
123-
"""
124-
Q64 = Q.double()
125-
K64 = K.double()
126-
Kq64 = K_q.double()
127-
128-
s = Q64 @ K64.transpose(-2, -1) / temperature # (..., S, N)
129-
s_q = Q64 @ Kq64.transpose(-2, -1) / temperature # (..., S, N)
130-
131-
log_Z = torch.logsumexp(s, dim=-1) # (..., S)
132-
log_Z_q = torch.logsumexp(s_q, dim=-1) # (..., S)
133-
134-
P = torch.softmax(s, dim=-1) # (..., S, N)
135-
136-
# KL per query position: sum_n P_n (s_n - s_q_n) + log_Z_q - log_Z
137-
kl_per_query = (P * (s - s_q)).sum(dim=-1) + log_Z_q - log_Z # (..., S)
138-
139-
# Convert nats to bits and return mean.
140-
import math
141-
142-
return (kl_per_query.mean() / math.log(2)).item()
143-
144-
145-
# ---------------------------------------------------------------------------
146-
# WaterSICKVHelper
147-
# ---------------------------------------------------------------------------
148-
149-
15084
class WaterSICKVHelper:
15185
"""Hook-based helper that captures Q/K activations and runs WaterSIC quantisation.
15286
@@ -178,8 +112,6 @@ def __init__(
178112

179113
self._original_fn = None
180114

181-
# ----- patching --------------------------------------------------
182-
183115
def setup(self):
184116
"""Patch ``_quantized_attention`` on the module instance to capture Q/K."""
185117
# The original is a @staticmethod on the class - grab the underlying function.
@@ -220,8 +152,6 @@ def cleanup(self):
220152
if "_quantized_attention" in vars(self.module):
221153
delattr(self.module, "_quantized_attention")
222154

223-
# ----- quantisation -----------------------------------------------
224-
225155
def quantize(
226156
self,
227157
target_rate: float = 4.0,
@@ -262,14 +192,17 @@ def quantize(
262192

263193
damp_pct = damp_for_rate(target_rate)
264194

195+
# Run quantization on GPU if available (much faster for real models).
196+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
197+
265198
for h in range(H):
266199
# K_h shape: (B, S_k, D) → treat as weight matrix (a, n) where
267200
# a = B * S_k (token-batch dimension) and n = D (head dimension).
268-
K_h = K_all[:, h, :, :].reshape(-1, D).double() # (B*S_k, D)
201+
K_h = K_all[:, h, :, :].reshape(-1, D).to(device=device, dtype=torch.float64)
269202

270203
# Activation matrix: use Q_h^T so the Hessian reflects query-key
271204
# interaction. A shape: (D, B*S_q).
272-
Q_h = Q_all[:, h, :, :].reshape(-1, D).double() # (B*S_q, D)
205+
Q_h = Q_all[:, h, :, :].reshape(-1, D).to(device=device, dtype=torch.float64)
273206
A = Q_h.T # (D, B*S_q)
274207

275208
# Optional importance weighting — scale K rows (not A) so that
@@ -320,19 +253,26 @@ def quantize(
320253
# Recover per-head state.
321254
# alpha = c / L.diag() (same as inside watersic_quantize).
322255
alpha_h = (c / L.diag()).float()
323-
324-
Z_heads.append(Z_h)
325-
alpha_heads.append(alpha_h)
326-
gamma_heads.append(gamma_h.float())
327-
perm_heads.append(perm)
256+
if perm is not None:
257+
inv_perm = torch.argsort(perm)
258+
alpha_h = alpha_h[inv_perm]
259+
260+
# Move results to CPU to free GPU memory for next head.
261+
Z_heads.append(Z_h.cpu())
262+
alpha_heads.append(alpha_h.cpu())
263+
gamma_heads.append(gamma_h.float().cpu())
264+
perm_heads.append(perm.cpu() if perm is not None else None)
328265
rates.append(rate)
329266

267+
if torch.cuda.is_available():
268+
torch.cuda.empty_cache()
269+
330270
mean_rate = sum(rates) / len(rates) if rates else 0.0
331271

332272
state = WaterSICKVState(
333-
Z=torch.stack(Z_heads), # (H, B*S_k, D)
334-
alpha=torch.stack(alpha_heads), # (H, D)
335-
gamma=torch.stack(gamma_heads), # (H, D)
273+
Z=torch.stack(Z_heads),
274+
alpha=torch.stack(alpha_heads),
275+
gamma=torch.stack(gamma_heads),
336276
perm=torch.stack(perm_heads) if perm_heads and perm_heads[0] is not None else None,
337277
rate=mean_rate,
338278
)
@@ -342,8 +282,6 @@ def quantize(
342282

343283
return state
344284

345-
# ----- cleanup -----------------------------------------------------
346-
347285
def free(self):
348286
"""Release collected calibration data."""
349287
self.collected_Q.clear()

modelopt/torch/quantization/algorithms/watersic_kv/zsic.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
"""Core ZSIC (Zero-Shot Integer Compression) algorithm for WaterSIC KV-cache quantization.
1717
1818
This is a pure math module with no Model-Optimizer dependencies. It implements
19-
the sequential integer coding algorithm described in the WaterSIC paper, ported
20-
from psx-vfp commit 39073b1.
19+
the sequential integer coding algorithm described in the WaterSIC paper.
2120
"""
2221

2322
from __future__ import annotations
@@ -27,10 +26,6 @@
2726
import torch
2827
from torch import Tensor
2928

30-
# ---------------------------------------------------------------------------
31-
# Helpers
32-
# ---------------------------------------------------------------------------
33-
3429

3530
def damp_for_rate(target_rate: float, base: float = 1e-4, knee: float = 5.0) -> float:
3631
"""Return a damping coefficient that decays for rates above *knee*.
@@ -51,15 +46,20 @@ def compute_entropy(Z: Tensor) -> float:
5146

5247

5348
def compute_output_nmse(W: Tensor, W_q: Tensor, A: Tensor) -> float:
54-
"""Normalised MSE measured in the output space: ``||err @ A||^2 / ||W @ A||^2``."""
55-
err = (W - W_q) @ A
56-
ref = W @ A
57-
return (err.norm() ** 2 / ref.norm() ** 2).item()
58-
49+
"""Normalised MSE measured in the output space: ``||err @ A||^2 / ||W @ A||^2``.
5950
60-
# ---------------------------------------------------------------------------
61-
# Hessian / Cholesky
62-
# ---------------------------------------------------------------------------
51+
Uses the trace identity ``||M @ N||_F^2 = tr(M^T M N N^T)`` to avoid
52+
materialising the ``(a, a)`` output matrix, which can be prohibitively large
53+
when the number of tokens *a* is high (e.g. real-model calibration).
54+
Only ``(n, n)`` intermediates are needed, where *n* = ``A.shape[0]``.
55+
"""
56+
Sigma_X = A @ A.T # (n, n)
57+
delta = W - W_q # (a, n)
58+
err_gram = delta.T @ delta # (n, n)
59+
ref_gram = W.T @ W # (n, n)
60+
err_sq = (err_gram * Sigma_X).sum()
61+
ref_sq = (ref_gram * Sigma_X).sum()
62+
return (err_sq / ref_sq).item()
6363

6464

6565
def _compute_hessian_cholesky(
@@ -111,11 +111,6 @@ def _compute_hessian_cholesky(
111111
return Sigma_X, L, perm
112112

113113

114-
# ---------------------------------------------------------------------------
115-
# Rescaler optimisation
116-
# ---------------------------------------------------------------------------
117-
118-
119114
def _optimize_rescalers(
120115
W_hat_0: Tensor,
121116
W: Tensor,
@@ -157,11 +152,6 @@ def _optimize_rescalers(
157152
return t.unsqueeze(1) * W_hat_0 * gamma.unsqueeze(0)
158153

159154

160-
# ---------------------------------------------------------------------------
161-
# Core sequential coding
162-
# ---------------------------------------------------------------------------
163-
164-
165155
def zsic_quantize(
166156
W: Tensor,
167157
A: Tensor,
@@ -234,11 +224,6 @@ def zsic_quantize(
234224
return W_hat, rate, nmse, Z, gamma
235225

236226

237-
# ---------------------------------------------------------------------------
238-
# WaterSIC interface
239-
# ---------------------------------------------------------------------------
240-
241-
242227
def watersic_quantize(
243228
W: Tensor,
244229
A: Tensor,
@@ -302,11 +287,6 @@ def watersic_quantize(
302287
return W_hat, rate, nmse, Z, gamma
303288

304289

305-
# ---------------------------------------------------------------------------
306-
# Binary search for c
307-
# ---------------------------------------------------------------------------
308-
309-
310290
def binary_search_c(
311291
W: Tensor,
312292
A: Tensor,

0 commit comments

Comments (0)