Skip to content

Commit 10ef822

Browse files
committed
Final fix for macOS CI failure
1 parent 0591090 commit 10ef822

2 files changed

Lines changed: 8 additions & 39 deletions

File tree

tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -29,45 +29,6 @@ def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_br
2929
assert bridge.tokenizer is None
3030

3131
tokens = _PROMPT_TOKENS.clone()
32-
33-
# === TEMP DEBUG: localize CI-only NaN; remove after diagnosing ===
34-
import sys
35-
36-
def _diag(label: str, t: torch.Tensor) -> None:
37-
print(
38-
f"[DIAG] {label}: nan={torch.isnan(t).any().item()} "
39-
f"inf={torch.isinf(t).any().item()} "
40-
f"shape={tuple(t.shape)} dtype={t.dtype} "
41-
f"sample={t.flatten()[:4].tolist()}",
42-
file=sys.stderr,
43-
flush=True,
44-
)
45-
46-
with torch.no_grad():
47-
bl = bridge(tokens, return_type="logits")
48-
_diag("bridge_fwd_no_cache", bl)
49-
50-
with torch.no_grad():
51-
ho = bridge.original_model(tokens)
52-
_diag("hf_fwd_no_cache", ho.logits)
53-
54-
with torch.no_grad():
55-
ho_cache = bridge.original_model(tokens, use_cache=True)
56-
_diag("hf_fwd_step0_use_cache", ho_cache.logits)
57-
print(
58-
f"[DIAG] cache_type={type(ho_cache.past_key_values).__name__}",
59-
file=sys.stderr,
60-
flush=True,
61-
)
62-
63-
next_id = ho_cache.logits[:, -1, :].argmax(-1, keepdim=True)
64-
with torch.no_grad():
65-
ho_step1 = bridge.original_model(
66-
next_id, past_key_values=ho_cache.past_key_values, use_cache=True
67-
)
68-
_diag("hf_fwd_step1_with_cache", ho_step1.logits)
69-
# === END TEMP DEBUG ===
70-
7132
output = bridge.generate(
7233
tokens,
7334
max_new_tokens=3,

transformer_lens/model_bridge/bridge.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,6 +2386,14 @@ def _generate_tokens(
23862386
forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
23872387
:, -1:
23882388
]
2389+
# HF v5 on macOS-arm64 produces NaNs when it infers the mask
2390+
# from past_key_values + a 1-token input, so pass the mask explicitly.
2391+
if "attention_mask" not in forward_kwargs:
2392+
forward_kwargs["attention_mask"] = torch.ones(
2393+
(current_tokens.shape[0], current_tokens.shape[1]),
2394+
dtype=torch.long,
2395+
device=current_tokens.device,
2396+
)
23892397
logits = self(
23902398
current_tokens[:, -1:],
23912399
return_type="logits",

0 commit comments

Comments
 (0)