
Commit c13e5b9

Add WaterSIC KV quantizer helper with unit tests
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent a7f65e3 commit c13e5b9

3 files changed

Lines changed: 479 additions & 15 deletions


Lines changed: 354 additions & 0 deletions
@@ -0,0 +1,354 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""WaterSIC KV-cache quantizer helper.

Wraps the core ZSIC math with attention-module hooking for real model
calibration. The :class:`WaterSICKVHelper` patches
``_QuantAttention._quantized_attention`` to capture query / key activations,
then runs :func:`watersic_quantize` per head.
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import Tensor

from modelopt.torch.utils import print_rank_0

from .zsic import (
    _compute_hessian_cholesky,
    binary_search_c,
    damp_for_rate,
    watersic_quantize,
)

# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------


@dataclass
class WaterSICKVState:
    """Per-layer quantisation state produced by :meth:`WaterSICKVHelper.quantize`."""

    Z: Tensor
    """Integer code-book indices."""
    alpha: Tensor
    """Per-column step sizes."""
    gamma: Tensor
    """Per-column LMMSE gains."""
    perm: Tensor | None
    """Column permutation (or *None*)."""
    rate: float
    """Achieved coding rate (bits per element)."""


# ---------------------------------------------------------------------------
# Importance weighting
# ---------------------------------------------------------------------------


def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Tensor:
    """Derive per-token importance weights from an attention probability matrix.

    Parameters
    ----------
    P : Tensor (B, N)
        Attention probabilities summed (or averaged) over queries – i.e.
        ``P[b, n]`` is how much attention is paid to token *n* in sample *b*.
        Typically ``P = softmax(Q K^T / sqrt(d)).sum(dim=-2)``.
    importance_clip : float
        Clamp the normalised weights to ``[1/clip, clip]`` to prevent
        extreme outliers.

    Returns
    -------
    sqrt_w : Tensor (N, 1)
        Square-root importance weights, suitable for left-multiplying the
        activation matrix so that high-attention tokens contribute more to
        the Hessian.
    """
    # Sum across the batch (rows) to get a per-token importance score.
    w = P.sum(dim=0)  # (N,)

    # Normalise so that the mean weight is 1.
    w = w / w.mean().clamp(min=1e-30)

    # Clip to avoid extreme values.
    w = w.clamp(min=1.0 / importance_clip, max=importance_clip)

    return w.sqrt().unsqueeze(1)  # (N, 1)

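# Illustrative sketch only: the function name and toy shapes below are made up
# for documentation purposes and are not used anywhere else. It mirrors the
# ``kl_aware`` branch of :meth:`WaterSICKVHelper.quantize`: build attention
# probabilities from Q/K, turn them into per-key-token weights, and scale the
# key rows so heavily-attended tokens dominate the quantisation objective.
def _importance_weighting_example() -> Tensor:
    """Toy demonstration of :func:`_compute_importance_weights` (illustration only)."""
    Q = torch.randn(32, 64)  # (tokens_q, d_head)
    K = torch.randn(48, 64)  # (tokens_k, d_head)

    # Attention probabilities: one row per query token, one column per key token.
    P = torch.softmax(Q @ K.T / 64**0.5, dim=-1)  # (tokens_q, tokens_k)

    # Square-root importance weights per key token; heavily-attended tokens get > 1.
    sqrt_w = _compute_importance_weights(P)  # (tokens_k, 1)

    # Scaling K rows by sqrt_w mirrors how the helper scales K_h before quantisation.
    return K * sqrt_w
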
# ---------------------------------------------------------------------------
# KL divergence in logit space
# ---------------------------------------------------------------------------


def kl_divergence_logits(
    Q: Tensor,
    K: Tensor,
    K_q: Tensor,
    temperature: float = 1.0,
) -> float:
    """Compute the KL divergence between attention distributions induced by *K* and *K_q*.

    Uses the logit identity, so the quantised attention distribution ``P_q``
    never has to be materialised:

        KL(P || P_q) = E_x[ P^T (s - s_q) + logsumexp(s_q) - logsumexp(s) ]

    where ``s = Q K^T / temperature`` and ``s_q = Q K_q^T / temperature``.

    Parameters
    ----------
    Q : Tensor (..., S, D)
    K : Tensor (..., N, D)
    K_q : Tensor (..., N, D)
    temperature : float

    Returns
    -------
    kl : float
        Mean KL divergence in **bits** (i.e. divided by ln 2).
    """
    Q64 = Q.double()
    K64 = K.double()
    Kq64 = K_q.double()

    s = Q64 @ K64.transpose(-2, -1) / temperature  # (..., S, N)
    s_q = Q64 @ Kq64.transpose(-2, -1) / temperature  # (..., S, N)

    log_Z = torch.logsumexp(s, dim=-1)  # (..., S)
    log_Z_q = torch.logsumexp(s_q, dim=-1)  # (..., S)

    P = torch.softmax(s, dim=-1)  # (..., S, N)

    # KL per query position: sum_n P_n (s_n - s_q_n) + log_Z_q - log_Z
    kl_per_query = (P * (s - s_q)).sum(dim=-1) + log_Z_q - log_Z  # (..., S)

    # Convert nats to bits and return mean.
    import math

    return (kl_per_query.mean() / math.log(2)).item()

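# Illustrative sanity check (an assumed helper, not referenced elsewhere): the
# logit identity above should agree with computing KL(P || P_q) directly from
# the two softmax distributions. The toy shapes and the perturbation standing
# in for quantisation are assumptions made purely for this sketch.
def _kl_identity_check() -> None:
    import math

    torch.manual_seed(0)
    Q = torch.randn(4, 16)  # (S, D)
    K = torch.randn(8, 16)  # (N, D)
    K_q = K + 0.05 * torch.randn_like(K)  # stand-in for a quantised key matrix

    # Direct definition of the KL divergence, converted from nats to bits.
    s = Q.double() @ K.double().T
    s_q = Q.double() @ K_q.double().T
    P = torch.softmax(s, dim=-1)
    P_q = torch.softmax(s_q, dim=-1)
    kl_direct = ((P * (P.log() - P_q.log())).sum(dim=-1).mean() / math.log(2)).item()

    # Logit-identity computation from this module; the two values should match.
    assert abs(kl_direct - kl_divergence_logits(Q, K, K_q)) < 1e-6
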
# ---------------------------------------------------------------------------
# WaterSICKVHelper
# ---------------------------------------------------------------------------


class WaterSICKVHelper:
    """Hook-based helper that captures Q/K activations and runs WaterSIC quantisation.

    Usage::

        helper = WaterSICKVHelper(quant_attn_module, "layer.3")
        helper.setup()
        # ... run calibration forward passes ...
        state = helper.quantize(target_rate=4.0)
        helper.cleanup()
        helper.free()
    """

    def __init__(
        self,
        module,
        name: str,
        kl_aware: bool = False,
        importance_clip: float = 50.0,
    ):
        self.module = module
        self.name = name
        self.kl_aware = kl_aware
        self.importance_clip = importance_clip

        self.collected_Q: list[Tensor] = []
        self.collected_K: list[Tensor] = []

        self._original_fn = None

    # ----- patching --------------------------------------------------

    def setup(self):
        """Patch ``_quantized_attention`` on the module instance to capture Q/K."""
        # The original is a @staticmethod on the class – grab the underlying function.
        original_fn = type(self.module)._quantized_attention
        self._original_fn = original_fn

        helper = self  # closure reference

        def patched_fn(
            original_attention_interface,
            self_attn,
            query_states,
            key_states,
            value_states,
            *args,
            **kwargs,
        ):
            # Capture detached CPU copies before quantizers touch them.
            helper.collected_Q.append(query_states.detach().cpu())
            helper.collected_K.append(key_states.detach().cpu())

            # Call the original static method (not bound, pass all args).
            return original_fn(
                original_attention_interface,
                self_attn,
                query_states,
                key_states,
                value_states,
                *args,
                **kwargs,
            )

        # Patch on the *instance* so it shadows the class-level staticmethod.
        self.module._quantized_attention = patched_fn

    def cleanup(self):
        """Remove the instance-level override, restoring the class staticmethod."""
        if "_quantized_attention" in vars(self.module):
            delattr(self.module, "_quantized_attention")

    # ----- quantisation -----------------------------------------------

    def quantize(
        self,
        target_rate: float = 4.0,
        use_lmmse: bool = True,
        n_rescaler_iters: int = 0,
        sample_frac: float = 0.1,
    ) -> WaterSICKVState:
        """Run WaterSIC quantisation on the collected key activations.

        Parameters
        ----------
        target_rate : float
            Target coding rate in bits per element.
        use_lmmse : bool
            Whether to apply LMMSE gain correction.
        n_rescaler_iters : int
            Number of alternating rescaler iterations (0 = disable).
        sample_frac : float
            Fraction of rows used by :func:`binary_search_c`.

        Returns
        -------
        WaterSICKVState
        """
        # Concatenate collected activations across calibration batches.
        # Each tensor is (batch, n_heads, seq, d_head).
        Q_all = torch.cat(self.collected_Q, dim=0)  # (B_total, H, S_q, D)
        K_all = torch.cat(self.collected_K, dim=0)  # (B_total, H, S_k, D)

        B, H, S_k, D = K_all.shape

        # We'll store per-head results.
        Z_heads = []
        alpha_heads = []
        gamma_heads = []
        perm_heads = []
        rates = []

        damp_pct = damp_for_rate(target_rate)

        for h in range(H):
            # K_h shape: (B, S_k, D) → treat as weight matrix (a, n) where
            # a = B * S_k (token-batch dimension) and n = D (head dimension).
            K_h = K_all[:, h, :, :].reshape(-1, D).double()  # (B*S_k, D)

            # Activation matrix: use Q_h^T so the Hessian reflects query-key
            # interaction. A shape: (D, B*S_q).
            Q_h = Q_all[:, h, :, :].reshape(-1, D).double()  # (B*S_q, D)
            A = Q_h.T  # (D, B*S_q)

            # Optional importance weighting — scale K rows (not A) so that
            # high-attention tokens contribute more to the quantisation objective.
            sqrt_w = None
            if self.kl_aware:
                # Compute attention probs: P = softmax(Q_h @ K_h^T / sqrt(D))
                scores = Q_h @ K_h.T / (D**0.5)
                P = torch.softmax(scores.double(), dim=-1).float()
                sqrt_w = _compute_importance_weights(P, self.importance_clip)
                K_h = K_h * sqrt_w  # Scale K rows by importance

            # Precompute Hessian / Cholesky.
            precomputed = _compute_hessian_cholesky(A, damp_pct=damp_pct)
            _, L, perm = precomputed

            # Binary search for the scale factor c.
            c = binary_search_c(
                K_h,
                A,
                target_rate=target_rate,
                damp_pct=damp_pct,
                use_lmmse=use_lmmse,
                n_rescaler_iters=n_rescaler_iters,
                sample_frac=sample_frac,
                _precomputed=precomputed,
            )

            # Full quantisation.
            W_hat, rate, nmse, Z_h, gamma_h = watersic_quantize(
                K_h,
                A,
                c,
                damp_pct=damp_pct,
                use_lmmse=use_lmmse,
                n_rescaler_iters=n_rescaler_iters,
                _precomputed=precomputed,
            )

            # Undo importance scaling after quantisation.
            if sqrt_w is not None:
                W_hat = W_hat / sqrt_w

            print_rank_0(
                f" [{self.name}] head {h}: rate={rate:.2f} bpe, nmse={nmse:.4f}"
            )

            # Recover per-head state.
            # alpha = c / L.diag() (same as inside watersic_quantize).
            alpha_h = (c / L.diag()).float()

            Z_heads.append(Z_h)
            alpha_heads.append(alpha_h)
            gamma_heads.append(gamma_h.float())
            perm_heads.append(perm)
            rates.append(rate)

        mean_rate = sum(rates) / len(rates) if rates else 0.0

        state = WaterSICKVState(
            Z=torch.stack(Z_heads),  # (H, B*S_k, D)
            alpha=torch.stack(alpha_heads),  # (H, D)
            gamma=torch.stack(gamma_heads),  # (H, D)
            perm=torch.stack(perm_heads) if perm_heads and perm_heads[0] is not None else None,
            rate=mean_rate,
        )

        # Attach state to the module for downstream consumers.
        self.module._watersic_kv_state = state

        return state

    # ----- cleanup -----------------------------------------------------

    def free(self):
        """Release collected calibration data."""
        self.collected_Q.clear()
        self.collected_K.clear()
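
# Hedged end-to-end sketch (illustration only): one way the helper could be
# driven during calibration. ``model`` and ``calib_loader``, and the use of
# ``hasattr(mod, "_quantized_attention")`` to discover attention modules, are
# assumptions for this sketch rather than APIs defined by this module.
#
#     helpers = [
#         WaterSICKVHelper(mod, name, kl_aware=True)
#         for name, mod in model.named_modules()
#         if hasattr(mod, "_quantized_attention")
#     ]
#     for helper in helpers:
#         helper.setup()
#     with torch.no_grad():
#         for batch in calib_loader:
#             model(**batch)
#     for helper in helpers:
#         state = helper.quantize(target_rate=4.0)
#         helper.cleanup()
#         helper.free()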
