3030
3131from modelopt .torch .utils import print_rank_0
3232
33- from .zsic import (
34- _compute_hessian_cholesky ,
35- binary_search_c ,
36- damp_for_rate ,
37- watersic_quantize ,
38- )
33+ from .zsic import _compute_hessian_cholesky , binary_search_c , damp_for_rate , watersic_quantize
3934
4035# ---------------------------------------------------------------------------
4136# Data structures
@@ -76,7 +71,7 @@ def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Ten
7671 Clamp the normalised weights to ``[1/clip, clip]`` to prevent
7772 extreme outliers.
7873
79- Returns
74+ Returns:
8075 -------
8176 sqrt_w : Tensor (N, 1)
8277 Square-root importance weights, suitable for left-multiplying the
@@ -121,7 +116,7 @@ def kl_divergence_logits(
121116 K_q : Tensor (..., N, D)
122117 temperature : float
123118
124- Returns
119+ Returns:
125120 -------
126121 kl : float
127122 Mean KL divergence in **bits** (i.e. divided by ln 2).
@@ -130,10 +125,10 @@ def kl_divergence_logits(
130125 K64 = K .double ()
131126 Kq64 = K_q .double ()
132127
133- s = Q64 @ K64 .transpose (- 2 , - 1 ) / temperature # (..., S, N)
128+ s = Q64 @ K64 .transpose (- 2 , - 1 ) / temperature # (..., S, N)
134129 s_q = Q64 @ Kq64 .transpose (- 2 , - 1 ) / temperature # (..., S, N)
135130
136- log_Z = torch .logsumexp (s , dim = - 1 ) # (..., S)
131+ log_Z = torch .logsumexp (s , dim = - 1 ) # (..., S)
137132 log_Z_q = torch .logsumexp (s_q , dim = - 1 ) # (..., S)
138133
139134 P = torch .softmax (s , dim = - 1 ) # (..., S, N)
@@ -172,6 +167,7 @@ def __init__(
172167 kl_aware : bool = False ,
173168 importance_clip : float = 50.0 ,
174169 ):
170+ """Initialize helper for a single attention module."""
175171 self .module = module
176172 self .name = name
177173 self .kl_aware = kl_aware
@@ -186,7 +182,7 @@ def __init__(
186182
187183 def setup (self ):
188184 """Patch ``_quantized_attention`` on the module instance to capture Q/K."""
189- # The original is a @staticmethod on the class – grab the underlying function.
185+ # The original is a @staticmethod on the class - grab the underlying function.
190186 original_fn = type (self .module )._quantized_attention
191187 self ._original_fn = original_fn
192188
@@ -231,7 +227,7 @@ def quantize(
231227 target_rate : float = 4.0 ,
232228 use_lmmse : bool = True ,
233229 n_rescaler_iters : int = 0 ,
234- sample_frac : float = 0.1 ,
230+ sample_frac : float | None = None ,
235231 ) -> WaterSICKVState :
236232 """Run WaterSIC quantisation on the collected key activations.
237233
@@ -246,7 +242,7 @@ def quantize(
246242 sample_frac : float
247243 Fraction of rows used by :func:`binary_search_c`.
248244
249- Returns
245+ Returns:
250246 -------
251247 WaterSICKVState
252248 """
@@ -291,14 +287,16 @@ def quantize(
291287 _ , L , perm = precomputed
292288
293289 # Binary search for the scale factor c.
290+ n_tokens = K_h .shape [0 ]
291+ sf = sample_frac if sample_frac is not None else min (0.1 , 1000.0 / max (n_tokens , 1 ))
294292 c = binary_search_c (
295293 K_h ,
296294 A ,
297295 target_rate = target_rate ,
298296 damp_pct = damp_pct ,
299297 use_lmmse = use_lmmse ,
300298 n_rescaler_iters = n_rescaler_iters ,
301- sample_frac = sample_frac ,
299+ sample_frac = sf ,
302300 _precomputed = precomputed ,
303301 )
304302
@@ -317,9 +315,7 @@ def quantize(
317315 if sqrt_w is not None :
318316 W_hat = W_hat / sqrt_w
319317
320- print_rank_0 (
321- f" [{ self .name } ] head { h } : rate={ rate :.2f} bpe, nmse={ nmse :.4f} "
322- )
318+ print_rank_0 (f" [{ self .name } ] head { h } : rate={ rate :.2f} bpe, nmse={ nmse :.4f} " )
323319
324320 # Recover per-head state.
325321 # alpha = c / L.diag() (same as inside watersic_quantize).
@@ -334,9 +330,9 @@ def quantize(
334330 mean_rate = sum (rates ) / len (rates ) if rates else 0.0
335331
336332 state = WaterSICKVState (
337- Z = torch .stack (Z_heads ), # (H, B*S_k, D)
338- alpha = torch .stack (alpha_heads ), # (H, D)
339- gamma = torch .stack (gamma_heads ), # (H, D)
333+ Z = torch .stack (Z_heads ), # (H, B*S_k, D)
334+ alpha = torch .stack (alpha_heads ), # (H, D)
335+ gamma = torch .stack (gamma_heads ), # (H, D)
340336 perm = torch .stack (perm_heads ) if perm_heads and perm_heads [0 ] is not None else None ,
341337 rate = mean_rate ,
342338 )
0 commit comments