Skip to content

Commit 2355d53

Browse files
abrichr and claude authored
test: synthetic vision-merge model proves fix correctness (#226)
VisionMergeModel mimics Qwen2.5/3.5-VL: it replaces placeholder tokens with N visual features, changing the sequence length.

4 new tests:
- test_manual_concat_crashes: OLD approach → IndexError (mask mismatch)
- test_unified_processor_works: NEW approach → correct post-merge shape
- test_no_vision_no_merge: no pixel_values → no merge → mask safe
- test_exclude_strips_vision: exclude mode → no pixel_values → safe

Architecture-agnostic. 12/12 pass in 0.05s.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b488794 commit 2355d53

1 file changed

Lines changed: 88 additions & 1 deletion

File tree

tests/test_vision_loss.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,94 @@ def test_different_sequence_lengths_same_result(self):
217217

218218

219219
# ---------------------------------------------------------------------------
220-
# Test 3: _compute_rollout_loss integration
220+
# Test 3: Synthetic vision-merging model (reproduces Qwen crash)
221+
# ---------------------------------------------------------------------------
222+
223+
224+
class TestVisionMergeCrash:
    """Reproduce the Qwen vision merge crash with a synthetic model.

    Qwen2.5-VL and Qwen3.5-VL replace image placeholder tokens with
    visual features of a DIFFERENT count, changing internal sequence
    length. If attention_mask is sized for pre-merge input_ids, crash.
    """

    @staticmethod
    def _make_vision_merge_model(vocab_size=200, placeholder_id=50, n_visual_features=7):
        """Build a tiny model whose forward pass changes the sequence length.

        Args:
            vocab_size: Size of the token embedding table and of the logits.
            placeholder_id: Token id treated as an image placeholder.
            n_visual_features: Number of feature vectors spliced in place of
                the placeholder tokens when ``pixel_values`` is given.

        Returns:
            An ``nn.Module`` whose ``forward`` returns an object with a
            ``logits`` attribute of shape ``(batch, merged_len, vocab_size)``.
        """
        from types import SimpleNamespace

        import torch.nn as nn

        class VisionMergeModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.embed = nn.Embedding(vocab_size, 16)
                self.visual_embed = nn.Parameter(torch.randn(n_visual_features, 16))
                self.head = nn.Linear(16, vocab_size)
                self._ph = placeholder_id
                self._nv = n_visual_features

            def forward(self, input_ids, attention_mask=None, pixel_values=None, **kw):
                """Embed ``input_ids``, optionally merging in visual features.

                ``pixel_values`` is only checked for presence; its content is
                never read. Placeholder positions are taken from the first
                row of ``input_ids`` and applied to every row of the batch.

                Raises:
                    IndexError: If ``attention_mask`` is given and its length
                        differs from the (possibly merged) sequence length.
                """
                h = self.embed(input_ids)
                if pixel_values is not None:
                    mask = input_ids[0] == self._ph
                    if mask.any():
                        keep = h[:, ~mask, :]
                        # Expand to the batch size so torch.cat also works for
                        # batches larger than one.
                        vis = self.visual_embed.unsqueeze(0).expand(h.shape[0], -1, -1)
                        # Index of the first placeholder; everything before it
                        # is a kept token, so it is also the insertion point.
                        idx = int(mask.nonzero(as_tuple=True)[0][0])
                        h = torch.cat([keep[:, :idx, :], vis, keep[:, idx:, :]], dim=1)
                if attention_mask is not None and attention_mask.shape[1] != h.shape[1]:
                    raise IndexError(
                        f"The shape of the mask [{attention_mask.shape[1]}] at index 0 "
                        f"does not match the shape of the indexed tensor [{h.shape[1]}] at index 0"
                    )
                return SimpleNamespace(logits=self.head(h))

        return VisionMergeModel()

    def test_manual_concat_crashes(self):
        """OLD approach: cat(prompt_ids, action_ids) + mask → crashes."""
        model = self._make_vision_merge_model()
        prompt_ids = torch.tensor([[10, 50, 50, 50, 20, 30]])
        action_ids = torch.tensor([[40, 41]])
        full_ids = torch.cat([prompt_ids, action_ids], dim=1)
        mask = torch.ones_like(full_ids)
        pv = torch.randn(1, 3, 10, 10)

        with pytest.raises(IndexError, match="shape of the mask"):
            model(input_ids=full_ids, attention_mask=mask, pixel_values=pv)

    def test_unified_processor_works(self):
        """NEW approach: no explicit mask → model handles merge."""
        model = self._make_vision_merge_model()
        full_ids = torch.tensor([[10, 50, 50, 50, 20, 30, 40, 41]])
        pv = torch.randn(1, 3, 10, 10)

        out = model(input_ids=full_ids, pixel_values=pv)
        # 8 - 3 placeholders + 7 features = 12
        assert out.logits.shape[1] == 12

    def test_no_vision_no_merge(self):
        """Without pixel_values, no merge, mask matches."""
        model = self._make_vision_merge_model()
        ids = torch.tensor([[10, 50, 50, 50, 20]])
        out = model(input_ids=ids, attention_mask=torch.ones_like(ids))
        assert out.logits.shape[1] == 5

    def test_exclude_strips_vision(self):
        """Exclude mode: no pixel_values passed, mask is safe."""
        model = self._make_vision_merge_model()
        ids = torch.tensor([[10, 50, 50, 50, 20, 30, 40, 41]])
        out = model(input_ids=ids, attention_mask=torch.ones_like(ids))
        assert out.logits.shape[1] == 8
304+
305+
306+
# ---------------------------------------------------------------------------
307+
# Test 4: _compute_rollout_loss integration
221308
# ---------------------------------------------------------------------------
222309

223310

0 commit comments

Comments
 (0)