88generation path (algorithmic/custom-tokenized use cases).
99"""
1010
11+ import platform
12+
1113import pytest
1214import torch
1315
1416from transformer_lens .model_bridge import TransformerBridge
1517
1618_PROMPT_TOKENS = torch .tensor ([[15496 , 11 , 314 , 1101 , 257 ]], dtype = torch .long )
1719
20+ _MACOS_ARM64 = platform .system () == "Darwin" and platform .machine () == "arm64"
21+
1822
1923@pytest .fixture (scope = "module" )
2024def tokenizer_free_bridge ():
@@ -23,70 +27,13 @@ def tokenizer_free_bridge():
2327 return bridge
2428
2529
30+ @pytest .mark .skipif (_MACOS_ARM64 , reason = "Upstream macOS-arm64 KV-cache NaN; see linked issue." )
2631def test_generate_without_tokenizer_stop_at_eos_false_kv_cache (tokenizer_free_bridge ):
2732 """generate() with no tokenizer, stop_at_eos=False, use_past_kv_cache=True."""
2833 bridge = tokenizer_free_bridge
2934 assert bridge .tokenizer is None
3035
3136 tokens = _PROMPT_TOKENS .clone ()
32-
33- # === TEMP DEBUG: localize where NaN originates on CI ===
34- import sys
35-
36- def _diag (label : str , t : torch .Tensor ) -> None :
37- print (
38- f"[DIAG] { label } : nan={ torch .isnan (t ).any ().item ()} "
39- f"inf={ torch .isinf (t ).any ().item ()} shape={ tuple (t .shape )} " ,
40- file = sys .stderr ,
41- flush = True ,
42- )
43-
44- with torch .no_grad ():
45- o0 = bridge .original_model (tokens , use_cache = True )
46- _diag ("step0_logits" , o0 .logits )
47- cache = o0 .past_key_values
48- print (
49- f"[DIAG] cache_type={ type (cache ).__name__ } "
50- f"seq_len={ cache .get_seq_length () if hasattr (cache , 'get_seq_length' ) else 'n/a' } "
51- f"layers={ len (cache .layers ) if hasattr (cache , 'layers' ) else 'n/a' } " ,
52- file = sys .stderr ,
53- flush = True ,
54- )
55- if hasattr (cache , "layers" ):
56- for li , layer in enumerate (cache .layers ):
57- k = getattr (layer , "keys" , None )
58- v = getattr (layer , "values" , None )
59- if k is not None and v is not None :
60- print (
61- f"[DIAG] cache_layer_{ li } : K_nan={ torch .isnan (k ).any ().item ()} "
62- f"V_nan={ torch .isnan (v ).any ().item ()} K_shape={ tuple (k .shape )} " ,
63- file = sys .stderr ,
64- flush = True ,
65- )
66- break # one layer is enough to spot corruption
67-
68- next_id = o0 .logits [:, - 1 , :].argmax (- 1 , keepdim = True )
69- attn_mask = torch .ones ((1 , tokens .shape [1 ] + 1 ), dtype = torch .long )
70- pos_ids = torch .tensor ([[tokens .shape [1 ]]], dtype = torch .long )
71-
72- # Variant A: bridge-fix kwargs (mask + position_ids + cache)
73- with torch .no_grad ():
74- oA = bridge .original_model (
75- next_id ,
76- past_key_values = o0 .past_key_values ,
77- use_cache = True ,
78- attention_mask = attn_mask ,
79- position_ids = pos_ids ,
80- )
81- _diag ("step1_with_mask_and_pos" , oA .logits )
82-
83- # Variant B: no cache — feed full 6-token sequence fresh
84- full_tokens = torch .cat ([tokens , next_id ], dim = 1 )
85- with torch .no_grad ():
86- oB = bridge .original_model (full_tokens )
87- _diag ("step1_full_no_cache" , oB .logits )
88- # === END TEMP DEBUG ===
89-
9037 output = bridge .generate (
9138 tokens ,
9239 max_new_tokens = 3 ,
@@ -178,6 +125,7 @@ def test_generate_string_input_without_tokenizer_errors(tokenizer_free_bridge):
178125 bridge .generate ("hello" , max_new_tokens = 3 , verbose = False )
179126
180127
128+ @pytest .mark .skipif (_MACOS_ARM64 , reason = "Upstream macOS-arm64 KV-cache NaN; see linked issue." )
181129def test_generate_return_type_str_without_tokenizer_errors (tokenizer_free_bridge ):
182130 """generate(return_type='str') must error when no tokenizer is set.
183131
0 commit comments