Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions openadapt_evals/training/trl_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def make_waa_rollout_func(
# Outlines generator is created lazily on first rollout call
# (needs the trainer's model and processor which aren't available yet).
_outlines_state: dict[str, Any] = {"generator": None, "attempted": False}
_prompt_logged: list[bool] = [False] # log the prompt once for diagnostics

def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
"""TRL GRPOTrainer rollout function.
Expand Down Expand Up @@ -501,20 +502,30 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
)
return "done", [], []

messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": [
{"type": "image", "image": img},
{"type": "text", "text": instruction},
]},
]
# Use the SAME message construction as the standalone trainer.
# This includes the "Goal:" prefix, format guidance, and the
# correct {"type": "image"} tag format that Qwen processors expect.
# Without this, the model sees just the raw instruction text and
# produces degenerate output (e.g., "# # # # # # #").
from openadapt_evals.training.standalone.prompt import build_agent_messages

messages = build_agent_messages(instruction, include_image=True)

import torch

text_input = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

# Log the prompt on first call so operators can verify
# the correct format is being used (DSL, not JSON).
if not _prompt_logged[0]:
_prompt_logged[0] = True
logger.info(
"TRL rollout prompt (first 300 chars of text_input): %.300s",
text_input,
)

# --- Constrained decoding path (Outlines) ---
if outlines_gen is not None:
import outlines
Expand Down
57 changes: 19 additions & 38 deletions openadapt_evals/training/trl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,29 +219,22 @@ def on_step_end(self, args, state, control, **kwargs):
pass

# --- TRL config: use provided or build sensible defaults ---
# TRL constraints on batch sizing:
# 1. per_device_train_batch_size must be <= len(dataset)
# 2. generation_batch_size must be divisible by num_generations
# 3. generation_batch_size defaults to per_device_train_batch_size
# TRL constraints:
# - generation_batch_size must be divisible by num_generations
# - per_device_train_batch_size must be <= len(dataset)
#
# Therefore: per_device_train_batch_size must be a MULTIPLE of
# num_generations AND <= len(dataset). The minimum valid value is
# num_generations itself. If the dataset is smaller, we pad it
# by repeating tasks to reach at least that size.
# For RL with few tasks: set batch_size=1 (one unique prompt per
# step) and generation_batch_size=num_generations (satisfies the
# divisibility requirement). This produces exactly num_generations
# rollouts per step — matching the standalone trainer.
#
# Previous approach (batch_size=num_gen, padded dataset) caused
# 4× over-generation: 4 identical prompts × 4 generations = 16
# rollouts when only 4 were needed.
num_gen = self._config.num_rollouts_per_step
n_tasks = len(task_configs)

if self._trl_config is not None:
trl_config = self._trl_config
bs = getattr(trl_config, "per_device_train_batch_size", 8)
ng = getattr(trl_config, "num_generations", num_gen)
if bs % ng != 0:
logger.warning(
"per_device_train_batch_size=%d is not divisible by "
"num_generations=%d. TRL will reject this. "
"Set per_device_train_batch_size=%d.",
bs, ng, ng,
)
else:
trl_config = GRPOConfig(
output_dir=self._config.output_dir,
Expand All @@ -254,28 +247,16 @@ def on_step_end(self, args, state, control, **kwargs):
bf16=True,
loss_type="grpo",
num_train_epochs=1,
# batch_size = num_generations: TRL requires
# batch_size % num_generations == 0. This is the
# minimum valid value. Each step processes
# batch_size prompts × num_generations rollouts each.
per_device_train_batch_size=num_gen,
per_device_train_batch_size=1,
# generation_batch_size must be divisible by num_generations.
# Setting it to num_generations satisfies the constraint
# while keeping batch_size=1 (one unique prompt per step).
generation_batch_size=num_gen,
)

# Pad dataset if needed: TRL needs len(dataset) >= batch_size.
# With 1 task and batch_size=4, we repeat the task 4 times.
# Each row triggers the same rollout_func, so repeats are fine
# for RL (same task, many rollouts = more learning signal).
bs = getattr(trl_config, "per_device_train_batch_size", num_gen)
if len(dataset) < bs:
import math
repeats = math.ceil(bs / len(dataset))
logger.info(
"Padding dataset from %d to %d rows (repeating tasks %dx) "
"to meet per_device_train_batch_size=%d",
len(dataset), len(dataset) * repeats, repeats, bs,
)
padded = {k: v * repeats for k, v in dataset.to_dict().items()}
dataset = Dataset.from_dict(padded)
# No dataset padding needed: with batch_size=1, even a single-task
# dataset works. TRL iterates one prompt per step, each getting
# num_generations rollouts via rollout_func.

# --- Train ---
trainer = _TRLTrainer(
Expand Down
Loading