|
| 1 | +"""LMGauss vs production codebook on REAL Qwen3-0.6B post-WHT K. |
| 2 | +
|
| 3 | +Tests Theorem 7 (Basat 2026) prediction on the actual case: 3-RHT + a |
| 4 | +Lloyd-Max-Gaussian codebook should match Gaussian-input quantization error |
| 5 | +(up to a vanishing additive term). The paper does not pin a specific |
| 6 | +codebook; we construct LM-Gauss for N(0, 1/d) so it matches the per-row |
| 7 | +L2-normalized post-WHT K scale. |
| 8 | +
|
| 9 | +Cross-product to find the actual best cell: |
| 10 | +
|
| 11 | + turbo4 (sub-Gaussian fit, ±0.174) LM-Gauss (Gaussian fit, ±0.241) |
| 12 | + 1-RHT (prod) [production stack] ? |
| 13 | + 2-RHT ? ? |
| 14 | + 3-RHT ? [paper's theoretical optimum] |
| 15 | +
|
| 16 | +Data: /Users/tom/dev/mse_scripts/eden-investigation/kv_dump/ |
| 17 | +Layers 0, 4, 8, 12, 16, 20, 23 from Qwen3-0.6B, 512 tokens, 8 heads, d=128. |
| 18 | +""" |
| 19 | + |
| 20 | +from __future__ import annotations |
| 21 | + |
| 22 | +from pathlib import Path |
| 23 | + |
| 24 | +import numpy as np |
| 25 | +from scipy import linalg |
| 26 | +from scipy.stats import norm |
| 27 | + |
# Per-head feature dimension of K (d=128 per the module docstring).
D = 128
# Directory holding the dumped K tensors (layerNN_k.npy files).
KV_DIR = Path("/Users/tom/dev/mse_scripts/eden-investigation/kv_dump")
# Layer indices sampled for the sweep.
LAYERS = [0, 4, 8, 12, 16, 20, 23]

# Production turbo4 codebook from ggml/src/ggml-cuda/turbo-quant.cuh.
# Header comment claims Lloyd-Max for N(0, 1/128), but the actual centroid
# spacing is tighter (outermost 0.174 vs 0.241 for true Lloyd-Max-Gaussian),
# so it is empirically sub-Gaussian fitted to the post-WHT K distribution.
# 16 symmetric levels -> 4 bits per coordinate.
PROD_CENTROIDS_4BIT = np.array([
    -0.173926, -0.117195, -0.089527, -0.068756,
    -0.051262, -0.035597, -0.020989, -0.006938,
    0.006938, 0.020989, 0.035597, 0.051262,
    0.068756, 0.089527, 0.117195, 0.173926,
])
| 42 | + |
| 43 | + |
def lloyd_max_gaussian(sigma: float, n_levels: int = 16, n_iter: int = 500) -> np.ndarray:
    """Lloyd-Max-optimal scalar quantizer for N(0, sigma^2).

    Starts from the quantile initializer and alternates the two Lloyd
    optimality conditions — boundaries at centroid midpoints, centroids at
    the conditional mean of their cell — until the centroids move by less
    than 1e-12 or n_iter sweeps elapse. Returns the n_levels centroids in
    ascending order.
    """
    var = sigma * sigma
    centroids = norm.ppf(
        np.linspace(0.5 / n_levels, 1 - 0.5 / n_levels, n_levels), 0, sigma
    )
    for _ in range(n_iter):
        # Nearest-neighbor partition: boundaries halfway between centroids.
        inner = 0.5 * (centroids[1:] + centroids[:-1])
        edges = np.concatenate([[-np.inf], inner, [np.inf]])
        updated = np.empty_like(centroids)
        for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
            dens_lo = norm.pdf(lo, 0, sigma) if np.isfinite(lo) else 0.0
            dens_hi = norm.pdf(hi, 0, sigma) if np.isfinite(hi) else 0.0
            prob = norm.cdf(hi, 0, sigma) - norm.cdf(lo, 0, sigma)
            # E[X | lo < X <= hi] = -sigma^2 (pdf(hi) - pdf(lo)) / P(lo < X <= hi);
            # keep the old centroid if the cell carries (numerically) no mass.
            updated[i] = (
                -var * (dens_hi - dens_lo) / prob if prob > 1e-15 else centroids[i]
            )
        converged = np.max(np.abs(updated - centroids)) < 1e-12
        centroids = updated
        if converged:
            break
    return centroids
