Skip to content

Commit d348f1b

Browse files
abrichrclaude
andauthored
fix: vision loss forward pass falls back to exclude on crash (#223)
Qwen3's vision-language merge changes internal sequence length unpredictably. Both include and checkpoint modes crash intermittently with attention mask mismatches (mask too large OR too small depending on generated sequence length). Fix: catch IndexError/RuntimeError from the vision forward pass and retry with exclude mode (text-only, no vision tensors) for that step. Training never crashes — some steps get vision-aware gradients, some get text-only gradients, but all steps contribute to learning. This is the pragmatic fix. The proper fix (capturing logits during generation to avoid re-forward entirely) is future work. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 02e8216 commit d348f1b

1 file changed

Lines changed: 29 additions & 15 deletions

File tree

openadapt_evals/training/standalone/trainer.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -446,24 +446,38 @@ def _compute_rollout_loss(self, rollout: Rollout, advantage: float, scale: float
446446

447447
full_inputs = {k: v.to(device) for k, v in full_inputs.items()}
448448

449-
outputs = self._model(**full_inputs)
450-
451-
# For "exclude" mode, logits shape matches input_ids (no vision merge).
452-
# For "include"/"checkpoint", Qwen3's vision merge changes the
453-
# sequence length. Slice action logits from the END of the
454-
# output sequence (action tokens are always last).
455449
n_action = action_ids.shape[1]
456-
if vision_loss_mode == "exclude":
457-
al = outputs.logits[:, prompt_len - 1: prompt_len - 1 + n_action, :]
458-
else:
459-
# Post-merge: total output length differs from input_ids length.
460-
# Action tokens are the last n_action tokens in the sequence.
461-
seq_len = outputs.logits.shape[1]
462-
al = outputs.logits[:, seq_len - n_action - 1: seq_len - 1, :]
463450

464-
lp = torch.nn.functional.log_softmax(al, dim=-1)
451+
# Forward pass with fallback: if include/checkpoint mode crashes
452+
# due to Qwen3's vision merge changing sequence length (attention
453+
# mask mismatch), retry with exclude mode for this step.
454+
try:
455+
outputs = self._model(**full_inputs)
456+
if vision_loss_mode == "exclude":
457+
al = outputs.logits[:, prompt_len - 1: prompt_len - 1 + n_action, :]
458+
else:
459+
seq_len = outputs.logits.shape[1]
460+
al = outputs.logits[:, seq_len - n_action - 1: seq_len - 1, :]
461+
except (IndexError, RuntimeError) as fwd_err:
462+
if vision_loss_mode != "exclude":
463+
logger.warning(
464+
"Vision forward pass failed (%s), retrying with "
465+
"exclude mode for this step: %s",
466+
vision_loss_mode, fwd_err,
467+
)
468+
fallback_inputs = {
469+
k: v for k, v in prompt_inputs.items()
470+
if k not in _VISION_KEYS
471+
}
472+
fallback_inputs["input_ids"] = full_ids
473+
fallback_inputs["attention_mask"] = torch.ones_like(full_ids)
474+
fallback_inputs = {k: v.to(device) for k, v in fallback_inputs.items()}
475+
outputs = self._model(**fallback_inputs)
476+
al = outputs.logits[:, prompt_len - 1: prompt_len - 1 + n_action, :]
477+
else:
478+
raise
465479

466-
# Gather log-probs for the actual action token IDs
480+
lp = torch.nn.functional.log_softmax(al, dim=-1)
467481
action_token_ids = action_ids.to(device)
468482
tlp = lp.gather(2, action_token_ids.unsqueeze(-1)).squeeze(-1)
469483
slp = tlp.sum()

0 commit comments

Comments
 (0)