Skip to content

Commit 6617e02

Browse files
abrichrclaude
andauthored
fix: include image placeholder in chat template for VLM GRPO (#59)
Qwen2.5-VL requires <|image_pad|> tokens in the input. These are inserted by apply_chat_template only when messages include {"type": "image"} content blocks. Fixed both agent_fn and _compute_rollout_loss. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 19c79c2 commit 6617e02

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

openadapt_ml/training/grpo/trainer.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,20 +97,35 @@ def policy_gradient_loss(
9797
# ---------------------------------------------------------------------------
9898

9999

100-
def _build_agent_messages(instruction: str) -> list[dict[str, str]]:
100+
def _build_agent_messages(
101+
instruction: str, *, include_image: bool = False
102+
) -> list[dict]:
101103
"""Build chat messages for the GRPO agent.
102104
103105
Uses the same SYSTEM_PROMPT as SFT training so GRPO operates in
104106
the same prompt distribution the model was warm-started on.
105107
106108
This is the **single source of truth** for prompt construction
107109
during both rollout collection and loss computation.
110+
111+
Args:
112+
instruction: Task instruction text.
113+
include_image: If True, include an image placeholder in the user
114+
message so ``apply_chat_template`` inserts ``<|image_pad|>``
115+
tokens required by Qwen2.5-VL and similar VLMs.
108116
"""
109-
user_content = (
117+
text_content = (
110118
f"Goal: {instruction}\n\n"
111119
"Look at the screenshot and determine the NEXT action.\n\n"
112120
'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
113121
)
122+
if include_image:
123+
user_content = [
124+
{"type": "image"},
125+
{"type": "text", "text": text_content},
126+
]
127+
else:
128+
user_content = text_content
114129
return [
115130
{"role": "system", "content": SYSTEM_PROMPT},
116131
{"role": "user", "content": user_content},
@@ -318,7 +333,7 @@ def agent_fn(obs: Any) -> BenchmarkAction:
318333
if raw_obs and isinstance(raw_obs, dict):
319334
instruction = raw_obs.get("instruction", "")
320335

321-
messages = _build_agent_messages(instruction)
336+
messages = _build_agent_messages(instruction, include_image=True)
322337

323338
if hasattr(processor, "apply_chat_template"):
324339
text_input = processor.apply_chat_template(
@@ -505,7 +520,7 @@ def _compute_rollout_loss(
505520
except Exception:
506521
continue
507522

508-
messages = _build_agent_messages(instruction)
523+
messages = _build_agent_messages(instruction, include_image=True)
509524

510525
# Raw text from rollout or reconstruct from DSL
511526
raw_text = getattr(action, "_grpo_raw_text", None)

0 commit comments

Comments
 (0)