#!/usr/bin/env python3
"""
TurboQuant CLI — Interactive chat with KV cache compression analysis.

Usage:
    python3 tools/tq_chat.py                       # Interactive mode
    python3 tools/tq_chat.py "Your question here"  # Single question
    python3 tools/tq_chat.py --benchmark           # Run benchmark suite
"""

import time
import argparse

# ANSI color codes and bar glyphs for terminal output
class C:
    BOLD = "\033[1m"
    DIM = "\033[2m"
    CYAN = "\033[36m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    RED = "\033[31m"
    MAGENTA = "\033[35m"
    BLUE = "\033[34m"
    NC = "\033[0m"
    BAR = "█"
    BAR_EMPTY = "░"

def bar(value, max_val, width=30, color=C.GREEN):
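    """Render a filled/empty block bar proportional to value / max_val."""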
    filled = int(value / max_val * width) if max_val > 0 else 0
    filled = min(filled, width)
    return f"{color}{C.BAR * filled}{C.DIM}{C.BAR_EMPTY * (width - filled)}{C.NC}"

def size_str(bytes_val):
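    """Format a byte count as a human-readable string (B/KB/MB/GB)."""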
    if bytes_val >= 1024 * 1024 * 1024:
        return f"{bytes_val / 1024**3:.2f} GB"
    elif bytes_val >= 1024 * 1024:
        return f"{bytes_val / 1024**2:.1f} MB"
    elif bytes_val >= 1024:
        return f"{bytes_val / 1024:.1f} KB"
    return f"{bytes_val} B"

def print_header():
    print()
    print(f"{C.CYAN}{C.BOLD}╔══════════════════════════════════════════════════════════╗{C.NC}")
    print(f"{C.CYAN}{C.BOLD}║ 🚀 TurboQuant CLI — KV Cache Compression for LLMs ║{C.NC}")
    print(f"{C.CYAN}{C.BOLD}║ Model: Qwen3.5-0.8B | Powered by QuantumAI Inc. ║{C.NC}")
    print(f"{C.CYAN}{C.BOLD}╚══════════════════════════════════════════════════════════╝{C.NC}")
    print()

def print_kv_analysis(cache, prompt_len):
    """Analyze and visualize KV cache compression."""
    import torch

    total_fp16 = 0
    layers = 0
    for k in cache.key_cache:
        if k is not None and isinstance(k, torch.Tensor) and k.dim() >= 3:
            total_fp16 += k.nelement() * 2 * 2  # 2 bytes/elem (fp16) x 2 tensors (K+V)
            layers += 1

    if total_fp16 == 0:
        print(f" {C.YELLOW}No KV cache tensors found; skipping analysis.{C.NC}")
        return

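    # Effective bits per value for each TurboQuant format; the extra 0.2 bits
    # on top of the 4b/2b payloads presumably cover quantization metadata
    # (per-group scales/zero-points). K4V2 averages the key (4.2b) and
    # value (2.2b) widths against the 16-bit fp16 baseline.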
    tq_4b = int(total_fp16 * 4.2 / 16)
    tq_2b = int(total_fp16 * 2.2 / 16)
    k4v2 = int(total_fp16 * (4.2 + 2.2) / 2 / 16)

    print()
    print(f" {C.BOLD}📊 KV Cache Analysis{C.NC}")
    print(f" {C.DIM}{'─' * 52}{C.NC}")
    print(f" Attention Layers: {C.BOLD}{layers}{C.NC} | Prompt Tokens: {C.BOLD}{prompt_len}{C.NC}")
    print()
    print(f" {C.BOLD}{'Method':<22} {'Size':>10} {'Compression':>12} Bar{C.NC}")
    print(f" {'─' * 22} {'─' * 10} {'─' * 12} {'─' * 30}")

    configs = [
        ("FP16 (baseline)", total_fp16, 1.0, C.RED),
        ("TQ uniform_4b", tq_4b, total_fp16 / tq_4b, C.GREEN),
        ("TQ K4V2 asymmetric", k4v2, total_fp16 / k4v2, C.GREEN),
        ("TQ uniform_2b", tq_2b, total_fp16 / tq_2b, C.YELLOW),
    ]

    for name, size, comp, color in configs:
        print(f" {name:<22} {size_str(size):>10} {comp:>10.1f}x {bar(size, total_fp16, 30, color)}")

    saved = total_fp16 - k4v2
    print()
    print(f" {C.GREEN}{C.BOLD}💾 Best balance (K4V2): saves {size_str(saved)} ({saved * 100 // total_fp16}%){C.NC}")

    # Scale projections
    print()
    print(f" {C.BOLD}📈 Projected at longer contexts:{C.NC}")
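    # Extrapolate linearly: bytes of fp16 KV cache per token, as measured on
    # the actual prompt, scaled out to longer context windows.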
    per_token = total_fp16 / prompt_len
    for ctx in [4096, 16384, 65536, 131072]:
        fp16 = per_token * ctx
        k4v2_proj = fp16 * (4.2 + 2.2) / 2 / 16
        saved_proj = fp16 - k4v2_proj
        ctx_str = f"{ctx // 1024}K"
        print(f" {ctx_str:>6}: FP16 {size_str(fp16):>10} → TQ {size_str(k4v2_proj):>10} {bar(k4v2_proj, fp16, 20, C.GREEN)} save {size_str(saved_proj)}")


def run_chat(question, model, tokenizer):
    """Run a single question through the model with analysis."""
    import torch

    print(f" {C.BOLD}{C.BLUE}Q:{C.NC} {question}")
    print()

    messages = [{"role": "user", "content": question}]
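    # Render the prompt with the model's chat template; enable_thinking=False
    # tells Qwen3-style templates not to emit a <think> reasoning preamble.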
    text = tokenizer.apply_chat_template(messages, tokenize=False,
                                         add_generation_prompt=True,
                                         enable_thinking=False)
    inputs = tokenizer(text, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    print(f" {C.BOLD}{C.GREEN}A:{C.NC} ", end="", flush=True)

    t0 = time.time()
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    elapsed = time.time() - t0

    # generate() returns prompt + completion tokens; slice off the echoed prompt
    answer = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    gen_tokens = out.shape[1] - prompt_len

    # Print answer with wrapping
    import textwrap
    for line in answer.split("\n"):
        wrapped = textwrap.fill(line, width=72, initial_indent=" ",
                                subsequent_indent=" ")
        print(wrapped)

    # Stats
    print()
    print(f" {C.DIM}{'─' * 52}{C.NC}")
    tps = gen_tokens / elapsed if elapsed > 0 else 0
    print(f" {C.DIM}⏱ {gen_tokens} tokens in {elapsed:.1f}s ({tps:.1f} tok/s) | prompt: {prompt_len} tokens{C.NC}")

    # KV cache analysis
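    # generate() above doesn't return its cache with these arguments, so
    # re-run the prompt once with use_cache=True to materialize
    # past_key_values for inspection.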
    with torch.no_grad():
        out2 = model(**inputs, use_cache=True)
        cache = out2.past_key_values

    print_kv_analysis(cache, prompt_len)


def main():
    parser = argparse.ArgumentParser(description="TurboQuant CLI — Chat with KV cache analysis")
    parser.add_argument("question", nargs="?", help="Question to ask (interactive if omitted)")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark suite")
    args = parser.parse_args()

    print_header()

    # Load model
    print(f" {C.DIM}Loading Qwen3.5-0.8B...{C.NC}", end="", flush=True)
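    # torch/transformers are imported lazily so the banner and loading
    # message appear before the slow import work starts.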
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen3.5-0.8B"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=torch.float32
    )
    model.eval()
    print(f" {C.GREEN}✓{C.NC}")
    print()

    if args.benchmark:
        questions = [
            "What is 2+2?",
            "Explain KV cache quantization in one paragraph.",
            "Write a Python function that computes fibonacci numbers.",
        ]
        for q in questions:
            run_chat(q, model, tokenizer)
            print()
            print(f" {C.DIM}{'═' * 52}{C.NC}")
            print()
    elif args.question:
        run_chat(args.question, model, tokenizer)
    else:
        # Interactive mode
        print(f" {C.YELLOW}Interactive mode. Type your question (or 'quit' to exit).{C.NC}")
        print()
        while True:
            try:
                q = input(f" {C.BOLD}You:{C.NC} ").strip()
                if not q or q.lower() in ("quit", "exit", "q"):
                    print(f"\n {C.DIM}Goodbye!{C.NC}\n")
                    break
                print()
                run_chat(q, model, tokenizer)
                print()
                print(f" {C.DIM}{'═' * 52}{C.NC}")
                print()
            except (KeyboardInterrupt, EOFError):
                print(f"\n {C.DIM}Goodbye!{C.NC}\n")
                break


if __name__ == "__main__":
    main()