@@ -29,45 +29,6 @@ def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_br
2929 assert bridge .tokenizer is None
3030
3131 tokens = _PROMPT_TOKENS .clone ()
32-
33- # === TEMP DEBUG: localize CI-only NaN; remove after diagnosing ===
34- import sys
35-
36- def _diag (label : str , t : torch .Tensor ) -> None :
37- print (
38- f"[DIAG] { label } : nan={ torch .isnan (t ).any ().item ()} "
39- f"inf={ torch .isinf (t ).any ().item ()} "
40- f"shape={ tuple (t .shape )} dtype={ t .dtype } "
41- f"sample={ t .flatten ()[:4 ].tolist ()} " ,
42- file = sys .stderr ,
43- flush = True ,
44- )
45-
46- with torch .no_grad ():
47- bl = bridge (tokens , return_type = "logits" )
48- _diag ("bridge_fwd_no_cache" , bl )
49-
50- with torch .no_grad ():
51- ho = bridge .original_model (tokens )
52- _diag ("hf_fwd_no_cache" , ho .logits )
53-
54- with torch .no_grad ():
55- ho_cache = bridge .original_model (tokens , use_cache = True )
56- _diag ("hf_fwd_step0_use_cache" , ho_cache .logits )
57- print (
58- f"[DIAG] cache_type={ type (ho_cache .past_key_values ).__name__ } " ,
59- file = sys .stderr ,
60- flush = True ,
61- )
62-
63- next_id = ho_cache .logits [:, - 1 , :].argmax (- 1 , keepdim = True )
64- with torch .no_grad ():
65- ho_step1 = bridge .original_model (
66- next_id , past_key_values = ho_cache .past_key_values , use_cache = True
67- )
68- _diag ("hf_fwd_step1_with_cache" , ho_step1 .logits )
69- # === END TEMP DEBUG ===
70-
7132 output = bridge .generate (
7233 tokens ,
7334 max_new_tokens = 3 ,
0 commit comments