diff --git a/openadapt_evals/training/trl_rollout.py b/openadapt_evals/training/trl_rollout.py index 1fd7e53..b2fe74c 100644 --- a/openadapt_evals/training/trl_rollout.py +++ b/openadapt_evals/training/trl_rollout.py @@ -415,6 +415,7 @@ def make_waa_rollout_func( # Outlines generator is created lazily on first rollout call # (needs the trainer's model and processor which aren't available yet). _outlines_state: dict[str, Any] = {"generator": None, "attempted": False} + _prompt_logged: list[bool] = [False] # log the prompt once for diagnostics def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]: """TRL GRPOTrainer rollout function. @@ -501,13 +502,14 @@ def generate_fn(screenshot_bytes: bytes, instruction: str): ) return "done", [], [] - messages = [ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": [ - {"type": "image", "image": img}, - {"type": "text", "text": instruction}, - ]}, - ] + # Use the SAME message construction as the standalone trainer. + # This includes the "Goal:" prefix, format guidance, and the + # correct {"type": "image"} tag format that Qwen processors expect. + # Without this, the model sees just the raw instruction text and + # produces degenerate output (e.g., "# # # # # # #"). + from openadapt_evals.training.standalone.prompt import build_agent_messages + + messages = build_agent_messages(instruction, include_image=True) import torch @@ -515,6 +517,15 @@ def generate_fn(screenshot_bytes: bytes, instruction: str): messages, tokenize=False, add_generation_prompt=True ) + # Log the prompt on first call so operators can verify + # the correct format is being used (DSL, not JSON). + if not _prompt_logged[0]: + _prompt_logged[0] = True + logger.info( + "TRL rollout prompt (first 300 chars of text_input): %.300s", + text_input, + ) + # --- Constrained decoding path (Outlines) --- if outlines_gen is not None: import outlines diff --git a/openadapt_evals/training/trl_wrapper.py b/openadapt_evals/training/trl_wrapper.py index ab1d69f..b4c48ee 100644 --- a/openadapt_evals/training/trl_wrapper.py +++ b/openadapt_evals/training/trl_wrapper.py @@ -219,29 +219,22 @@ def on_step_end(self, args, state, control, **kwargs): pass # --- TRL config: use provided or build sensible defaults --- - # TRL constraints on batch sizing: - # 1. per_device_train_batch_size must be <= len(dataset) - # 2. generation_batch_size must be divisible by num_generations - # 3. generation_batch_size defaults to per_device_train_batch_size + # TRL constraints: + # - generation_batch_size must be divisible by num_generations + # - per_device_train_batch_size must be <= len(dataset) # - # Therefore: per_device_train_batch_size must be a MULTIPLE of - # num_generations AND <= len(dataset). The minimum valid value is - # num_generations itself. If the dataset is smaller, we pad it - # by repeating tasks to reach at least that size. + # For RL with few tasks: set batch_size=1 (one unique prompt per + # step) and generation_batch_size=num_generations (satisfies the + # divisibility requirement). This produces exactly num_generations + # rollouts per step — matching the standalone trainer. + # + # Previous approach (batch_size=num_gen, padded dataset) caused + # 4× over-generation: 4 identical prompts × 4 generations = 16 + # rollouts when only 4 were needed. num_gen = self._config.num_rollouts_per_step - n_tasks = len(task_configs) if self._trl_config is not None: trl_config = self._trl_config - bs = getattr(trl_config, "per_device_train_batch_size", 8) - ng = getattr(trl_config, "num_generations", num_gen) - if bs % ng != 0: - logger.warning( - "per_device_train_batch_size=%d is not divisible by " - "num_generations=%d. TRL will reject this. " - "Set per_device_train_batch_size=%d.", - bs, ng, ng, - ) else: trl_config = GRPOConfig( output_dir=self._config.output_dir, @@ -254,28 +247,16 @@ def on_step_end(self, args, state, control, **kwargs): bf16=True, loss_type="grpo", num_train_epochs=1, - # batch_size = num_generations: TRL requires - # batch_size % num_generations == 0. This is the - # minimum valid value. Each step processes - # batch_size prompts × num_generations rollouts each. - per_device_train_batch_size=num_gen, + per_device_train_batch_size=1, + # generation_batch_size must be divisible by num_generations. + # Setting it to num_generations satisfies the constraint + # while keeping batch_size=1 (one unique prompt per step). + generation_batch_size=num_gen, ) - # Pad dataset if needed: TRL needs len(dataset) >= batch_size. - # With 1 task and batch_size=4, we repeat the task 4 times. - # Each row triggers the same rollout_func, so repeats are fine - # for RL (same task, many rollouts = more learning signal). - bs = getattr(trl_config, "per_device_train_batch_size", num_gen) - if len(dataset) < bs: - import math - repeats = math.ceil(bs / len(dataset)) - logger.info( - "Padding dataset from %d to %d rows (repeating tasks %dx) " - "to meet per_device_train_batch_size=%d", - len(dataset), len(dataset) * repeats, repeats, bs, - ) - padded = {k: v * repeats for k, v in dataset.to_dict().items()} - dataset = Dataset.from_dict(padded) + # No dataset padding needed: with batch_size=1, even a single-task + # dataset works. TRL iterates one prompt per step, each getting + # num_generations rollouts via rollout_func. # --- Train --- trainer = _TRLTrainer(