Skip to content

Commit 852e89b

Browse files
unamedkrclaude
andcommitted
CLI: add model spec + speed to KV Cache Analysis label
KV Cache Analysis now shows: Model: Qwen3.5-0.8B │ 6 attn layers │ 2 KV heads │ dim 256 Speed: 80 tokens in 96.6s (0.8 tok/s) │ prompt 18 tokens Removed duplicate stats line (was shown twice). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 77ff00b commit 852e89b

1 file changed

Lines changed: 28 additions & 14 deletions

File tree

tools/tq_chat.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,21 @@ def print_header():
4949
print(f"{C.CYAN}{C.BOLD}╚══════════════════════════════════════════════════════════╝{C.NC}")
5050
print()
5151

52-
def print_kv_analysis(cache, prompt_len):
52+
def print_kv_analysis(cache, prompt_len, gen_tokens=0, elapsed=0):
5353
"""Analyze and visualize KV cache compression."""
5454
import torch
5555

5656
total_fp16 = 0
5757
layers = 0
58+
head_dim = 0
59+
kv_heads = 0
5860
for i in range(len(cache.key_cache)):
5961
k = cache.key_cache[i]
6062
if k is not None and isinstance(k, torch.Tensor) and k.dim() >= 3:
6163
total_fp16 += k.nelement() * 2 * 2 # K+V, fp16
64+
if head_dim == 0:
65+
kv_heads = k.shape[1]
66+
head_dim = k.shape[-1]
6267
layers += 1
6368

6469
tq_4b = int(total_fp16 * 4.2 / 16)
@@ -67,11 +72,26 @@ def print_kv_analysis(cache, prompt_len):
6772

6873
print()
6974
print(f" {C.BOLD}📊 KV Cache Analysis{C.NC}")
70-
print(f" {C.DIM}{'─' * 52}{C.NC}")
71-
print(f" Attention Layers: {C.BOLD}{layers}{C.NC} | Prompt Tokens: {C.BOLD}{prompt_len}{C.NC}")
75+
print(f" {C.DIM}{'─' * 56}{C.NC}")
76+
77+
# Model spec line
78+
print(f" {C.BOLD}Model:{C.NC} Qwen3.5-0.8B {C.DIM}{C.NC} "
79+
f"{C.BOLD}{layers}{C.NC} attn layers {C.DIM}{C.NC} "
80+
f"{C.BOLD}{kv_heads}{C.NC} KV heads {C.DIM}{C.NC} "
81+
f"dim {C.BOLD}{head_dim}{C.NC}")
82+
83+
# Performance line
84+
if gen_tokens > 0 and elapsed > 0:
85+
tps = gen_tokens / elapsed
86+
print(f" {C.BOLD}Speed:{C.NC} {gen_tokens} tokens in {elapsed:.1f}s "
87+
f"({C.CYAN}{C.BOLD}{tps:.1f} tok/s{C.NC}) {C.DIM}{C.NC} "
88+
f"prompt {C.BOLD}{prompt_len}{C.NC} tokens")
89+
else:
90+
print(f" {C.BOLD}Tokens:{C.NC} {prompt_len} prompt")
91+
7292
print()
73-
print(f" {C.BOLD}{'Method':<22} {'Size':>10} {'Compression':>12} Bar{C.NC}")
74-
print(f" {'─' * 22} {'─' * 10} {'─' * 12} {'─' * 30}")
93+
print(f" {C.BOLD}{'Method':<22} {'Size':>10} {'Compress':>9} Bar{C.NC}")
94+
print(f" {'─' * 22} {'─' * 10} {'─' * 9} {'─' * 30}")
7595

7696
configs = [
7797
("FP16 (baseline)", total_fp16, 1.0, C.RED),
@@ -81,7 +101,7 @@ def print_kv_analysis(cache, prompt_len):
81101
]
82102

83103
for name, size, comp, color in configs:
84-
print(f" {name:<22} {size_str(size):>10} {comp:>10.1f}x {bar(size, total_fp16, 30, color)}")
104+
print(f" {name:<22} {size_str(size):>10} {comp:>7.1f}x {bar(size, total_fp16, 30, color)}")
85105

86106
saved = total_fp16 - k4v2
87107
print()
@@ -158,18 +178,12 @@ def spinner():
158178
subsequent_indent=" ")
159179
print(wrapped)
160180

161-
# Stats
162-
print()
163-
print(f" {C.DIM}{'─' * 52}{C.NC}")
164-
tps = gen_tokens / elapsed if elapsed > 0 else 0
165-
print(f" {C.DIM}{gen_tokens} tokens in {elapsed:.1f}s ({tps:.1f} tok/s) | prompt: {prompt_len} tokens{C.NC}")
166-
167-
# KV cache analysis
181+
# KV cache analysis (with timing info)
168182
with torch.no_grad():
169183
out2 = model(**inputs, use_cache=True)
170184
cache = out2.past_key_values
171185

172-
print_kv_analysis(cache, prompt_len)
186+
print_kv_analysis(cache, prompt_len, gen_tokens, elapsed)
173187

174188

175189
def main():

0 commit comments

Comments
 (0)