|
| 1 | +"""RHT-count sweep using the ACTUAL production turbo4 codebook. |
| 2 | +
|
| 3 | +Source: ggml/src/ggml-cuda/turbo-quant.cuh TURBO_CENTROIDS_4BIT (16 levels). |
| 4 | +Block normalization: per-128-element L2 (QK_TURBO4 = 128 = head_dim per |
| 5 | +ggml-common.h), so each row of K (one head's worth of values for one token) |
| 6 | +is one normalization block. |
| 7 | +
|
| 8 | +This is the codebook the fork actually ships. v3 used Lloyd-Max-on-Gaussian |
| 9 | +(±2.7σ extremes). Production is ±2σ-ish, sub-Gaussian-fitted. Result should |
| 10 | +hold (same mechanism: production codebook lucky-aligns with sub-Gaussian K, |
| 11 | ++RHT breaks alignment), but the magnitude could differ. |
| 12 | +""" |
| 13 | + |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +from scipy import linalg, stats |
| 18 | + |
d = 128  # = QK_TURBO4 = head_dim, also the L2-norm block size
T = 512  # number of K rows (tokens) per synthetic matrix
n_trials = 5  # independent RHT seeds averaged per configuration
n_queries = 64  # Gaussian query vectors per trial for the attention-KL metric
sigma = 1.0 / np.sqrt(d)  # per-coordinate std so a row's L2 norm is ~1
TARGET_KURT = -1.56  # §3 layer-0 excess kurtosis the noise_frac sweep matches
| 25 | + |
| 26 | + |
# Production turbo4 codebook (16 levels, 4-bit)
# From ggml/src/ggml-cuda/turbo-quant.cuh:
# Values are symmetric about zero and expressed as fractions of the per-row
# L2 norm (the kernel reconstructs kv = centroid * norm).
PROD_CENTROIDS_4BIT = np.array([
    -0.173926, -0.117195, -0.089527, -0.068756,
    -0.051262, -0.035597, -0.020989, -0.006938,
    0.006938, 0.020989, 0.035597, 0.051262,
    0.068756, 0.089527, 0.117195, 0.173926
])
| 35 | + |
| 36 | + |
| 37 | +# --------------------------------------------------------------------------- |
| 38 | +# Source K matching §3 layer-0 stats |
| 39 | +# --------------------------------------------------------------------------- |
| 40 | + |
def gen_postwht_K(T, d, rng, noise_frac):
    """Synthesize a post-WHT-like K matrix.

    Mixture of Bernoulli ±1 signs plus Gaussian jitter of std ``noise_frac``,
    rescaled so the global standard deviation equals the module-level
    ``sigma``.
    """
    signs = rng.choice([-1.0, 1.0], size=(T, d))
    jitter = rng.normal(0.0, noise_frac, size=(T, d))
    mixed = signs + jitter
    # Rescale to the target per-coordinate std.
    return mixed * (sigma / np.sqrt(np.var(mixed)))
| 46 | + |
| 47 | + |
| 48 | +H = linalg.hadamard(d) / np.sqrt(d) |
| 49 | + |
| 50 | + |
def apply_rht(X, k, seed):
    """Apply ``k`` rounds of randomized Hadamard transform to the rows of X.

    Each round multiplies by a fresh random ±1 diagonal and then by the
    orthonormal Hadamard matrix H. Returns the rotated matrix together with
    the sign sequences needed by invert_rht.
    """
    rng = np.random.default_rng(seed)
    rotated = X.copy()
    used_signs = []
    for _ in range(k):
        diag = rng.choice([-1.0, 1.0], size=d).astype(np.float64)
        used_signs.append(diag)
        rotated = (rotated * diag) @ H.T
    return rotated, used_signs
| 60 | + |
| 61 | + |
def invert_rht(Y, sign_seqs):
    """Undo apply_rht: reverse each round (H is orthonormal, signs are ±1)."""
    restored = Y.copy()
    for diag in sign_seqs[::-1]:
        restored = (restored @ H) * diag
    return restored
