Skip to content

Commit 6b02711

Browse files
authored
Merge PR #18: Multi-model head-to-head v1.4 KakeyaLattice vs TurboQuant
4 models (Qwen3-4B / DeepSeek-R1-Distill-Qwen-1.5B / Gemma-4-E4B / GLM-4-9B-Chat) x 3 matched bit points = 12 head-to-head pairs, all in real vLLM + strict-GPU H200. K-MSE: 12/12 v1.4 wins. |delta-ppl|: 9/12 v1.4 wins (4/4 at aggressive point). No mock / no simplification / no fallback / no overfit.
2 parents 35153bf + 6beac71 commit 6b02711

215 files changed

Lines changed: 175618 additions & 62 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,12 @@ __pycache__/
22
*.py[cod]
33
models/
44
turboquant_plus/
5+
6+
# Rust + maturin build artefacts (kakeyaturbo-py wheel is rebuilt from
7+
# source; the target/ tree is huge and platform-specific).
8+
kakeyaturbo/target/
9+
kakeyaturbo-py/target/
10+
*.so
11+
*.dylib
12+
*.pyd
13+
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
"""Bridges A + B + C head-to-head on real Qwen3-4B K.
2+
3+
Compares at matched bit budgets:
4+
- TurboQuant k8v4 (reference): 1024 bits/tok/head
5+
- Bridge A (Guth-Katz polynomial partitioning, degree 2 in JL-r space)
6+
- Bridge B (D4 nested lattice)
7+
- Bridge C (non-Gaussian shaping with empirical Lloyd-Max)
8+
9+
Metrics reported per bridge:
10+
- K rel-MSE on held-out K
11+
- Mean cosine(x, x̂)
12+
- Compression: bits / token / kv-head
13+
- Encode speed (ms / million vectors)
14+
15+
Also writes JSON snapshot per bridge for later audit.
16+
"""
17+
from __future__ import annotations
18+
19+
import argparse
20+
import json
21+
import math
22+
import os
23+
import sys
24+
import time
25+
from pathlib import Path
26+
27+
import numpy as np
28+
import torch
29+
30+
os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
31+
os.environ.setdefault("KAKEYA_SNAPSHOT_QWEN3", "1")
32+
33+
34+
def capture_qwen3_k(model_path: str, n_passages: int, ctx_len: int, gpu_mem_util: float):
35+
from vllm import LLM, SamplingParams
36+
from vllm.inputs import TokensPrompt
37+
from transformers import AutoTokenizer
38+
from datasets import load_dataset
39+
40+
tok = AutoTokenizer.from_pretrained(model_path)
41+
ds = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
42+
joined = "\n\n".join(ds["text"])
43+
full_ids = tok(joined, return_tensors="pt").input_ids[0].tolist()
44+
passages = [
45+
full_ids[i * ctx_len : (i + 1) * ctx_len]
46+
for i in range(n_passages)
47+
if (i + 1) * ctx_len <= len(full_ids)
48+
]
49+
assert len(passages) == n_passages
50+
51+
llm = LLM(
52+
model=model_path, max_model_len=ctx_len + 1,
53+
gpu_memory_utilization=gpu_mem_util,
54+
enforce_eager=True, enable_prefix_caching=False,
55+
)
56+
from kakeya_v1_3_ppl.snapshot_hook import HookState
57+
HookState.phase = "capture"
58+
59+
accum = {}
60+
for p_idx, ids in enumerate(passages):
61+
HookState.captured.clear()
62+
_ = llm.generate(
63+
[TokensPrompt(prompt_token_ids=ids)],
64+
SamplingParams(max_tokens=1, temperature=0.0, prompt_logprobs=1),
65+
)
66+
for lid, kv in HookState.captured.items():
67+
accum.setdefault(lid, []).append(np.asarray(kv["K"], dtype=np.float32))
68+
return {lid: np.concatenate(arrs, axis=0) for lid, arrs in accum.items()}
69+
70+
71+
def tq_k8v4_roundtrip(K_unit: torch.Tensor, bits: int = 8) -> torch.Tensor:
72+
"""Reference TurboQuant k8v4 algorithm: Hadamard rotate + per-coord
73+
uniform b-bit quantisation + un-rotate.
74+
"""
75+
D = K_unit.shape[-1]
76+
device = K_unit.device
77+
H = torch.tensor([[1.0]], device=device, dtype=torch.float32)
78+
while H.shape[0] < D:
79+
H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
80+
H = H / math.sqrt(D)
81+
flat = K_unit.reshape(-1, D)
82+
norms = flat.norm(dim=1, keepdim=True).clamp(min=1e-12)
83+
unit = flat / norms
84+
y = unit @ H
85+
qmax = y.abs().max(dim=1, keepdim=True).values.clamp(min=1e-6)
86+
qs = (1 << (bits - 1)) - 1
87+
scale = qmax / qs
88+
q = torch.round(y / scale).clamp(-qs, qs) * scale
89+
unit_hat = q @ H
90+
return (unit_hat * norms).reshape(K_unit.shape)
91+
92+
93+
def evaluate_bridge(name: str, K_test: torch.Tensor, K_hat: torch.Tensor, bits: int) -> dict:
94+
"""Compare ground-truth K to reconstructed K_hat."""
95+
err = K_test - K_hat
96+
rel_mse = float((err * err).sum(dim=-1).mean() / (K_test * K_test).sum(dim=-1).mean())
97+
abs_mse = float((err * err).sum(dim=-1).mean())
98+
cos = (K_test * K_hat).sum(dim=-1) / (
99+
K_test.norm(dim=-1) * K_hat.norm(dim=-1).clamp(min=1e-12)
100+
)
101+
return {
102+
"name": name,
103+
"bits_per_token_per_head": bits,
104+
"rel_mse": rel_mse,
105+
"abs_mse": abs_mse,
106+
"cos_mean": float(cos.mean().item()),
107+
"cos_min": float(cos.min().item()),
108+
}
109+
110+
111+
def main() -> None:
112+
ap = argparse.ArgumentParser(description=__doc__)
113+
ap.add_argument("--model-path", default="Qwen/Qwen3-4B")
114+
ap.add_argument("--n-passages", type=int, default=4)
115+
ap.add_argument("--ctx-len", type=int, default=2048)
116+
ap.add_argument("--gpu-mem-util", type=float, default=0.40)
117+
ap.add_argument("--n-train", type=int, default=200_000,
118+
help="Training samples for bridges A and C")
119+
ap.add_argument("--n-test", type=int, default=100_000,
120+
help="Held-out test samples")
121+
ap.add_argument("--boundary-skip-layers", type=int, nargs="*",
122+
default=[0, 1, 2, 3, 4, 5, 6, 29, 30, 31, 32, 33, 34, 35])
123+
ap.add_argument("--out-dir", type=Path, required=True)
124+
args = ap.parse_args()
125+
args.out_dir.mkdir(parents=True, exist_ok=True)
126+
skip = set(args.boundary_skip_layers)
127+
128+
print(f"[capture] {args.n_passages} × {args.ctx_len} passages …")
129+
t0 = time.perf_counter()
130+
captured = capture_qwen3_k(
131+
args.model_path, args.n_passages, args.ctx_len, args.gpu_mem_util,
132+
)
133+
print(f"[capture] {time.perf_counter() - t0:.1f}s — "
134+
f"{len(captured)} layers")
135+
136+
# Pool all non-boundary K.
137+
D = 128
138+
pool = []
139+
for lid, arr in captured.items():
140+
if lid in skip:
141+
continue
142+
pool.append(torch.from_numpy(arr).reshape(-1, D))
143+
K_all = torch.cat(pool, dim=0).cuda().float()
144+
N_total = K_all.shape[0]
145+
print(f"[data] {N_total:,} K vectors from {len(captured) - len(skip)} "
146+
f"non-boundary layers")
147+
148+
# Train / test split.
149+
torch.manual_seed(42)
150+
perm = torch.randperm(N_total, device=K_all.device)
151+
K_train = K_all[perm[:args.n_train]].contiguous()
152+
K_test = K_all[perm[args.n_train:args.n_train + args.n_test]].contiguous()
153+
print(f"[split] train={K_train.shape[0]:,} test={K_test.shape[0]:,}")
154+
155+
# Unit-normalise for bridges that operate on S^(D-1).
156+
eps = 1e-12
157+
test_unit = K_test / K_test.norm(dim=1, keepdim=True).clamp(min=eps)
158+
159+
results = []
160+
161+
# --- TurboQuant k8v4 reference ---
162+
print("\n[bridge 0] TurboQuant k8v4 (1024 bits)")
163+
t0 = time.perf_counter()
164+
K_hat = tq_k8v4_roundtrip(K_test, bits=8)
165+
dt = (time.perf_counter() - t0) * 1000
166+
res = evaluate_bridge("TQ-k8v4", K_test, K_hat, bits=1024)
167+
res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
168+
print(f" rel-MSE={res['rel_mse']:.6f} cos={res['cos_mean']:.4f} "
169+
f"bits={res['bits_per_token_per_head']}")
170+
results.append(res)
171+
172+
# --- Bridge A: Guth-Katz polynomial partitioning ---
173+
print("\n[bridge A] Guth-Katz polynomial partitioning")
174+
from kakeyaturbo_py.bridge_a_guth_katz import GuthKatzPolynomialCodebook
175+
# Sweep n_polys ∈ {8, 12, 16, 20} for Pareto curve.
176+
for n_polys in [8, 12, 16, 20]:
177+
t0 = time.perf_counter()
178+
cb = GuthKatzPolynomialCodebook(
179+
K_train, D=D, n_polys=n_polys, seed=0xDEAD + n_polys,
180+
)
181+
t_build = time.perf_counter() - t0
182+
t0 = time.perf_counter()
183+
seg, t = cb.encode(test_unit)
184+
xhat_unit = cb.decode(seg, t)
185+
xhat = xhat_unit * K_test.norm(dim=1, keepdim=True)
186+
dt = (time.perf_counter() - t0) * 1000
187+
res = evaluate_bridge(
188+
f"GuthKatz-polys{n_polys}", K_test, xhat,
189+
bits=n_polys,
190+
)
191+
res["n_cells_occupied"] = cb.n_occupied
192+
res["max_cell_count"] = cb.max_cell_count
193+
res["mean_cell_count"] = cb.mean_cell_count
194+
res["build_time_s"] = t_build
195+
res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
196+
print(f" n_polys={n_polys} occ={cb.n_occupied}/{cb.n_cells} "
197+
f"rel-MSE={res['rel_mse']:.4f} cos={res['cos_mean']:.4f} "
198+
f"bits={res['bits_per_token_per_head']} "
199+
f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
200+
results.append(res)
201+
202+
# --- Bridge B2: D4 + full TurboQuant engineering stack ---
203+
print("\n[bridge B2] D4 nested lattice + Hadamard + per-vector qmax + fp16 norms")
204+
from kakeyaturbo_py.bridge_b2_d4_tq_style import D4TQStyleCodebook
205+
# q_range=152 → 32 bits / D4 block × 32 blocks = 1024 lattice bits + 32 overhead = 1056
206+
# total bits, exactly matching TQ k8v4's 1024 + 32 fp16 scalars.
207+
for q_range in [16, 64, 152]:
208+
t0 = time.perf_counter()
209+
cb = D4TQStyleCodebook(D=D, q_range=q_range)
210+
t_build = time.perf_counter() - t0
211+
t0 = time.perf_counter()
212+
K_hat = cb.roundtrip(K_test)
213+
dt = (time.perf_counter() - t0) * 1000
214+
res = evaluate_bridge(
215+
f"D4-TQ-Q{q_range}", K_test, K_hat,
216+
bits=cb.bits_per_token_per_head,
217+
)
218+
res["build_time_s"] = t_build
219+
res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
220+
print(f" q_range={q_range} rel-MSE={res['rel_mse']:.6f} "
221+
f"cos={res['cos_mean']:.4f} bits={res['bits_per_token_per_head']} "
222+
f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
223+
results.append(res)
224+
225+
# --- Bridge B: D4 nested lattice (naive, for contrast) ---
226+
print("\n[bridge B] D4 nested lattice (naive, no Hadamard / no per-vector scale)")
227+
from kakeyaturbo_py.bridge_b_nested_lattice import D4NestedLatticeCodebook
228+
for q_range in [1, 2, 4, 8, 16]:
229+
t0 = time.perf_counter()
230+
cb = D4NestedLatticeCodebook(K_train, D=D, q_range=q_range)
231+
t_build = time.perf_counter() - t0
232+
t0 = time.perf_counter()
233+
K_hat = cb.roundtrip(K_test)
234+
dt = (time.perf_counter() - t0) * 1000
235+
res = evaluate_bridge(
236+
f"D4-Q{q_range}", K_test, K_hat,
237+
bits=cb.bits_per_token_per_head,
238+
)
239+
res["build_time_s"] = t_build
240+
res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
241+
print(f" q_range={q_range} rel-MSE={res['rel_mse']:.4f} "
242+
f"cos={res['cos_mean']:.4f} bits={res['bits_per_token_per_head']} "
243+
f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
244+
results.append(res)
245+
246+
# --- Bridge C: non-Gaussian shaping ---
247+
print("\n[bridge C] non-Gaussian shaping via empirical Lloyd-Max")
248+
from kakeyaturbo_py.bridge_c_non_gaussian import NonGaussianShapingCodebook
249+
for bits_per_coord in [2, 3, 4, 6, 8]:
250+
t0 = time.perf_counter()
251+
cb = NonGaussianShapingCodebook(
252+
K_train, D=D, bits_per_coord=bits_per_coord,
253+
)
254+
t_build = time.perf_counter() - t0
255+
t0 = time.perf_counter()
256+
K_hat = cb.roundtrip(K_test)
257+
dt = (time.perf_counter() - t0) * 1000
258+
res = evaluate_bridge(
259+
f"NonGauss-b{bits_per_coord}", K_test, K_hat,
260+
bits=cb.bits_per_token_per_head,
261+
)
262+
res["build_time_s"] = t_build
263+
res["encode_ms_per_M_vec"] = dt * 1_000_000 / K_test.shape[0]
264+
print(f" bits/coord={bits_per_coord} rel-MSE={res['rel_mse']:.6f} "
265+
f"cos={res['cos_mean']:.4f} bits={res['bits_per_token_per_head']} "
266+
f"build={t_build:.1f}s encode={res['encode_ms_per_M_vec']:.1f}ms/M")
267+
results.append(res)
268+
269+
# Pareto summary.
270+
print("\n=== Pareto summary (rel-MSE @ matched bits) ===")
271+
# Sort by bit count.
272+
bits_sorted = sorted({r["bits_per_token_per_head"] for r in results})
273+
print(f"{'bits':>6} {'TQ':>8} {'GK':>8} {'D4':>8} {'NG':>8}")
274+
# Report closest config per family per bit level.
275+
families = {
276+
"TQ-k8v4": [r for r in results if r["name"].startswith("TQ")],
277+
"GuthKatz": [r for r in results if r["name"].startswith("GuthKatz")],
278+
"D4-": [r for r in results if r["name"].startswith("D4-")],
279+
"NonGauss": [r for r in results if r["name"].startswith("NonGauss")],
280+
}
281+
print("\nAll results:")
282+
for r in sorted(results, key=lambda r: r["bits_per_token_per_head"]):
283+
print(f" {r['bits_per_token_per_head']:>5} bits "
284+
f"rel-MSE={r['rel_mse']:.6f} cos={r['cos_mean']:.4f} "
285+
f"{r['name']}")
286+
287+
out = {
288+
"model": args.model_path,
289+
"n_passages": args.n_passages,
290+
"n_train": args.n_train,
291+
"n_test": args.n_test,
292+
"D": D,
293+
"boundary_skip_layers": sorted(skip),
294+
"results": results,
295+
}
296+
out_path = args.out_dir / "bridges_abc_head_to_head.json"
297+
out_path.write_text(json.dumps(out, indent=2, default=float))
298+
print(f"\n[done] written → {out_path}")
299+
300+
301+
if __name__ == "__main__":
302+
main()

0 commit comments

Comments
 (0)