Skip to content

Commit b75813b

Browse files
brettTheTom
authored and committed
fix: V-norm in memory_stats, SeedSequence PRNG, MSE compressed_size_bits
Subset of @brosequist's #90 commit 0fd5de9 — keeping the actual fixes, deferring the streaming + serialization API surface until a production caller exists. Included: - KVCacheCompressor.memory_stats() was omitting the float32 norm stored per V vector, inflating reported compression ratio. Adds v_bits_total += n_vectors * 32. - TurboQuantMSE.compressed_size_bits() — was missing (TurboQuant already had it). - Replaces seed + 1000 magic offset with np.random.SeedSequence(seed).spawn(2) for true PRNG independence between PolarQuant and QJL stages, and between K and V quantizers. Deferred (not in this commit): - compress_token() / get_compressed_cache() streaming API - CompressedVector.to_bytes() / from_bytes() binary serialization - CompressedKVCache.save() / load() npz serialization
1 parent 1224fef commit b75813b

3 files changed

Lines changed: 39 additions & 17 deletions

File tree

tests/test_kv_cache.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ def test_memory_stats(self):
102102
compressor = KVCacheCompressor(head_dim=128, k_bits=3, v_bits=3)
103103
stats = compressor.memory_stats(seq_len=1024, num_layers=32, num_heads=32)
104104

105-
# K: 3 bits/val + norm overhead, V: 3 bits/val
106-
# Ratio vs fp16 (16 bits): 16 / ((3+3)/2 + overhead) ≈ 2.5-3x
105+
# K: 3 bits/val + 32-bit norm, V: 3 bits/val + 32-bit norm
106+
# Both K and V include per-vector norm (float32) for rescaling.
107+
# Ratio vs fp16 (16 bits/val): 16*128 / ((128*3 + 32 + 128*3 + 32) / 2) ≈ 4.92x
107108
assert stats["compression_ratio"] > 2.0
108109
assert stats["compressed_mb"] < stats["original_mb"]
109110

@@ -125,6 +126,7 @@ def test_metadata_stored(self):
125126
assert compressed.v_bit_width == 3
126127

127128

129+
128130
def _softmax(x):
129131
"""Simple softmax for testing."""
130132
e = np.exp(x - np.max(x, axis=-1, keepdims=True))

turboquant/kv_cache.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ class KVCacheCompressor:
4747
4848
# Decompress
4949
k_hat, v_hat = compressor.decompress(compressed)
50-
51-
# Or compress streaming (one token at a time)
52-
compressor.compress_token(k_vec, v_vec, layer=0, head=0)
5350
"""
5451

5552
def __init__(
@@ -71,14 +68,20 @@ def __init__(
7168
self.k_bits = k_bits
7269
self.v_bits = v_bits
7370

71+
# Spawn independent child seeds so K and V quantizers use statistically
72+
# independent random streams without magic offset arithmetic.
73+
# Accept either an int or an already-created SeedSequence.
74+
ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed)
75+
k_child, v_child = ss.spawn(2)
76+
7477
# K cache uses full TurboQuant (inner product preservation)
7578
self.k_quantizer = TurboQuant(
76-
head_dim, bit_width=k_bits, seed=seed, norm_correction=norm_correction,
79+
head_dim, bit_width=k_bits, seed=k_child, norm_correction=norm_correction,
7780
)
7881

7982
# V cache uses MSE-only PolarQuant (value reconstruction)
8083
self.v_quantizer = TurboQuantMSE(
81-
head_dim, bit_width=v_bits, seed=seed + 500, norm_correction=norm_correction,
84+
head_dim, bit_width=v_bits, seed=v_child, norm_correction=norm_correction,
8285
)
8386

8487
def compress(self, k_cache: np.ndarray, v_cache: np.ndarray) -> CompressedKVCache:
@@ -160,8 +163,8 @@ def memory_stats(self, seq_len: int, num_layers: int, num_heads: int) -> dict:
160163

161164
# K: b bits per coord + 32-bit norm
162165
k_bits_total = n_vectors * (self.head_dim * self.k_bits + 32)
163-
# V: b bits per coord (no norm needed for MSE-only)
164-
v_bits_total = n_vectors * self.head_dim * self.v_bits
166+
# V: b bits per coord + 32-bit norm (PolarQuant stores per-vector norm for rescaling)
167+
v_bits_total = n_vectors * self.head_dim * self.v_bits + n_vectors * 32
165168

166169
compressed_bytes = (k_bits_total + v_bits_total) / 8
167170

turboquant/turboquant.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
@dataclass
2020
class CompressedVector:
2121
"""Container for a TurboQuant-compressed vector."""
22-
mse_indices: np.ndarray # (d,) or (batch, d) — PolarQuant indices, (b-1)-bit integers
23-
vector_norms: np.ndarray # scalar or (batch,) — original ||x||_2 for rescaling
24-
qjl_signs: np.ndarray # (d,) or (batch, d) — QJL sign bits, int8 {+1, -1}
25-
residual_norms: np.ndarray # scalar or (batch,) — ||residual||_2
26-
bit_width: int # total bits per coordinate
22+
mse_indices: np.ndarray # (d,) or (batch, d) — PolarQuant indices, (b-1)-bit integers
23+
vector_norms: np.ndarray # scalar or (batch,) — original ||x||_2 for rescaling
24+
qjl_signs: np.ndarray # (d,) or (batch, d) — QJL sign bits, int8 {+1, -1}
25+
residual_norms: np.ndarray # scalar or (batch,) — ||residual||_2
26+
bit_width: int # total bits per coordinate
2727

2828

2929
class TurboQuant:
@@ -54,13 +54,19 @@ def __init__(self, d: int, bit_width: int, seed: int = 42, norm_correction: bool
5454
self.d = d
5555
self.bit_width = bit_width
5656

57+
# Spawn independent child seeds from a SeedSequence so PolarQuant and QJL
58+
# use statistically independent random streams without magic offset arithmetic.
59+
# Accept either an int or an already-created SeedSequence (e.g. from a parent spawner).
60+
ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed)
61+
pq_child, qjl_child = ss.spawn(2)
62+
5763
# Stage 1: PolarQuant at (b-1) bits
5864
self.polar_quant = PolarQuant(
59-
d, bit_width=bit_width - 1, seed=seed, norm_correction=norm_correction,
65+
d, bit_width=bit_width - 1, seed=pq_child, norm_correction=norm_correction,
6066
)
6167

62-
# Stage 2: QJL for residual (uses different seed)
63-
self.qjl = QJL(d, seed=seed + 1000)
68+
# Stage 2: QJL for residual (independent seed stream)
69+
self.qjl = QJL(d, seed=qjl_child)
6470

6571
def quantize(self, x: np.ndarray) -> CompressedVector:
6672
"""Quantize a vector or batch.
@@ -148,3 +154,14 @@ def quantize(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
148154

149155
def dequantize(self, indices: np.ndarray, norms: np.ndarray) -> np.ndarray:
150156
return self.polar_quant.dequantize(indices, norms)
157+
158+
def compressed_size_bits(self, n_vectors: int) -> int:
159+
"""Compute total storage in bits for n_vectors compressed vectors.
160+
161+
Includes:
162+
- PolarQuant indices: b bits per coordinate per vector
163+
- Norms: 32 bits (float32) per vector (stored for per-vector rescaling)
164+
"""
165+
per_vector = self.d * self.bit_width
166+
norms = 32 # float32 per vector
167+
return n_vectors * (per_vector + norms)

0 commit comments

Comments
 (0)