3232
3333from .zsic import _compute_hessian_cholesky , binary_search_c , damp_for_rate , watersic_quantize
3434
35- # ---------------------------------------------------------------------------
36- # Data structures
37- # ---------------------------------------------------------------------------
38-
3935
4036@dataclass
4137class WaterSICKVState :
@@ -53,11 +49,6 @@ class WaterSICKVState:
5349 """Achieved coding rate (bits per element)."""
5450
5551
56- # ---------------------------------------------------------------------------
57- # Importance weighting
58- # ---------------------------------------------------------------------------
59-
60-
6152def _compute_importance_weights (P : Tensor , importance_clip : float = 50.0 ) -> Tensor :
6253 """Derive per-token importance weights from an attention probability matrix.
6354
@@ -90,63 +81,6 @@ def _compute_importance_weights(P: Tensor, importance_clip: float = 50.0) -> Ten
9081 return w .sqrt ().unsqueeze (1 ) # (N, 1)
9182
9283
93- # ---------------------------------------------------------------------------
94- # KL divergence in logit space
95- # ---------------------------------------------------------------------------
96-
97-
def kl_divergence_logits(
    Q: Tensor,
    K: Tensor,
    K_q: Tensor,
    temperature: float = 1.0,
) -> float:
    """Compute the KL divergence between attention distributions induced by *K* and *K_q*.

    Uses the logit identity to avoid materialising the full attention matrix:

        KL(P || P_q) = E_x[ P^T (s - s_q) + logsumexp(s_q) - logsumexp(s) ]

    where ``s = Q K^T / temperature`` and ``s_q = Q K_q^T / temperature``.

    Parameters
    ----------
    Q : Tensor (..., S, D)
        Query activations.
    K : Tensor (..., N, D)
        Reference (unquantised) key activations.
    K_q : Tensor (..., N, D)
        Quantised key activations; must match the shape of ``K``.
    temperature : float
        Softmax temperature applied to the attention logits.

    Returns
    -------
    kl : float
        Mean KL divergence in **bits** (i.e. divided by ln 2). Always >= 0
        up to floating-point error; exactly 0 when ``K_q`` equals ``K``.
    """
    import math

    # Promote to float64: logsumexp over long sequences is sensitive to
    # accumulation error in half/single precision.
    Q64 = Q.double()
    K64 = K.double()
    Kq64 = K_q.double()

    s = Q64 @ K64.transpose(-2, -1) / temperature      # (..., S, N)
    s_q = Q64 @ Kq64.transpose(-2, -1) / temperature   # (..., S, N)

    log_Z = torch.logsumexp(s, dim=-1)                 # (..., S)
    log_Z_q = torch.logsumexp(s_q, dim=-1)             # (..., S)

    P = torch.softmax(s, dim=-1)                       # (..., S, N)

    # KL per query position: sum_n P_n (s_n - s_q_n) + log_Z_q - log_Z.
    kl_per_query = (P * (s - s_q)).sum(dim=-1) + log_Z_q - log_Z  # (..., S)

    # Convert nats to bits and return the mean over all query positions.
    return (kl_per_query.mean() / math.log(2)).item()
144-
145- # ---------------------------------------------------------------------------
146- # WaterSICKVHelper
147- # ---------------------------------------------------------------------------
148-
149-
15084class WaterSICKVHelper :
15185 """Hook-based helper that captures Q/K activations and runs WaterSIC quantisation.
15286
@@ -178,8 +112,6 @@ def __init__(
178112
179113 self ._original_fn = None
180114
181- # ----- patching --------------------------------------------------
182-
183115 def setup (self ):
184116 """Patch ``_quantized_attention`` on the module instance to capture Q/K."""
185117 # The original is a @staticmethod on the class - grab the underlying function.
@@ -220,8 +152,6 @@ def cleanup(self):
220152 if "_quantized_attention" in vars (self .module ):
221153 delattr (self .module , "_quantized_attention" )
222154
223- # ----- quantisation -----------------------------------------------
224-
225155 def quantize (
226156 self ,
227157 target_rate : float = 4.0 ,
@@ -246,6 +176,13 @@ def quantize(
246176 -------
247177 WaterSICKVState
248178 """
179+ if not self .collected_Q or not self .collected_K :
180+ raise RuntimeError (
181+ f"[{ self .name } ] No Q/K activations were collected during the calibration "
182+ f"forward pass. Ensure setup() was called before the forward loop and that "
183+ f"the forward loop passes data through this attention layer."
184+ )
185+
249186 # Concatenate collected activations across calibration batches.
250187 # Each tensor is (batch, n_heads, seq, d_head).
251188 Q_all = torch .cat (self .collected_Q , dim = 0 ) # (B_total, H, S_q, D)
@@ -262,14 +199,17 @@ def quantize(
262199
263200 damp_pct = damp_for_rate (target_rate )
264201
202+ # Run quantization on GPU if available (much faster for real models).
203+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
204+
265205 for h in range (H ):
266206 # K_h shape: (B, S_k, D) → treat as weight matrix (a, n) where
267207 # a = B * S_k (token-batch dimension) and n = D (head dimension).
268- K_h = K_all [:, h , :, :].reshape (- 1 , D ).double () # (B*S_k, D )
208+ K_h = K_all [:, h , :, :].reshape (- 1 , D ).to ( device = device , dtype = torch . float64 )
269209
270210 # Activation matrix: use Q_h^T so the Hessian reflects query-key
271211 # interaction. A shape: (D, B*S_q).
272- Q_h = Q_all [:, h , :, :].reshape (- 1 , D ).double () # (B*S_q, D )
212+ Q_h = Q_all [:, h , :, :].reshape (- 1 , D ).to ( device = device , dtype = torch . float64 )
273213 A = Q_h .T # (D, B*S_q)
274214
275215 # Optional importance weighting — scale K rows (not A) so that
@@ -320,19 +260,26 @@ def quantize(
320260 # Recover per-head state.
321261 # alpha = c / L.diag() (same as inside watersic_quantize).
322262 alpha_h = (c / L .diag ()).float ()
323-
324- Z_heads .append (Z_h )
325- alpha_heads .append (alpha_h )
326- gamma_heads .append (gamma_h .float ())
327- perm_heads .append (perm )
263+ if perm is not None :
264+ inv_perm = torch .argsort (perm )
265+ alpha_h = alpha_h [inv_perm ]
266+
267+ # Move results to CPU to free GPU memory for next head.
268+ Z_heads .append (Z_h .cpu ())
269+ alpha_heads .append (alpha_h .cpu ())
270+ gamma_heads .append (gamma_h .float ().cpu ())
271+ perm_heads .append (perm .cpu () if perm is not None else None )
328272 rates .append (rate )
329273
274+ if torch .cuda .is_available ():
275+ torch .cuda .empty_cache ()
276+
330277 mean_rate = sum (rates ) / len (rates ) if rates else 0.0
331278
332279 state = WaterSICKVState (
333- Z = torch .stack (Z_heads ), # (H, B*S_k, D)
334- alpha = torch .stack (alpha_heads ), # (H, D)
335- gamma = torch .stack (gamma_heads ), # (H, D)
280+ Z = torch .stack (Z_heads ),
281+ alpha = torch .stack (alpha_heads ),
282+ gamma = torch .stack (gamma_heads ),
336283 perm = torch .stack (perm_heads ) if perm_heads and perm_heads [0 ] is not None else None ,
337284 rate = mean_rate ,
338285 )
@@ -342,8 +289,6 @@ def quantize(
342289
343290 return state
344291
345- # ----- cleanup -----------------------------------------------------
346-
347292 def free (self ):
348293 """Release collected calibration data."""
349294 self .collected_Q .clear ()
0 commit comments