| 67 | + |
| 68 | + |
| 69 | +# --------------------------------------------------------------------------- |
| 70 | +# Production-style block quantization |
| 71 | +# |
| 72 | +# Per row (one head_dim block of size 128): |
| 73 | +# norm = ||K_row||_2 (then divide by sqrt(d) to get rms-scaled centroids |
| 74 | +# relative to row magnitude — see kernel: kv = c * norm) |
| 75 | +# Then nearest-centroid lookup. |
| 76 | +# Dequantize: K_row[i] ≈ centroid[idx[i]] * norm |
| 77 | +# --------------------------------------------------------------------------- |
| 78 | + |
def quantize_prod(X, codebook):
    """Production-style quantize/dequantize round trip.

    X: (T, d). Each row is L2-normalized, every coordinate is snapped to its
    nearest codebook level, and the result is rescaled by the row norm
    (mirrors the kernel's kv = centroid * norm reconstruction).
    """
    row_norm = np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)
    unit_rows = X / row_norm
    # Nearest-centroid index per element; broadcasting over the last axis is
    # equivalent to the flatten/reshape formulation (ties -> lowest index).
    nearest = np.argmin(np.abs(unit_rows[..., None] - codebook), axis=-1)
    return codebook[nearest] * row_norm
| 91 | + |
| 92 | + |
def softmax(x, axis=-1):
    """Numerically stable softmax along *axis* (max-shifted before exp)."""
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)
| 97 | + |
| 98 | + |
def attn_kl(K_ref, K_test, Q):
    """Per-query KL(p || q) between attention rows built from K_ref vs K_test.

    Logits use the standard 1/sqrt(d) scaling; a small epsilon guards the
    logs against zero probabilities.
    """
    inv_scale = 1.0 / np.sqrt(d)
    p = softmax((Q @ K_ref.T) * inv_scale)
    q = softmax((Q @ K_test.T) * inv_scale)
    tiny = 1e-12
    return np.sum(p * (np.log(p + tiny) - np.log(q + tiny)), axis=-1)
| 105 | + |
| 106 | + |
| 107 | +# --------------------------------------------------------------------------- |
| 108 | +# Find noise_frac that matches §3 layer-0 stats with production codebook |
| 109 | +# --------------------------------------------------------------------------- |
| 110 | + |
# Report the production codebook's range, then tune noise_frac so the
# synthetic K matches the §3 layer-0 kurtosis/KS stats.
print("=== Production turbo4 codebook ===")
print(f" centroids (16 levels): extremes ±{PROD_CENTROIDS_4BIT[-1]:.4f}")
print(f" L2-normalized extreme: ±{PROD_CENTROIDS_4BIT[-1]:.4f} of row norm")
print(" vs Lloyd-Max-Gaussian extremes: ±0.240 (my v3 codebook)")
print()

# Sweep noise_frac; keep the last setting whose excess kurtosis lands within
# 0.05 of the target.
print("Tuning noise_frac to match §3 layer-0 stats...")
print(" target: kurt = -1.56, KS = 0.155")
print()
chosen = None
for nf in [0.30, 0.32, 0.34, 0.35, 0.36, 0.38, 0.40]:
    K_try = gen_postwht_K(T, d, np.random.default_rng(42), nf)
    kurt = stats.kurtosis(K_try.flatten())
    ks, _ = stats.kstest(K_try.flatten(), 'norm', args=(0, sigma))
    print(f" noise_frac={nf:.2f} kurt={kurt:+.3f} KS={ks:.4f}")
    if abs(kurt - TARGET_KURT) < 0.05:
        chosen = (nf, kurt, ks, K_try)

# BUG FIX: the original fallback,
#   chosen if chosen else (0.35, *stats.kurtosis(...), 0, gen_postwht_K(...))
# raised TypeError whenever no sweep value matched, because stats.kurtosis on
# a 1-D array returns a scalar that cannot be star-unpacked — and it reported
# KS as a hard-coded 0. Build the nf=0.35 fallback explicitly and compute its
# real stats instead.
if chosen is not None:
    nf, kurt, ks, K_orig = chosen
else:
    nf = 0.35
    K_orig = gen_postwht_K(T, d, np.random.default_rng(42), nf)
    kurt = stats.kurtosis(K_orig.flatten())
    ks, _ = stats.kstest(K_orig.flatten(), 'norm', args=(0, sigma))
print()
print(f" selected: noise_frac={nf:.3f}, kurt={kurt:+.3f}, KS={ks:.4f}")
print()
| 133 | + |
| 134 | + |
| 135 | +# --------------------------------------------------------------------------- |
| 136 | +# Sweep k_extra with production codebook |
| 137 | +# --------------------------------------------------------------------------- |
| 138 | + |
print("=== Sweep k_extra on production codebook ===")
print(f"{'k_extra':<10}{'post-kurt':<14}{'post-KS':<12}"
      f"{'MSE':<18}{'KL mean':<16}{'KL p99':<16}{'Cat rate':<12}")
