Skip to content

Commit 6ecd136

Browse files
abrichr and claude committed
feat: add --task-ids, --max-steps-per-episode, --max-new-tokens to standalone GRPO CLI
Without --task-ids, the trainer cycles through ALL tasks in --task-dir, including hard ones (calc-formula) that base models can't complete. You can now restrict training to specific tasks, for example: --task-ids custom-notepad-hello

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 25eb2f2 commit 6ecd136

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

openadapt_evals/training/standalone/trainer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,14 @@ def main() -> None:
281281
"""CLI entry point."""
282282
p = argparse.ArgumentParser(description="Standalone GRPO trainer for WAA")
283283
p.add_argument("--task-dir", required=True, help="Directory of TaskConfig YAMLs")
284+
p.add_argument("--task-ids", nargs="+", default=None, help="Specific task IDs to train on (default: all from task-dir)")
284285
p.add_argument("--server-url", default="http://localhost:5001")
285286
p.add_argument("--model", default="Qwen/Qwen2.5-VL-7B-Instruct")
286287
p.add_argument("--lora-checkpoint", default=None)
287288
p.add_argument("--num-steps", type=int, default=10)
288289
p.add_argument("--num-rollouts", type=int, default=8)
290+
p.add_argument("--max-steps-per-episode", type=int, default=15)
291+
p.add_argument("--max-new-tokens", type=int, default=2048)
289292
p.add_argument("--output", default="checkpoints/grpo")
290293
p.add_argument("--no-4bit", action="store_true")
291294
p.add_argument("--eval-model", default="gpt-4.1-mini")
@@ -295,7 +298,13 @@ def main() -> None:
295298
config = TrainingConfig(
296299
model_name=a.model, load_in_4bit=not a.no_4bit, lora_checkpoint=a.lora_checkpoint,
297300
server_url=a.server_url, task_dir=a.task_dir, num_training_steps=a.num_steps,
298-
num_rollouts_per_step=a.num_rollouts, output_dir=a.output, eval_model=a.eval_model)
301+
num_rollouts_per_step=a.num_rollouts, max_steps_per_episode=a.max_steps_per_episode,
302+
max_new_tokens=a.max_new_tokens, output_dir=a.output, eval_model=a.eval_model)
303+
304+
# Filter to specific tasks if requested
305+
if a.task_ids:
306+
config.task_ids = a.task_ids
307+
299308
GRPOTrainer(config).train()
300309

301310

0 commit comments

Comments
 (0)