Skip to content

Commit c470158

Browse files
experiments/rht-k-sweep: LM-Gauss vs production codebook on real Qwen3-0.6B K
Tests Basat 2026 Theorem 7 (3-RHT + URR-optimized codebook matches Gaussian- input quantization) on real post-WHT K extracted from Qwen3-0.6B, layers 0/4/8/12/16/20/23. Constructs Lloyd-Max-optimal codebook for N(0, 1/d) and compares against the production turbo4 codebook (sub-Gaussian fit) across 1/2/3 RHTs. Result: 1-RHT × LMGauss is best by 2% over production; every 2-RHT and 3-RHT cell is worse, including 3-RHT × LMGauss (paper's recommended stack) at +23% vs production. The O(sqrt(log d / d)) additive term in Theorem 7 does not vanish at d=128. Co-Authored-By: tturney <tturney@psyguard.ai>
1 parent 05afaeb commit c470158

1 file changed

Lines changed: 191 additions & 0 deletions

File tree

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""LMGauss vs production codebook on REAL Qwen3-0.6B post-WHT K.
2+
3+
Tests Theorem 7 (Basat 2026) prediction on the actual case: 3-RHT + a
4+
Lloyd-Max-Gaussian codebook should match Gaussian-input quantization error
5+
(up to a vanishing additive term). The paper does not pin a specific
6+
codebook; we construct LM-Gauss for N(0, 1/d) so it matches the per-row
7+
L2-normalized post-WHT K scale.
8+
9+
Cross-product to find the actual best cell:
10+
11+
turbo4 (sub-Gaussian fit, ±0.174) LM-Gauss (Gaussian fit, ±0.241)
12+
1-RHT (prod) [production stack] ?
13+
2-RHT ? ?
14+
3-RHT ? [paper's theoretical optimum]
15+
16+
Data: /Users/tom/dev/mse_scripts/eden-investigation/kv_dump/
17+
Layers 0, 4, 8, 12, 16, 20, 23 from Qwen3-0.6B, 512 tokens, 8 heads, d=128.
18+
"""
19+
20+
from __future__ import annotations
21+
22+
from pathlib import Path
23+
24+
import numpy as np
25+
from scipy import linalg
26+
from scipy.stats import norm
27+
28+
# Head dimension of the Qwen3-0.6B K vectors under test.
D = 128
# Directory holding the dumped per-layer K tensors (layerNN_k.npy).
KV_DIR = Path("/Users/tom/dev/mse_scripts/eden-investigation/kv_dump")
# Layers sampled across the depth of the model.
LAYERS = [0, 4, 8, 12, 16, 20, 23]

# Production turbo4 codebook from ggml/src/ggml-cuda/turbo-quant.cuh.
# Header comment claims Lloyd-Max for N(0, 1/128), but the actual centroid
# spacing is tighter (outermost 0.174 vs 0.241 for true Lloyd-Max-Gaussian),
# so it is empirically sub-Gaussian fitted to the post-WHT K distribution.
# The codebook is symmetric about zero, so only the positive half is listed
# and mirrored (IEEE negation is exact, so values match the full listing).
_POS_HALF = (
    0.006938, 0.020989, 0.035597, 0.051262,
    0.068756, 0.089527, 0.117195, 0.173926,
)
PROD_CENTROIDS_4BIT = np.array(
    [-c for c in reversed(_POS_HALF)] + list(_POS_HALF)
)
42+
43+
44+
def lloyd_max_gaussian(sigma: float, n_levels: int = 16, n_iter: int = 500) -> np.ndarray:
    """Lloyd-Max-optimal scalar quantizer for N(0, sigma^2).

    Alternates between placing decision boundaries midway between
    centroids and moving each centroid to the conditional mean of its
    cell, until the update falls below 1e-12 or n_iter is exhausted.
    Returns the n_levels centroids in ascending order.
    """
    # Seed the centroids at evenly spaced quantiles of the target Gaussian.
    probs = np.linspace(0.5 / n_levels, 1 - 0.5 / n_levels, n_levels)
    centroids = norm.ppf(probs, 0, sigma)
    var = sigma * sigma
    for _ in range(n_iter):
        # Nearest-neighbor cells: boundaries bisect adjacent centroids.
        mids = (centroids[:-1] + centroids[1:]) / 2
        edges = np.concatenate(([-np.inf], mids, [np.inf]))
        # norm.pdf(+-inf) evaluates to 0.0, so no explicit guard is needed.
        dens = norm.pdf(edges, 0, sigma)
        mass = np.diff(norm.cdf(edges, 0, sigma))
        # Mean of N(0, var) truncated to [a, b] is -var*(pdf(b)-pdf(a))/mass.
        # Cells with ~zero probability mass keep their previous centroid.
        updated = centroids.copy()
        filled = mass > 1e-15
        updated[filled] = -var * np.diff(dens)[filled] / mass[filled]
        converged = np.max(np.abs(updated - centroids)) < 1e-12
        centroids = updated
        if converged:
            break
    return centroids
62+
63+
64+
# Per-row L2 normalization of a d-dim row puts each coordinate at scale
# ~1/sqrt(d), so the LM-Gauss codebook is fitted to N(0, 1/d).
SIGMA = 1.0 / np.sqrt(D)
LMG_CENTROIDS = lloyd_max_gaussian(SIGMA)

# Walsh-Hadamard basis (orthonormal): hadamard(D) has +-1 entries, so
# dividing by sqrt(D) makes H an orthonormal rotation.
H = linalg.hadamard(D) / np.sqrt(D)
69+
70+
71+
def apply_rht(X: np.ndarray, k: int, seed: int) -> np.ndarray:
    """Apply k random Hadamard transforms (random signs + Hadamard) to X (..., d).

    Each round multiplies by a fresh random +-1 diagonal and then rotates
    with the orthonormal Hadamard basis H. The input is not modified.
    """
    rng = np.random.default_rng(seed)
    out = X.copy()
    for _ in range(k):
        # Same draw as the original so the random stream (and thus the
        # exact rotation for a given seed) is unchanged.
        flips = rng.choice([-1.0, 1.0], size=D).astype(np.float64)
        out = (out * flips) @ H.T
    return out
79+
80+
81+
def quantize_per_row(X: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Production scheme: per-row L2 normalize, nearest-centroid lookup, dequantize."""
    # Clamp the row norm so zero rows do not divide by zero.
    row_scale = np.maximum(np.linalg.norm(X, axis=-1, keepdims=True), 1e-12)
    unit = X / row_scale
    # Nearest centroid per element; argmin breaks ties toward the lower index,
    # matching the original flat-array formulation.
    nearest = np.argmin(np.abs(unit[..., None] - codebook), axis=-1)
    # Dequantize by restoring the per-row scale.
    return codebook[nearest] * row_scale
89+
90+
91+
def load_layer_k(layer: int) -> np.ndarray:
    """Load real K for layer, flatten (1, H, T, D) -> (H*T, D)."""
    path = KV_DIR / f"layer{layer:02d}_k.npy"
    # float64 so downstream MSE accumulation is not limited by the dump dtype.
    return np.load(path).astype(np.float64).reshape(-1, D)
95+
96+
97+
# ---------------------------------------------------------------------------
# Print codebook overview
# ---------------------------------------------------------------------------

print("=== LMGauss vs Production codebook on REAL Qwen3-0.6B K ===")
# 8 heads x 512 tokens per layer per the docstring; rows may be fewer if a
# dump is shorter -- hence the "~".
print(f" d={D}, layers={LAYERS}, ~{8 * 512} rows per layer")
print(f" PROD turbo4 outermost: ±{PROD_CENTROIDS_4BIT[-1]:.4f}")
print(f" LM-Gauss(sigma=1/sqrt(d)) outer: ±{LMG_CENTROIDS[-1]:.4f} (= {LMG_CENTROIDS[-1]/SIGMA:.3f} sigma)")
print()

# ---------------------------------------------------------------------------
# Sweep
# ---------------------------------------------------------------------------

# results[layer][(k_extra, codebook_name)] -> MSE of quantizing the rotated K.
results: dict[int, dict[tuple[int, str], float]] = {}
# NOTE(review): collected below but never printed anywhere in this script --
# either surface it in the report or drop the bookkeeping.
post_wht_stats: dict[int, dict[str, float]] = {}

for layer in LAYERS:
    K_raw = load_layer_k(layer)
    # Production: one fixed-seed WHT-with-signs. This is the "1-RHT" baseline.
    K_post_wht = apply_rht(K_raw, 1, seed=layer * 1000)
    post_wht_stats[layer] = {
        "std": float(K_post_wht.std()),
        "abs_max": float(np.abs(K_post_wht).max()),
        # NOTE(review): scales the codebook's outermost centroid by the MEAN
        # row norm rather than each row's own norm, so this is only a rough
        # estimate of the clipped fraction -- confirm that is intended.
        "frac_outside_turbo4": float(np.mean(np.abs(K_post_wht) > PROD_CENTROIDS_4BIT[-1] * np.linalg.norm(K_post_wht, axis=-1, keepdims=True).mean())),
    }

    layer_res = {}
    # k_extra counts RHTs applied ON TOP of the baseline WHT, so the cells
    # below are labeled (k_extra + 1)-RHT in the tables.
    for k_extra in (0, 1, 2):
        if k_extra == 0:
            K_rot = K_post_wht
        else:
            K_rot = apply_rht(K_post_wht, k_extra, seed=42 + layer * 10 + k_extra)
        for cb_name, cb in (("turbo4", PROD_CENTROIDS_4BIT), ("LMGauss", LMG_CENTROIDS)):
            K_q = quantize_per_row(K_rot, cb)
            # MSE is measured against the ROTATED K: rotations are orthonormal,
            # so quantization error in the rotated basis equals error in the
            # original basis.
            layer_res[(k_extra, cb_name)] = float(np.mean((K_q - K_rot) ** 2))
    results[layer] = layer_res

# ---------------------------------------------------------------------------
# Per-layer table
# ---------------------------------------------------------------------------

print(f"{'layer':<8}{'RHTs':<6}{'turbo4 MSE':<18}{'LMGauss MSE':<18}"
      f"{'turbo4 % vs 1-RHT':<22}{'LMGauss % vs 1-RHT':<22}")
print("-" * 110)
for layer in LAYERS:
    # Each codebook's 1-RHT cell is its own baseline for the % columns.
    base_t = results[layer][(0, "turbo4")]
    base_l = results[layer][(0, "LMGauss")]
    for k_extra in (0, 1, 2):
        t = results[layer][(k_extra, "turbo4")]
        l = results[layer][(k_extra, "LMGauss")]
        dt = (t / base_t - 1) * 100 if k_extra > 0 else 0.0
        dl = (l / base_l - 1) * 100 if k_extra > 0 else 0.0
        rht_label = f"{k_extra + 1}-RHT"
        print(f"{layer:<8}{rht_label:<6}{t:<18.6e}{l:<18.6e}{dt:<+22.2f}{dl:<+22.2f}")
    print()

# ---------------------------------------------------------------------------
# Aggregate ranking
# ---------------------------------------------------------------------------

print("=== Mean MSE across layers (lowest is best) ===")
# cells: (total RHT count, codebook name, mean MSE across LAYERS).
cells = []
for k_extra in (0, 1, 2):
    for cb_name in ("turbo4", "LMGauss"):
        vals = [results[l][(k_extra, cb_name)] for l in LAYERS]
        mean_mse = float(np.mean(vals))
        cells.append((k_extra + 1, cb_name, mean_mse))

cells.sort(key=lambda x: x[2])
# The production stack (1-RHT x turbo4) is the reference for the delta column.
prod_baseline = next(m for n, c, m in cells if n == 1 and c == "turbo4")

print(f"{'rank':<6}{'cell':<25}{'mean MSE':<18}{'vs prod (1-RHT × turbo4)':<26}")
print("-" * 80)
for i, (n_rht, cb, mse) in enumerate(cells):
    delta = (mse / prod_baseline - 1) * 100
    label = f"{n_rht}-RHT × {cb}"
    print(f"{i + 1:<6}{label:<25}{mse:<18.6e}{delta:<+26.2f}")

# ---------------------------------------------------------------------------
# Headline interpretation
# ---------------------------------------------------------------------------

# cells is sorted ascending by mean MSE, so cells[0] is the best stack.
best = cells[0]
prod = (1, "turbo4", prod_baseline)
gain = (prod_baseline / best[2] - 1) * 100
print()
print(f"Best: {best[0]}-RHT × {best[1]} MSE={best[2]:.6e}")
print(f"Prod: 1-RHT × turbo4 MSE={prod_baseline:.6e}")
if best[0] == 1 and best[1] == "turbo4":
    print("=> Production stack is optimal. Paper's recommendation loses on real K.")
elif best[0] == 3 and best[1] == "LMGauss":
    print(f"=> Paper's recommended stack (3-RHT + Gaussian-fit codebook) wins by {gain:.1f}% over production.")
else:
    print(f"=> Best cell is {best[0]}-RHT × {best[1]}, gain {gain:.1f}% over production.")

0 commit comments

Comments
 (0)