Skip to content

Commit 1d94899

Browse files
abrichr and claude
authored
fix: use build_agent_messages for TRL prompt + fix 4x over-generation (#247)
Two critical fixes: 1. Garbage output root cause: TRL constructed user messages differently from the standalone trainer. Standalone wraps instruction with "Goal:" prefix, format guidance, and {"type": "image"} placeholder. TRL passed raw instruction text. Now imports build_agent_messages from standalone.prompt so both paths produce identical messages. 2. 4x over-generation: batch_size=num_gen with padded dataset caused 4 identical prompts × 4 generations = 16 rollouts (standalone does 4). Now: batch_size=1, generation_batch_size=num_gen. One unique prompt per step with num_gen rollouts. No dataset padding needed. Also adds one-time prompt logging for operator verification. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 682b581 commit 1d94899

2 files changed

Lines changed: 37 additions & 45 deletions

File tree

openadapt_evals/training/trl_rollout.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -415,6 +415,7 @@ def make_waa_rollout_func(
415415
# Outlines generator is created lazily on first rollout call
416416
# (needs the trainer's model and processor which aren't available yet).
417417
_outlines_state: dict[str, Any] = {"generator": None, "attempted": False}
418+
_prompt_logged: list[bool] = [False] # log the prompt once for diagnostics
418419

419420
def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
420421
"""TRL GRPOTrainer rollout function.
@@ -501,20 +502,30 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
501502
)
502503
return "done", [], []
503504

504-
messages = [
505-
{"role": "system", "content": SYSTEM_PROMPT},
506-
{"role": "user", "content": [
507-
{"type": "image", "image": img},
508-
{"type": "text", "text": instruction},
509-
]},
510-
]
505+
# Use the SAME message construction as the standalone trainer.
506+
# This includes the "Goal:" prefix, format guidance, and the
507+
# correct {"type": "image"} tag format that Qwen processors expect.
508+
# Without this, the model sees just the raw instruction text and
509+
# produces degenerate output (e.g., "# # # # # # #").
510+
from openadapt_evals.training.standalone.prompt import build_agent_messages
511+
512+
messages = build_agent_messages(instruction, include_image=True)
511513

512514
import torch
513515

514516
text_input = processor.apply_chat_template(
515517
messages, tokenize=False, add_generation_prompt=True
516518
)
517519

520+
# Log the prompt on first call so operators can verify
521+
# the correct format is being used (DSL, not JSON).
522+
if not _prompt_logged[0]:
523+
_prompt_logged[0] = True
524+
logger.info(
525+
"TRL rollout prompt (first 300 chars of text_input): %.300s",
526+
text_input,
527+
)
528+
518529
# --- Constrained decoding path (Outlines) ---
519530
if outlines_gen is not None:
520531
import outlines

openadapt_evals/training/trl_wrapper.py

Lines changed: 19 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -219,29 +219,22 @@ def on_step_end(self, args, state, control, **kwargs):
219219
pass
220220

221221
# --- TRL config: use provided or build sensible defaults ---
222-
# TRL constraints on batch sizing:
223-
# 1. per_device_train_batch_size must be <= len(dataset)
224-
# 2. generation_batch_size must be divisible by num_generations
225-
# 3. generation_batch_size defaults to per_device_train_batch_size
222+
# TRL constraints:
223+
# - generation_batch_size must be divisible by num_generations
224+
# - per_device_train_batch_size must be <= len(dataset)
226225
#
227-
# Therefore: per_device_train_batch_size must be a MULTIPLE of
228-
# num_generations AND <= len(dataset). The minimum valid value is
229-
# num_generations itself. If the dataset is smaller, we pad it
230-
# by repeating tasks to reach at least that size.
226+
# For RL with few tasks: set batch_size=1 (one unique prompt per
227+
# step) and generation_batch_size=num_generations (satisfies the
228+
# divisibility requirement). This produces exactly num_generations
229+
# rollouts per step — matching the standalone trainer.
230+
#
231+
# Previous approach (batch_size=num_gen, padded dataset) caused
232+
# 4× over-generation: 4 identical prompts × 4 generations = 16
233+
# rollouts when only 4 were needed.
231234
num_gen = self._config.num_rollouts_per_step
232-
n_tasks = len(task_configs)
233235

234236
if self._trl_config is not None:
235237
trl_config = self._trl_config
236-
bs = getattr(trl_config, "per_device_train_batch_size", 8)
237-
ng = getattr(trl_config, "num_generations", num_gen)
238-
if bs % ng != 0:
239-
logger.warning(
240-
"per_device_train_batch_size=%d is not divisible by "
241-
"num_generations=%d. TRL will reject this. "
242-
"Set per_device_train_batch_size=%d.",
243-
bs, ng, ng,
244-
)
245238
else:
246239
trl_config = GRPOConfig(
247240
output_dir=self._config.output_dir,
@@ -254,28 +247,16 @@ def on_step_end(self, args, state, control, **kwargs):
254247
bf16=True,
255248
loss_type="grpo",
256249
num_train_epochs=1,
257-
# batch_size = num_generations: TRL requires
258-
# batch_size % num_generations == 0. This is the
259-
# minimum valid value. Each step processes
260-
# batch_size prompts × num_generations rollouts each.
261-
per_device_train_batch_size=num_gen,
250+
per_device_train_batch_size=1,
251+
# generation_batch_size must be divisible by num_generations.
252+
# Setting it to num_generations satisfies the constraint
253+
# while keeping batch_size=1 (one unique prompt per step).
254+
generation_batch_size=num_gen,
262255
)
263256

264-
# Pad dataset if needed: TRL needs len(dataset) >= batch_size.
265-
# With 1 task and batch_size=4, we repeat the task 4 times.
266-
# Each row triggers the same rollout_func, so repeats are fine
267-
# for RL (same task, many rollouts = more learning signal).
268-
bs = getattr(trl_config, "per_device_train_batch_size", num_gen)
269-
if len(dataset) < bs:
270-
import math
271-
repeats = math.ceil(bs / len(dataset))
272-
logger.info(
273-
"Padding dataset from %d to %d rows (repeating tasks %dx) "
274-
"to meet per_device_train_batch_size=%d",
275-
len(dataset), len(dataset) * repeats, repeats, bs,
276-
)
277-
padded = {k: v * repeats for k, v in dataset.to_dict().items()}
278-
dataset = Dataset.from_dict(padded)
257+
# No dataset padding needed: with batch_size=1, even a single-task
258+
# dataset works. TRL iterates one prompt per step, each getting
259+
# num_generations rollouts via rollout_func.
279260

280261
# --- Train ---
281262
trainer = _TRLTrainer(

0 commit comments

Comments (0)