|
| 1 | +"""Bridges A + B + C head-to-head on real Qwen3-4B K. |
| 2 | +
|
| 3 | +Compares at matched bit budgets: |
| 4 | + - TurboQuant k8v4 (reference): 1024 bits/tok/head |
| 5 | + - Bridge A (Guth-Katz polynomial partitioning, degree 2 in JL-r space) |
| 6 | + - Bridge B (D4 nested lattice) |
| 7 | + - Bridge C (non-Gaussian shaping with empirical Lloyd-Max) |
| 8 | +
|
| 9 | +Metrics reported per bridge: |
| 10 | + - K rel-MSE on held-out K |
| 11 | + - Mean cosine(x, x̂) |
| 12 | + - Compression: bits / token / kv-head |
| 13 | + - Encode speed (ms / million vectors) |
| 14 | +
|
| 15 | +Also writes a single JSON snapshot covering every bridge configuration for later audit. |
| 16 | +""" |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import argparse |
| 20 | +import json |
| 21 | +import math |
| 22 | +import os |
| 23 | +import sys |
| 24 | +import time |
| 25 | +from pathlib import Path |
| 26 | + |
| 27 | +import numpy as np |
| 28 | +import torch |
| 29 | + |
# Environment defaults are set at import time, before vLLM is imported lazily
# inside capture_qwen3_k(); setdefault() leaves any value the caller exported.
# NOTE(review): presumably "0" keeps the vLLM engine in this process so the
# snapshot hook state is visible here — confirm against the vLLM version used.
os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
# NOTE(review): presumably switches on the project's Qwen3 snapshot hook —
# verify against kakeya_v1_3_ppl.snapshot_hook.
os.environ.setdefault("KAKEYA_SNAPSHOT_QWEN3", "1")
| 32 | + |
| 33 | + |
def capture_qwen3_k(model_path: str, n_passages: int, ctx_len: int, gpu_mem_util: float):
    """Run WikiText-103 passages through vLLM and collect the captured K.

    The WikiText-103 test split is joined, tokenised once, and cut into
    ``n_passages`` consecutive, non-overlapping windows of ``ctx_len``
    tokens.  Each window is pushed through ``llm.generate`` with the
    snapshot hook in ``"capture"`` phase, and whatever the hook stored under
    ``"K"`` for each layer id is accumulated.

    Args:
        model_path: Model name or path handed to both the tokenizer and vLLM.
        n_passages: Number of token windows to run.
        ctx_len: Number of tokens per window.
        gpu_mem_util: Forwarded to vLLM as ``gpu_memory_utilization``.

    Returns:
        Dict mapping each layer id found in ``HookState.captured`` to a
        float32 NumPy array: that layer's per-passage captures concatenated
        along axis 0.

    Raises:
        ValueError: If the tokenised corpus is too short to supply
            ``n_passages`` full windows of ``ctx_len`` tokens.
    """
    from vllm import LLM, SamplingParams
    from vllm.inputs import TokensPrompt
    from transformers import AutoTokenizer
    from datasets import load_dataset

    tok = AutoTokenizer.from_pretrained(model_path)
    ds = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
    joined = "\n\n".join(ds["text"])
    full_ids = tok(joined, return_tensors="pt").input_ids[0].tolist()
    passages = [
        full_ids[i * ctx_len : (i + 1) * ctx_len]
        for i in range(n_passages)
        if (i + 1) * ctx_len <= len(full_ids)
    ]
    # Explicit raise instead of ``assert``: asserts vanish under ``python -O``
    # and the run would then silently use fewer passages than requested.
    if len(passages) != n_passages:
        raise ValueError(
            f"corpus has {len(full_ids):,} tokens, enough for only "
            f"{len(passages)} passages of {ctx_len} tokens "
            f"(requested {n_passages})"
        )

    llm = LLM(
        # +1 leaves room for the single generated token after the prompt.
        model=model_path, max_model_len=ctx_len + 1,
        gpu_memory_utilization=gpu_mem_util,
        enforce_eager=True, enable_prefix_caching=False,
    )
    from kakeya_v1_3_ppl.snapshot_hook import HookState
    HookState.phase = "capture"

    accum = {}
    for ids in passages:
        # Drop the previous passage's captures so each pass is collected once.
        HookState.captured.clear()
        llm.generate(
            [TokensPrompt(prompt_token_ids=ids)],
            SamplingParams(max_tokens=1, temperature=0.0, prompt_logprobs=1),
        )
        for lid, kv in HookState.captured.items():
            accum.setdefault(lid, []).append(np.asarray(kv["K"], dtype=np.float32))
    return {lid: np.concatenate(arrs, axis=0) for lid, arrs in accum.items()}
