Skip to content

Commit 21c1fc7

Browse files
TheTom and claude committed
experiments/rht-k-sweep: add production codebook + 4-shape sensitivity tests
Two follow-ups to the v3 synthetic refutation of Basat 2026: 1. sensitivity_check.py — 4 sub-Gaussian source shapes at kurt ≈ -1.5 confirm the v3 finding is robust across the distribution family: bernoulli+noise (bimodal): +20.7% MSE, +22.1% KL, cat 26%→76% uniform on [-a,+a] (flat): +29.9% MSE, +30.6% KL, cat 25%→87% uniform/gaussian mix: +29.9% MSE, +30.1% KL, cat 24%→86% truncated gaussian: +193.6% MSE, +195% KL, cat 27%→100% Direction is identical (more RHTs hurts). Magnitude varies with shape. 2. production_codebook.py — the *actual* production turbo4 codebook from ggml/src/ggml-cuda/turbo-quant.cuh (16 centroids, extremes ±0.174, per-128-element L2 normalization matching QK_TURBO4=128). This is what the fork ships, not Lloyd-Max-on-Gaussian. At §3 layer-0 conditions (kurt -1.52, KS 0.14): k_extra=0: MSE 6.40e-5, KL 2.52e-7, catastrophic 25.0% k_extra=1: MSE 1.55e-4 (+141.7%), KL 5.91e-7 (+134.5%), catastrophic 100.0% k_extra=2: MSE 1.51e-4 (+136.5%), KL 5.90e-7 (+134.1%), catastrophic 100.0% The production result is 6x larger than the Lloyd-Max-Gaussian result because production centroids are fit tighter than ±2σ (sub-Gaussian-fitted), so when +1 RHT pushes K to standard Gaussian, the new ±3σ tails clip at the ±0.174 codebook extreme → saturation error → every query becomes catastrophic. This is the production-fork-grade refutation of "more RHTs help KV-cache quantization." Theorem still holds (kurt drifts to 0 as proven); application to TurboQuant prescribed by Basat 2026 strictly degrades downstream quality on real-shipping production codebook. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0ba7022 commit 21c1fc7

2 files changed

Lines changed: 474 additions & 0 deletions

