1313
1414from transformer_lens .model_bridge import TransformerBridge
1515
16- # Non-zero IDs to dodge macOS-arm64 KV-cache NaN with all-zero input.
1716_PROMPT_TOKENS = torch .tensor ([[15496 , 11 , 314 , 1101 , 257 ]], dtype = torch .long )
1817
1918
@@ -30,6 +29,45 @@ def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_br
3029 assert bridge .tokenizer is None
3130
3231 tokens = _PROMPT_TOKENS .clone ()
32+
33+ # === TEMP DEBUG: localize CI-only NaN; remove after diagnosing ===
34+ import sys
35+
36+ def _diag (label : str , t : torch .Tensor ) -> None :
37+ print (
38+ f"[DIAG] { label } : nan={ torch .isnan (t ).any ().item ()} "
39+ f"inf={ torch .isinf (t ).any ().item ()} "
40+ f"shape={ tuple (t .shape )} dtype={ t .dtype } "
41+ f"sample={ t .flatten ()[:4 ].tolist ()} " ,
42+ file = sys .stderr ,
43+ flush = True ,
44+ )
45+
46+ with torch .no_grad ():
47+ bl = bridge (tokens , return_type = "logits" )
48+ _diag ("bridge_fwd_no_cache" , bl )
49+
50+ with torch .no_grad ():
51+ ho = bridge .original_model (tokens )
52+ _diag ("hf_fwd_no_cache" , ho .logits )
53+
54+ with torch .no_grad ():
55+ ho_cache = bridge .original_model (tokens , use_cache = True )
56+ _diag ("hf_fwd_step0_use_cache" , ho_cache .logits )
57+ print (
58+ f"[DIAG] cache_type={ type (ho_cache .past_key_values ).__name__ } " ,
59+ file = sys .stderr ,
60+ flush = True ,
61+ )
62+
63+ next_id = ho_cache .logits [:, - 1 , :].argmax (- 1 , keepdim = True )
64+ with torch .no_grad ():
65+ ho_step1 = bridge .original_model (
66+ next_id , past_key_values = ho_cache .past_key_values , use_cache = True
67+ )
68+ _diag ("hf_fwd_step1_with_cache" , ho_step1 .logits )
69+ # === END TEMP DEBUG ===
70+
3371 output = bridge .generate (
3472 tokens ,
3573 max_new_tokens = 3 ,
0 commit comments