print("-" * 100)

results = {}        # k_extra -> summary stats dict
all_kl_arrays = {}  # k_extra -> concatenated per-query KL samples

for k_extra in [0, 1, 2]:
    mse_runs = []
    kl_per_query: list[np.ndarray] = []
    kurt_post = None  # post-rotation distribution stats (trial 0 only)
    ks_post = None
    for trial in range(n_trials):
        # Distinct RHT seed per (k_extra, trial) pair.
        seed = 2000 + 100 * (k_extra + 1) + trial
        if k_extra == 0:
            # Baseline: quantize the un-rotated K directly (no sign seqs).
            K_rot, sign_seqs = K_orig.copy(), []
        else:
            K_rot, sign_seqs = apply_rht(K_orig, k_extra, seed)
        if trial == 0:
            # Measure how Gaussian the rotated values look (one trial is
            # representative; the stats are over T*d samples).
            kurt_post = stats.kurtosis(K_rot.flatten())
            ks_post, _ = stats.kstest(K_rot.flatten(), 'norm', args=(0, sigma))
        # Quantize in the rotated domain, then rotate back before scoring
        # against the original K.
        K_rot_q = quantize_prod(K_rot, PROD_CENTROIDS_4BIT)
        K_recon = invert_rht(K_rot_q, sign_seqs) if sign_seqs else K_rot_q
        mse_runs.append(np.mean((K_recon - K_orig) ** 2))
        # Fresh Gaussian queries per trial; the +50000 offset keeps the query
        # stream independent of the RHT sign draws.
        q_rng = np.random.default_rng(seed + 50000)
        Q = q_rng.normal(0.0, sigma, size=(n_queries, d))
        kl_per_query.append(attn_kl(K_orig, K_recon, Q))

    mse_mean = float(np.mean(mse_runs))
    kl_concat = np.concatenate(kl_per_query)
    all_kl_arrays[k_extra] = kl_concat
    results[k_extra] = {
        "post_kurt": kurt_post,
        "post_ks": ks_post,
        "mse_mean": mse_mean,
        "kl_mean": float(kl_concat.mean()),
        "kl_p99": float(np.percentile(kl_concat, 99)),
    }

# "Cat rate": fraction of queries whose KL exceeds 110% of the k_extra=0
# median KL (a relative catastrophic-degradation threshold).
base_median = float(np.median(all_kl_arrays[0]))
for k_extra in [0, 1, 2]:
    r = results[k_extra]
    cat = float(np.mean(all_kl_arrays[k_extra] >= 1.10 * base_median))
    print(
        f"{k_extra:<10}{r['post_kurt']:<+14.4f}{r['post_ks']:<12.4f}"
        f"{r['mse_mean']:<18.4e}{r['kl_mean']:<16.4e}"
        f"{r['kl_p99']:<16.4e}{cat:<12.1%}"
    )

print()
print("=== Delta vs k_extra=0 (production baseline) ===")
print(f" k_extra=1 MSE: {(results[1]['mse_mean']/results[0]['mse_mean'] - 1)*100:+.2f}% "
      f"KL mean: {(results[1]['kl_mean']/results[0]['kl_mean'] - 1)*100:+.2f}% "
      f"KL p99: {(results[1]['kl_p99']/results[0]['kl_p99'] - 1)*100:+.2f}%")
print(f" k_extra=2 MSE: {(results[2]['mse_mean']/results[0]['mse_mean'] - 1)*100:+.2f}% "
      f"KL mean: {(results[2]['kl_mean']/results[0]['kl_mean'] - 1)*100:+.2f}% "
      f"KL p99: {(results[2]['kl_p99']/results[0]['kl_p99'] - 1)*100:+.2f}%")
| 197 | + |
| 198 | +# --------------------------------------------------------------------------- |
| 199 | +# Side-by-side: production vs Lloyd-Max-on-Gaussian codebook |
| 200 | +# --------------------------------------------------------------------------- |
| 201 | + |
print()
print("=== Side-by-side: PROD codebook vs Lloyd-Max-Gaussian (v3) codebook ===")

