Skip to content

Commit f6192fa

Browse files
committed
Final fix for MacOS CI failure
1 parent 0591090 commit f6192fa

2 files changed

Lines changed: 59 additions & 20 deletions

File tree

tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,42 +30,61 @@ def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_br
3030

3131
tokens = _PROMPT_TOKENS.clone()
3232

33-
# === TEMP DEBUG: localize CI-only NaN; remove after diagnosing ===
33+
# === TEMP DEBUG: localize where NaN originates on CI ===
3434
import sys
3535

3636
def _diag(label: str, t: torch.Tensor) -> None:
3737
print(
3838
f"[DIAG] {label}: nan={torch.isnan(t).any().item()} "
39-
f"inf={torch.isinf(t).any().item()} "
40-
f"shape={tuple(t.shape)} dtype={t.dtype} "
41-
f"sample={t.flatten()[:4].tolist()}",
39+
f"inf={torch.isinf(t).any().item()} shape={tuple(t.shape)}",
4240
file=sys.stderr,
4341
flush=True,
4442
)
4543

4644
with torch.no_grad():
47-
bl = bridge(tokens, return_type="logits")
48-
_diag("bridge_fwd_no_cache", bl)
49-
50-
with torch.no_grad():
51-
ho = bridge.original_model(tokens)
52-
_diag("hf_fwd_no_cache", ho.logits)
53-
54-
with torch.no_grad():
55-
ho_cache = bridge.original_model(tokens, use_cache=True)
56-
_diag("hf_fwd_step0_use_cache", ho_cache.logits)
45+
o0 = bridge.original_model(tokens, use_cache=True)
46+
_diag("step0_logits", o0.logits)
47+
cache = o0.past_key_values
5748
print(
58-
f"[DIAG] cache_type={type(ho_cache.past_key_values).__name__}",
49+
f"[DIAG] cache_type={type(cache).__name__} "
50+
f"seq_len={cache.get_seq_length() if hasattr(cache, 'get_seq_length') else 'n/a'} "
51+
f"layers={len(cache.layers) if hasattr(cache, 'layers') else 'n/a'}",
5952
file=sys.stderr,
6053
flush=True,
6154
)
62-
63-
next_id = ho_cache.logits[:, -1, :].argmax(-1, keepdim=True)
55+
if hasattr(cache, "layers"):
56+
for li, layer in enumerate(cache.layers):
57+
k = getattr(layer, "keys", None)
58+
v = getattr(layer, "values", None)
59+
if k is not None and v is not None:
60+
print(
61+
f"[DIAG] cache_layer_{li}: K_nan={torch.isnan(k).any().item()} "
62+
f"V_nan={torch.isnan(v).any().item()} K_shape={tuple(k.shape)}",
63+
file=sys.stderr,
64+
flush=True,
65+
)
66+
break # one layer is enough to spot corruption
67+
68+
next_id = o0.logits[:, -1, :].argmax(-1, keepdim=True)
69+
attn_mask = torch.ones((1, tokens.shape[1] + 1), dtype=torch.long)
70+
pos_ids = torch.tensor([[tokens.shape[1]]], dtype=torch.long)
71+
72+
# Variant A: bridge-fix kwargs (mask + position_ids + cache)
6473
with torch.no_grad():
65-
ho_step1 = bridge.original_model(
66-
next_id, past_key_values=ho_cache.past_key_values, use_cache=True
74+
oA = bridge.original_model(
75+
next_id,
76+
past_key_values=o0.past_key_values,
77+
use_cache=True,
78+
attention_mask=attn_mask,
79+
position_ids=pos_ids,
6780
)
68-
_diag("hf_fwd_step1_with_cache", ho_step1.logits)
81+
_diag("step1_with_mask_and_pos", oA.logits)
82+
83+
# Variant B: no cache — feed full 6-token sequence fresh
84+
full_tokens = torch.cat([tokens, next_id], dim=1)
85+
with torch.no_grad():
86+
oB = bridge.original_model(full_tokens)
87+
_diag("step1_full_no_cache", oB.logits)
6988
# === END TEMP DEBUG ===
7089

7190
output = bridge.generate(

transformer_lens/model_bridge/bridge.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,10 +2382,30 @@ def _generate_tokens(
23822382
forward_kwargs["use_cache"] = True
23832383
if _hf_kv_cache is not None:
23842384
forward_kwargs["past_key_values"] = _hf_kv_cache
2385+
# HF v5 + macOS-arm64 NaNs when these are inferred
2386+
# from cache state alone. Mirror HF generate(): pass
2387+
# both an (batch, total_len) attention_mask and a
2388+
# (batch, 1) position_ids for the new token.
2389+
batch_size = current_tokens.shape[0]
2390+
total_len = current_tokens.shape[1]
2391+
device = current_tokens.device
2392+
if "attention_mask" not in forward_kwargs:
2393+
forward_kwargs["attention_mask"] = torch.ones(
2394+
(batch_size, total_len),
2395+
dtype=torch.long,
2396+
device=device,
2397+
)
23852398
if "position_ids" in forward_kwargs:
23862399
forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
23872400
:, -1:
23882401
]
2402+
else:
2403+
forward_kwargs["position_ids"] = torch.full(
2404+
(batch_size, 1),
2405+
total_len - 1,
2406+
dtype=torch.long,
2407+
device=device,
2408+
)
23892409
logits = self(
23902410
current_tokens[:, -1:],
23912411
return_type="logits",

0 commit comments

Comments
 (0)