@@ -60,7 +60,7 @@ def print_kv_analysis(cache, prompt_len, gen_tokens=0, elapsed=0):
6060 for i in range (len (cache .key_cache )):
6161 k = cache .key_cache [i ]
6262 if k is not None and isinstance (k , torch .Tensor ) and k .dim () >= 3 :
63- total_fp16 += k .nelement () * 2 * 2 # K+V, fp16
63+ total_fp16 += k .cpu (). nelement () * 2 * 2 # K+V, fp16
6464 if head_dim == 0 :
6565 kv_heads = k .shape [1 ]
6666 head_dim = k .shape [- 1 ]
@@ -131,11 +131,16 @@ def run_chat(question, model, tokenizer):
131131 add_generation_prompt = True ,
132132 enable_thinking = False )
133133 inputs = tokenizer (text , return_tensors = "pt" )
134+ # Move to same device as model
135+ inputs = {k : v .to (model .device ) for k , v in inputs .items ()}
134136 prompt_len = inputs ["input_ids" ].shape [1 ]
135137
136- max_tokens = 80 # ~80 tokens ≈ 2 paragraphs, ~100s on CPU
138+ is_gpu = str (model .device ) != "cpu"
139+ max_tokens = 150 if is_gpu else 80
140+ est_time = max_tokens * 0.1 if is_gpu else max_tokens * 1.3
141+ dev_name = "GPU" if is_gpu else "CPU"
137142
138- print (f" { C .BOLD } { C .GREEN } A:{ C .NC } { C .DIM } (generating ~{ max_tokens } tokens, ~{ max_tokens * 1.3 :.0f} s on CPU ){ C .NC } " )
143+ print (f" { C .BOLD } { C .GREEN } A:{ C .NC } { C .DIM } (generating ~{ max_tokens } tokens, ~{ est_time :.0f} s on { dev_name } ){ C .NC } " )
139144 print (f" " , end = "" , flush = True )
140145
141146 import contextlib , io , threading
@@ -196,6 +201,7 @@ def main():
196201
197202 # Load model (suppress noisy warnings)
198203 print (f" { C .DIM } Loading Qwen3.5-0.8B...{ C .NC } " , end = "" , flush = True )
204+ # Note: device_label is set after torch import below
199205
200206 import warnings
201207 import logging
@@ -210,19 +216,30 @@ def main():
210216 from transformers import AutoModelForCausalLM , AutoTokenizer
211217
212218 model_name = "Qwen/Qwen3.5-0.8B"
219+
220+ # Auto-detect best device: MPS (Apple GPU) > CPU
221+ if torch .backends .mps .is_available ():
222+ device = "mps"
223+ dtype = torch .float16
224+ device_label = "MPS (Apple GPU)"
225+ else :
226+ device = "cpu"
227+ dtype = torch .float32
228+ device_label = "CPU"
229+
213230 with contextlib .redirect_stderr (io .StringIO ()):
214231 tokenizer = AutoTokenizer .from_pretrained (model_name , trust_remote_code = True )
215232 model = AutoModelForCausalLM .from_pretrained (
216- model_name , trust_remote_code = True , dtype = torch . float32
217- )
233+ model_name , trust_remote_code = True , dtype = dtype
234+ ). to ( device )
218235 model .eval ()
219236
220237 # Pre-set pad_token_id to suppress "Setting pad_token_id" message
221238 if tokenizer .pad_token_id is None :
222239 tokenizer .pad_token_id = tokenizer .eos_token_id
223240 model .generation_config .pad_token_id = tokenizer .eos_token_id
224241
225- print (f" { C .GREEN } ✓{ C .NC } " )
242+ print (f" { C .GREEN } ✓{ C .NC } { C . DIM } ( { device_label } ) { C . NC } " )
226243 print ()
227244
228245 if args .benchmark :
0 commit comments