# NOTE(review): mid-file import; `norm` is also reachable as `stats.norm`
# from the top-of-file scipy import — consider hoisting this to the header.
from scipy.stats import norm
def lloyd_max_gaussian(sig, n_levels=16, n_iter=100):
    """Lloyd-Max scalar quantizer levels for N(0, sig^2).

    Starts from equiprobable Gaussian quantiles and alternates boundary /
    conditional-mean updates until the levels move by < 1e-9 or n_iter
    rounds elapse.
    """
    quantiles = np.linspace(0.5 / n_levels, 1 - 0.5 / n_levels, n_levels)
    pts = norm.ppf(quantiles, 0, sig)
    for _ in range(n_iter):
        # Decision boundaries are midpoints between adjacent levels.
        mids = 0.5 * (pts[:-1] + pts[1:])
        edges = np.concatenate([[-np.inf], mids, [np.inf]])
        updated = np.empty(n_levels)
        for i in range(n_levels):
            lo, hi = edges[i], edges[i + 1]
            pdf_lo = norm.pdf(lo, 0, sig) if np.isfinite(lo) else 0
            pdf_hi = norm.pdf(hi, 0, sig) if np.isfinite(hi) else 0
            mass = norm.cdf(hi, 0, sig) - norm.cdf(lo, 0, sig)
            if mass > 1e-15:
                # Conditional mean of the cell: -sig^2 * (pdf_hi - pdf_lo) / mass.
                updated[i] = -sig * sig * (pdf_hi - pdf_lo) / mass
            else:
                updated[i] = pts[i]  # degenerate cell: keep the old level
        if np.max(np.abs(updated - pts)) < 1e-9:
            break
        pts = updated
    return pts
| 223 | + |
| 224 | +LM_CENTROIDS = lloyd_max_gaussian(sigma) |
| 225 | + |
| 226 | + |
def quantize_per_coord(X, codebook):
    """Snap every coordinate of X to its nearest codebook level.

    The v3 scheme: one global codebook, no per-block normalization.
    """
    nearest = np.argmin(np.abs(X[..., None] - codebook), axis=-1)
    return codebook[nearest]
| 232 | + |
| 233 | + |
print(f"{'codebook':<20}{'quant scheme':<20}{'k_extra=0 MSE':<16}"
      f"{'k_extra=1 MSE':<16}{'Δ MSE':<12}{'Δ KL':<12}{'cat 0→1':<14}")
print("-" * 110)

# Compare the production codebook (per-row L2 scheme) against the v3
# Lloyd-Max-Gaussian codebook (per-coord global scheme) at k_extra = 0 vs 1.
for cb_name, codebook, scheme_name, quant_fn in [
    ("PROD turbo4", PROD_CENTROIDS_4BIT, "per-row L2 norm", quantize_prod),
    ("Lloyd-Max Gauss", LM_CENTROIDS, "per-coord global", quantize_per_coord),
]:
    base_mse, base_kl, k1_mse, k1_kl, base_kls, k1_kls = (None,) * 6
    for k_extra in [0, 1]:
        mse_runs, kls_runs = [], []
        for trial in range(n_trials):
            # Seed family 3000+ keeps these draws disjoint from the earlier
            # sweep's 2000+ family.
            seed = 3000 + 100*(k_extra+1) + trial
            if k_extra == 0:
                K_rot, sign_seqs = K_orig.copy(), []
            else:
                K_rot, sign_seqs = apply_rht(K_orig, k_extra, seed)
            # Quantize in the (possibly rotated) domain, rotate back, score.
            K_rot_q = quant_fn(K_rot, codebook)
            K_recon = invert_rht(K_rot_q, sign_seqs) if sign_seqs else K_rot_q
            mse_runs.append(np.mean((K_recon - K_orig)**2))
            q_rng = np.random.default_rng(seed + 50000)
            Q = q_rng.normal(0.0, sigma, size=(n_queries, d))
            kls_runs.append(attn_kl(K_orig, K_recon, Q))
        mse = float(np.mean(mse_runs))
        kls = np.concatenate(kls_runs)
        if k_extra == 0:
            base_mse, base_kl, base_kls = mse, kls.mean(), kls
        else:
            k1_mse, k1_kl, k1_kls = mse, kls.mean(), kls
    # Cat rates use the k_extra=0 median of THIS codebook as the reference.
    base_median = float(np.median(base_kls))
    cat0 = float(np.mean(base_kls >= 1.10 * base_median))
    cat1 = float(np.mean(k1_kls >= 1.10 * base_median))
    d_mse = (k1_mse/base_mse - 1) * 100
    d_kl = (k1_kl/base_kl - 1) * 100
    print(
        f"{cb_name:<20}{scheme_name:<20}{base_mse:<16.3e}"
        f"{k1_mse:<16.3e}{d_mse:<+12.1f}{d_kl:<+12.1f}{cat0:.1%} → {cat1:.1%}"
    )
0 commit comments