Skip to content

Commit 4b44815

Browse files
committed
Integrate WaterSIC KV-cache in hf_ptq.py
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 1877ee7 commit 4b44815

9 files changed

Lines changed: 121 additions & 165 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
127127
"nvfp4": "NVFP4_KV_CFG",
128128
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
129129
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
130+
"watersic_kv": "WATERSIC_KV_CFG",
130131
}
131132

132133
# Formats that use use_constant_amax (no calibration needed).
@@ -384,7 +385,7 @@ def forward_step(model, batch):
384385
# We need to explicitly set up KV cache quantization after auto_quantize
385386
enable_quant_kv_cache = args.kv_cache_qformat != "none"
386387
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
387-
if enable_quant_kv_cache:
388+
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
388389
kv_cache_quant_cfg = copy.deepcopy(
389390
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
390391
)
@@ -403,6 +404,16 @@ def forward_step(model, batch):
403404
[{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg],
404405
):
405406
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
407+
408+
# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
409+
if args.kv_cache_qformat == "watersic_kv":
410+
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
411+
if args.watersic_target_rate is not None:
412+
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
413+
if args.watersic_kl_aware:
414+
watersic_cfg["algorithm"]["kl_aware"] = True
415+
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)
416+
406417
return language_model
407418

408419

@@ -423,7 +434,7 @@ def load_model(args: argparse.Namespace):
423434
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
424435
)
425436
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
426-
if args.kv_cache_qformat != "none":
437+
if args.kv_cache_qformat not in {"none", "watersic_kv"}:
427438
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
428439
quant_cfg,
429440
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
@@ -652,6 +663,15 @@ def mono_quantize(
652663
else:
653664
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)
654665

666+
# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
667+
if args.kv_cache_qformat == "watersic_kv":
668+
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
669+
if args.watersic_target_rate is not None:
670+
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
671+
if args.watersic_kl_aware:
672+
watersic_cfg["algorithm"]["kl_aware"] = True
673+
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)
674+
655675
# For VL models, update full_model to use the quantized language model
656676
if is_nemotron_vl_model:
657677
language_model_lineage = get_language_model_from_vl(full_model)
@@ -1083,7 +1103,8 @@ def quantize_main(
10831103
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
10841104

10851105
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
1086-
if enable_quant_kv_cache:
1106+
# WaterSIC KV-cache uses a separate quantization pass, so skip merging here.
1107+
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
10871108
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
10881109
quant_cfg,
10891110
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
@@ -1242,6 +1263,17 @@ def parse_args() -> argparse.Namespace:
12421263
"Other formats (fp8, nvfp4, etc.) use data-driven calibration."
12431264
),
12441265
)
1266+
parser.add_argument(
1267+
"--watersic_target_rate",
1268+
type=float,
1269+
default=None,
1270+
help="Target bits per element for WaterSIC KV-cache quantization (default: 2.0)",
1271+
)
1272+
parser.add_argument(
1273+
"--watersic_kl_aware",
1274+
action="store_true",
1275+
help="Enable KL-aware importance weighting for WaterSIC KV-cache quantization",
1276+
)
12451277
parser.add_argument(
12461278
"--export_fmt",
12471279
required=False,

modelopt/torch/quantization/algorithms/watersic_kv/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@
1818
from __future__ import annotations
1919

2020
from .config import WaterSICKVCalibConfig
21-
from .kv_quantizer import WaterSICKVHelper, WaterSICKVState
21+
from .helper import WaterSICKVHelper, WaterSICKVState
2222

2323
__all__ = ["WaterSICKVCalibConfig", "WaterSICKVHelper", "WaterSICKVState"]

modelopt/torch/quantization/algorithms/watersic_kv/config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ class WaterSICKVCalibConfig(QuantizeAlgorithmConfig):
105105
)
106106

