@@ -356,126 +356,79 @@ def _compute_rollout_loss(self, rollout: Rollout, advantage: float, scale: float
356356
357357 for step in valid :
358358 try :
359- image = Image .open (io .BytesIO (step .screenshot )). convert ( "RGB" )
359+ image = Image .open (io .BytesIO (step .screenshot ))
360360 except Exception :
361361 continue
362+ if image .mode != "RGB" :
363+ image = image .convert ("RGB" )
364+ image .format = "PNG"
362365 messages = build_agent_messages (rollout .instruction , include_image = True )
363366 action_text = step .raw_text or format_action_as_text (step .action , self ._config .screen_size )
364367
365368 if hasattr (self ._processor , "apply_chat_template" ):
366- text_input = self ._processor .apply_chat_template (
369+ prompt_text = self ._processor .apply_chat_template (
367370 messages , tokenize = False , add_generation_prompt = True )
368371 else :
369- text_input = messages [- 1 ]["content" ]
370-
371- prompt_inputs = self ._processor (text = [text_input ], images = [image ], return_tensors = "pt" )
372- prompt_len = prompt_inputs ["input_ids" ].shape [1 ]
373- inner_tok = getattr (self ._processor , "tokenizer" , self ._processor )
374- action_ids = inner_tok (action_text , return_tensors = "pt" , add_special_tokens = False )["input_ids" ]
375- if action_ids .shape [1 ] <= 0 :
376- continue
372+ prompt_text = messages [- 1 ]["content" ]
377373
378- full_ids = torch .cat ([prompt_inputs ["input_ids" ], action_ids .to (prompt_inputs ["input_ids" ].device )], dim = 1 )
379-
380- # --- Vision tensor handling during loss computation ---
381- # Current default ("exclude"): strips vision tensors so the
382- # forward pass only sees text embeddings. This avoids OOM on
383- # L40S-class GPUs (48 GB) because the vision encoder backward
384- # pass is very expensive and unnecessary — we only compute loss
385- # on *action* tokens (past prompt_len).
374+ # --- Vision-safe loss computation ---
375+ #
376+ # Process the FULL text (prompt + action) through the processor
377+ # as a single unit. This ensures the model's vision merge
378+ # operates on consistent input.
386379 #
387- # Proper fixes (future work):
388- # 1. "include" – keep vision tensors and let the full
389- # multimodal forward pass run. May OOM on < 80 GB VRAM
390- # without further optimisation.
391- # 2. "checkpoint" – use torch.utils.checkpoint on the vision
392- # encoder so activations are recomputed during backward
393- # instead of stored, dramatically cutting peak VRAM.
394- # 3. Cached KV – pre-compute and cache the vision encoder's
395- # key/value projections per screenshot so we never
396- # backpropagate through the encoder at all. Requires
397- # architecture-specific hooks (e.g. Qwen2-VL cross-attn).
380+ # WHY: The old approach processed prompt alone, then manually
381+ # concatenated action_ids onto input_ids. This created a
382+ # frankenstein input where pixel_values were sized for the
383+ # prompt but input_ids included action tokens. Qwen3's vision
384+ # merge changed internal sequence length, causing attention
385+ # mask mismatches (crash on step 5 intermittently).
386+ #
387+ # NOW: processor(prompt + action, image) produces consistent
388+ # input_ids + pixel_values + attention_mask. The model's
389+ # forward pass handles vision merge correctly.
390+
391+ vision_loss_mode = getattr (self ._config , "vision_loss_mode" , "exclude" )
398392 _VISION_KEYS = {"pixel_values" , "pixel_values_videos" ,
399393 "image_grid_thw" , "video_grid_thw" }
400394
401- vision_loss_mode = getattr (self ._config , "vision_loss_mode" , "exclude" )
395+ inner_tok = getattr (self ._processor , "tokenizer" , self ._processor )
396+ action_ids = inner_tok (action_text , add_special_tokens = False , return_tensors = "pt" )["input_ids" ]
397+ n_action = action_ids .shape [1 ]
398+ if n_action <= 0 :
399+ continue
400+
401+ full_text = prompt_text + action_text
402+ full_inputs = self ._processor (
403+ text = [full_text ], images = [image ], return_tensors = "pt" ,
404+ )
402405
403406 if vision_loss_mode == "exclude" :
404- excluded = _VISION_KEYS & set (prompt_inputs .keys ())
407+ excluded = _VISION_KEYS & set (full_inputs .keys ())
405408 if excluded and not getattr (self , "_vision_exclude_warned" , False ):
406409 logger .warning (
407- "vision_loss_mode='exclude': stripping vision tensors %s "
408- "from loss forward pass. Log-probs are TEXT-ONLY and do "
409- "not reflect visual grounding gradients. Set "
410- "vision_loss_mode='include' or 'checkpoint' once your "
411- "GPU VRAM allows it." ,
410+ "vision_loss_mode='exclude': stripping vision tensors %s" ,
412411 sorted (excluded ),
413412 )
414413 self ._vision_exclude_warned = True
415- full_inputs = {k : v for k , v in prompt_inputs .items ()
414+ full_inputs = {k : v for k , v in full_inputs .items ()
416415 if k not in _VISION_KEYS }
417- elif vision_loss_mode == "include" :
418- if not getattr (self , "_vision_include_warned" , False ):
419- logger .info ("vision_loss_mode='include': keeping all vision tensors (may OOM)." )
420- self ._vision_include_warned = True
421- full_inputs = dict (prompt_inputs )
422416 elif vision_loss_mode == "checkpoint" :
423417 if not getattr (self , "_vision_checkpoint_warned" , False ):
424- logger .info ("vision_loss_mode='checkpoint': enabling gradient checkpointing on vision encoder." )
418+ logger .info ("vision_loss_mode='checkpoint': gradient checkpointing on vision encoder." )
425419 self ._vision_checkpoint_warned = True
426420 if hasattr (self ._model , "visual" ) and hasattr (self ._model .visual , "gradient_checkpointing_enable" ):
427421 self ._model .visual .gradient_checkpointing_enable ()
428422 elif hasattr (self ._model , "vision_tower" ):
429423 self ._model .vision_tower .gradient_checkpointing_enable ()
430- else :
431- logger .warning ("Cannot find vision encoder for gradient checkpointing; falling back to 'include'." )
432- full_inputs = dict (prompt_inputs )
433- else :
434- raise ValueError (f"Unknown vision_loss_mode={ vision_loss_mode !r} . Use 'exclude', 'include', or 'checkpoint'." )
435-
436- full_inputs ["input_ids" ] = full_ids
437-
438- # Only set attention_mask for "exclude" mode (text-only forward).
439- # For "include" and "checkpoint" modes, vision tensors are present
440- # and Qwen3's vision-language merge changes the internal sequence
441- # length (e.g., 1305 input tokens → 1202 post-merge). An
442- # explicit attention_mask sized to input_ids will mismatch.
443- # Let the model construct its own mask internally.
444- if vision_loss_mode == "exclude" :
445- full_inputs ["attention_mask" ] = torch .ones_like (full_ids )
424+ # Any mode other than "exclude" (e.g. "include", "checkpoint") keeps all tensors as-is
446425
447426 full_inputs = {k : v .to (device ) for k , v in full_inputs .items ()}
427+ outputs = self ._model (** full_inputs )
448428
449- n_action = action_ids .shape [1 ]
450-
451- # Forward pass with fallback: if include/checkpoint mode crashes
452- # due to Qwen3's vision merge changing sequence length (attention
453- # mask mismatch), retry with exclude mode for this step.
454- try :
455- outputs = self ._model (** full_inputs )
456- if vision_loss_mode == "exclude" :
457- al = outputs .logits [:, prompt_len - 1 : prompt_len - 1 + n_action , :]
458- else :
459- seq_len = outputs .logits .shape [1 ]
460- al = outputs .logits [:, seq_len - n_action - 1 : seq_len - 1 , :]
461- except (IndexError , RuntimeError ) as fwd_err :
462- if vision_loss_mode != "exclude" :
463- logger .warning (
464- "Vision forward pass failed (%s), retrying with "
465- "exclude mode for this step: %s" ,
466- vision_loss_mode , fwd_err ,
467- )
468- fallback_inputs = {
469- k : v for k , v in prompt_inputs .items ()
470- if k not in _VISION_KEYS
471- }
472- fallback_inputs ["input_ids" ] = full_ids
473- fallback_inputs ["attention_mask" ] = torch .ones_like (full_ids )
474- fallback_inputs = {k : v .to (device ) for k , v in fallback_inputs .items ()}
475- outputs = self ._model (** fallback_inputs )
476- al = outputs .logits [:, prompt_len - 1 : prompt_len - 1 + n_action , :]
477- else :
478- raise
429+ # Action logits: the n_action positions ending one before the last output position (shifted back by one from the action tokens)
430+ seq_len = outputs .logits .shape [1 ]
431+ al = outputs .logits [:, seq_len - n_action - 1 : seq_len - 1 , :]
479432
480433 lp = torch .nn .functional .log_softmax (al , dim = - 1 )
481434 action_token_ids = action_ids .to (device )
0 commit comments