# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""WaterSIC KV-cache quantizer helper.

Wraps the core ZSIC math with attention-module hooking for real model
calibration. The :class:`WaterSICKVHelper` patches
``_QuantAttention._quantized_attention`` to capture query / key activations,
then runs :func:`watersic_quantize` per head.
"""

from __future__ import annotations

import math
from dataclasses import dataclass

import torch
from torch import Tensor

from modelopt.torch.utils import print_rank_0

from .zsic import (
    _compute_hessian_cholesky,
    binary_search_c,
    damp_for_rate,
    watersic_quantize,
)

# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------


@dataclass
class WaterSICKVState:
    """Per-layer quantization state produced by :meth:`WaterSICKVHelper.quantize`."""

    Z: Tensor
    """Integer code-book indices."""
    alpha: Tensor
    """Per-column step sizes."""
    gamma: Tensor
    """Per-column LMMSE gains."""
    perm: Tensor | None
    """Column permutation (or *None*)."""
    rate: float
    """Achieved coding rate (bits per element)."""

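
# Illustrative sketch (not part of the quantization pipeline): a rough estimate
# of the compressed footprint of a :class:`WaterSICKVState`. ``rate`` is bits per
# element of ``Z`` (see the field docstrings above); treating alpha / gamma / perm
# as 32-bit side information is an assumption made only for this example.
def _example_state_footprint_kib(state: WaterSICKVState) -> float:
    payload_bits = state.Z.numel() * state.rate
    side_bits = 32 * (state.alpha.numel() + state.gamma.numel())
    if state.perm is not None:
        side_bits += 32 * state.perm.numel()
    return (payload_bits + side_bits) / 8.0 / 1024.0  # KiB
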

# ---------------------------------------------------------------------------
# Importance weighting
# ---------------------------------------------------------------------------


def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Tensor:
    """Derive per-token importance weights from an attention probability matrix.

    Parameters
    ----------
    P : Tensor (B, N)
        Attention probabilities summed (or averaged) over queries, i.e.
        ``P[b, n]`` is how much attention is paid to token *n* in sample *b*.
        Typically ``P = softmax(Q K^T / sqrt(d)).sum(dim=-2)``.
    importance_clip : float
        Clamp the normalized weights to ``[1/clip, clip]`` to prevent
        extreme outliers.

    Returns
    -------
    sqrt_w : Tensor (N, 1)
        Square-root importance weights, suitable for row-scaling the key
        matrix being quantized so that high-attention tokens contribute
        more to the quantization objective.
    """
    # Sum across the batch (rows) to get a per-token importance score.
    w = P.sum(dim=0)  # (N,)

    # Normalize so that the mean weight is 1.
    w = w / w.mean().clamp(min=1e-30)

    # Clip to avoid extreme values.
    w = w.clamp(min=1.0 / importance_clip, max=importance_clip)

    return w.sqrt().unsqueeze(1)  # (N, 1)


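# A minimal illustrative sketch (not used by the pipeline): deriving importance
# weights from random Q/K activations for a single head. The shapes and names
# below are assumptions chosen purely for illustration.
def _example_importance_weights() -> Tensor:
    S_q, S_k, D = 32, 128, 64  # hypothetical query count, key count, head dim
    Q = torch.randn(S_q, D, dtype=torch.float64)
    K = torch.randn(S_k, D, dtype=torch.float64)
    # Each row of P is one query's attention distribution over the S_k keys.
    P = torch.softmax(Q @ K.T / D**0.5, dim=-1)
    # Summing over query rows happens inside _compute_importance_weights.
    return _compute_importance_weights(P.float())  # (S_k, 1)

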
# ---------------------------------------------------------------------------
# KL divergence in logit space
# ---------------------------------------------------------------------------


def kl_divergence_logits(
    Q: Tensor,
    K: Tensor,
    K_q: Tensor,
    temperature: float = 1.0,
) -> float:
    """Compute the KL divergence between attention distributions induced by *K* and *K_q*.

    Works directly in logit space, so the quantized attention probabilities never
    have to be materialized or log-transformed:

        KL(P || P_q) = E_q[ sum_n P_n (s_n - s_q_n) + logsumexp(s_q) - logsumexp(s) ]

    where ``s = Q K^T / temperature``, ``s_q = Q K_q^T / temperature``, and the
    expectation is the mean over query positions.

    Parameters
    ----------
    Q : Tensor (..., S, D)
    K : Tensor (..., N, D)
    K_q : Tensor (..., N, D)
    temperature : float

    Returns
    -------
    kl : float
        Mean KL divergence in **bits** (i.e. divided by ln 2).
    """
    Q64 = Q.double()
    K64 = K.double()
    Kq64 = K_q.double()

    s = Q64 @ K64.transpose(-2, -1) / temperature  # (..., S, N)
    s_q = Q64 @ Kq64.transpose(-2, -1) / temperature  # (..., S, N)

    log_Z = torch.logsumexp(s, dim=-1)  # (..., S)
    log_Z_q = torch.logsumexp(s_q, dim=-1)  # (..., S)

    P = torch.softmax(s, dim=-1)  # (..., S, N)

    # KL per query position: sum_n P_n (s_n - s_q_n) + log_Z_q - log_Z
    kl_per_query = (P * (s - s_q)).sum(dim=-1) + log_Z_q - log_Z  # (..., S)

    # Convert nats to bits and return the mean over query positions.
    return (kl_per_query.mean() / math.log(2)).item()


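# Illustrative sanity check (not used by the pipeline): the logit-space identity
# above should match a direct softmax-based KL computation. Shapes here are
# assumptions chosen only for illustration.
def _example_kl_check() -> None:
    S, N, D = 16, 64, 32
    Q = torch.randn(S, D, dtype=torch.float64)
    K = torch.randn(N, D, dtype=torch.float64)
    K_q = K + 0.05 * torch.randn(N, D, dtype=torch.float64)  # stand-in for quantized keys

    P = torch.softmax(Q @ K.T, dim=-1)
    P_q = torch.softmax(Q @ K_q.T, dim=-1)
    kl_direct = ((P * (P.log() - P_q.log())).sum(dim=-1).mean() / math.log(2)).item()

    assert abs(kl_divergence_logits(Q, K, K_q) - kl_direct) < 1e-8

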
# ---------------------------------------------------------------------------
# WaterSICKVHelper
# ---------------------------------------------------------------------------


class WaterSICKVHelper:
    """Hook-based helper that captures Q/K activations and runs WaterSIC quantization.

    Usage::

        helper = WaterSICKVHelper(quant_attn_module, "layer.3")
        helper.setup()
        # ... run calibration forward passes ...
        state = helper.quantize(target_rate=4.0)
        helper.cleanup()
        helper.free()
    """

    def __init__(
        self,
        module,
        name: str,
        kl_aware: bool = False,
        importance_clip: float = 50.0,
    ):
        self.module = module
        self.name = name
        self.kl_aware = kl_aware
        self.importance_clip = importance_clip

        self.collected_Q: list[Tensor] = []
        self.collected_K: list[Tensor] = []

        self._original_fn = None

    # ----- patching --------------------------------------------------

    def setup(self):
        """Patch ``_quantized_attention`` on the module instance to capture Q/K."""
        # The original is a @staticmethod on the class; grab the underlying function.
        original_fn = type(self.module)._quantized_attention
        self._original_fn = original_fn

        helper = self  # closure reference

        def patched_fn(
            original_attention_interface,
            self_attn,
            query_states,
            key_states,
            value_states,
            *args,
            **kwargs,
        ):
            # Capture detached CPU copies before quantizers touch them.
            helper.collected_Q.append(query_states.detach().cpu())
            helper.collected_K.append(key_states.detach().cpu())

            # Call the original static method (not bound, pass all args).
            return original_fn(
                original_attention_interface,
                self_attn,
                query_states,
                key_states,
                value_states,
                *args,
                **kwargs,
            )

        # Patch on the *instance* so it shadows the class-level staticmethod.
        self.module._quantized_attention = patched_fn

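    # Why instance-level patching works (illustrative note): a plain function
    # assigned to the instance lands in ``module.__dict__`` and wins over the
    # class-level staticmethod during attribute lookup, e.g.
    #
    #     class M:
    #         @staticmethod
    #         def f(x):
    #             return x
    #
    #     m = M()
    #     m.f = lambda x: x + 1   # shadows M.f for this instance only
    #     assert m.f(1) == 2 and M.f(1) == 1
    #
    # ``cleanup`` below simply deletes the instance attribute to restore the
    # class-level behavior.
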
    def cleanup(self):
        """Remove the instance-level override, restoring the class staticmethod."""
        if "_quantized_attention" in vars(self.module):
            delattr(self.module, "_quantized_attention")

    # ----- quantization -----------------------------------------------

    def quantize(
        self,
        target_rate: float = 4.0,
        use_lmmse: bool = True,
        n_rescaler_iters: int = 0,
        sample_frac: float = 0.1,
    ) -> WaterSICKVState:
        """Run WaterSIC quantization on the collected key activations.

        Parameters
        ----------
        target_rate : float
            Target coding rate in bits per element.
        use_lmmse : bool
            Whether to apply LMMSE gain correction.
        n_rescaler_iters : int
            Number of alternating rescaler iterations (0 = disable).
        sample_frac : float
            Fraction of rows used by :func:`binary_search_c`.

        Returns
        -------
        WaterSICKVState
        """
        # Concatenate collected activations across calibration batches.
        # Each tensor is (batch, n_heads, seq, d_head).
        Q_all = torch.cat(self.collected_Q, dim=0)  # (B_total, H, S_q, D)
        K_all = torch.cat(self.collected_K, dim=0)  # (B_total, H, S_k, D)

        B, H, S_k, D = K_all.shape

        # We'll store per-head results.
        Z_heads = []
        alpha_heads = []
        gamma_heads = []
        perm_heads = []
        rates = []

        damp_pct = damp_for_rate(target_rate)

        for h in range(H):
            # K_h shape: (B, S_k, D) -> treat as weight matrix (a, n) where
            # a = B * S_k (token-batch dimension) and n = D (head dimension).
            K_h = K_all[:, h, :, :].reshape(-1, D).double()  # (B*S_k, D)

            # Activation matrix: use Q_h^T so the Hessian reflects query-key
            # interaction. A shape: (D, B*S_q).
            Q_h = Q_all[:, h, :, :].reshape(-1, D).double()  # (B*S_q, D)
            A = Q_h.T  # (D, B*S_q)

            # Optional importance weighting: scale K rows (not A) so that
            # high-attention tokens contribute more to the quantization objective.
            sqrt_w = None
            if self.kl_aware:
                # Compute attention probs: P = softmax(Q_h @ K_h^T / sqrt(D)).
                scores = Q_h @ K_h.T / (D**0.5)
                P = torch.softmax(scores.double(), dim=-1).float()
                sqrt_w = _compute_importance_weights(P, self.importance_clip)
                K_h = K_h * sqrt_w  # Scale K rows by importance.

            # Precompute Hessian / Cholesky.
            precomputed = _compute_hessian_cholesky(A, damp_pct=damp_pct)
            _, L, perm = precomputed

            # Binary search for the scale factor c.
            c = binary_search_c(
                K_h,
                A,
                target_rate=target_rate,
                damp_pct=damp_pct,
                use_lmmse=use_lmmse,
                n_rescaler_iters=n_rescaler_iters,
                sample_frac=sample_frac,
                _precomputed=precomputed,
            )

            # Full quantization.
            W_hat, rate, nmse, Z_h, gamma_h = watersic_quantize(
                K_h,
                A,
                c,
                damp_pct=damp_pct,
                use_lmmse=use_lmmse,
                n_rescaler_iters=n_rescaler_iters,
                _precomputed=precomputed,
            )

            # Undo importance scaling after quantization.
            if sqrt_w is not None:
                W_hat = W_hat / sqrt_w

            print_rank_0(
                f" [{self.name}] head {h}: rate={rate:.2f} bpe, nmse={nmse:.4f}"
            )

            # Recover per-head state.
            # alpha = c / L.diag() (same as inside watersic_quantize).
            alpha_h = (c / L.diag()).float()

            Z_heads.append(Z_h)
            alpha_heads.append(alpha_h)
            gamma_heads.append(gamma_h.float())
            perm_heads.append(perm)
            rates.append(rate)

        mean_rate = sum(rates) / len(rates) if rates else 0.0

        state = WaterSICKVState(
            Z=torch.stack(Z_heads),  # (H, B*S_k, D)
            alpha=torch.stack(alpha_heads),  # (H, D)
            gamma=torch.stack(gamma_heads),  # (H, D)
            perm=torch.stack(perm_heads) if perm_heads and perm_heads[0] is not None else None,
            rate=mean_rate,
        )

        # Attach state to the module for downstream consumers.
        self.module._watersic_kv_state = state

        return state

    # ----- cleanup -----------------------------------------------------

    def free(self):
        """Release collected calibration data."""
        self.collected_Q.clear()
        self.collected_K.clear()
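

# ---------------------------------------------------------------------------
# Example usage (illustrative)
# ---------------------------------------------------------------------------


# A hedged end-to-end sketch, not part of the calibration pipeline. It assumes
# ``model`` exposes attention modules whose *class* defines ``_quantized_attention``
# and that ``calib_loader`` yields keyword batches accepted by ``model(...)``;
# both names are illustrative assumptions.
def _example_calibration_loop(model, calib_loader, target_rate: float = 4.0) -> dict[str, WaterSICKVState]:
    helpers = {
        name: WaterSICKVHelper(mod, name, kl_aware=True)
        for name, mod in model.named_modules()
        if hasattr(type(mod), "_quantized_attention")
    }
    for helper in helpers.values():
        helper.setup()
    try:
        with torch.no_grad():
            for batch in calib_loader:
                model(**batch)
    finally:
        for helper in helpers.values():
            helper.cleanup()
    states = {name: helper.quantize(target_rate=target_rate) for name, helper in helpers.items()}
    for helper in helpers.values():
        helper.free()
    return states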