Skip to content

Commit da370b8

Browse files
committed
Final fix for MacOS CI failure
1 parent 0591090 commit da370b8

2 files changed

Lines changed: 20 additions & 39 deletions

File tree

tests/unit/model_bridge/test_bridge_generate_no_tokenizer.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -29,45 +29,6 @@ def test_generate_without_tokenizer_stop_at_eos_false_kv_cache(tokenizer_free_br
2929
assert bridge.tokenizer is None
3030

3131
tokens = _PROMPT_TOKENS.clone()
32-
33-
# === TEMP DEBUG: localize CI-only NaN; remove after diagnosing ===
34-
import sys
35-
36-
def _diag(label: str, t: torch.Tensor) -> None:
37-
print(
38-
f"[DIAG] {label}: nan={torch.isnan(t).any().item()} "
39-
f"inf={torch.isinf(t).any().item()} "
40-
f"shape={tuple(t.shape)} dtype={t.dtype} "
41-
f"sample={t.flatten()[:4].tolist()}",
42-
file=sys.stderr,
43-
flush=True,
44-
)
45-
46-
with torch.no_grad():
47-
bl = bridge(tokens, return_type="logits")
48-
_diag("bridge_fwd_no_cache", bl)
49-
50-
with torch.no_grad():
51-
ho = bridge.original_model(tokens)
52-
_diag("hf_fwd_no_cache", ho.logits)
53-
54-
with torch.no_grad():
55-
ho_cache = bridge.original_model(tokens, use_cache=True)
56-
_diag("hf_fwd_step0_use_cache", ho_cache.logits)
57-
print(
58-
f"[DIAG] cache_type={type(ho_cache.past_key_values).__name__}",
59-
file=sys.stderr,
60-
flush=True,
61-
)
62-
63-
next_id = ho_cache.logits[:, -1, :].argmax(-1, keepdim=True)
64-
with torch.no_grad():
65-
ho_step1 = bridge.original_model(
66-
next_id, past_key_values=ho_cache.past_key_values, use_cache=True
67-
)
68-
_diag("hf_fwd_step1_with_cache", ho_step1.logits)
69-
# === END TEMP DEBUG ===
70-
7132
output = bridge.generate(
7233
tokens,
7334
max_new_tokens=3,

transformer_lens/model_bridge/bridge.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,10 +2382,30 @@ def _generate_tokens(
23822382
forward_kwargs["use_cache"] = True
23832383
if _hf_kv_cache is not None:
23842384
forward_kwargs["past_key_values"] = _hf_kv_cache
2385+
# HF v5 + macOS-arm64 NaNs when these are inferred
2386+
# from cache state alone. Mirror HF generate(): pass
2387+
# both an (batch, total_len) attention_mask and a
2388+
# (batch, 1) position_ids for the new token.
2389+
batch_size = current_tokens.shape[0]
2390+
total_len = current_tokens.shape[1]
2391+
device = current_tokens.device
2392+
if "attention_mask" not in forward_kwargs:
2393+
forward_kwargs["attention_mask"] = torch.ones(
2394+
(batch_size, total_len),
2395+
dtype=torch.long,
2396+
device=device,
2397+
)
23852398
if "position_ids" in forward_kwargs:
23862399
forward_kwargs["position_ids"] = forward_kwargs["position_ids"][
23872400
:, -1:
23882401
]
2402+
else:
2403+
forward_kwargs["position_ids"] = torch.full(
2404+
(batch_size, 1),
2405+
total_len - 1,
2406+
dtype=torch.long,
2407+
device=device,
2408+
)
23892409
logits = self(
23902410
current_tokens[:, -1:],
23912411
return_type="logits",

0 commit comments

Comments
 (0)