@@ -217,7 +217,94 @@ def test_different_sequence_lengths_same_result(self):
217217
218218
219219# ---------------------------------------------------------------------------
220- # Test 3: _compute_rollout_loss integration
220+ # Test 3: Synthetic vision-merging model (reproduces Qwen crash)
221+ # ---------------------------------------------------------------------------
222+
223+
224+ class TestVisionMergeCrash :
225+ """Reproduce the Qwen vision merge crash with a synthetic model.
226+
227+ Qwen2.5-VL and Qwen3.5-VL replace image placeholder tokens with
228+ visual features of a DIFFERENT count, changing internal sequence
229+ length. If attention_mask is sized for pre-merge input_ids, crash.
230+ """
231+
232+ @staticmethod
233+ def _make_vision_merge_model (vocab_size = 200 , placeholder_id = 50 , n_visual_features = 7 ):
234+ import torch .nn as nn
235+
236+ class VisionMergeModel (nn .Module ):
237+ def __init__ (self ):
238+ super ().__init__ ()
239+ self .embed = nn .Embedding (vocab_size , 16 )
240+ self .visual_embed = nn .Parameter (torch .randn (n_visual_features , 16 ))
241+ self .head = nn .Linear (16 , vocab_size )
242+ self ._ph = placeholder_id
243+ self ._nv = n_visual_features
244+
245+ def forward (self , input_ids , attention_mask = None , pixel_values = None , ** kw ):
246+ h = self .embed (input_ids )
247+ if pixel_values is not None :
248+ mask = input_ids [0 ] == self ._ph
249+ n_ph = mask .sum ().item ()
250+ if n_ph > 0 :
251+ keep = h [:, ~ mask , :]
252+ vis = self .visual_embed .unsqueeze (0 )
253+ idx = mask .nonzero (as_tuple = True )[0 ][0 ].item ()
254+ h = torch .cat ([keep [:, :idx , :], vis , keep [:, idx :, :]], dim = 1 )
255+ if attention_mask is not None and attention_mask .shape [1 ] != h .shape [1 ]:
256+ raise IndexError (
257+ f"The shape of the mask [{ attention_mask .shape [1 ]} ] at index 0 "
258+ f"does not match the shape of the indexed tensor [{ h .shape [1 ]} ] at index 0"
259+ )
260+
261+ class Out :
262+ pass
263+ out = Out ()
264+ out .logits = self .head (h )
265+ return out
266+
267+ return VisionMergeModel ()
268+
269+ def test_manual_concat_crashes (self ):
270+ """OLD approach: cat(prompt_ids, action_ids) + mask → crashes."""
271+ model = self ._make_vision_merge_model ()
272+ prompt_ids = torch .tensor ([[10 , 50 , 50 , 50 , 20 , 30 ]])
273+ action_ids = torch .tensor ([[40 , 41 ]])
274+ full_ids = torch .cat ([prompt_ids , action_ids ], dim = 1 )
275+ mask = torch .ones_like (full_ids )
276+ pv = torch .randn (1 , 3 , 10 , 10 )
277+
278+ with pytest .raises (IndexError , match = "shape of the mask" ):
279+ model (input_ids = full_ids , attention_mask = mask , pixel_values = pv )
280+
281+ def test_unified_processor_works (self ):
282+ """NEW approach: no explicit mask → model handles merge."""
283+ model = self ._make_vision_merge_model ()
284+ full_ids = torch .tensor ([[10 , 50 , 50 , 50 , 20 , 30 , 40 , 41 ]])
285+ pv = torch .randn (1 , 3 , 10 , 10 )
286+
287+ out = model (input_ids = full_ids , pixel_values = pv )
288+ # 8 - 3 placeholders + 7 features = 12
289+ assert out .logits .shape [1 ] == 12
290+
291+ def test_no_vision_no_merge (self ):
292+ """Without pixel_values, no merge, mask matches."""
293+ model = self ._make_vision_merge_model ()
294+ ids = torch .tensor ([[10 , 50 , 50 , 50 , 20 ]])
295+ out = model (input_ids = ids , attention_mask = torch .ones_like (ids ))
296+ assert out .logits .shape [1 ] == 5
297+
298+ def test_exclude_strips_vision (self ):
299+ """Exclude mode: no pixel_values passed, mask is safe."""
300+ model = self ._make_vision_merge_model ()
301+ ids = torch .tensor ([[10 , 50 , 50 , 50 , 20 , 30 , 40 , 41 ]])
302+ out = model (input_ids = ids , attention_mask = torch .ones_like (ids ))
303+ assert out .logits .shape [1 ] == 8
304+
305+
306+ # ---------------------------------------------------------------------------
307+ # Test 4: _compute_rollout_loss integration
221308# ---------------------------------------------------------------------------
222309
223310
0 commit comments