|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +tq — TurboQuant CLI |
| 4 | +
|
| 5 | +Unified command-line interface for KV cache compression. |
| 6 | +Designed for both humans and AI agents (JSON output by default). |
| 7 | +
|
| 8 | +Usage: |
| 9 | + tq quantize <input> [--type TYPE] [--output FILE] |
| 10 | + tq bench [--seq-len N] [--head-dim N] [--json] |
| 11 | + tq info [--type TYPE] |
| 12 | + tq demo [--question TEXT] |
| 13 | + tq +compare # A/B test helper |
| 14 | + tq +memory <model> <context> # Memory savings calculator |
| 15 | +
|
| 16 | +Google CLI design principles applied: |
| 17 | + - JSON-first output (--json flag, default for scripts) |
| 18 | + - Structured exit codes |
| 19 | + - Help-driven discovery |
| 20 | + - + prefix for high-level helpers |
| 21 | +""" |
| 22 | + |
| 23 | +import sys |
| 24 | +import os |
| 25 | +import json |
| 26 | +import argparse |
| 27 | +import time |
| 28 | +import struct |
| 29 | +import numpy as np |
| 30 | + |
| 31 | +# Add bindings to path |
| 32 | +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "../bindings/python")) |
| 33 | + |
| 34 | +# ═══════════════════════════════════════════════════════════ |
| 35 | +# Colors (disabled when piped) |
| 36 | +# ═══════════════════════════════════════════════════════════ |
| 37 | +IS_TTY = sys.stdout.isatty() |
| 38 | + |
| 39 | +class C: |
| 40 | + if IS_TTY: |
| 41 | + BOLD = "\033[1m"; DIM = "\033[2m"; NC = "\033[0m" |
| 42 | + CYAN = "\033[36m"; GREEN = "\033[32m"; YELLOW = "\033[33m" |
| 43 | + RED = "\033[31m"; MAGENTA = "\033[35m"; BLUE = "\033[34m" |
| 44 | + BAR = "█"; BAR_E = "░" |
| 45 | + else: |
| 46 | + BOLD = DIM = NC = CYAN = GREEN = YELLOW = RED = MAGENTA = BLUE = "" |
| 47 | + BAR = "#"; BAR_E = "-" |
| 48 | + |
| 49 | +def bar(val, mx, w=25): |
| 50 | + f = int(val / mx * w) if mx > 0 else 0 |
| 51 | + f = min(f, w) |
| 52 | + return f"{C.GREEN}{C.BAR * f}{C.DIM}{C.BAR_E * (w - f)}{C.NC}" |
| 53 | + |
| 54 | +def sz(b): |
| 55 | + if b >= 1e9: return f"{b/1e9:.2f} GB" |
| 56 | + if b >= 1e6: return f"{b/1e6:.1f} MB" |
| 57 | + if b >= 1e3: return f"{b/1e3:.1f} KB" |
| 58 | + return f"{b} B" |
| 59 | + |
| 60 | +# ═══════════════════════════════════════════════════════════ |
| 61 | +# EXIT CODES (structured, Google CLI pattern) |
| 62 | +# ═══════════════════════════════════════════════════════════ |
| 63 | +EXIT_OK = 0 |
| 64 | +EXIT_USAGE = 1 |
| 65 | +EXIT_LIB_MISSING = 2 |
| 66 | +EXIT_MODEL_ERROR = 3 |
| 67 | +EXIT_IO_ERROR = 4 |
| 68 | + |
| 69 | +# ═══════════════════════════════════════════════════════════ |
| 70 | +# COMMANDS |
| 71 | +# ═══════════════════════════════════════════════════════════ |
| 72 | + |
| 73 | +def cmd_info(args): |
| 74 | + """Show quantization type information.""" |
| 75 | + try: |
| 76 | + from turboquant import TurboQuant |
| 77 | + tq = TurboQuant("cpu") |
| 78 | + except Exception as e: |
| 79 | + print(json.dumps({"error": "TurboQuant library not found", "detail": str(e)})) |
| 80 | + return EXIT_LIB_MISSING |
| 81 | + |
| 82 | + types = [ |
| 83 | + {"name": "uniform_4b", "id": 5, "bits": 4.2, "compression": 7.5, "grade": "A+", "recommended": True}, |
| 84 | + {"name": "mixed_4b8", "id": 7, "bits": 5.0, "compression": 6.4, "grade": "A+", "recommended": True}, |
| 85 | + {"name": "uniform_2b", "id": 6, "bits": 2.2, "compression": 14.2, "grade": "A", "recommended": False}, |
| 86 | + {"name": "turbo_3b", "id": 3, "bits": 5.8, "compression": 4.6, "grade": "B+", "recommended": False}, |
| 87 | + {"name": "polar_4b", "id": 1, "bits": 4.5, "compression": 7.1, "grade": "B", "recommended": False}, |
| 88 | + {"name": "qjl_1b", "id": 2, "bits": 1.2, "compression": 25.6,"grade": "C", "recommended": False}, |
| 89 | + ] |
| 90 | + |
| 91 | + if args.json_output: |
| 92 | + print(json.dumps({"types": types}, indent=2)) |
| 93 | + else: |
| 94 | + print(f"\n {C.BOLD}TurboQuant Quantization Types{C.NC}") |
| 95 | + print(f" Ranked by real Qwen3.5-0.8B A/B test results\n") |
| 96 | + print(f" {C.BOLD}{'Type':<14} {'Bits':>5} {'Compress':>9} {'Grade':>6} {'Note':<20}{C.NC}") |
| 97 | + print(f" {'─'*14} {'─'*5} {'─'*9} {'─'*6} {'─'*20}") |
| 98 | + for t in types: |
| 99 | + star = f"{C.GREEN}★{C.NC}" if t["recommended"] else " " |
| 100 | + note = "← recommended" if t["recommended"] else "" |
| 101 | + print(f" {star} {t['name']:<12} {t['bits']:>5.1f} {t['compression']:>8.1f}x {t['grade']:>5} {note}") |
| 102 | + print() |
| 103 | + return EXIT_OK |
| 104 | + |
| 105 | + |
| 106 | +def cmd_bench(args): |
| 107 | + """Run performance benchmark.""" |
| 108 | + try: |
| 109 | + from turboquant import TurboQuant |
| 110 | + tq = TurboQuant("cpu") |
| 111 | + except Exception as e: |
| 112 | + print(json.dumps({"error": str(e)})) |
| 113 | + return EXIT_LIB_MISSING |
| 114 | + |
| 115 | + seq_len = args.seq_len or 512 |
| 116 | + head_dim = args.head_dim or 128 |
| 117 | + reps = 500 |
| 118 | + |
| 119 | + np.random.seed(42) |
| 120 | + keys = np.random.randn(seq_len, head_dim).astype(np.float32) * 0.15 |
| 121 | + query = np.random.randn(head_dim).astype(np.float32) * 0.15 |
| 122 | + |
| 123 | + results = [] |
| 124 | + for qtype, name in [(5, "uniform_4b"), (7, "mixed_4b8"), (6, "uniform_2b")]: |
| 125 | + t0 = time.time() |
| 126 | + for _ in range(reps): |
| 127 | + q = tq.quantize_keys(keys, qtype) |
| 128 | + quant_time = (time.time() - t0) / reps |
| 129 | + |
| 130 | + deq = tq.dequantize_keys(q, seq_len, head_dim, qtype) |
| 131 | + mse = float(np.mean((keys - deq) ** 2)) |
| 132 | + |
| 133 | + fp32_scores = keys @ query |
| 134 | + scores = tq.attention(query, q, seq_len, head_dim, qtype) |
| 135 | + cos = float(np.dot(scores, fp32_scores) / (np.linalg.norm(scores) * np.linalg.norm(fp32_scores) + 1e-10)) |
| 136 | + |
| 137 | + results.append({ |
| 138 | + "type": name, "seq_len": seq_len, "head_dim": head_dim, |
| 139 | + "mse": round(mse, 6), "cosine": round(cos, 4), |
| 140 | + "quant_ms": round(quant_time * 1000, 3), |
| 141 | + "compression": round(keys.nbytes / len(q), 1), |
| 142 | + }) |
| 143 | + |
| 144 | + if args.json_output: |
| 145 | + print(json.dumps({"benchmark": results}, indent=2)) |
| 146 | + else: |
| 147 | + print(f"\n {C.BOLD}TurboQuant Benchmark{C.NC} (seq={seq_len}, dim={head_dim})\n") |
| 148 | + print(f" {C.BOLD}{'Type':<14} {'MSE':>10} {'Cosine':>8} {'Time':>8} {'Compress':>9}{C.NC}") |
| 149 | + print(f" {'─'*14} {'─'*10} {'─'*8} {'─'*8} {'─'*9}") |
| 150 | + for r in results: |
| 151 | + g = C.GREEN if r["cosine"] > 0.99 else C.YELLOW if r["cosine"] > 0.95 else C.RED |
| 152 | + print(f" {r['type']:<14} {r['mse']:>10.6f} {g}{r['cosine']:>8.4f}{C.NC} {r['quant_ms']:>6.1f}ms {r['compression']:>7.1f}x") |
| 153 | + print() |
| 154 | + return EXIT_OK |
| 155 | + |
| 156 | + |
| 157 | +def cmd_memory(args): |
| 158 | + """Calculate memory savings for a model+context combination.""" |
| 159 | + models = { |
| 160 | + "qwen3.5-0.8b": {"layers": 6, "kv_heads": 2, "head_dim": 256, "params": 0.8}, |
| 161 | + "llama-3.2-1b": {"layers": 16, "kv_heads": 8, "head_dim": 64, "params": 1.2}, |
| 162 | + "llama-3.2-3b": {"layers": 28, "kv_heads": 8, "head_dim": 128, "params": 3.2}, |
| 163 | + "phi-3-mini": {"layers": 32, "kv_heads": 32, "head_dim": 96, "params": 3.8}, |
| 164 | + } |
| 165 | + |
| 166 | + model_key = args.model.lower().replace(" ", "-") |
| 167 | + if model_key not in models: |
| 168 | + avail = ", ".join(models.keys()) |
| 169 | + if args.json_output: |
| 170 | + print(json.dumps({"error": f"Unknown model: {args.model}", "available": list(models.keys())})) |
| 171 | + else: |
| 172 | + print(f" {C.RED}Unknown model: {args.model}{C.NC}") |
| 173 | + print(f" Available: {avail}") |
| 174 | + return EXIT_USAGE |
| 175 | + |
| 176 | + m = models[model_key] |
| 177 | + ctx = args.context |
| 178 | + |
| 179 | + fp16 = m["layers"] * m["kv_heads"] * m["head_dim"] * ctx * 2 * 2 |
| 180 | + tq4b = fp16 * 4.2 / 16 |
| 181 | + k4v2 = fp16 * (4.2 + 2.2) / 2 / 16 |
| 182 | + tq2b = fp16 * 2.2 / 16 |
| 183 | + |
| 184 | + result = { |
| 185 | + "model": args.model, "context": ctx, |
| 186 | + "fp16_bytes": int(fp16), |
| 187 | + "uniform_4b_bytes": int(tq4b), |
| 188 | + "k4v2_bytes": int(k4v2), |
| 189 | + "uniform_2b_bytes": int(tq2b), |
| 190 | + "saved_k4v2_bytes": int(fp16 - k4v2), |
| 191 | + "saved_pct": round((1 - k4v2 / fp16) * 100, 1), |
| 192 | + } |
| 193 | + |
| 194 | + if args.json_output: |
| 195 | + print(json.dumps(result, indent=2)) |
| 196 | + else: |
| 197 | + ctx_str = f"{ctx//1024}K" if ctx >= 1024 else str(ctx) |
| 198 | + print(f"\n {C.BOLD}Memory: {args.model} @ {ctx_str} context{C.NC}\n") |
| 199 | + configs = [ |
| 200 | + ("FP16 (baseline)", fp16, C.RED), |
| 201 | + ("TQ uniform_4b", tq4b, C.GREEN), |
| 202 | + ("TQ K4V2", k4v2, C.GREEN), |
| 203 | + ("TQ uniform_2b", tq2b, C.YELLOW), |
| 204 | + ] |
| 205 | + for name, size, color in configs: |
| 206 | + comp = fp16 / size if size > 0 else 1 |
| 207 | + print(f" {name:<20} {sz(size):>10} {comp:>5.1f}x {bar(size, fp16)}") |
| 208 | + print(f"\n {C.GREEN}{C.BOLD}Best balance (K4V2): saves {sz(fp16 - k4v2)} ({(1-k4v2/fp16)*100:.0f}%){C.NC}\n") |
| 209 | + return EXIT_OK |
| 210 | + |
| 211 | + |
| 212 | +def cmd_compare(args): |
| 213 | + """A/B comparison helper.""" |
| 214 | + os.execvp(sys.executable, [sys.executable, "-c", |
| 215 | + "import subprocess; subprocess.run(['./build/ab_test'])"]) |
| 216 | + |
| 217 | + |
| 218 | +# ═══════════════════════════════════════════════════════════ |
| 219 | +# MAIN |
| 220 | +# ═══════════════════════════════════════════════════════════ |
| 221 | + |
| 222 | +def main(): |
| 223 | + parser = argparse.ArgumentParser( |
| 224 | + prog="tq", |
| 225 | + description="TurboQuant CLI — KV cache compression for LLM inference", |
| 226 | + formatter_class=argparse.RawDescriptionHelpFormatter, |
| 227 | + epilog=""" |
| 228 | +commands: |
| 229 | + info Show quantization types and recommendations |
| 230 | + bench Run performance benchmark |
| 231 | + +memory MODEL CTX Calculate memory savings |
| 232 | + +compare Run A/B comparison (requires build) |
| 233 | + demo Interactive chat with Qwen3.5-0.8B |
| 234 | +
|
| 235 | +examples: |
| 236 | + tq info |
| 237 | + tq info --json |
| 238 | + tq bench --seq-len 2048 --head-dim 256 |
| 239 | + tq +memory llama-3.2-3b 65536 |
| 240 | + tq +memory qwen3.5-0.8b 131072 --json |
| 241 | + tq demo "What is quantization?" |
| 242 | +""") |
| 243 | + parser.add_argument("--json", dest="json_output", action="store_true", help="JSON output (for AI agents)") |
| 244 | + sub = parser.add_subparsers(dest="command") |
| 245 | + |
| 246 | + # info |
| 247 | + p_info = sub.add_parser("info", help="Quantization type information") |
| 248 | + p_info.add_argument("--json", dest="json_output", action="store_true") |
| 249 | + |
| 250 | + # bench |
| 251 | + p_bench = sub.add_parser("bench", help="Performance benchmark") |
| 252 | + p_bench.add_argument("--seq-len", type=int) |
| 253 | + p_bench.add_argument("--head-dim", type=int) |
| 254 | + p_bench.add_argument("--json", dest="json_output", action="store_true") |
| 255 | + |
| 256 | + # +memory |
| 257 | + p_mem = sub.add_parser("+memory", help="Memory savings calculator") |
| 258 | + p_mem.add_argument("model", help="Model name (e.g., llama-3.2-3b)") |
| 259 | + p_mem.add_argument("context", type=int, help="Context length in tokens") |
| 260 | + p_mem.add_argument("--json", dest="json_output", action="store_true") |
| 261 | + |
| 262 | + # +compare |
| 263 | + sub.add_parser("+compare", help="Run A/B comparison") |
| 264 | + |
| 265 | + # demo |
| 266 | + p_demo = sub.add_parser("demo", help="Chat with Qwen3.5-0.8B") |
| 267 | + p_demo.add_argument("question", nargs="?", help="Question (interactive if omitted)") |
| 268 | + |
| 269 | + args = parser.parse_args() |
| 270 | + |
| 271 | + if not args.command: |
| 272 | + parser.print_help() |
| 273 | + return EXIT_USAGE |
| 274 | + |
| 275 | + if args.command == "info": |
| 276 | + return cmd_info(args) |
| 277 | + elif args.command == "bench": |
| 278 | + return cmd_bench(args) |
| 279 | + elif args.command == "+memory": |
| 280 | + return cmd_memory(args) |
| 281 | + elif args.command == "+compare": |
| 282 | + return cmd_compare(args) |
| 283 | + elif args.command == "demo": |
| 284 | + os.execvp(sys.executable, [sys.executable, |
| 285 | + os.path.join(os.path.dirname(__file__), "tq_chat.py"), |
| 286 | + *([] if not args.question else [args.question])]) |
| 287 | + return EXIT_OK |
| 288 | + |
| 289 | + |
| 290 | +if __name__ == "__main__": |
| 291 | + sys.exit(main()) |
0 commit comments