diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index f19e82c5d4..24a27f3cca 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -127,6 +127,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "nvfp4": "NVFP4_KV_CFG", "nvfp4_affine": "NVFP4_AFFINE_KV_CFG", "nvfp4_rotate": "NVFP4_KV_ROTATE_CFG", + "watersic_kv": "WATERSIC_KV_CFG", } # Formats that use use_constant_amax (no calibration needed). @@ -384,7 +385,7 @@ def forward_step(model, batch): # We need to explicitly set up KV cache quantization after auto_quantize enable_quant_kv_cache = args.kv_cache_qformat != "none" print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - if enable_quant_kv_cache: + if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv": kv_cache_quant_cfg = copy.deepcopy( getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] ) @@ -403,6 +404,16 @@ def forward_step(model, batch): [{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg], ): mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) + + # WaterSIC KV-cache needs a separate quantization pass with its own algorithm + if args.kv_cache_qformat == "watersic_kv": + watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"])) + if args.watersic_target_rate is not None: + watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate + if args.watersic_kl_aware: + watersic_cfg["algorithm"]["kl_aware"] = True + language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop) + return language_model @@ -423,7 +434,7 @@ def load_model(args: argparse.Namespace): f"Quantization format is not supported for low memory mode. 
Supported formats: {QUANT_CFG_CHOICES.keys()}" ) quant_cfg = QUANT_CFG_CHOICES[args.qformat] - if args.kv_cache_qformat != "none": + if args.kv_cache_qformat not in {"none", "watersic_kv"}: quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], @@ -652,6 +663,15 @@ def mono_quantize( else: language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) + # WaterSIC KV-cache needs a separate quantization pass with its own algorithm + if args.kv_cache_qformat == "watersic_kv": + watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"])) + if args.watersic_target_rate is not None: + watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate + if args.watersic_kl_aware: + watersic_cfg["algorithm"]["kl_aware"] = True + language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop) + # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: language_model_lineage = get_language_model_from_vl(full_model) @@ -1083,7 +1103,8 @@ def quantize_main( print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. - if enable_quant_kv_cache: + # WaterSIC KV-cache uses a separate quantization pass, so skip merging here. + if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv": quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], @@ -1242,6 +1263,17 @@ def parse_args() -> argparse.Namespace: "Other formats (fp8, nvfp4, etc.) use data-driven calibration." 
), ) + parser.add_argument( + "--watersic_target_rate", + type=float, + default=None, + help="Target bits per element for WaterSIC KV-cache quantization (default: 2.0)", + ) + parser.add_argument( + "--watersic_kl_aware", + action="store_true", + help="Enable KL-aware importance weighting for WaterSIC KV-cache quantization", + ) parser.add_argument( "--export_fmt", required=False, diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms/__init__.py similarity index 99% rename from modelopt/torch/quantization/algorithms.py rename to modelopt/torch/quantization/algorithms/__init__.py index f1db2df9e8..9e5d16f8e5 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms/__init__.py @@ -38,12 +38,12 @@ from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master -from . import config as mtq_config -from . import model_calib -from .config import QuantizeConfig, QuantizerAttributeConfig -from .conversion import set_quantizer_by_cfg -from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer -from .utils import is_quantized_linear +from .. import config as mtq_config +from .. 
import model_calib +from ..config import QuantizeConfig, QuantizerAttributeConfig +from ..conversion import set_quantizer_by_cfg +from ..nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer +from ..utils import is_quantized_linear def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: @@ -615,8 +615,8 @@ def before_search(self): # Import here to avoid circular import from modelopt.torch.quantization.model_quant import calibrate - from .conversion import restore_quantizer_state, update_quantize_metadata - from .utils import get_quantizer_state_dict, set_quantizer_state_dict + from ..conversion import restore_quantizer_state, update_quantize_metadata + from ..utils import get_quantizer_state_dict, set_quantizer_state_dict super().before_search() restored_method = getattr(self, "method", None) diff --git a/modelopt/torch/quantization/algorithms/watersic_kv/__init__.py b/modelopt/torch/quantization/algorithms/watersic_kv/__init__.py new file mode 100644 index 0000000000..2843cf4e55 --- /dev/null +++ b/modelopt/torch/quantization/algorithms/watersic_kv/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""WaterSIC KV-cache quantization algorithm.""" + +from __future__ import annotations + +from .config import WaterSICKVCalibConfig +from .helper import WaterSICKVHelper, WaterSICKVState + +__all__ = ["WaterSICKVCalibConfig", "WaterSICKVHelper", "WaterSICKVState"] diff --git a/modelopt/torch/quantization/algorithms/watersic_kv/config.py b/modelopt/torch/quantization/algorithms/watersic_kv/config.py new file mode 100644 index 0000000000..3232ba4e09 --- /dev/null +++ b/modelopt/torch/quantization/algorithms/watersic_kv/config.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for the WaterSIC KV-cache quantization algorithm.""" + +from __future__ import annotations + +from typing import Literal + +from modelopt.torch.opt.config import ModeloptField +from modelopt.torch.quantization.config import QuantizeAlgorithmConfig + + +class WaterSICKVCalibConfig(QuantizeAlgorithmConfig): + """Configuration for WaterSIC KV-cache quantization. + + WaterSIC (Water-filling Successive Interference Cancellation) is a + rate-adaptive quantization method for KV-cache compression. It + applies the ZSIC algorithm with optional KL-aware importance + weighting and LMMSE shrinkage correction to minimize attention-output + distortion at a target bits-per-element budget. 
+ + Reference: "WaterSIC: Water-filling Successive Interference + Cancellation for KV-Cache Quantization" (2024). + """ + + method: Literal["watersic_kv"] = ModeloptField( + "watersic_kv", + title="Calibration algorithm identifier.", + description="Fixed identifier for the WaterSIC KV-cache calibration method.", + ) + + target_rate: float = ModeloptField( + default=2.0, + gt=0.0, + title="Target bits per element.", + description=( + "Average number of bits per quantized KV-cache element. The binary " + "search over the ZSIC damping parameter c is driven to hit this rate." + ), + ) + + kl_aware: bool = ModeloptField( + default=False, + title="Enable KL-aware importance weighting.", + description=( + "When True, per-token importance weights derived from the attention " + "distribution are folded into the Hessian so that tokens with higher " + "attention mass receive tighter quantization." + ), + ) + + importance_clip: float = ModeloptField( + default=50.0, + gt=0.0, + title="Importance weight clipping ratio.", + description=( + "Maximum ratio by which a single token's importance weight may exceed " + "the mean weight. Clips extreme outlier tokens to prevent them from " + "dominating the Hessian estimate." + ), + ) + + use_lmmse: bool = ModeloptField( + default=True, + title="Apply LMMSE shrinkage correction.", + description=( + "When True, the LMMSE (Linear Minimum Mean-Squared Error) shrinkage " + "correction is applied after ZSIC quantization to partially undo " + "quantization bias and reduce reconstruction NMSE." + ), + ) + + n_rescaler_iters: int = ModeloptField( + default=0, + ge=0, + title="Diagonal rescaler optimization iterations.", + description=( + "Number of coordinate-descent iterations for the diagonal rescaler " + "that adjusts per-column scale factors after LMMSE. Set to 0 to " + "disable the rescaler (faster but slightly higher distortion)." 
+ ), + ) + + sample_frac: float | None = ModeloptField( + default=None, + title="Row subsampling fraction for binary search.", + description=( + "If set, only this fraction of rows (KV heads) are used during the " + "binary search for c. Full rows are then quantized with the found c. " + "Speeds up calibration on large KV caches at a small accuracy cost." + ), + ) + + use_sequential: bool = ModeloptField( + default=False, + title="Enable sequential layer-by-layer calibration.", + description=( + "Must be False for WaterSIC. Unlike weight quantization, KV-cache " + "quantization does not have progressive error accumulation between " + "layers, so sequential calibration is not needed." + ), + ) diff --git a/modelopt/torch/quantization/algorithms/watersic_kv/helper.py b/modelopt/torch/quantization/algorithms/watersic_kv/helper.py new file mode 100644 index 0000000000..5621ad2316 --- /dev/null +++ b/modelopt/torch/quantization/algorithms/watersic_kv/helper.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WaterSIC KV-cache quantizer helper. + +Wraps the core ZSIC math with attention-module hooking for real model +calibration. The :class:`WaterSICKVHelper` patches +``_QuantAttention._quantized_attention`` to capture query / key activations, +then runs :func:`watersic_quantize` per head. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from torch import Tensor + +from modelopt.torch.utils import print_rank_0 + +from .zsic import _compute_hessian_cholesky, binary_search_c, damp_for_rate, watersic_quantize + + +@dataclass +class WaterSICKVState: + """Per-layer quantisation state produced by :meth:`WaterSICKVHelper.quantize`.""" + + Z: Tensor + """Integer code-book indices.""" + alpha: Tensor + """Per-column step sizes.""" + gamma: Tensor + """Per-column LMMSE gains.""" + perm: Tensor | None + """Column permutation (or *None*).""" + rate: float + """Achieved coding rate (bits per element).""" + + +def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Tensor: + """Derive per-token importance weights from an attention probability matrix. + + Parameters + ---------- + P : Tensor (B, N) + Attention probabilities summed (or averaged) over queries – i.e. + ``P[b, n]`` is how much attention is paid to token *n* in sample *b*. + Typically ``P = softmax(Q K^T / sqrt(d)).sum(dim=-2)``. + importance_clip : float + Clamp the normalised weights to ``[1/clip, clip]`` to prevent + extreme outliers. + + Returns: + ------- + sqrt_w : Tensor (N, 1) + Square-root importance weights, suitable for left-multiplying the + activation matrix so that high-attention tokens contribute more to + the Hessian. + """ + # Sum across the batch (rows) to get a per-token importance score. + w = P.sum(dim=0) # (N,) + + # Normalise so that the mean weight is 1. + w = w / w.mean().clamp(min=1e-30) + + # Clip to avoid extreme values. + w = w.clamp(min=1.0 / importance_clip, max=importance_clip) + + return w.sqrt().unsqueeze(1) # (N, 1) + + +class WaterSICKVHelper: + """Hook-based helper that captures Q/K activations and runs WaterSIC quantisation. + + Usage:: + + helper = WaterSICKVHelper(quant_attn_module, "layer.3") + helper.setup() + # ... run calibration forward passes ... 
+ state = helper.quantize(target_rate=4.0) + helper.cleanup() + helper.free() + """ + + def __init__( + self, + module, + name: str, + kl_aware: bool = False, + importance_clip: float = 50.0, + ): + """Initialize helper for a single attention module.""" + self.module = module + self.name = name + self.kl_aware = kl_aware + self.importance_clip = importance_clip + + self.collected_Q: list[Tensor] = [] + self.collected_K: list[Tensor] = [] + + self._original_fn = None + + def setup(self): + """Patch ``_quantized_attention`` on the module instance to capture Q/K.""" + # The original is a @staticmethod on the class - grab the underlying function. + original_fn = type(self.module)._quantized_attention + self._original_fn = original_fn + + helper = self # closure reference + + def patched_fn( + original_attention_interface, + self_attn, + query_states, + key_states, + value_states, + *args, + **kwargs, + ): + # Capture detached CPU copies before quantizers touch them. + helper.collected_Q.append(query_states.detach().cpu()) + helper.collected_K.append(key_states.detach().cpu()) + + # Call the original static method (not bound, pass all args). + return original_fn( + original_attention_interface, + self_attn, + query_states, + key_states, + value_states, + *args, + **kwargs, + ) + + # Patch on the *instance* so it shadows the class-level staticmethod. + self.module._quantized_attention = patched_fn + + def cleanup(self): + """Remove the instance-level override, restoring the class staticmethod.""" + if "_quantized_attention" in vars(self.module): + delattr(self.module, "_quantized_attention") + + def quantize( + self, + target_rate: float = 4.0, + use_lmmse: bool = True, + n_rescaler_iters: int = 0, + sample_frac: float | None = None, + ) -> WaterSICKVState: + """Run WaterSIC quantisation on the collected key activations. + + Parameters + ---------- + target_rate : float + Target coding rate in bits per element. + use_lmmse : bool + Whether to apply LMMSE gain correction. 
+ n_rescaler_iters : int + Number of alternating rescaler iterations (0 = disable). + sample_frac : float + Fraction of rows used by :func:`binary_search_c`. + + Returns: + ------- + WaterSICKVState + """ + if not self.collected_Q or not self.collected_K: + raise RuntimeError( + f"[{self.name}] No Q/K activations were collected during the calibration " + f"forward pass. Ensure setup() was called before the forward loop and that " + f"the forward loop passes data through this attention layer." + ) + + # Concatenate collected activations across calibration batches. + # Each tensor is (batch, n_heads, seq, d_head). + Q_all = torch.cat(self.collected_Q, dim=0) # (B_total, H, S_q, D) + K_all = torch.cat(self.collected_K, dim=0) # (B_total, H, S_k, D) + + B, H, S_k, D = K_all.shape + + # We'll store per-head results. + Z_heads = [] + alpha_heads = [] + gamma_heads = [] + perm_heads = [] + rates = [] + + damp_pct = damp_for_rate(target_rate) + + # Run quantization on GPU if available (much faster for real models). + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + for h in range(H): + # K_h shape: (B, S_k, D) → treat as weight matrix (a, n) where + # a = B * S_k (token-batch dimension) and n = D (head dimension). + K_h = K_all[:, h, :, :].reshape(-1, D).to(device=device, dtype=torch.float64) + + # Activation matrix: use Q_h^T so the Hessian reflects query-key + # interaction. A shape: (D, B*S_q). + Q_h = Q_all[:, h, :, :].reshape(-1, D).to(device=device, dtype=torch.float64) + A = Q_h.T # (D, B*S_q) + + # Optional importance weighting — scale K rows (not A) so that + # high-attention tokens contribute more to the quantisation objective. 
+ sqrt_w = None + if self.kl_aware: + # Compute attention probs: P = softmax(Q_h @ K_h^T / sqrt(D)) + scores = Q_h @ K_h.T / (D**0.5) + P = torch.softmax(scores.double(), dim=-1).float() + sqrt_w = _compute_importance_weights(P, self.importance_clip) + K_h = K_h * sqrt_w # Scale K rows by importance + + # Precompute Hessian / Cholesky. + precomputed = _compute_hessian_cholesky(A, damp_pct=damp_pct) + _, L, perm = precomputed + + # Binary search for the scale factor c. + n_tokens = K_h.shape[0] + sf = sample_frac if sample_frac is not None else min(0.1, 1000.0 / max(n_tokens, 1)) + c = binary_search_c( + K_h, + A, + target_rate=target_rate, + damp_pct=damp_pct, + use_lmmse=use_lmmse, + n_rescaler_iters=n_rescaler_iters, + sample_frac=sf, + _precomputed=precomputed, + ) + + # Full quantisation. + W_hat, rate, nmse, Z_h, gamma_h = watersic_quantize( + K_h, + A, + c, + damp_pct=damp_pct, + use_lmmse=use_lmmse, + n_rescaler_iters=n_rescaler_iters, + _precomputed=precomputed, + ) + + # Undo importance scaling after quantisation. + if sqrt_w is not None: + W_hat = W_hat / sqrt_w + + print_rank_0(f" [{self.name}] head {h}: rate={rate:.2f} bpe, nmse={nmse:.4f}") + + # Recover per-head state. + # alpha = c / L.diag() (same as inside watersic_quantize). + alpha_h = (c / L.diag()).float() + if perm is not None: + inv_perm = torch.argsort(perm) + alpha_h = alpha_h[inv_perm] + + # Move results to CPU to free GPU memory for next head. 
+ Z_heads.append(Z_h.cpu()) + alpha_heads.append(alpha_h.cpu()) + gamma_heads.append(gamma_h.float().cpu()) + perm_heads.append(perm.cpu() if perm is not None else None) + rates.append(rate) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + mean_rate = sum(rates) / len(rates) if rates else 0.0 + + state = WaterSICKVState( + Z=torch.stack(Z_heads), + alpha=torch.stack(alpha_heads), + gamma=torch.stack(gamma_heads), + perm=torch.stack(perm_heads) if perm_heads and perm_heads[0] is not None else None, + rate=mean_rate, + ) + + # Attach state to the module for downstream consumers. + self.module._watersic_kv_state = state + + return state + + def free(self): + """Release collected calibration data.""" + self.collected_Q.clear() + self.collected_K.clear() diff --git a/modelopt/torch/quantization/algorithms/watersic_kv/zsic.py b/modelopt/torch/quantization/algorithms/watersic_kv/zsic.py new file mode 100644 index 0000000000..30cfad50ab --- /dev/null +++ b/modelopt/torch/quantization/algorithms/watersic_kv/zsic.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core ZSIC (Zero-Shot Integer Compression) algorithm for WaterSIC KV-cache quantization. + +This is a pure math module with no Model-Optimizer dependencies. It implements +the sequential integer coding algorithm described in the WaterSIC paper. 
+""" + +from __future__ import annotations + +import math + +import torch +from torch import Tensor + + +def damp_for_rate(target_rate: float, base: float = 1e-4, knee: float = 5.0) -> float: + """Return a damping coefficient that decays for rates above *knee*. + + ``base * 4 ** (-max(0, target_rate - knee))`` + """ + return base * 4.0 ** (-max(0.0, target_rate - knee)) + + +def compute_entropy(Z: Tensor) -> float: + """Compute Shannon entropy (in bits) of integer-valued tensor *Z*.""" + # Flatten and count occurrences of each unique integer value. + flat = Z.flatten().long() + counts = torch.bincount(flat - flat.min()) + counts = counts[counts > 0] + probs = counts.float() / counts.sum().float() + return -(probs * probs.log2()).sum().item() + + +def compute_output_nmse(W: Tensor, W_q: Tensor, A: Tensor) -> float: + """Normalised MSE measured in the output space: ``||err @ A||^2 / ||W @ A||^2``. + + Uses the trace identity ``||M @ N||_F^2 = tr(M^T M N N^T)`` to avoid + materialising the ``(a, a)`` output matrix, which can be prohibitively large + when the number of tokens *a* is high (e.g. real-model calibration). + Only ``(n, n)`` intermediates are needed, where *n* = ``A.shape[0]``. + """ + Sigma_X = A @ A.T # (n, n) + delta = W - W_q # (a, n) + err_gram = delta.T @ delta # (n, n) + ref_gram = W.T @ W # (n, n) + err_sq = (err_gram * Sigma_X).sum() + ref_sq = (ref_gram * Sigma_X).sum() + if ref_sq < 1e-30: + return float("inf") + return (err_sq / ref_sq).item() + + +def _compute_hessian_cholesky( + A: Tensor, + damp_pct: float = 1e-4, + sort_cols: bool = True, +) -> tuple[Tensor, Tensor, Tensor | None]: + """Build the Hessian ``A A^T`` and return its damped Cholesky factor. + + Parameters + ---------- + A : Tensor + Activation matrix of shape ``(n, T)`` where *n* is the number of + weight columns and *T* is the number of calibration tokens. + damp_pct : float + Fraction of the mean diagonal used as Tikhonov damping. 
+ sort_cols : bool + If *True*, reorder columns by ascending diagonal of ``A A^T`` + (improves numerical stability of the sequential coding). + + Returns: + ------- + Sigma_X : Tensor – ``A A^T`` (possibly column-reordered), shape ``(n, n)`` + L : Tensor – lower-triangular Cholesky factor of the damped + Hessian, shape ``(n, n)`` + perm : Tensor | None – permutation used (LongTensor of length *n*), or + *None* when ``sort_cols`` is *False*. + """ + perm: Tensor | None = None + + if sort_cols: + # Sort by ascending diagonal of A A^T ≡ ascending row-norms of A. + diag = (A * A).sum(dim=1) + perm = torch.argsort(diag) + A = A[perm] + + Sigma_X = A @ A.T + + damp = damp_pct * Sigma_X.diag().mean() + H = Sigma_X + damp * torch.eye(Sigma_X.shape[0], device=A.device, dtype=A.dtype) + + try: + L = torch.linalg.cholesky(H) + except torch.linalg.LinAlgError: + retry_damp = max(10 * damp, 1e-6) + H += retry_damp * torch.eye(H.shape[0], device=A.device, dtype=A.dtype) + try: + L = torch.linalg.cholesky(H) + except torch.linalg.LinAlgError as e: + raise RuntimeError( + f"Cholesky factorization failed even with increased damping " + f"({retry_damp:.2e}). The activation matrix (shape {tuple(A.shape)}) " + f"may be degenerate. Check that calibration data produces non-trivial " + f"activations." + ) from e + + return Sigma_X, L, perm + + +def _optimize_rescalers( + W_hat_0: Tensor, + W: Tensor, + Sigma_X: Tensor, + gamma_init: Tensor, + n_iters: int = 10, +) -> Tensor: + """Alternating row / column rescaler optimisation. + + Starting from ``gamma_init`` (per-column), iterate: + 1. Fix gamma, solve for row rescalers *t*. + 2. Fix *t*, solve for column rescalers *gamma*. + + Returns the rescaled reconstruction ``diag(t) @ W_hat_0 @ diag(gamma)``. + """ + gamma = gamma_init.clone() + t = torch.ones(W.shape[0], device=W.device, dtype=W.dtype) + + for _ in range(n_iters): + # --- Row rescalers (t) --- given gamma, minimise over t_i independently. 
+ # For each row i: t_i = (W_hat_0[i] * gamma) . Sigma_X . W[i] + # / (W_hat_0[i] * gamma) . Sigma_X . (W_hat_0[i] * gamma) + scaled = W_hat_0 * gamma.unsqueeze(0) # (a, n) + num_t = (scaled @ Sigma_X * W).sum(dim=1) # (a,) + den_t = (scaled @ Sigma_X * scaled).sum(dim=1) # (a,) + t = num_t / den_t.clamp(min=1e-20) + + # --- Column rescalers (gamma) --- given t, minimise over gamma_j independently. + # gamma_j = sum_i t_i * W_hat_0[i,j] * (Sigma_X[j,:] @ W[i,:].T) + # / sum_i (t_i * W_hat_0[i,j])^2 * Sigma_X[j,j] + t_col = t.unsqueeze(1) # (a, 1) + tw = t_col * W_hat_0 # (a, n) + # numerator: for each j, sum_i tw[i,j] * (Sigma_X[j,:] @ W[i,:]) + num_g = (tw.T @ W @ Sigma_X.T).diag() # (n,) -- Sigma_X symmetric so .T ok + den_g = (tw * tw).T @ torch.ones(W.shape[0], device=W.device, dtype=W.dtype) + den_g = den_g * Sigma_X.diag() # (n,) + gamma = num_g / den_g.clamp(min=1e-20) + + return t.unsqueeze(1) * W_hat_0 * gamma.unsqueeze(0) + + +def zsic_quantize( + W: Tensor, + A: Tensor, + alpha: Tensor, + Sigma_X: Tensor, + L: Tensor, + use_lmmse: bool = True, + n_rescaler_iters: int = 0, +) -> tuple[Tensor, float, float, Tensor, Tensor]: + """Run the ZSIC sequential integer coding loop. + + Parameters + ---------- + W : Tensor (a, n) – weight matrix (rows = output channels). + A : Tensor (n, T) – activation matrix. + alpha : Tensor (n,) – per-column step sizes. + Sigma_X : Tensor (n, n) – ``A A^T``. + L : Tensor (n, n) – lower-triangular Cholesky factor. + use_lmmse : bool + Apply per-column LMMSE gain correction. + n_rescaler_iters : int + Number of alternating rescaler iterations (0 = disable). + + Returns: + ------- + W_hat : Tensor (a, n) – quantised reconstruction. + rate : float – estimated coding rate (bits per weight element). + nmse : float – output NMSE. + Z : Tensor (a, n) – integer codes. + gamma : Tensor (n,) – per-column LMMSE shrinkage gains. 
+ """ + a, n = W.shape + + # M_T = L^{-1} Sigma_X (solve L M_T = Sigma_X for M_T) + M_T = torch.linalg.solve_triangular(L, Sigma_X, upper=False) + Y = W @ M_T.T # (a, n) + + Z = torch.zeros(a, n, device=W.device, dtype=torch.long) + gamma = torch.ones(n, device=W.device, dtype=W.dtype) + + for i in range(n - 1, -1, -1): + d_i = alpha[i] * L[i, i] + z_i = torch.round(Y[:, i] / d_i).long() + Z[:, i] = z_i + + z_f = z_i.float().to(W.dtype) + z_sq = z_f.dot(z_f) + + if use_lmmse and z_sq > 1e-20: + gamma[i] = z_f.dot(Y[:, i]) / (d_i * z_sq) + + # Efficient rank-1 update: Y -= (gamma[i]*alpha[i]) * z_f outer L[i,:] + Y.addr_(z_f, L[i], alpha=-(gamma[i] * alpha[i]).item()) + + # --- Entropy and rate --- + entropy = compute_entropy(Z) + rate = entropy + 16.0 / a + 16.0 / n + + # --- Reconstruction --- + W_hat_0 = Z.float().to(W.dtype) * alpha.unsqueeze(0) + + if n_rescaler_iters > 0: + W_hat = _optimize_rescalers(W_hat_0, W, Sigma_X, gamma, n_iters=n_rescaler_iters) + elif use_lmmse: + W_hat = W_hat_0 * gamma.unsqueeze(0) + else: + W_hat = W_hat_0 + + nmse = compute_output_nmse(W, W_hat, A) + return W_hat, rate, nmse, Z, gamma + + +def watersic_quantize( + W: Tensor, + A: Tensor, + c: float, + damp_pct: float = 1e-4, + use_lmmse: bool = True, + n_rescaler_iters: int = 0, + _precomputed: tuple[Tensor, Tensor, Tensor | None] | None = None, +) -> tuple[Tensor, float, float, Tensor, Tensor]: + """Quantise *W* using the WaterSIC algorithm for a given scale factor *c*. + + Parameters + ---------- + W : Tensor (a, n) + A : Tensor (n, T) + c : float – global scale factor that controls the rate/distortion trade-off. + damp_pct : float + use_lmmse : bool + n_rescaler_iters : int + _precomputed : tuple, optional + ``(Sigma_X, L, perm)`` from a prior call to + :func:`_compute_hessian_cholesky` to avoid redundant computation. + + Returns: + ------- + W_hat : Tensor (a, n) – quantised reconstruction (in original column order). 
+ rate : float + nmse : float + Z : Tensor (a, n) – integer codes (in original column order). + gamma : Tensor (n,) – per-column LMMSE shrinkage gains (in original column order). + """ + if _precomputed is not None: + Sigma_X, L, perm = _precomputed + else: + Sigma_X, L, perm = _compute_hessian_cholesky(A, damp_pct=damp_pct) + + # Apply permutation to weight columns if used. + if perm is not None: + W = W[:, perm] + A = A[perm] + + alpha = c / L.diag() + + W_hat, rate, nmse, Z, gamma = zsic_quantize( + W, + A, + alpha, + Sigma_X, + L, + use_lmmse=use_lmmse, + n_rescaler_iters=n_rescaler_iters, + ) + + # Undo permutation. + if perm is not None: + inv_perm = torch.argsort(perm) + W_hat = W_hat[:, inv_perm] + Z = Z[:, inv_perm] + gamma = gamma[inv_perm] + + return W_hat, rate, nmse, Z, gamma + + +def binary_search_c( + W: Tensor, + A: Tensor, + target_rate: float, + damp_pct: float | None = None, + use_lmmse: bool = True, + n_rescaler_iters: int = 0, + n_iters: int = 30, + sample_frac: float = 0.1, + _precomputed: tuple[Tensor, Tensor, Tensor | None] | None = None, +) -> float: + """Find the scale factor *c* that achieves *target_rate* via log-space binary search. + + Parameters + ---------- + W : Tensor (a, n) + A : Tensor (n, T) + target_rate : float – desired bits per weight element. + damp_pct : float | None + If *None*, determined automatically via :func:`damp_for_rate`. + use_lmmse : bool + n_rescaler_iters : int + n_iters : int – number of binary-search iterations (default 30). + sample_frac : float – fraction of rows to use (default 10%). + _precomputed : tuple, optional + + Returns: + ------- + c : float – optimal scale factor. + """ + if damp_pct is None: + damp_pct = damp_for_rate(target_rate) + + a = W.shape[0] + n_sample = max(4, int(a * sample_frac)) + + # Subsample rows for speed. + if n_sample < a: + idx = torch.randperm(a, device=W.device)[:n_sample] + W_sub = W[idx] + else: + W_sub = W + + # Precompute Hessian / Cholesky (shared across iterations). 
+ if _precomputed is not None: + precomputed = _precomputed + else: + precomputed = _compute_hessian_cholesky(A, damp_pct=damp_pct) + + log_c_lo = math.log(1e-6) + log_c_hi = math.log(1e3) + + for _ in range(n_iters): + log_c_mid = 0.5 * (log_c_lo + log_c_hi) + c_mid = math.exp(log_c_mid) + + _, rate, _, _, _ = watersic_quantize( + W_sub, + A, + c_mid, + damp_pct=damp_pct, + use_lmmse=use_lmmse, + n_rescaler_iters=n_rescaler_iters, + _precomputed=precomputed, + ) + + if rate > target_rate: + # c is too small (finer quantization = higher rate), increase c. + log_c_lo = log_c_mid + else: + log_c_hi = log_c_mid + + return math.exp(0.5 * (log_c_lo + log_c_hi)) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 99c729efbc..393e8ba7e2 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -741,6 +741,17 @@ def _nvfp4_selective_quant_cfg( "algorithm": "max", } +WATERSIC_KV_CFG = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + {"quantizer_name": "*[kv]_bmm_quantizer", "enable": True}, + ], + "algorithm": { + "method": "watersic_kv", + "target_rate": 2.0, + }, +} + NVFP4_SVDQUANT_DEFAULT_CFG = _nvfp4_selective_quant_cfg( ["*"], algorithm={"method": "svdquant", "lowrank": 32} ) @@ -833,6 +844,7 @@ def _nvfp4_selective_quant_cfg( "MAMBA_MOE_FP8_CONSERVATIVE_CFG", "MAMBA_MOE_FP8_AGGRESSIVE_CFG", "NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG", + "WATERSIC_KV_CFG", } BiasType = Literal["static", "dynamic"] diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index c81d5c89c7..0363a15748 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -66,6 +66,7 @@ sequential_calibrate, smoothquant, svdquant, + watersic_kv, ) __all__ = ["BaseCalibrateModeDescriptor"] @@ -502,3 +503,17 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: return GPTQCalibConfig _calib_func = gptq + + 
@CalibrateModeRegistry.register_mode
class WaterSICKVModeDescriptor(BaseCalibrateModeDescriptor):
    """Mode for WaterSIC KV-cache quantization algorithm."""

    @property
    def config_class(self) -> type[QuantizeAlgorithmConfig]:
        """Specifies the config class for the mode."""
        # Imported lazily inside the property — presumably to avoid a circular
        # import between this module and the algorithm package; TODO(review):
        # confirm the cycle actually exists before hoisting.
        from .algorithms.watersic_kv.config import WaterSICKVCalibConfig

        return WaterSICKVCalibConfig

    # Calibration entry point the mode machinery invokes; `watersic_kv` is
    # imported from model_calib at the top of this module.
    _calib_func = watersic_kv
+ " not be synced correctly.", ) # Skip amax sync for INT4 / W4A8 block quantization # Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale) @@ -249,7 +254,7 @@ def sync_quantizer_amax_across_tp( for quantizer in [module.k_bmm_quantizer, module.v_bmm_quantizer]: if isinstance(quantizer, TensorQuantizer) and quantizer.amax is not None: quantizer.sync_amax_across_distributed_group( - module.parallel_state.tensor_parallel_group + module.parallel_state.tensor_parallel_group, ) @@ -373,7 +378,7 @@ def mse_calibrate( # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation # This prevents massive memory accumulation seen in large models for idx, (parent_module, weight_name, weight_quantizer) in enumerate( - tqdm(weight_quantizers, desc="MSE weight calibration") + tqdm(weight_quantizers, desc="MSE weight calibration"), ): # Enable calibration mode for the weight quantizer weight_quantizer.disable_quant() @@ -486,7 +491,7 @@ def setup(self): if self.cin % self.block_size != 0: warnings.warn( f"Module {self.name}: input features ({self.cin}) not divisible by " - f"block_size ({self.block_size}). Skipping local Hessian for this module." + f"block_size ({self.block_size}). Skipping local Hessian for this module.", ) self.is_enabled = False @@ -705,7 +710,7 @@ def enable_stats_collection(model: nn.Module): # Disable quantization during calibration so it doesn't affect other quantizers. 
module.disable_quant() continue - elif module._calibrator is not None: + if module._calibrator is not None: module.disable_quant() module.enable_calib() else: @@ -752,7 +757,7 @@ def disable_pre_quant_scale_and_resmooth(linear: nn.Module, delete_pre_quant_sca pre_quant_scale = linear.input_quantizer._pre_quant_scale.to(torch.float32) linear.weight.copy_( - (linear.weight * pre_quant_scale.squeeze()[None, :]).to(linear.weight.dtype) + (linear.weight * pre_quant_scale.squeeze()[None, :]).to(linear.weight.dtype), ) linear.weight_quantizer.reset_amax() max_calibrate(linear, lambda linear: linear.weight_quantizer(linear.weight)) @@ -764,7 +769,8 @@ def disable_pre_quant_scale_and_resmooth(linear: nn.Module, delete_pre_quant_sca assert hasattr(linear.input_quantizer, "_amax_for_smoothing") device, dtype = linear.weight.device, linear.weight.dtype linear.input_quantizer.amax = linear.input_quantizer._amax_for_smoothing.amax().to( - device=device, dtype=dtype + device=device, + dtype=dtype, ) if delete_pre_quant_scale: @@ -781,13 +787,13 @@ def _apply_weight_pre_quant_scale(linear, pre_quant_scale): if _ENABLE_FOLDING_PQS_TO_WEIGHTS: linear.weight.data.copy_( (linear.weight * pre_quant_scale.to(linear.weight.device).squeeze()[None, :]).to( - linear.weight.dtype - ) + linear.weight.dtype, + ), ) else: linear.weight_quantizer._enable_pre_quant_scale = True linear.weight_quantizer.pre_quant_scale = pre_quant_scale.squeeze()[None, :].to( - linear.weight.dtype + linear.weight.dtype, ) linear.weight_quantizer.reset_amax() @@ -796,7 +802,8 @@ def _apply_weight_pre_quant_scale(linear, pre_quant_scale): @torch.no_grad() def apply_pre_quant_scale_and_smooth( - linear: nn.Module, pre_quant_scale: torch.Tensor | None = None + linear: nn.Module, + pre_quant_scale: torch.Tensor | None = None, ): """Apply pre_quant_scale and smooth the quantized linear weights. 
@@ -829,7 +836,8 @@ def apply_pre_quant_scale_and_smooth( assert hasattr(linear.input_quantizer, "_amax_for_smoothing") device, dtype = linear.weight.device, linear.weight.dtype _amax_for_smoothing = linear.input_quantizer._amax_for_smoothing.to( - device=device, dtype=dtype + device=device, + dtype=dtype, ) linear.input_quantizer.amax = ( (_amax_for_smoothing * pre_quant_scale.to(device)).amax().to(dtype) @@ -837,7 +845,7 @@ def apply_pre_quant_scale_and_smooth( if is_quantized_column_parallel_linear(linear) or is_quantized_row_parallel_linear(linear): linear.input_quantizer.sync_amax_across_distributed_group( - linear.parallel_state.tensor_parallel_group + linear.parallel_state.tensor_parallel_group, ) @@ -1104,7 +1112,8 @@ def forward(self, input, *args, **kwargs): self.awq_lite.num_tokens += input.numel() / input.shape[-1] if self.awq_lite.is_input_quantized: with set_quantizer_by_cfg_context( - self.input_quantizer, [{"quantizer_name": "*", "enable": True}] + self.input_quantizer, + [{"quantizer_name": "*", "enable": True}], ): max_calibrate(self.input_quantizer, lambda quantizer: quantizer(input), False) return out_actual @@ -1155,7 +1164,9 @@ def sync_act_scale_across_dp(module, data_parallel_group): """Sync activation scale across Data Parallel (DP).""" if data_parallel_group.is_initialized(): dist.all_reduce( - module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group + module.awq_lite.act_scale, + op=dist.ReduceOp.AVG, + group=data_parallel_group.group, ) for name, module in model.named_modules(): @@ -1169,10 +1180,12 @@ def sync_act_scale_across_dp(module, data_parallel_group): module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( - torch.isnan(module.awq_lite.weight_scale) + torch.isnan(module.awq_lite.weight_scale), ) has_nan = DistributedProcessGroup.get_dist_syncd_obj( - has_nan_local, 
module.parallel_state.data_parallel_group, lambda objs: any(objs) + has_nan_local, + module.parallel_state.data_parallel_group, + lambda objs: any(objs), ) if has_nan: @@ -1250,7 +1263,7 @@ def postprocess(module, name): warnings.warn( "awq_lite: Calling `forward_loop(model)` the second time did not forward" f" data through the {name}. Please provide a valid `forward_loop` function" - " that can be used to forward data through the model many times." + " that can be used to forward data through the model many times.", ) with enable_weight_access_and_writeback(module, model, name_to_module): postprocess(module, name) @@ -1333,7 +1346,9 @@ def update_best_params(self): indices = loss < self.awq_clip.best_loss self.awq_clip.best_loss = torch.where(indices, loss, self.awq_clip.best_loss) self.awq_clip.best_clip_val = torch.where( - indices, self.awq_clip.w_amax * shrink, self.awq_clip.best_clip_val + indices, + self.awq_clip.w_amax * shrink, + self.awq_clip.best_clip_val, ) def _clip_search(self, inputs, co_bsz=256, max_tokens=16): @@ -1424,7 +1439,7 @@ def forward(name, self, input, *args, **kwargs): if "CUDA out of memory" in str(e): raise RuntimeError( f"Clip search on {name} failed due to CUDA out of memory, try reducing" - " max_co_batch_size" + " max_co_batch_size", ) from e raise RuntimeError(e) @@ -1499,7 +1514,7 @@ def svd(weight, rank): warnings.warn( "The low-rank dimensions do not match the layer dimensions. " "Please verify your configuration and model settings. 
" - f"Rank is {us.shape[1]} and {vt.shape[0]}" + f"Rank is {us.shape[1]} and {vt.shape[0]}", ) us_temp = torch.zeros((us.shape[0], rank), dtype=us.dtype, device=us.device) vt_temp = torch.zeros((rank, vt.shape[1]), dtype=vt.dtype, device=vt.device) @@ -1535,7 +1550,7 @@ def postprocess(module, name): module.weight_quantizer.svdquant_lora_a = vt module.weight_quantizer.svdquant_lora_b = us module.weight.data.sub_( - module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a + module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a, ) module.weight_quantizer.reset_amax() module.input_quantizer.reset_amax() @@ -1567,14 +1582,14 @@ def sequential_calibrate( if forward_loop is None: raise ValueError( "forward_loop must not be None for sequential calibration. " - "Please provide a valid forward_loop callable." + "Please provide a valid forward_loop callable.", ) transformer_layers = LayerActivationCollector.get_decoder_layers(model) if transformer_layers is None or len(transformer_layers) == 0: raise ValueError( "Could not find transformer layers in model. " - "Sequential calibration requires a model with identifiable transformer layers." 
+ "Sequential calibration requires a model with identifiable transformer layers.", ) print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") @@ -1655,7 +1670,8 @@ def gptq( print_rank_0(f"Computing Hessians for {len(gptq_handles)} linear layers...") with set_quantizer_by_cfg_context( - model, [{"quantizer_name": "*weight_quantizer", "enable": False}] + model, + [{"quantizer_name": "*weight_quantizer", "enable": False}], ): forward_loop(model) @@ -1671,3 +1687,78 @@ def gptq( if torch.cuda.is_available(): torch.cuda.empty_cache() print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s") + + +@torch.no_grad() +def watersic_kv( + model: nn.Module, + forward_loop: ForwardLoop, + target_rate: float = 2.0, + kl_aware: bool = False, + importance_clip: float = 50.0, + use_lmmse: bool = True, + n_rescaler_iters: int = 0, + sample_frac: float | None = None, +): + """WaterSIC KV-cache quantization. + + Collects post-RoPE Q and K tensors from attention layers during calibration, + then runs WaterSIC entropy-coded quantization on K cache per attention head. + + Args: + model: Module to quantize (full model or single decoder layer). + forward_loop: Callable that replays calibration inputs. + target_rate: Target bits per element. + kl_aware: Use attention-based importance weighting. + importance_clip: Clamp range for importance weights. + use_lmmse: Apply LMMSE shrinkage correction. + n_rescaler_iters: Diagonal rescaler optimization iterations. + sample_frac: Row subsampling for binary search (None = auto). + """ + from modelopt.torch.quantization.algorithms.watersic_kv.helper import WaterSICKVHelper + from modelopt.torch.quantization.plugins.huggingface import _QuantAttention + + total_start = time.time() + + attn_modules = [(n, m) for n, m in model.named_modules() if isinstance(m, _QuantAttention)] + if not attn_modules: + raise ValueError( + "WaterSIC KV-cache quantization was requested, but no _QuantAttention modules " + "were found. 
Ensure the model has been quantized with KV-cache quantizers enabled " + "before running WaterSIC KV calibration." + ) + + print_rank_0(f"WaterSIC KV: Found {len(attn_modules)} attention layers") + + # Phase 1: Collect Q, K + helpers = { + name: WaterSICKVHelper(m, name, kl_aware=kl_aware, importance_clip=importance_clip) + for name, m in attn_modules + } + for helper in helpers.values(): + helper.setup() + + print_rank_0("Collecting Q, K activations...") + try: + forward_loop(model) + finally: + for helper in helpers.values(): + helper.cleanup() + + # Phase 2: Run WaterSIC per layer + print_rank_0("Running WaterSIC quantization...") + for helper in helpers.values(): + helper.quantize( + target_rate=target_rate, + use_lmmse=use_lmmse, + n_rescaler_iters=n_rescaler_iters, + sample_frac=sample_frac, + ) + helper.free() + del helpers + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + elapsed = time.time() - total_start + print_rank_0(f"WaterSIC KV completed in {elapsed:.1f}s") diff --git a/pyproject.toml b/pyproject.toml index 6170876308..246eab65ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -211,6 +211,8 @@ extend-ignore = [ "N806", ] # triton kernel style "examples/deepseek/ds_kernel.py" = ["N803", "N806", "E731"] # triton style +"modelopt/torch/quantization/algorithms/watersic_kv/*" = ["N803", "N806"] +"tests/unit/torch/quantization/test_watersic_kv.py" = ["N803", "N806"] [tool.ruff.lint.pycodestyle] max-line-length = 120 # Line length limit for comments and docstrings diff --git a/tests/gpu/torch/quantization/test_watersic_kv.py b/tests/gpu/torch/quantization/test_watersic_kv.py new file mode 100644 index 0000000000..596cdcee6a --- /dev/null +++ b/tests/gpu/torch/quantization/test_watersic_kv.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
"""GPU end-to-end tests for WaterSIC KV-cache quantization."""

import pytest
import torch
from _test_utils.torch.transformers_models import get_tiny_llama

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.plugins.huggingface import _QuantAttention


@pytest.fixture
def tiny_llama():
    return get_tiny_llama()


@pytest.fixture
def calib_loop():
    def forward_loop(m):
        # Use vocab_size=32 matching the tiny_llama fixture
        input_ids = torch.randint(0, 32, (2, 32), device=next(m.parameters()).device)
        m(input_ids)

    return forward_loop


def _watersic_config(target_rate: float, **algorithm_overrides) -> dict:
    """Build a WaterSIC KV quantize config enabling only the KV bmm quantizers."""
    return {
        "quant_cfg": [
            {"quantizer_name": "*", "enable": False},
            {"quantizer_name": "*[kv]_bmm_quantizer", "enable": True},
        ],
        "algorithm": {
            "method": "watersic_kv",
            "target_rate": target_rate,
            "use_sequential": False,
            **algorithm_overrides,
        },
    }


def _assert_watersic_state(model) -> None:
    """Assert every _QuantAttention module carries a plausible WaterSIC state."""
    attn_modules = [m for m in model.modules() if isinstance(m, _QuantAttention)]
    assert len(attn_modules) > 0, "No _QuantAttention modules found"

    for m in attn_modules:
        assert hasattr(m, "_watersic_kv_state"), f"Module {m} missing _watersic_kv_state"
        state = m._watersic_kv_state
        assert state.rate > 0, f"Rate should be positive, got {state.rate}"
        assert state.rate < 10, f"Rate should be < 10, got {state.rate}"


class TestWaterSICKVEndToEnd:
    """End-to-end GPU tests for WaterSIC KV-cache quantization."""

    def test_standalone_watersic_kv(self, tiny_llama, calib_loop):
        """Test WaterSIC KV-cache quantization as a standalone algorithm."""
        model = tiny_llama.to("cuda")
        model.eval()

        model = mtq.quantize(model, _watersic_config(4.0), forward_loop=calib_loop)

        # Verify _watersic_kv_state exists on _QuantAttention modules
        _assert_watersic_state(model)

    def test_composable_with_fp8_weights(self, tiny_llama, calib_loop):
        """Test composition with FP8 weight quantization."""
        model = tiny_llama.to("cuda")
        model.eval()

        # Step 1: FP8 weight quantization
        model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calib_loop)

        # Step 2: WaterSIC KV-cache quantization
        model = mtq.quantize(model, _watersic_config(3.0), forward_loop=calib_loop)

        # Verify model produces valid output (no NaN)
        input_ids = torch.randint(0, 32, (1, 16), device="cuda")
        with torch.no_grad():
            output = model(input_ids)
        assert not torch.isnan(output.logits).any(), "Output contains NaN values"

    def test_kl_aware_mode(self, tiny_llama, calib_loop):
        """Test KL-aware importance weighting."""
        model = tiny_llama.to("cuda")
        model.eval()

        model = mtq.quantize(
            model,
            _watersic_config(4.0, kl_aware=True, importance_clip=20.0),
            forward_loop=calib_loop,
        )

        # Verify _watersic_kv_state exists on _QuantAttention modules
        _assert_watersic_state(model)
+ +"""Unit tests for the core ZSIC algorithm (WaterSIC KV-cache quantization).""" + +from __future__ import annotations + +import pytest +import torch + +from modelopt.torch.quantization.algorithms.watersic_kv.zsic import ( + _compute_hessian_cholesky, + binary_search_c, + compute_entropy, + compute_output_nmse, + damp_for_rate, + watersic_quantize, + zsic_quantize, +) + +# --------------------------------------------------------------------------- +# TestDampForRate +# --------------------------------------------------------------------------- + + +class TestDampForRate: + """Tests for :func:`damp_for_rate`.""" + + def test_below_knee_returns_base(self): + """Rates below the knee should return the base value.""" + assert damp_for_rate(3.0) == pytest.approx(1e-4) + + def test_at_knee_returns_base(self): + """Rate exactly at the knee should return the base value.""" + assert damp_for_rate(5.0) == pytest.approx(1e-4) + + def test_above_knee_decays(self): + """Rate above the knee should decay: rate=6.0 gives base * 4^{-1} = 2.5e-5.""" + assert damp_for_rate(6.0) == pytest.approx(2.5e-5) + + def test_high_rate_very_small(self): + """Very high rates should produce a very small damping value.""" + val = damp_for_rate(10.0) + assert val < 1e-6 + + +# --------------------------------------------------------------------------- +# TestComputeEntropy +# --------------------------------------------------------------------------- + + +class TestComputeEntropy: + """Tests for :func:`compute_entropy`.""" + + def test_single_value_zero_entropy(self): + """A constant tensor has zero entropy.""" + Z = torch.full((10, 5), 3, dtype=torch.long) + assert compute_entropy(Z) == pytest.approx(0.0, abs=1e-7) + + def test_uniform_distribution(self): + """Four equally-likely values should give log2(4) = 2.0 bits.""" + Z = torch.tensor([0, 1, 2, 3] * 25, dtype=torch.long) + assert compute_entropy(Z) == pytest.approx(2.0, abs=1e-5) + + def test_binary(self): + """Half 0s, half 1s should give 
1.0 bit.""" + Z = torch.tensor([0] * 50 + [1] * 50, dtype=torch.long) + assert compute_entropy(Z) == pytest.approx(1.0, abs=1e-5) + + +# --------------------------------------------------------------------------- +# TestComputeHessianCholesky +# --------------------------------------------------------------------------- + + +class TestComputeHessianCholesky: + """Tests for :func:`_compute_hessian_cholesky`.""" + + def test_identity_activations(self): + """With identity-like activations the Hessian should be PSD and L lower-triangular.""" + n = 8 + A = torch.eye(n, dtype=torch.float64) + Sigma_X, L, perm = _compute_hessian_cholesky(A, sort_cols=False) + + # PSD: eigenvalues >= 0 + eigvals = torch.linalg.eigvalsh(Sigma_X) + assert (eigvals >= -1e-8).all() + + # L is lower triangular + assert torch.allclose(L, L.tril()) + + # L @ L^T should approximate H = Sigma_X + damp*I + damp = 1e-4 * Sigma_X.diag().mean() + H = Sigma_X + damp * torch.eye(n, dtype=torch.float64) + assert torch.allclose(L @ L.T, H, atol=1e-6) + + def test_with_column_sorting(self): + """Column sorting should return a valid permutation with ascending diagonal.""" + torch.manual_seed(42) + A = torch.randn(6, 20, dtype=torch.float64) + Sigma_X, L, perm = _compute_hessian_cholesky(A, sort_cols=True) + + assert perm is not None + # perm is a valid permutation of 0..n-1 + assert set(perm.tolist()) == set(range(6)) + # Diagonal of the reordered Hessian should be ascending + diag_vals = Sigma_X.diag() + assert (diag_vals[1:] >= diag_vals[:-1] - 1e-8).all() + + +# --------------------------------------------------------------------------- +# TestComputeOutputNmse +# --------------------------------------------------------------------------- + + +class TestComputeOutputNmse: + """Tests for :func:`compute_output_nmse`.""" + + def test_zero_error(self): + """Perfect reconstruction should give NMSE = 0.""" + W = torch.randn(4, 8) + A = torch.randn(8, 16) + assert compute_output_nmse(W, W, A) == 
pytest.approx(0.0, abs=1e-7) + + def test_positive_error(self): + """Perturbed reconstruction should give positive NMSE.""" + torch.manual_seed(0) + W = torch.randn(4, 8) + W_q = W + 0.1 * torch.randn_like(W) + A = torch.randn(8, 16) + nmse = compute_output_nmse(W, W_q, A) + assert nmse > 0.0 + + +# --------------------------------------------------------------------------- +# TestZsicQuantize +# --------------------------------------------------------------------------- + + +class TestZsicQuantize: + """Tests for :func:`zsic_quantize`.""" + + @pytest.fixture + def setup(self): + torch.manual_seed(123) + a, n, T = 16, 8, 64 + W = torch.randn(a, n, dtype=torch.float64) + A = torch.randn(n, T, dtype=torch.float64) + Sigma_X, L, _ = _compute_hessian_cholesky(A, sort_cols=False) + alpha = 0.5 / L.diag() + return W, A, alpha, Sigma_X, L + + def test_produces_valid_output(self, setup): + """Output should have correct shape, positive rate, and NMSE in (0, 1).""" + W, A, alpha, Sigma_X, L = setup + W_hat, rate, nmse, Z, gamma = zsic_quantize(W, A, alpha, Sigma_X, L, use_lmmse=False) + + assert W_hat.shape == W.shape + assert rate > 0.0 + assert 0.0 < nmse < 1.0 + assert Z.shape == W.shape + assert gamma.shape == (W.shape[1],) + + def test_lmmse_improves_nmse(self, setup): + """LMMSE correction should reduce (or at least not increase) the NMSE.""" + W, A, alpha, Sigma_X, L = setup + _, _, nmse_no, _, _ = zsic_quantize(W, A, alpha, Sigma_X, L, use_lmmse=False) + _, _, nmse_yes, _, _ = zsic_quantize(W, A, alpha, Sigma_X, L, use_lmmse=True) + + assert nmse_yes <= nmse_no + 1e-8 + + def test_rescaler_produces_valid_low_nmse(self, setup): + """Rescaler path with n_rescaler_iters=5 should produce valid output with low NMSE.""" + W, A, alpha, Sigma_X, L = setup + W_hat, rate, nmse, Z, gamma = zsic_quantize( + W, + A, + alpha, + Sigma_X, + L, + use_lmmse=True, + n_rescaler_iters=5, + ) + + assert W_hat.shape == W.shape + assert rate > 0.0 + # The rescaler path should still achieve 
reasonable NMSE (< 0.5). + assert 0.0 < nmse < 0.5 + # Reconstruction should be meaningfully close to original. + assert (W - W_hat).norm() < W.norm() + + +# --------------------------------------------------------------------------- +# TestWatersicQuantize +# --------------------------------------------------------------------------- + + +class TestWatersicQuantize: + """Tests for :func:`watersic_quantize`.""" + + @pytest.fixture + def data(self): + torch.manual_seed(7) + a, n, T = 32, 12, 100 + W = torch.randn(a, n, dtype=torch.float64) + A = torch.randn(n, T, dtype=torch.float64) + return W, A + + def test_basic_quantization(self, data): + """Should return valid W_hat, rate, and NMSE.""" + W, A = data + W_hat, rate, nmse, Z, gamma = watersic_quantize(W, A, c=0.5) + assert W_hat.shape == W.shape + assert rate > 0.0 + assert nmse > 0.0 + assert Z.shape == W.shape + assert gamma.shape == (W.shape[1],) + + def test_smaller_c_gives_higher_rate(self, data): + """Smaller c should produce finer quantization and a higher coding rate.""" + W, A = data + _, rate_large, _, _, _ = watersic_quantize(W, A, c=2.0) + _, rate_small, _, _, _ = watersic_quantize(W, A, c=0.1) + assert rate_small > rate_large + + def test_permutation_roundtrip(self, data): + """Columns should be correctly un-permuted so W_hat is in the original order.""" + W, A = data + W_hat, _, _, _, _ = watersic_quantize(W, A, c=0.5) + # The reconstruction error should be smaller than the weight norm + # (i.e. it's not just garbage / misaligned columns). 
+ assert (W - W_hat).norm() < W.norm() + + +# --------------------------------------------------------------------------- +# TestBinarySearchC +# --------------------------------------------------------------------------- + + +class TestBinarySearchC: + """Tests for :func:`binary_search_c`.""" + + def test_achieves_target_rate(self): + """The returned *c* should achieve a rate within 1.0 bit of the target.""" + torch.manual_seed(99) + a, n, T = 32, 10, 80 + W = torch.randn(a, n, dtype=torch.float64) + A = torch.randn(n, T, dtype=torch.float64) + + target = 4.0 + # Use sample_frac=1.0 so the search operates on the same rows as the + # verification call (the matrix is small, so subsampling would cause + # a large mismatch between search and evaluation rates). + c = binary_search_c(W, A, target_rate=target, sample_frac=1.0) + + # Evaluate at full size to verify. + _, rate, _, _, _ = watersic_quantize(W, A, c) + assert abs(rate - target) < 1.0 + + +# --------------------------------------------------------------------------- +# KV Quantizer Helper tests +# --------------------------------------------------------------------------- + +from modelopt.torch.quantization.algorithms.watersic_kv.helper import ( + WaterSICKVState, + _compute_importance_weights, +) + +# --------------------------------------------------------------------------- +# TestComputeImportanceWeights +# --------------------------------------------------------------------------- + + +class TestComputeImportanceWeights: + """Tests for :func:`_compute_importance_weights`.""" + + def test_uniform_attention_gives_uniform_weights(self): + """Uniform attention matrix should produce equal importance weights.""" + N = 16 + P = torch.ones(8, N) / N # uniform over tokens + sqrt_w = _compute_importance_weights(P) + + assert sqrt_w.shape == (N, 1) + # All weights should be identical (since input is uniform). 
+ assert torch.allclose(sqrt_w, sqrt_w[0].expand_as(sqrt_w)) + + def test_peaked_attention_gives_high_weight(self): + """When all attention is on token 0, token 0 should have the highest weight.""" + N = 16 + P = torch.zeros(8, N) + P[:, 0] = 1.0 # all attention on token 0 + sqrt_w = _compute_importance_weights(P) + + assert sqrt_w.shape == (N, 1) + # Token 0 should have the largest weight. + assert sqrt_w[0, 0] == sqrt_w.max() + + def test_clipping(self): + """Clipping should limit the maximum importance weight.""" + N = 16 + P = torch.zeros(8, N) + P[:, 0] = 1.0 # all attention on token 0 + clip = 10.0 + sqrt_w = _compute_importance_weights(P, importance_clip=clip) + + import math + + assert sqrt_w.max().item() <= math.sqrt(clip) + 1e-6 + + +# --------------------------------------------------------------------------- +# TestWaterSICKVState +# --------------------------------------------------------------------------- + + +class TestWaterSICKVState: + """Tests for :class:`WaterSICKVState`.""" + + def test_state_creation(self): + """State dataclass should store all fields correctly.""" + Z = torch.randint(0, 10, (4, 32, 16)) + alpha = torch.randn(4, 16) + gamma = torch.randn(4, 16) + state = WaterSICKVState(Z=Z, alpha=alpha, gamma=gamma, perm=None, rate=2.5) + + assert state.Z is Z + assert state.alpha is alpha + assert state.gamma is gamma + assert state.perm is None + assert state.rate == 2.5 + + +# --------------------------------------------------------------------------- +# TestWaterSICKVCalibConfig +# --------------------------------------------------------------------------- + + +class TestWaterSICKVCalibConfig: + def test_defaults(self): + from modelopt.torch.quantization.algorithms.watersic_kv.config import WaterSICKVCalibConfig + + cfg = WaterSICKVCalibConfig() + assert cfg.method == "watersic_kv" + assert cfg.target_rate == 2.0 + assert cfg.kl_aware is False + assert cfg.use_lmmse is True + assert cfg.use_sequential is False + + def 
test_custom_values(self): + from modelopt.torch.quantization.algorithms.watersic_kv.config import WaterSICKVCalibConfig + + cfg = WaterSICKVCalibConfig(target_rate=4.0, kl_aware=True, importance_clip=20.0) + assert cfg.target_rate == 4.0 + assert cfg.kl_aware is True + assert cfg.importance_clip == 20.0 + + def test_serialization_roundtrip(self): + from modelopt.torch.quantization.algorithms.watersic_kv.config import WaterSICKVCalibConfig + + cfg = WaterSICKVCalibConfig(target_rate=3.0, kl_aware=True) + data = cfg.model_dump() + cfg2 = WaterSICKVCalibConfig(**data) + assert cfg2.target_rate == 3.0 + assert cfg2.kl_aware is True