| 62 | + |
| 63 | + |
# Per-coordinate scale of an L2-normalized d-dim row: std = 1/sqrt(d).
SIGMA = 1.0 / np.sqrt(D)
# Lloyd-Max-Gaussian codebook fitted to that post-normalization scale.
LMG_CENTROIDS = lloyd_max_gaussian(SIGMA)

# Walsh-Hadamard basis (orthonormal)
H = linalg.hadamard(D) / np.sqrt(D)
| 69 | + |
| 70 | + |
def apply_rht(X: np.ndarray, k: int, seed: int) -> np.ndarray:
    """Apply k random Hadamard transforms (random signs + Hadamard) to X (..., d).

    Each round draws an independent ±1 sign vector from the seeded generator
    and multiplies by the orthonormal Hadamard basis H, so the rotation is
    reproducible for a given seed. The input is never mutated.
    """
    rng = np.random.default_rng(seed)
    rotated = np.array(X, copy=True)
    for _ in range(k):
        flips = rng.choice([-1.0, 1.0], size=D).astype(np.float64)
        rotated = np.matmul(rotated * flips, H.T)
    return rotated
| 79 | + |
| 80 | + |
def quantize_per_row(X: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Production scheme: per-row L2 normalize, nearest-centroid lookup, dequantize.

    Zero rows are guarded with a 1e-12 norm floor so the division is safe;
    ties between equidistant centroids resolve to the lower index (argmin).
    """
    row_scale = np.maximum(np.linalg.norm(X, axis=-1, keepdims=True), 1e-12)
    unit = X / row_scale
    # Distance from every coordinate to every centroid, then snap to nearest.
    dist = np.abs(unit.reshape(-1, 1) - codebook.reshape(1, -1))
    snapped = codebook[np.argmin(dist, axis=1)].reshape(unit.shape)
    return snapped * row_scale
| 89 | + |
| 90 | + |
def load_layer_k(layer: int) -> np.ndarray:
    """Load real K for layer, flatten (1, H, T, D) -> (H*T, D).

    Dumps are float16/float32 on disk; promote to float64 for the MSE math.
    """
    path = KV_DIR / f"layer{layer:02d}_k.npy"
    tensor = np.load(path).astype(np.float64)
    return tensor.reshape(-1, D)
| 95 | + |
| 96 | + |
# ---------------------------------------------------------------------------
# Print codebook overview
# ---------------------------------------------------------------------------

print("=== LMGauss vs Production codebook on REAL Qwen3-0.6B K ===")
# 8 heads x 512 tokens per layer (see module docstring).
print(f" d={D}, layers={LAYERS}, ~{8 * 512} rows per layer")
print(f" PROD turbo4 outermost: ±{PROD_CENTROIDS_4BIT[-1]:.4f}")
# Report the LM-Gauss outermost centroid both absolutely and in sigma units.
print(f" LM-Gauss(sigma=1/sqrt(d)) outer: ±{LMG_CENTROIDS[-1]:.4f} (= {LMG_CENTROIDS[-1]/SIGMA:.3f} sigma)")
print()
| 106 | + |
# ---------------------------------------------------------------------------
# Sweep: (number of RHTs) x (codebook) over all sampled layers
# ---------------------------------------------------------------------------

# results[layer][(k_extra, codebook_name)] -> MSE in the rotated domain.
results: dict[int, dict[tuple[int, str], float]] = {}
# Distribution diagnostics of the post-WHT K, per layer.
post_wht_stats: dict[int, dict[str, float]] = {}

for layer in LAYERS:
    K_raw = load_layer_k(layer)
    # Production: one fixed-seed WHT-with-signs. This is the "1-RHT" baseline.
    K_post_wht = apply_rht(K_raw, 1, seed=layer * 1000)
    # Fraction of per-row L2-normalized coordinates that land outside the
    # outermost turbo4 centroid (i.e. would be clipped to it). Normalizing
    # per row matches what quantize_per_row actually does; the previous
    # mean-row-norm comparison only approximated that.
    row_norms = np.maximum(np.linalg.norm(K_post_wht, axis=-1, keepdims=True), 1e-12)
    post_wht_stats[layer] = {
        "std": float(K_post_wht.std()),
        "abs_max": float(np.abs(K_post_wht).max()),
        "frac_outside_turbo4": float(
            np.mean(np.abs(K_post_wht / row_norms) > PROD_CENTROIDS_4BIT[-1])
        ),
    }

    layer_res = {}
    for k_extra in (0, 1, 2):
        if k_extra == 0:
            K_rot = K_post_wht
        else:
            # Extra rotations stack on top of the production one, so the
            # total RHT count is k_extra + 1 (matching the table labels).
            K_rot = apply_rht(K_post_wht, k_extra, seed=42 + layer * 10 + k_extra)
        for cb_name, cb in (("turbo4", PROD_CENTROIDS_4BIT), ("LMGauss", LMG_CENTROIDS)):
            K_q = quantize_per_row(K_rot, cb)
            # MSE measured in the rotated domain (rotations are orthonormal,
            # so this equals the MSE in the original domain).
            layer_res[(k_extra, cb_name)] = float(np.mean((K_q - K_rot) ** 2))
    results[layer] = layer_res
| 134 | + |
# ---------------------------------------------------------------------------
# Per-layer table
# ---------------------------------------------------------------------------

header = (f"{'layer':<8}{'RHTs':<6}{'turbo4 MSE':<18}{'LMGauss MSE':<18}"
          f"{'turbo4 % vs 1-RHT':<22}{'LMGauss % vs 1-RHT':<22}")
print(header)
print("-" * 110)
for layer in LAYERS:
    per_layer = results[layer]
    # The 1-RHT (k_extra == 0) cell is the reference for the % columns.
    ref_turbo = per_layer[(0, "turbo4")]
    ref_lmg = per_layer[(0, "LMGauss")]
    for k_extra in (0, 1, 2):
        mse_turbo = per_layer[(k_extra, "turbo4")]
        mse_lmg = per_layer[(k_extra, "LMGauss")]
        if k_extra > 0:
            pct_turbo = (mse_turbo / ref_turbo - 1) * 100
            pct_lmg = (mse_lmg / ref_lmg - 1) * 100
        else:
            pct_turbo = pct_lmg = 0.0
        rht_label = f"{k_extra + 1}-RHT"
        print(f"{layer:<8}{rht_label:<6}{mse_turbo:<18.6e}{mse_lmg:<18.6e}"
              f"{pct_turbo:<+22.2f}{pct_lmg:<+22.2f}")
    print()
| 153 | + |
# ---------------------------------------------------------------------------
# Aggregate ranking
# ---------------------------------------------------------------------------

print("=== Mean MSE across layers (lowest is best) ===")
cells = []
for k_extra in (0, 1, 2):
    for cb_name in ("turbo4", "LMGauss"):
        layer_mses = [results[layer][(k_extra, cb_name)] for layer in LAYERS]
        cells.append((k_extra + 1, cb_name, float(np.mean(layer_mses))))

# Ascending mean MSE: cells[0] becomes the overall winner.
cells.sort(key=lambda cell: cell[2])
prod_baseline = next(mse for n_rht, cb, mse in cells if n_rht == 1 and cb == "turbo4")

print(f"{'rank':<6}{'cell':<25}{'mean MSE':<18}{'vs prod (1-RHT × turbo4)':<26}")
print("-" * 80)
for rank, (n_rht, cb, mse) in enumerate(cells, start=1):
    pct_vs_prod = (mse / prod_baseline - 1) * 100
    cell_label = f"{n_rht}-RHT × {cb}"
    print(f"{rank:<6}{cell_label:<25}{mse:<18.6e}{pct_vs_prod:<+26.2f}")
| 175 | + |
# ---------------------------------------------------------------------------
# Headline interpretation
# ---------------------------------------------------------------------------

# cells is sorted ascending by mean MSE, so cells[0] is the best cell.
best = cells[0]
# Relative improvement of the best cell over the production baseline, in %.
gain = (prod_baseline / best[2] - 1) * 100
print()
print(f"Best: {best[0]}-RHT × {best[1]} MSE={best[2]:.6e}")
print(f"Prod: 1-RHT × turbo4 MSE={prod_baseline:.6e}")
if best[0] == 1 and best[1] == "turbo4":
    print("=> Production stack is optimal. Paper's recommendation loses on real K.")
elif best[0] == 3 and best[1] == "LMGauss":
    print(f"=> Paper's recommended stack (3-RHT + Gaussian-fit codebook) wins by {gain:.1f}% over production.")
else:
    print(f"=> Best cell is {best[0]}-RHT × {best[1]}, gain {gain:.1f}% over production.")