@@ -49,16 +49,21 @@ def print_header():
4949 print (f"{ C .CYAN } { C .BOLD } ╚══════════════════════════════════════════════════════════╝{ C .NC } " )
5050 print ()
5151
52- def print_kv_analysis (cache , prompt_len ):
52+ def print_kv_analysis (cache , prompt_len , gen_tokens = 0 , elapsed = 0 ):
5353 """Analyze and visualize KV cache compression."""
5454 import torch
5555
5656 total_fp16 = 0
5757 layers = 0
58+ head_dim = 0
59+ kv_heads = 0
5860 for i in range (len (cache .key_cache )):
5961 k = cache .key_cache [i ]
6062 if k is not None and isinstance (k , torch .Tensor ) and k .dim () >= 3 :
6163 total_fp16 += k .nelement () * 2 * 2 # K+V, fp16
64+ if head_dim == 0 :
65+ kv_heads = k .shape [1 ]
66+ head_dim = k .shape [- 1 ]
6267 layers += 1
6368
6469 tq_4b = int (total_fp16 * 4.2 / 16 )
@@ -67,11 +72,26 @@ def print_kv_analysis(cache, prompt_len):
6772
6873 print ()
6974 print (f" { C .BOLD } 📊 KV Cache Analysis{ C .NC } " )
70- print (f" { C .DIM } { '─' * 52 } { C .NC } " )
71- print (f" Attention Layers: { C .BOLD } { layers } { C .NC } | Prompt Tokens: { C .BOLD } { prompt_len } { C .NC } " )
75+ print (f" { C .DIM } { '─' * 56 } { C .NC } " )
76+
77+ # Model spec line
78+ print (f" { C .BOLD } Model:{ C .NC } Qwen3.5-0.8B { C .DIM } │{ C .NC } "
79+ f"{ C .BOLD } { layers } { C .NC } attn layers { C .DIM } │{ C .NC } "
80+ f"{ C .BOLD } { kv_heads } { C .NC } KV heads { C .DIM } │{ C .NC } "
81+ f"dim { C .BOLD } { head_dim } { C .NC } " )
82+
83+ # Performance line
84+ if gen_tokens > 0 and elapsed > 0 :
85+ tps = gen_tokens / elapsed
86+ print (f" { C .BOLD } Speed:{ C .NC } { gen_tokens } tokens in { elapsed :.1f} s "
87+ f"({ C .CYAN } { C .BOLD } { tps :.1f} tok/s{ C .NC } ) { C .DIM } │{ C .NC } "
88+ f"prompt { C .BOLD } { prompt_len } { C .NC } tokens" )
89+ else :
90+ print (f" { C .BOLD } Tokens:{ C .NC } { prompt_len } prompt" )
91+
7292 print ()
73- print (f" { C .BOLD } { 'Method' :<22} { 'Size' :>10} { 'Compression ' :>12 } Bar{ C .NC } " )
74- print (f" { '─' * 22 } { '─' * 10 } { '─' * 12 } { '─' * 30 } " )
93+ print (f" { C .BOLD } { 'Method' :<22} { 'Size' :>10} { 'Compress ' :>9 } Bar{ C .NC } " )
94+ print (f" { '─' * 22 } { '─' * 10 } { '─' * 9 } { '─' * 30 } " )
7595
7696 configs = [
7797 ("FP16 (baseline)" , total_fp16 , 1.0 , C .RED ),
@@ -81,7 +101,7 @@ def print_kv_analysis(cache, prompt_len):
81101 ]
82102
83103 for name , size , comp , color in configs :
84- print (f" { name :<22} { size_str (size ):>10} { comp :>10 .1f} x { bar (size , total_fp16 , 30 , color )} " )
104+ print (f" { name :<22} { size_str (size ):>10} { comp :>7 .1f} x { bar (size , total_fp16 , 30 , color )} " )
85105
86106 saved = total_fp16 - k4v2
87107 print ()
@@ -158,18 +178,12 @@ def spinner():
158178 subsequent_indent = " " )
159179 print (wrapped )
160180
161- # Stats
162- print ()
163- print (f" { C .DIM } { '─' * 52 } { C .NC } " )
164- tps = gen_tokens / elapsed if elapsed > 0 else 0
165- print (f" { C .DIM } ⏱ { gen_tokens } tokens in { elapsed :.1f} s ({ tps :.1f} tok/s) | prompt: { prompt_len } tokens{ C .NC } " )
166-
167- # KV cache analysis
181+ # KV cache analysis (with timing info)
168182 with torch .no_grad ():
169183 out2 = model (** inputs , use_cache = True )
170184 cache = out2 .past_key_values
171185
172- print_kv_analysis (cache , prompt_len )
186+ print_kv_analysis (cache , prompt_len , gen_tokens , elapsed )
173187
174188
175189def main ():
0 commit comments