107107
use_sequential: bool = ModeloptField(
108-
default=True,
108+
default=False,
109109
title="Enable sequential layer-by-layer calibration.",
110110
description=(
111-
"When True, the WaterSIC calibration is applied layer-by-layer in "
112-
"decoder-block order so that each layer's quantized KV representation "
113-
"is propagated to subsequent layers before they are calibrated."
111+
"Must be False for WaterSIC. Unlike weight quantization, KV-cache "
112+
"quantization does not have progressive error accumulation between "
113+
"layers, so sequential calibration is not needed."
114114
),
115115
)

modelopt/torch/quantization/algorithms/watersic_kv/kv_quantizer.py renamed to modelopt/torch/quantization/algorithms/watersic_kv/helper.py

Lines changed: 27 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@
3232

3333
from .zsic import _compute_hessian_cholesky, binary_search_c, damp_for_rate, watersic_quantize
3434

35-
# ---------------------------------------------------------------------------
36-
# Data structures
37-
# ---------------------------------------------------------------------------
38-
3935

4036
@dataclass
4137
class WaterSICKVState:
@@ -53,11 +49,6 @@ class WaterSICKVState:
5349
"""Achieved coding rate (bits per element)."""
5450

5551

56-
# ---------------------------------------------------------------------------
57-
# Importance weighting
58-
# ---------------------------------------------------------------------------
59-
60-
6152
def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Tensor:
6253
"""Derive per-token importance weights from an attention probability matrix.
6354
@@ -90,63 +81,6 @@ def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Ten
9081
return w.sqrt().unsqueeze(1) # (N, 1)
9182

9283

93-
# ---------------------------------------------------------------------------
94-
# KL divergence in logit space
95-
# ---------------------------------------------------------------------------
96-
97-
98-
def kl_divergence_logits(
99-
Q: Tensor,
100-
K: Tensor,
101-
K_q: Tensor,
102-
temperature: float = 1.0,
103-
) -> float:
104-
"""Compute the KL divergence between attention distributions induced by *K* and *K_q*.
105-
106-
Uses the logit identity to avoid materialising the full attention matrix:
107-
108-
KL(P || P_q) = E_x[ P^T (s - s_q) + logsumexp(s_q) - logsumexp(s) ]
109-
110-
where ``s = Q K^T / temperature`` and ``s_q = Q K_q^T / temperature``.
111-
112-
Parameters
113-
----------
114-
Q : Tensor (..., S, D)
115-
K : Tensor (..., N, D)
116-
K_q : Tensor (..., N, D)
117-
temperature : float
118-
119-
Returns:
120-
-------
121-
kl : float
122-
Mean KL divergence in **bits** (i.e. divided by ln 2).
123-
"""
124-
Q64 = Q.double()
125-
K64 = K.double()
126-
Kq64 = K_q.double()
127-
128-
s = Q64 @ K64.transpose(-2, -1) / temperature # (..., S, N)
129-
s_q = Q64 @ Kq64.transpose(-2, -1) / temperature # (..., S, N)
130-
131-
log_Z = torch.logsumexp(s, dim=-1) # (..., S)
132-
log_Z_q = torch.logsumexp(s_q, dim=-1) # (..., S)
133-
134-
P = torch.softmax(s, dim=-1) # (..., S, N)
135-
136-
# KL per query position: sum_n P_n (s_n - s_q_n) + log_Z_q - log_Z
137-
kl_per_query = (P * (s - s_q)).sum(dim=-1) + log_Z_q - log_Z # (..., S)
138-
139-
# Convert nats to bits and return mean.
140-
import math
141-
142-
return (kl_per_query.mean() / math.log(2)).item()
143-
144-
145-
# ---------------------------------------------------------------------------
146-
# WaterSICKVHelper
147-
# ---------------------------------------------------------------------------
148-
149-
15084
class WaterSICKVHelper:
15185
"""Hook-based helper that captures Q/K activations and runs WaterSIC quantisation.
15286
@@ -178,8 +112,6 @@ def __init__(
178112

179113
self._original_fn = None
180114

181-
# ----- patching --------------------------------------------------
182-
183115
def setup(self):
184116
"""Patch ``_quantized_attention`` on the module instance to capture Q/K."""
185117
# The original is a @staticmethod on the class - grab the underlying function.
@@ -220,8 +152,6 @@ def cleanup(self):
220152
if "_quantized_attention" in vars(self.module):
221153
delattr(self.module, "_quantized_attention")
222154

223-
# ----- quantisation -----------------------------------------------
224-
225155
def quantize(
226156
self,
227157
target_rate: float = 4.0,
@@ -246,6 +176,13 @@ def quantize(
246176
-------
247177
WaterSICKVState
248178
"""
179+
if not self.collected_Q or not self.collected_K:
180+
raise RuntimeError(
181+
f"[{self.name}] No Q/K activations were collected during the calibration "
182+
f"forward pass. Ensure setup() was called before the forward loop and that "
183+
f"the forward loop passes data through this attention layer."
184+
)
185+
249186
# Concatenate collected activations across calibration batches.
250187
# Each tensor is (batch, n_heads, seq, d_head).
251188
Q_all = torch.cat(self.collected_Q, dim=0) # (B_total, H, S_q, D)
@@ -262,14 +199,17 @@ def quantize(
262199

263200
damp_pct = damp_for_rate(target_rate)
264201

202+
# Run quantization on GPU if available (much faster for real models).
203+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
204+
265205
for h in range(H):
266206
# K_h shape: (B, S_k, D) → treat as weight matrix (a, n) where
267207
# a = B * S_k (token-batch dimension) and n = D (head dimension).
268-
K_h = K_all[:, h, :, :].reshape(-1, D).double() # (B*S_k, D)
208+
K_h = K_all[:, h, :, :].reshape(-1, D).to(device=device, dtype=torch.float64)
269209

270210
# Activation matrix: use Q_h^T so the Hessian reflects query-key
271211
# interaction. A shape: (D, B*S_q).
272-
Q_h = Q_all[:, h, :, :].reshape(-1, D).double() # (B*S_q, D)
212+
Q_h = Q_all[:, h, :, :].reshape(-1, D).to(device=device, dtype=torch.float64)
273213
A = Q_h.T # (D, B*S_q)
274214

275215
# Optional importance weighting — scale K rows (not A) so that
@@ -320,19 +260,26 @@ def quantize(
320260
# Recover per-head state.
321261
# alpha = c / L.diag() (same as inside watersic_quantize).
322262
alpha_h = (c / L.diag()).float()
323-
324-
Z_heads.append(Z_h)
325-
alpha_heads.append(alpha_h)
326-
gamma_heads.append(gamma_h.float())
327-
perm_heads.append(perm)
263+
if perm is not None:
264+
inv_perm = torch.argsort(perm)
265+
alpha_h = alpha_h[inv_perm]
266+
267+
# Move results to CPU to free GPU memory for next head.
268+
Z_heads.append(Z_h.cpu())
269+
alpha_heads.append(alpha_h.cpu())
270+
gamma_heads.append(gamma_h.float().cpu())
271+
perm_heads.append(perm.cpu() if perm is not None else None)
328272
rates.append(rate)
329273

274+
if torch.cuda.is_available():
275+
torch.cuda.empty_cache()
276+
330277
mean_rate = sum(rates) / len(rates) if rates else 0.0
331278

332279
state = WaterSICKVState(
333-
Z=torch.stack(Z_heads), # (H, B*S_k, D)
334-
alpha=torch.stack(alpha_heads), # (H, D)
335-
gamma=torch.stack(gamma_heads), # (H, D)
280+
Z=torch.stack(Z_heads),
281+
alpha=torch.stack(alpha_heads),
282+
gamma=torch.stack(gamma_heads),
336283
perm=torch.stack(perm_heads) if perm_heads and perm_heads[0] is not None else None,
337284
rate=mean_rate,
338285
)
@@ -342,8 +289,6 @@ def quantize(
342289

343290
return state
344291

345-
# ----- cleanup -----------------------------------------------------
346-
347292
def free(self):
348293
"""Release collected calibration data."""
349294
self.collected_Q.clear()

0 commit comments

Comments (0)