Skip to content

Commit 92d62c9

Browse files
committed
Fix TurboQuant KV zeroed by low-mem export (993cff5): _is_kv_buffer only frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T
1 parent 993cff5 commit 92d62c9

1 file changed

Lines changed: 23 additions & 5 deletions

File tree

backends/cuda/cuda_backend.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,29 @@ def _codegen_device_target_aware(self, device):
166166

167167

168168
def _is_kv_buffer(name, v) -> bool:
169-
return (
170-
isinstance(v, torch.Tensor)
171-
and not isinstance(v, torch.nn.Parameter)
172-
and "kv_cache" in name
173-
)
169+
"""True only for an actual KV-cache *content* buffer that is safe to free.
170+
171+
The low-memory path (``_move_to_device_resize_kv``) frees every buffer this
172+
matches and re-synthesizes it as ZEROS in both the lifted graph and the
173+
serialized ``.ptd`` (see ``_full_zeros_preserving_strides`` /
174+
``_get_const_synthesize_zeros``). That is only valid for genuine KV *content*,
175+
which is all-zeros at export time (caches start empty).
176+
177+
It must NOT match the non-zero constants that some KV-cache modules register
178+
alongside the cache — e.g. TurboQuant registers its codebook/rotation
179+
(``centroids``/``boundaries``/``rotation``/``rotation_T``) as buffers on the
180+
``kv_cache`` module, so their FQNs also contain ``kv_cache``. Freeing+zeroing
181+
those silently corrupts the serialized model (TQ4 dequant -> 0 -> garbage).
182+
Gate on the buffer actually being all-zeros so only empty KV content is freed;
183+
this is robust to any future constant name (a non-zero buffer is never freed).
184+
"""
185+
if not isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter):
186+
return False
187+
if "kv_cache" not in name or v.numel() == 0 or v.is_meta:
188+
return False
189+
# Only the genuinely all-zero KV content may be freed + re-zeroed; non-zero
190+
# constants (TurboQuant centroids/rotation/...) must be preserved as-is.
191+
return bool(torch.count_nonzero(v) == 0)
174192

175193

176194
def _empty_strided_on_device(v, location):

0 commit comments

Comments
 (0)