Skip to content

Commit 5413864

Browse files
abrichr and claude authored
fix: proper vision-safe loss — process full text as one unit (#224)
Root cause: manually concatenating action_ids onto the prompt input_ids created an inconsistent input (pixel_values sized for the prompt, input_ids including action tokens). Qwen3's vision merge changes the internal sequence length, crashing with attention mask mismatches.

Fix: process prompt_text + action_text as a SINGLE string through the processor. This produces consistent input_ids, pixel_values, and attention_mask. The model handles the vision merge correctly on processor output.

Replaces the silent fallback from PR #223 with a proper solution that gives correct vision-aware gradients for ALL steps in ALL modes.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5e42bd6 commit 5413864

1 file changed

Lines changed: 42 additions & 89 deletions

File tree

openadapt_evals/training/standalone/trainer.py

Lines changed: 42 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -356,126 +356,79 @@ def _compute_rollout_loss(self, rollout: Rollout, advantage: float, scale: float
356356

357357
for step in valid:
358358
try:
359-
image = Image.open(io.BytesIO(step.screenshot)).convert("RGB")
359+
image = Image.open(io.BytesIO(step.screenshot))
360360
except Exception:
361361
continue
362+
if image.mode != "RGB":
363+
image = image.convert("RGB")
364+
image.format = "PNG"
362365
messages = build_agent_messages(rollout.instruction, include_image=True)
363366
action_text = step.raw_text or format_action_as_text(step.action, self._config.screen_size)
364367

365368
if hasattr(self._processor, "apply_chat_template"):
366-
text_input = self._processor.apply_chat_template(
369+
prompt_text = self._processor.apply_chat_template(
367370
messages, tokenize=False, add_generation_prompt=True)
368371
else:
369-
text_input = messages[-1]["content"]
370-
371-
prompt_inputs = self._processor(text=[text_input], images=[image], return_tensors="pt")
372-
prompt_len = prompt_inputs["input_ids"].shape[1]
373-
inner_tok = getattr(self._processor, "tokenizer", self._processor)
374-
action_ids = inner_tok(action_text, return_tensors="pt", add_special_tokens=False)["input_ids"]
375-
if action_ids.shape[1] <= 0:
376-
continue
372+
prompt_text = messages[-1]["content"]
377373

378-
full_ids = torch.cat([prompt_inputs["input_ids"], action_ids.to(prompt_inputs["input_ids"].device)], dim=1)
379-
380-
# --- Vision tensor handling during loss computation ---
381-
# Current default ("exclude"): strips vision tensors so the
382-
# forward pass only sees text embeddings. This avoids OOM on
383-
# L40S-class GPUs (48 GB) because the vision encoder backward
384-
# pass is very expensive and unnecessary — we only compute loss
385-
# on *action* tokens (past prompt_len).
374+
# --- Vision-safe loss computation ---
375+
#
376+
# Process the FULL text (prompt + action) through the processor
377+
# as a single unit. This ensures the model's vision merge
378+
# operates on consistent input.
386379
#
387-
# Proper fixes (future work):
388-
# 1. "include" – keep vision tensors and let the full
389-
# multimodal forward pass run. May OOM on < 80 GB VRAM
390-
# without further optimisation.
391-
# 2. "checkpoint" – use torch.utils.checkpoint on the vision
392-
# encoder so activations are recomputed during backward
393-
# instead of stored, dramatically cutting peak VRAM.
394-
# 3. Cached KV – pre-compute and cache the vision encoder's
395-
# key/value projections per screenshot so we never
396-
# backpropagate through the encoder at all. Requires
397-
# architecture-specific hooks (e.g. Qwen2-VL cross-attn).
380+
# WHY: The old approach processed prompt alone, then manually
381+
# concatenated action_ids onto input_ids. This created a
382+
# frankenstein input where pixel_values were sized for the
383+
# prompt but input_ids included action tokens. Qwen3's vision
384+
# merge changed internal sequence length, causing attention
385+
# mask mismatches (crash on step 5 intermittently).
386+
#
387+
# NOW: processor(prompt + action, image) produces consistent
388+
# input_ids + pixel_values + attention_mask. The model's
389+
# forward pass handles vision merge correctly.
390+
391+
vision_loss_mode = getattr(self._config, "vision_loss_mode", "exclude")
398392
_VISION_KEYS = {"pixel_values", "pixel_values_videos",
399393
"image_grid_thw", "video_grid_thw"}
400394

401-
vision_loss_mode = getattr(self._config, "vision_loss_mode", "exclude")
395+
inner_tok = getattr(self._processor, "tokenizer", self._processor)
396+
action_ids = inner_tok(action_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
397+
n_action = action_ids.shape[1]
398+
if n_action <= 0:
399+
continue
400+
401+
full_text = prompt_text + action_text
402+
full_inputs = self._processor(
403+
text=[full_text], images=[image], return_tensors="pt",
404+
)
402405

403406
if vision_loss_mode == "exclude":
404-
excluded = _VISION_KEYS & set(prompt_inputs.keys())
407+
excluded = _VISION_KEYS & set(full_inputs.keys())
405408
if excluded and not getattr(self, "_vision_exclude_warned", False):
406409
logger.warning(
407-
"vision_loss_mode='exclude': stripping vision tensors %s "
408-
"from loss forward pass. Log-probs are TEXT-ONLY and do "
409-
"not reflect visual grounding gradients. Set "
410-
"vision_loss_mode='include' or 'checkpoint' once your "
411-
"GPU VRAM allows it.",
410+
"vision_loss_mode='exclude': stripping vision tensors %s",
412411
sorted(excluded),
413412
)
414413
self._vision_exclude_warned = True
415-
full_inputs = {k: v for k, v in prompt_inputs.items()
414+
full_inputs = {k: v for k, v in full_inputs.items()
416415
if k not in _VISION_KEYS}
417-
elif vision_loss_mode == "include":
418-
if not getattr(self, "_vision_include_warned", False):
419-
logger.info("vision_loss_mode='include': keeping all vision tensors (may OOM).")
420-
self._vision_include_warned = True
421-
full_inputs = dict(prompt_inputs)
422416
elif vision_loss_mode == "checkpoint":
423417
if not getattr(self, "_vision_checkpoint_warned", False):
424-
logger.info("vision_loss_mode='checkpoint': enabling gradient checkpointing on vision encoder.")
418+
logger.info("vision_loss_mode='checkpoint': gradient checkpointing on vision encoder.")
425419
self._vision_checkpoint_warned = True
426420
if hasattr(self._model, "visual") and hasattr(self._model.visual, "gradient_checkpointing_enable"):
427421
self._model.visual.gradient_checkpointing_enable()
428422
elif hasattr(self._model, "vision_tower"):
429423
self._model.vision_tower.gradient_checkpointing_enable()
430-
else:
431-
logger.warning("Cannot find vision encoder for gradient checkpointing; falling back to 'include'.")
432-
full_inputs = dict(prompt_inputs)
433-
else:
434-
raise ValueError(f"Unknown vision_loss_mode={vision_loss_mode!r}. Use 'exclude', 'include', or 'checkpoint'.")
435-
436-
full_inputs["input_ids"] = full_ids
437-
438-
# Only set attention_mask for "exclude" mode (text-only forward).
439-
# For "include" and "checkpoint" modes, vision tensors are present
440-
# and Qwen3's vision-language merge changes the internal sequence
441-
# length (e.g., 1305 input tokens → 1202 post-merge). An
442-
# explicit attention_mask sized to input_ids will mismatch.
443-
# Let the model construct its own mask internally.
444-
if vision_loss_mode == "exclude":
445-
full_inputs["attention_mask"] = torch.ones_like(full_ids)
424+
# "include" mode: keep all tensors as-is
446425

447426
full_inputs = {k: v.to(device) for k, v in full_inputs.items()}
427+
outputs = self._model(**full_inputs)
448428

449-
n_action = action_ids.shape[1]
450-
451-
# Forward pass with fallback: if include/checkpoint mode crashes
452-
# due to Qwen3's vision merge changing sequence length (attention
453-
# mask mismatch), retry with exclude mode for this step.
454-
try:
455-
outputs = self._model(**full_inputs)
456-
if vision_loss_mode == "exclude":
457-
al = outputs.logits[:, prompt_len - 1: prompt_len - 1 + n_action, :]
458-
else:
459-
seq_len = outputs.logits.shape[1]
460-
al = outputs.logits[:, seq_len - n_action - 1: seq_len - 1, :]
461-
except (IndexError, RuntimeError) as fwd_err:
462-
if vision_loss_mode != "exclude":
463-
logger.warning(
464-
"Vision forward pass failed (%s), retrying with "
465-
"exclude mode for this step: %s",
466-
vision_loss_mode, fwd_err,
467-
)
468-
fallback_inputs = {
469-
k: v for k, v in prompt_inputs.items()
470-
if k not in _VISION_KEYS
471-
}
472-
fallback_inputs["input_ids"] = full_ids
473-
fallback_inputs["attention_mask"] = torch.ones_like(full_ids)
474-
fallback_inputs = {k: v.to(device) for k, v in fallback_inputs.items()}
475-
outputs = self._model(**fallback_inputs)
476-
al = outputs.logits[:, prompt_len - 1: prompt_len - 1 + n_action, :]
477-
else:
478-
raise
429+
# Action logits are the last n_action positions in the output
430+
seq_len = outputs.logits.shape[1]
431+
al = outputs.logits[:, seq_len - n_action - 1: seq_len - 1, :]
479432

480433
lp = torch.nn.functional.log_softmax(al, dim=-1)
481434
action_token_ids = action_ids.to(device)

0 commit comments

Comments
 (0)