Skip to content

Commit 296ac08

Browse files
committed
Fix pre-commit issues, circular import, and sample_frac handling
- Fix circular import in mode.py by using lazy import for WaterSICKVCalibConfig
- Fix sample_frac=None propagation to binary_search_c
- Fix ruff lint (en-dash, missing docstring) and formatting
- All pre-commit hooks pass (ruff, mypy, bandit, license)

Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent c4f20b2 commit 296ac08

File tree

6 files changed

+84 additions, -68 deletions (lines changed)

6 files changed

+84
-68
lines changed

modelopt/torch/quantization/algorithms/watersic_kv/kv_quantizer.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,7 @@
3030

3131
from modelopt.torch.utils import print_rank_0
3232

33-
from .zsic import (
34-
_compute_hessian_cholesky,
35-
binary_search_c,
36-
damp_for_rate,
37-
watersic_quantize,
38-
)
33+
from .zsic import _compute_hessian_cholesky, binary_search_c, damp_for_rate, watersic_quantize
3934

4035
# ---------------------------------------------------------------------------
4136
# Data structures
@@ -76,7 +71,7 @@ def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Ten
7671
Clamp the normalised weights to ``[1/clip, clip]`` to prevent
7772
extreme outliers.
7873
79-
Returns
74+
Returns:
8075
-------
8176
sqrt_w : Tensor (N, 1)
8277
Square-root importance weights, suitable for left-multiplying the
@@ -121,7 +116,7 @@ def kl_divergence_logits(
121116
K_q : Tensor (..., N, D)
122117
temperature : float
123118
124-
Returns
119+
Returns:
125120
-------
126121
kl : float
127122
Mean KL divergence in **bits** (i.e. divided by ln 2).
@@ -130,10 +125,10 @@ def kl_divergence_logits(
130125
K64 = K.double()
131126
Kq64 = K_q.double()
132127

133-
s = Q64 @ K64.transpose(-2, -1) / temperature # (..., S, N)
128+
s = Q64 @ K64.transpose(-2, -1) / temperature # (..., S, N)
134129
s_q = Q64 @ Kq64.transpose(-2, -1) / temperature # (..., S, N)
135130

136-
log_Z = torch.logsumexp(s, dim=-1) # (..., S)
131+
log_Z = torch.logsumexp(s, dim=-1) # (..., S)
137132
log_Z_q = torch.logsumexp(s_q, dim=-1) # (..., S)
138133

139134
P = torch.softmax(s, dim=-1) # (..., S, N)
@@ -172,6 +167,7 @@ def __init__(
172167
kl_aware: bool = False,
173168
importance_clip: float = 50.0,
174169
):
170+
"""Initialize helper for a single attention module."""
175171
self.module = module
176172
self.name = name
177173
self.kl_aware = kl_aware
@@ -186,7 +182,7 @@ def __init__(
186182

187183
def setup(self):
188184
"""Patch ``_quantized_attention`` on the module instance to capture Q/K."""
189-
# The original is a @staticmethod on the class grab the underlying function.
185+
# The original is a @staticmethod on the class - grab the underlying function.
190186
original_fn = type(self.module)._quantized_attention
191187
self._original_fn = original_fn
192188

@@ -231,7 +227,7 @@ def quantize(
231227
target_rate: float = 4.0,
232228
use_lmmse: bool = True,
233229
n_rescaler_iters: int = 0,
234-
sample_frac: float = 0.1,
230+
sample_frac: float | None = None,
235231
) -> WaterSICKVState:
236232
"""Run WaterSIC quantisation on the collected key activations.
237233
@@ -246,7 +242,7 @@ def quantize(
246242
sample_frac : float
247243
Fraction of rows used by :func:`binary_search_c`.
248244
249-
Returns
245+
Returns:
250246
-------
251247
WaterSICKVState
252248
"""
@@ -291,14 +287,16 @@ def quantize(
291287
_, L, perm = precomputed
292288

293289
# Binary search for the scale factor c.
290+
n_tokens = K_h.shape[0]
291+
sf = sample_frac if sample_frac is not None else min(0.1, 1000.0 / max(n_tokens, 1))
294292
c = binary_search_c(
295293
K_h,
296294
A,
297295
target_rate=target_rate,
298296
damp_pct=damp_pct,
299297
use_lmmse=use_lmmse,
300298
n_rescaler_iters=n_rescaler_iters,
301-
sample_frac=sample_frac,
299+
sample_frac=sf,
302300
_precomputed=precomputed,
303301
)
304302

@@ -317,9 +315,7 @@ def quantize(
317315
if sqrt_w is not None:
318316
W_hat = W_hat / sqrt_w
319317

320-
print_rank_0(
321-
f" [{self.name}] head {h}: rate={rate:.2f} bpe, nmse={nmse:.4f}"
322-
)
318+
print_rank_0(f" [{self.name}] head {h}: rate={rate:.2f} bpe, nmse={nmse:.4f}")
323319

324320
# Recover per-head state.
325321
# alpha = c / L.diag() (same as inside watersic_quantize).
@@ -334,9 +330,9 @@ def quantize(
334330
mean_rate = sum(rates) / len(rates) if rates else 0.0
335331

336332
state = WaterSICKVState(
337-
Z=torch.stack(Z_heads), # (H, B*S_k, D)
338-
alpha=torch.stack(alpha_heads), # (H, D)
339-
gamma=torch.stack(gamma_heads), # (H, D)
333+
Z=torch.stack(Z_heads), # (H, B*S_k, D)
334+
alpha=torch.stack(alpha_heads), # (H, D)
335+
gamma=torch.stack(gamma_heads), # (H, D)
340336
perm=torch.stack(perm_heads) if perm_heads and perm_heads[0] is not None else None,
341337
rate=mean_rate,
342338
)

modelopt/torch/quantization/algorithms/watersic_kv/zsic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,13 @@ def watersic_quantize(
283283
alpha = c / L.diag()
284284

285285
W_hat, rate, nmse, Z, gamma = zsic_quantize(
286-
W, A, alpha, Sigma_X, L, use_lmmse=use_lmmse, n_rescaler_iters=n_rescaler_iters
286+
W,
287+
A,
288+
alpha,
289+
Sigma_X,
290+
L,
291+
use_lmmse=use_lmmse,
292+
n_rescaler_iters=n_rescaler_iters,
287293
)
288294

289295
# Undo permutation.

modelopt/torch/quantization/mode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from modelopt.torch.opt.searcher import ForwardLoop
3333

3434
from .compress import compress_convert, compress_restore, update_compress_metadata
35-
from .algorithms.watersic_kv.config import WaterSICKVCalibConfig
3635
from .config import (
3736
AWQClipCalibConfig,
3837
AWQFullCalibConfig,
@@ -513,6 +512,8 @@ class WaterSICKVModeDescriptor(BaseCalibrateModeDescriptor):
513512
@property
514513
def config_class(self) -> type[QuantizeAlgorithmConfig]:
515514
"""Specifies the config class for the mode."""
515+
from .algorithms.watersic_kv.config import WaterSICKVCalibConfig
516+
516517
return WaterSICKVCalibConfig
517518

518519
_calib_func = watersic_kv

0 commit comments

Comments
 (0)