File tree

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""RHT-count sweep using the ACTUAL production turbo4 codebook.
2+
3+
Source: ggml/src/ggml-cuda/turbo-quant.cuh TURBO_CENTROIDS_4BIT (16 levels).
4+
Block normalization: per-128-element L2 (QK_TURBO4 = 128 = head_dim per
5+
ggml-common.h), so each row of K (one head's worth of values for one token)
6+
is one normalization block.
7+
8+
This is the codebook the fork actually ships. v3 used Lloyd-Max-on-Gaussian
9+
(±2.7σ extremes). Production is ±2σ-ish, sub-Gaussian-fitted. Result should
10+
hold (same mechanism: production codebook lucky-aligns with sub-Gaussian K,
11+
+RHT breaks alignment), but the magnitude could differ.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import numpy as np
17+
from scipy import linalg, stats
18+
19+
d = 128  # = QK_TURBO4 = head_dim, also the L2-norm block size
T = 512  # number of tokens (rows of K)
n_trials = 5  # independent RHT/query seeds per k_extra setting
n_queries = 64  # query vectors per trial for the attention-KL metric
sigma = 1.0 / np.sqrt(d)  # per-coordinate scale; d * sigma^2 = 1, so E||row||^2 ≈ 1
TARGET_KURT = -1.56  # §3 layer-0 excess kurtosis the noise_frac sweep tries to match


# Production turbo4 codebook (16 levels, 4-bit)
# From ggml/src/ggml-cuda/turbo-quant.cuh:
# Symmetric about 0; extremes ±0.173926 of the per-row L2 norm.
PROD_CENTROIDS_4BIT = np.array([
    -0.173926, -0.117195, -0.089527, -0.068756,
    -0.051262, -0.035597, -0.020989, -0.006938,
    0.006938, 0.020989, 0.035597, 0.051262,
    0.068756, 0.089527, 0.117195, 0.173926
])
35+
36+
37+
# ---------------------------------------------------------------------------
38+
# Source K matching §3 layer-0 stats
39+
# ---------------------------------------------------------------------------
40+
41+
def gen_postwht_K(T, d, rng, noise_frac):
    """Draw a (T, d) sub-Gaussian K matrix.

    Rademacher (±1) signs plus Gaussian jitter of std ``noise_frac``,
    rescaled so the global standard deviation equals the module-level
    ``sigma``.  Larger ``noise_frac`` moves the kurtosis toward 0.
    """
    signs = rng.choice([-1.0, 1.0], size=(T, d))
    jitter = rng.normal(0.0, noise_frac, size=(T, d))
    mix = signs + jitter
    scale = sigma / np.sqrt(np.var(mix))
    return mix * scale
46+
47+
48+
# Orthonormal Hadamard matrix (H @ H.T = I); shared by apply_rht / invert_rht.
H = linalg.hadamard(d) / np.sqrt(d)
49+
50+
51+
def apply_rht(X, k, seed):
    """Apply k randomized Hadamard transforms to X.

    Each round flips coordinate signs with a fresh Rademacher vector and
    multiplies by the orthonormal Hadamard matrix H.  Returns the rotated
    matrix together with the sign vectors (in application order) so the
    rotation can be undone by invert_rht.
    """
    rng = np.random.default_rng(seed)
    sign_seqs = []
    rotated = X.copy()
    for _ in range(k):
        flips = rng.choice([-1.0, 1.0], size=d).astype(np.float64)
        sign_seqs.append(flips)
        rotated = (rotated * flips) @ H.T
    return rotated, sign_seqs
60+
61+
62+
def invert_rht(Y, sign_seqs):
    """Undo apply_rht: peel off each Hadamard/sign round in reverse order."""
    out = Y.copy()
    for flips in sign_seqs[::-1]:
        out = (out @ H) * flips
    return out
67+
68+
69+
# ---------------------------------------------------------------------------
70+
# Production-style block quantization
71+
#
72+
# Per row (one head_dim block of size 128):
73+
# norm = ||K_row||_2 (then divide by sqrt(d) to get rms-scaled centroids
74+
# relative to row magnitude — see kernel: kv = c * norm)
75+
# Then nearest-centroid lookup.
76+
# Dequantize: K_row[i] ≈ centroid[idx[i]] * norm
77+
# ---------------------------------------------------------------------------
78+
79+
def quantize_prod(X, codebook):
    """Production-style block quantization of X (T, d).

    L2-normalize each row (one QK_TURBO4 block), snap every normalized
    entry to its nearest codebook centroid, then rescale by the row norm —
    mirroring the kernel's ``kv = centroid * norm`` reconstruction.
    """
    row_norm = np.linalg.norm(X, axis=1, keepdims=True)  # (T, 1)
    row_norm = np.maximum(row_norm, 1e-12)  # guard all-zero rows
    unit = X / row_norm
    # Nearest-centroid lookup over the flattened normalized values.
    dist = np.abs(unit.reshape(-1)[:, None] - codebook[None, :])
    nearest = codebook[np.argmin(dist, axis=1)].reshape(unit.shape)
    return nearest * row_norm  # dequantize
91+
92+
93+
def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis`` (max-shift before exp)."""
    shifted = x - x.max(axis=axis, keepdims=True)
    expd = np.exp(shifted)
    return expd / expd.sum(axis=axis, keepdims=True)
97+
98+
99+
def attn_kl(K_ref, K_test, Q):
    """Per-query KL divergence between attention rows.

    Computes KL( softmax(Q K_ref^T / sqrt(d)) || softmax(Q K_test^T / sqrt(d)) )
    for each query; returns an array of shape (n_queries,).  eps guards log(0).
    """
    scale = np.sqrt(d)
    p = softmax(Q @ K_ref.T / scale)
    q = softmax(Q @ K_test.T / scale)
    eps = 1e-12
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1)
105+
106+
107+
# ---------------------------------------------------------------------------
# Find noise_frac that matches §3 layer-0 stats with production codebook
# ---------------------------------------------------------------------------

print("=== Production turbo4 codebook ===")
print(f" centroids (16 levels): extremes ±{PROD_CENTROIDS_4BIT[-1]:.4f}")
print(f" L2-normalized extreme: ±{PROD_CENTROIDS_4BIT[-1]:.4f} of row norm")
print(" vs Lloyd-Max-Gaussian extremes: ±0.240 (my v3 codebook)")
print()

# Sweep noise_frac until the generated K matches the target kurtosis; the
# last matching value wins.  KS is reported but not used for selection.
print("Tuning noise_frac to match §3 layer-0 stats...")
print(f" target: kurt = {TARGET_KURT}, KS = 0.155")
print()
chosen = None
for nf in [0.30, 0.32, 0.34, 0.35, 0.36, 0.38, 0.40]:
    K_try = gen_postwht_K(T, d, np.random.default_rng(42), nf)
    kurt = stats.kurtosis(K_try.flatten())
    ks, _ = stats.kstest(K_try.flatten(), 'norm', args=(0, sigma))
    print(f" noise_frac={nf:.2f} kurt={kurt:+.3f} KS={ks:.4f}")
    if abs(kurt - TARGET_KURT) < 0.05:
        chosen = (nf, kurt, ks, K_try)
# BUG FIX: the original fallback was
#   (0.35, *stats.kurtosis(...), 0, gen_postwht_K(...))
# which star-unpacks a scalar kurtosis and raises TypeError whenever no
# noise_frac matched.  Build the fallback explicitly (and compute a real KS
# instead of hard-coding 0).
if chosen is None:
    nf = 0.35
    K_orig = gen_postwht_K(T, d, np.random.default_rng(42), nf)
    flat = K_orig.flatten()
    kurt = stats.kurtosis(flat)
    ks, _ = stats.kstest(flat, 'norm', args=(0, sigma))
else:
    nf, kurt, ks, K_orig = chosen
print()
print(f" selected: noise_frac={nf:.3f}, kurt={kurt:+.3f}, KS={ks:.4f}")
print()
133+
134+
135+
# ---------------------------------------------------------------------------
# Sweep k_extra with production codebook
#
# For each number of extra RHTs: rotate K, quantize with the production
# codebook, invert the rotation, then measure reconstruction MSE and the
# per-query attention KL against the unrotated original.
# ---------------------------------------------------------------------------

print("=== Sweep k_extra on production codebook ===")
print(f"{'k_extra':<10}{'post-kurt':<14}{'post-KS':<12}"
f"{'MSE':<18}{'KL mean':<16}{'KL p99':<16}{'Cat rate':<12}")
print("-" * 100)

results = {}
all_kl_arrays = {}

for k_extra in [0, 1, 2]:
    mse_runs = []
    kl_per_query: list[np.ndarray] = []
    kurt_post = None
    ks_post = None
    for trial in range(n_trials):
        # Deterministic per-(k_extra, trial) seed; +50000 below decorrelates
        # the query RNG from the rotation RNG.
        seed = 2000 + 100 * (k_extra + 1) + trial
        if k_extra == 0:
            # No rotation: every trial quantizes the same matrix, only Q varies.
            K_rot, sign_seqs = K_orig.copy(), []
        else:
            K_rot, sign_seqs = apply_rht(K_orig, k_extra, seed)
        if trial == 0:
            # Post-rotation distribution stats recorded once (first trial only).
            kurt_post = stats.kurtosis(K_rot.flatten())
            ks_post, _ = stats.kstest(K_rot.flatten(), 'norm', args=(0, sigma))
        K_rot_q = quantize_prod(K_rot, PROD_CENTROIDS_4BIT)
        K_recon = invert_rht(K_rot_q, sign_seqs) if sign_seqs else K_rot_q
        mse_runs.append(np.mean((K_recon - K_orig) ** 2))
        q_rng = np.random.default_rng(seed + 50000)
        Q = q_rng.normal(0.0, sigma, size=(n_queries, d))
        kl_per_query.append(attn_kl(K_orig, K_recon, Q))

    mse_mean = float(np.mean(mse_runs))
    kl_concat = np.concatenate(kl_per_query)  # all trials pooled per k_extra
    all_kl_arrays[k_extra] = kl_concat
    results[k_extra] = {
        "post_kurt": kurt_post,
        "post_ks": ks_post,
        "mse_mean": mse_mean,
        "kl_mean": float(kl_concat.mean()),
        "kl_p99": float(np.percentile(kl_concat, 99)),
    }

# A query is "catastrophic" when its KL is >= 110% of the k_extra=0 median.
base_median = float(np.median(all_kl_arrays[0]))
for k_extra in [0, 1, 2]:
    r = results[k_extra]
    cat = float(np.mean(all_kl_arrays[k_extra] >= 1.10 * base_median))
    print(
        f"{k_extra:<10}{r['post_kurt']:<+14.4f}{r['post_ks']:<12.4f}"
        f"{r['mse_mean']:<18.4e}{r['kl_mean']:<16.4e}"
        f"{r['kl_p99']:<16.4e}{cat:<12.1%}"
    )

print()
print("=== Delta vs k_extra=0 (production baseline) ===")
print(f" k_extra=1 MSE: {(results[1]['mse_mean']/results[0]['mse_mean'] - 1)*100:+.2f}% "
f"KL mean: {(results[1]['kl_mean']/results[0]['kl_mean'] - 1)*100:+.2f}% "
f"KL p99: {(results[1]['kl_p99']/results[0]['kl_p99'] - 1)*100:+.2f}%")
print(f" k_extra=2 MSE: {(results[2]['mse_mean']/results[0]['mse_mean'] - 1)*100:+.2f}% "
f"KL mean: {(results[2]['kl_mean']/results[0]['kl_mean'] - 1)*100:+.2f}% "
f"KL p99: {(results[2]['kl_p99']/results[0]['kl_p99'] - 1)*100:+.2f}%")
197+
198+
# ---------------------------------------------------------------------------
# Side-by-side: production vs Lloyd-Max-on-Gaussian codebook
# ---------------------------------------------------------------------------

print()
print("=== Side-by-side: PROD codebook vs Lloyd-Max-Gaussian (v3) codebook ===")

# NOTE(review): mid-script import; conventionally belongs at the top of the
# file alongside the other scipy imports.
from scipy.stats import norm
206+
def lloyd_max_gaussian(sig, n_levels=16, n_iter=100):
    """Lloyd-Max optimal scalar quantizer for N(0, sig^2).

    Starts from the quantile midpoints, then alternates the classic two
    steps — decision boundaries at centroid midpoints, centroids at the
    conditional mean of each cell (-sig^2 * (pdf(hi) - pdf(lo)) / mass) —
    until the centroids move by less than 1e-9 or n_iter rounds elapse.
    """
    quantiles = np.linspace(0.5 / n_levels, 1 - 0.5 / n_levels, n_levels)
    pts = norm.ppf(quantiles, 0, sig)
    for _ in range(n_iter):
        mid = (pts[1:] + pts[:-1]) / 2
        edges = np.concatenate([[-np.inf], mid, [np.inf]])
        updated = []
        for i in range(n_levels):
            lo, hi = edges[i], edges[i + 1]
            pdf_lo = norm.pdf(lo, 0, sig) if np.isfinite(lo) else 0
            pdf_hi = norm.pdf(hi, 0, sig) if np.isfinite(hi) else 0
            mass = norm.cdf(hi, 0, sig) - norm.cdf(lo, 0, sig)
            if mass > 1e-15:
                updated.append(-sig * sig * (pdf_hi - pdf_lo) / mass)
            else:
                # Degenerate cell: keep the previous centroid.
                updated.append(pts[i])
        updated = np.array(updated)
        if np.max(np.abs(updated - pts)) < 1e-9:
            break
        pts = updated
    return pts
223+
224+
# 16-level Lloyd-Max codebook fitted to N(0, sigma^2) — the v3 baseline.
LM_CENTROIDS = lloyd_max_gaussian(sigma)
225+
226+
227+
def quantize_per_coord(X, codebook):
    """v3 baseline quantizer: snap every coordinate of X to the nearest
    codebook value, with no per-block normalization."""
    vals = X.reshape(-1)
    nearest = np.argmin(np.abs(vals[:, None] - codebook[None, :]), axis=1)
    return codebook[nearest].reshape(X.shape)
232+
233+
234+
# Compare the two (codebook, quantization scheme) pairs at k_extra 0 vs 1:
# reconstruction MSE, attention-KL deltas, and catastrophic-query rates.
print(f"{'codebook':<20}{'quant scheme':<20}{'k_extra=0 MSE':<16}"
f"{'k_extra=1 MSE':<16}{'Δ MSE':<12}{'Δ KL':<12}{'cat 0→1':<14}")
print("-" * 110)

for cb_name, codebook, scheme_name, quant_fn in [
    ("PROD turbo4", PROD_CENTROIDS_4BIT, "per-row L2 norm", quantize_prod),
    ("Lloyd-Max Gauss", LM_CENTROIDS, "per-coord global", quantize_per_coord),
]:
    base_mse, base_kl, k1_mse, k1_kl, base_kls, k1_kls = (None,) * 6
    for k_extra in [0, 1]:
        mse_runs, kls_runs = [], []
        for trial in range(n_trials):
            # 3000-block seeds keep this sweep independent of the 2000-block
            # seeds used in the main k_extra sweep above.
            seed = 3000 + 100*(k_extra+1) + trial
            if k_extra == 0:
                K_rot, sign_seqs = K_orig.copy(), []
            else:
                K_rot, sign_seqs = apply_rht(K_orig, k_extra, seed)
            K_rot_q = quant_fn(K_rot, codebook)
            K_recon = invert_rht(K_rot_q, sign_seqs) if sign_seqs else K_rot_q
            mse_runs.append(np.mean((K_recon - K_orig)**2))
            q_rng = np.random.default_rng(seed + 50000)
            Q = q_rng.normal(0.0, sigma, size=(n_queries, d))
            kls_runs.append(attn_kl(K_orig, K_recon, Q))
        mse = float(np.mean(mse_runs))
        kls = np.concatenate(kls_runs)
        if k_extra == 0:
            base_mse, base_kl, base_kls = mse, kls.mean(), kls
        else:
            k1_mse, k1_kl, k1_kls = mse, kls.mean(), kls
    # Catastrophic threshold is anchored to this codebook's own k_extra=0
    # median, so cat0/cat1 are comparable within a row.
    base_median = float(np.median(base_kls))
    cat0 = float(np.mean(base_kls >= 1.10 * base_median))
    cat1 = float(np.mean(k1_kls >= 1.10 * base_median))
    d_mse = (k1_mse/base_mse - 1) * 100
    d_kl = (k1_kl/base_kl - 1) * 100
    # BUG FIX: the original printed {cat0:.1%}{cat1:.1%} with no separator,
    # so the two rates ran together ("25.0%100.0%"); add the arrow to match
    # the "cat 0→1" column header.
    print(
        f"{cb_name:<20}{scheme_name:<20}{base_mse:<16.3e}"
        f"{k1_mse:<16.3e}{d_mse:<+12.1f}{d_kl:<+12.1f}{cat0:.1%}→{cat1:.1%}"
    )

0 commit comments

Comments
 (0)