| 69 | + |
| 70 | + |
def tq_k8v4_roundtrip(K_unit: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """Quantise and reconstruct vectors with the reference TurboQuant scheme.

    Each vector along the last dimension is split into norm and direction;
    the direction is rotated by a normalised Sylvester-Hadamard matrix,
    quantised per coordinate onto a symmetric uniform grid scaled by that
    vector's own max-abs value, rotated back, and re-scaled by the norm.
    Despite the parameter name, the input does not need to be unit-norm:
    norms are factored out and restored internally.

    Args:
        K_unit: Floating-point tensor of shape ``(..., D)`` with ``D`` a
            power of two.
        bits: Bits per coordinate (sign included); must be at least 2.

    Returns:
        Reconstruction with the same shape and dtype as ``K_unit``.

    Raises:
        ValueError: If ``bits < 2`` or ``D`` is not a positive power of two.
    """
    if bits < 2:
        # bits == 1 gives qs == 0, hence scale == inf and a NaN output.
        raise ValueError(f"bits must be >= 2, got {bits}")
    D = K_unit.shape[-1]
    if D < 1 or D & (D - 1):
        # The doubling construction below only yields a D×D matrix when D is
        # a power of two; otherwise the rotation matmul cannot line up.
        raise ValueError(f"last dimension must be a power of two, got {D}")
    device = K_unit.device
    H = torch.tensor([[1.0]], device=device, dtype=torch.float32)
    while H.shape[0] < D:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    H = H / math.sqrt(D)
    # Work in float32 so half / bfloat16 / float64 inputs can multiply with H.
    flat = K_unit.reshape(-1, D).to(torch.float32)
    norms = flat.norm(dim=1, keepdim=True).clamp(min=1e-12)
    unit = flat / norms
    y = unit @ H
    qmax = y.abs().max(dim=1, keepdim=True).values.clamp(min=1e-6)
    qs = (1 << (bits - 1)) - 1
    scale = qmax / qs
    q = torch.round(y / scale).clamp(-qs, qs) * scale
    # The normalised Sylvester-Hadamard matrix is symmetric and orthogonal,
    # so multiplying by H again undoes the rotation.
    unit_hat = q @ H
    return (unit_hat * norms).reshape(K_unit.shape).to(K_unit.dtype)
| 91 | + |
| 92 | + |
def evaluate_bridge(name: str, K_test: torch.Tensor, K_hat: torch.Tensor, bits: int) -> dict:
    """Score a reconstruction ``K_hat`` against the ground truth ``K_test``.

    Args:
        name: Label stored verbatim in the result.
        K_test: Ground-truth vectors, one per index of the leading dims.
        K_hat: Reconstructed vectors, same shape as ``K_test``.
        bits: Bit budget stored verbatim under ``bits_per_token_per_head``.

    Returns:
        Dict with keys ``name``, ``bits_per_token_per_head``, ``rel_mse``
        (mean squared error norm divided by mean squared vector norm),
        ``abs_mse`` (mean squared L2 norm of the error per vector, i.e.
        summed over the last dim, not averaged over coordinates),
        ``cos_mean`` and ``cos_min`` (per-vector cosine similarity).
    """
    eps = 1e-12
    err = K_test - K_hat
    err_energy = (err * err).sum(dim=-1).mean()
    ref_energy = (K_test * K_test).sum(dim=-1).mean()
    # Clamp both norms: clamping only K_hat's left a zero-norm ground-truth
    # row as 0/0 = NaN, which then poisoned cos_mean and cos_min.
    cos = (K_test * K_hat).sum(dim=-1) / (
        K_test.norm(dim=-1).clamp(min=eps) * K_hat.norm(dim=-1).clamp(min=eps)
    )
    return {
        "name": name,
        "bits_per_token_per_head": bits,
        "rel_mse": float(err_energy / ref_energy),
        "abs_mse": float(err_energy),
        "cos_mean": float(cos.mean().item()),
        "cos_min": float(cos.min().item()),
    }
| 109 | + |
| 110 | + |
def _synced_clock() -> float:
    """Return ``time.perf_counter()`` after queued CUDA work has finished.

    CUDA kernels are launched asynchronously, so reading the host clock
    without synchronising measures launch overhead, not the computation.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()


def _print_pareto_table(results: list) -> None:
    """Print one row per distinct bit budget with each family's rel-MSE.

    For every bit level present in ``results`` the value shown for a family
    comes from that family's configuration whose bit count is closest to the
    level; a family with no results shows ``-``.
    """
    print("\n=== Pareto summary (rel-MSE @ matched bits) ===")
    bits_sorted = sorted({r["bits_per_token_per_head"] for r in results})
    print(f"{'bits':>6} {'TQ':>8} {'GK':>8} {'D4':>8} {'NG':>8}")
    # Report closest config per family per bit level.
    families = {
        "TQ-k8v4": [r for r in results if r["name"].startswith("TQ")],
        "GuthKatz": [r for r in results if r["name"].startswith("GuthKatz")],
        "D4-": [r for r in results if r["name"].startswith("D4-")],
        "NonGauss": [r for r in results if r["name"].startswith("NonGauss")],
    }
    for level in bits_sorted:
        cells = []
        for members in families.values():
            if not members:
                cells.append(f"{'-':>8}")
                continue
            nearest = min(
                members,
                key=lambda r, level=level: abs(r["bits_per_token_per_head"] - level),
            )
            cells.append(f"{nearest['rel_mse']:>8.4f}")
        print(f"{level:>6} " + " ".join(cells))


def main() -> None:
    """Capture K from the model, benchmark every bridge, and dump a JSON report.

    Steps: parse CLI arguments, capture K via ``capture_qwen3_k``, pool the
    layers not listed in ``--boundary-skip-layers``, draw a seeded random
    train/test split, run the TurboQuant reference and each bridge family
    over a sweep of configurations, print per-configuration metrics and a
    summary, then write ``bridges_abc_head_to_head.json`` into ``--out-dir``.

    Raises:
        RuntimeError: If every captured layer is excluded by the skip list.
        ValueError: If the pooled data cannot fill a non-empty train and
            test split.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--model-path", default="Qwen/Qwen3-4B")
    ap.add_argument("--n-passages", type=int, default=4)
    ap.add_argument("--ctx-len", type=int, default=2048)
    ap.add_argument("--gpu-mem-util", type=float, default=0.40)
    ap.add_argument("--n-train", type=int, default=200_000,
                    help="Training samples for bridges A and C")
    ap.add_argument("--n-test", type=int, default=100_000,
                    help="Held-out test samples")
    ap.add_argument("--boundary-skip-layers", type=int, nargs="*",
                    default=[0, 1, 2, 3, 4, 5, 6, 29, 30, 31, 32, 33, 34, 35])
    ap.add_argument("--out-dir", type=Path, required=True)
    args = ap.parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)
    skip = set(args.boundary_skip_layers)

    print(f"[capture] {args.n_passages} × {args.ctx_len} passages …")
    t0 = time.perf_counter()
    captured = capture_qwen3_k(
        args.model_path, args.n_passages, args.ctx_len, args.gpu_mem_util,
    )
    print(f"[capture] {time.perf_counter() - t0:.1f}s — "
          f"{len(captured)} layers")

    # Pool all non-boundary K as rows of width D.
    # NOTE(review): D is assumed to be the model's head dimension — confirm
    # before pointing --model-path at a different checkpoint.
    D = 128
    pool = [
        torch.from_numpy(arr).reshape(-1, D)
        for lid, arr in captured.items()
        if lid not in skip
    ]
    if not pool:
        raise RuntimeError(
            "no captured layer survives --boundary-skip-layers; nothing to evaluate"
        )
    K_all = torch.cat(pool, dim=0).cuda().float()
    N_total = K_all.shape[0]
    # Count the layers actually pooled: skip ids that were never captured
    # must not be subtracted from the total.
    print(f"[data] {N_total:,} K vectors from {len(pool)} "
          f"non-boundary layers")

    # Train / test split.
    torch.manual_seed(42)
    perm = torch.randperm(N_total, device=K_all.device)
    K_train = K_all[perm[:args.n_train]].contiguous()
    K_test = K_all[perm[args.n_train:args.n_train + args.n_test]].contiguous()
    if K_train.shape[0] == 0 or K_test.shape[0] == 0:
        raise ValueError(
            f"only {N_total:,} K vectors available; cannot form a non-empty "
            f"split with --n-train={args.n_train:,} and --n-test={args.n_test:,}"
        )
    if K_train.shape[0] < args.n_train or K_test.shape[0] < args.n_test:
        print(f"[split] WARNING: requested train={args.n_train:,} "
              f"test={args.n_test:,} but only {N_total:,} vectors available; "
              f"split was truncated")
    print(f"[split] train={K_train.shape[0]:,} test={K_test.shape[0]:,}")

    # Unit-normalise for bridges that operate on S^(D-1).
    eps = 1e-12
    test_unit = K_test / K_test.norm(dim=1, keepdim=True).clamp(min=eps)

    results = []

    # --- TurboQuant k8v4 reference ---
    print("\n[bridge 0] TurboQuant k8v4 (1024 bits)")
    t0 = _synced_clock()
    K_hat = tq_k8v4_roundtrip(K_test, bits=8)
    dt = (_synced_clock() - t0) * 1000
    res = evaluate_bridge("TQ-k8v4", K_test, K_hat, bits=1024)
    res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
    print(f"  rel-MSE={res['rel_mse']:.6f} cos={res['cos_mean']:.4f} "
          f"bits={res['bits_per_token_per_head']}")
    results.append(res)

    # --- Bridge A: Guth-Katz polynomial partitioning ---
    print("\n[bridge A] Guth-Katz polynomial partitioning")
    from kakeyaturbo_py.bridge_a_guth_katz import GuthKatzPolynomialCodebook
    # Sweep n_polys ∈ {8, 12, 16, 20} for Pareto curve.
    for n_polys in [8, 12, 16, 20]:
        t0 = _synced_clock()
        cb = GuthKatzPolynomialCodebook(
            K_train, D=D, n_polys=n_polys, seed=0xDEAD + n_polys,
        )
        t_build = _synced_clock() - t0
        t0 = _synced_clock()
        seg, t = cb.encode(test_unit)
        xhat_unit = cb.decode(seg, t)
        # The codebook sees unit vectors; restore each vector's own norm.
        xhat = xhat_unit * K_test.norm(dim=1, keepdim=True)
        dt = (_synced_clock() - t0) * 1000
        res = evaluate_bridge(
            f"GuthKatz-polys{n_polys}", K_test, xhat,
            bits=n_polys,
        )
        res["n_cells_occupied"] = cb.n_occupied
        res["max_cell_count"] = cb.max_cell_count
        res["mean_cell_count"] = cb.mean_cell_count
        res["build_time_s"] = t_build
        res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
        print(f"  n_polys={n_polys} occ={cb.n_occupied}/{cb.n_cells} "
              f"rel-MSE={res['rel_mse']:.4f} cos={res['cos_mean']:.4f} "
              f"bits={res['bits_per_token_per_head']} "
              f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
        results.append(res)

    # --- Bridge B2: D4 + full TurboQuant engineering stack ---
    print("\n[bridge B2] D4 nested lattice + Hadamard + per-vector qmax + fp16 norms")
    from kakeyaturbo_py.bridge_b2_d4_tq_style import D4TQStyleCodebook
    # q_range=152 → 32 bits / D4 block × 32 blocks = 1024 lattice bits + 32 overhead = 1056
    # total bits, exactly matching TQ k8v4's 1024 + 32 fp16 scalars.
    for q_range in [16, 64, 152]:
        t0 = _synced_clock()
        cb = D4TQStyleCodebook(D=D, q_range=q_range)
        t_build = _synced_clock() - t0
        t0 = _synced_clock()
        K_hat = cb.roundtrip(K_test)
        dt = (_synced_clock() - t0) * 1000
        res = evaluate_bridge(
            f"D4-TQ-Q{q_range}", K_test, K_hat,
            bits=cb.bits_per_token_per_head,
        )
        res["build_time_s"] = t_build
        res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
        print(f"  q_range={q_range} rel-MSE={res['rel_mse']:.6f} "
              f"cos={res['cos_mean']:.4f} bits={res['bits_per_token_per_head']} "
              f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
        results.append(res)

    # --- Bridge B: D4 nested lattice (naive, for contrast) ---
    print("\n[bridge B] D4 nested lattice (naive, no Hadamard / no per-vector scale)")
    from kakeyaturbo_py.bridge_b_nested_lattice import D4NestedLatticeCodebook
    for q_range in [1, 2, 4, 8, 16]:
        t0 = _synced_clock()
        cb = D4NestedLatticeCodebook(K_train, D=D, q_range=q_range)
        t_build = _synced_clock() - t0
        t0 = _synced_clock()
        K_hat = cb.roundtrip(K_test)
        dt = (_synced_clock() - t0) * 1000
        res = evaluate_bridge(
            f"D4-Q{q_range}", K_test, K_hat,
            bits=cb.bits_per_token_per_head,
        )
        res["build_time_s"] = t_build
        res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
        print(f"  q_range={q_range} rel-MSE={res['rel_mse']:.4f} "
              f"cos={res['cos_mean']:.4f} bits={res['bits_per_token_per_head']} "
              f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
        results.append(res)

    # --- Bridge C: non-Gaussian shaping ---
    print("\n[bridge C] non-Gaussian shaping via empirical Lloyd-Max")
    from kakeyaturbo_py.bridge_c_non_gaussian import NonGaussianShapingCodebook
    for bits_per_coord in [2, 3, 4, 6, 8]:
        t0 = _synced_clock()
        cb = NonGaussianShapingCodebook(
            K_train, D=D, bits_per_coord=bits_per_coord,
        )
        t_build = _synced_clock() - t0
        t0 = _synced_clock()
        K_hat = cb.roundtrip(K_test)
        dt = (_synced_clock() - t0) * 1000
        res = evaluate_bridge(
            f"NonGauss-b{bits_per_coord}", K_test, K_hat,
            bits=cb.bits_per_token_per_head,
        )
        res["build_time_s"] = t_build
        res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
        print(f"  bits/coord={bits_per_coord} rel-MSE={res['rel_mse']:.6f} "
              f"cos={res['cos_mean']:.4f} bits={res['bits_per_token_per_head']} "
              f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
        results.append(res)

    # Pareto summary.
    _print_pareto_table(results)

    print("\nAll results:")
    for r in sorted(results, key=lambda r: r["bits_per_token_per_head"]):
        print(f"  {r['bits_per_token_per_head']:>5} bits  "
              f"rel-MSE={r['rel_mse']:.6f} cos={r['cos_mean']:.4f} "
              f"{r['name']}")

    out = {
        "model": args.model_path,
        "n_passages": args.n_passages,
        # Record the split sizes actually used, which can be smaller than the
        # requested ones when the pool was too small.
        "n_train": K_train.shape[0],
        "n_test": K_test.shape[0],
        "D": D,
        "boundary_skip_layers": sorted(skip),
        "results": results,
    }
    out_path = args.out_dir / "bridges_abc_head_to_head.json"
    # default=float converts any non-JSON-native scalar stored in the results.
    out_path.write_text(json.dumps(out, indent=2, default=float))
    print(f"\n[done] written → {out_path}")
| 299 | + |
| 300 | + |
| 301 | +if __name__ == "__main__": |
| 302 | + main() |
0 commit comments