@@ -30,42 +30,61 @@ def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_br
3030
3131 tokens = _PROMPT_TOKENS .clone ()
3232
33- # === TEMP DEBUG: localize CI-only NaN; remove after diagnosing ===
33+ # === TEMP DEBUG: localize where NaN originates on CI ===
3434 import sys
3535
3636 def _diag (label : str , t : torch .Tensor ) -> None :
3737 print (
3838 f"[DIAG] { label } : nan={ torch .isnan (t ).any ().item ()} "
39- f"inf={ torch .isinf (t ).any ().item ()} "
40- f"shape={ tuple (t .shape )} dtype={ t .dtype } "
41- f"sample={ t .flatten ()[:4 ].tolist ()} " ,
39+ f"inf={ torch .isinf (t ).any ().item ()} shape={ tuple (t .shape )} " ,
4240 file = sys .stderr ,
4341 flush = True ,
4442 )
4543
4644 with torch .no_grad ():
47- bl = bridge (tokens , return_type = "logits" )
48- _diag ("bridge_fwd_no_cache" , bl )
49-
50- with torch .no_grad ():
51- ho = bridge .original_model (tokens )
52- _diag ("hf_fwd_no_cache" , ho .logits )
53-
54- with torch .no_grad ():
55- ho_cache = bridge .original_model (tokens , use_cache = True )
56- _diag ("hf_fwd_step0_use_cache" , ho_cache .logits )
45+ o0 = bridge .original_model (tokens , use_cache = True )
46+ _diag ("step0_logits" , o0 .logits )
47+ cache = o0 .past_key_values
5748 print (
58- f"[DIAG] cache_type={ type (ho_cache .past_key_values ).__name__ } " ,
49+ f"[DIAG] cache_type={ type (cache ).__name__ } "
50+ f"seq_len={ cache .get_seq_length () if hasattr (cache , 'get_seq_length' ) else 'n/a' } "
51+ f"layers={ len (cache .layers ) if hasattr (cache , 'layers' ) else 'n/a' } " ,
5952 file = sys .stderr ,
6053 flush = True ,
6154 )
62-
63- next_id = ho_cache .logits [:, - 1 , :].argmax (- 1 , keepdim = True )
55+ if hasattr (cache , "layers" ):
56+ for li , layer in enumerate (cache .layers ):
57+ k = getattr (layer , "keys" , None )
58+ v = getattr (layer , "values" , None )
59+ if k is not None and v is not None :
60+ print (
61+ f"[DIAG] cache_layer_{ li } : K_nan={ torch .isnan (k ).any ().item ()} "
62+ f"V_nan={ torch .isnan (v ).any ().item ()} K_shape={ tuple (k .shape )} " ,
63+ file = sys .stderr ,
64+ flush = True ,
65+ )
66+ break # one layer is enough to spot corruption
67+
68+ next_id = o0 .logits [:, - 1 , :].argmax (- 1 , keepdim = True )
69+ attn_mask = torch .ones ((1 , tokens .shape [1 ] + 1 ), dtype = torch .long )
70+ pos_ids = torch .tensor ([[tokens .shape [1 ]]], dtype = torch .long )
71+
72+ # Variant A: bridge-fix kwargs (mask + position_ids + cache)
6473 with torch .no_grad ():
65- ho_step1 = bridge .original_model (
66- next_id , past_key_values = ho_cache .past_key_values , use_cache = True
74+ oA = bridge .original_model (
75+ next_id ,
76+ past_key_values = o0 .past_key_values ,
77+ use_cache = True ,
78+ attention_mask = attn_mask ,
79+ position_ids = pos_ids ,
6780 )
68- _diag ("hf_fwd_step1_with_cache" , ho_step1 .logits )
81+ _diag ("step1_with_mask_and_pos" , oA .logits )
82+
83+ # Variant B: no cache — feed full 6-token sequence fresh
84+ full_tokens = torch .cat ([tokens , next_id ], dim = 1 )
85+ with torch .no_grad ():
86+ oB = bridge .original_model (full_tokens )
87+ _diag ("step1_full_no_cache" , oB .logits )
6988 # === END TEMP DEBUG ===
7089
7190 output = bridge .generate (
0 commit comments