
Commit abcafe8

abrichr and claude authored
feat: add end-to-end GRPO training script with TRL + Unsloth (#129)
One command to train: task YAMLs → dense rewards → GRPO.

Features:
- --mock mode for pipeline validation (no VM/GPU)
- --use-unsloth for VRAM efficiency (4bit + LoRA)
- --task-dir loads YAML task configs with milestones
- Dense rewards via milestones (reward = passed/total)
- Configurable: model, group size, loss type, vLLM, learning rate
- LoRA checkpoint loading (--lora-checkpoint)

Usage:
    python scripts/train_trl_grpo.py --task-dir ./example_tasks --mock
    python scripts/train_trl_grpo.py --task-dir ./tasks --server-url http://vm:5001

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
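The --task-dir flag loads YAML task configs. As a rough sketch, one such config might look like the following; the fields mirror what the script reads from each TaskConfig (id, name, milestones, max_steps), but the milestone entries and values here are illustrative, not taken from this commit:

    id: notepad-001
    name: "Open Notepad and save a file named hello.txt"
    max_steps: 15
    milestones:
      - notepad_opened
      - text_entered
      - file_saved

With three milestones, a rollout that passes two of them earns a dense reward of 2/3 ≈ 0.67 (reward = passed/total), rather than a sparse pass/fail signal.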
1 parent ebae5a6 commit abcafe8

1 file changed

scripts/train_trl_grpo.py

Lines changed: 376 additions & 0 deletions
@@ -0,0 +1,376 @@
#!/usr/bin/env python3
"""End-to-end GRPO training script using TRL + Unsloth + WAA.

One command to train a VLM desktop agent with dense milestone rewards:

    # With real WAA VM:
    python scripts/train_trl_grpo.py \
        --task-dir ./example_tasks \
        --server-url http://localhost:5001 \
        --model Qwen/Qwen2.5-VL-3B-Instruct \
        --output ./grpo_output

    # Mock mode (no VM, no GPU — validates pipeline):
    python scripts/train_trl_grpo.py \
        --task-dir ./example_tasks \
        --mock \
        --output ./grpo_output_mock

    # With Unsloth (recommended for GPU training):
    python scripts/train_trl_grpo.py \
        --task-dir ./example_tasks \
        --server-url http://localhost:5001 \
        --model Qwen/Qwen2.5-VL-7B-Instruct \
        --use-unsloth \
        --output ./grpo_output

Requirements:
    pip install openadapt-evals "trl>=0.17"
    pip install unsloth  # optional, for VRAM efficiency
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("train_trl_grpo")

def load_task_dataset(task_configs):
    """Create a HuggingFace Dataset from TaskConfig objects.

    Each row has a 'prompt' field (task instruction) that TRL's
    GRPOTrainer samples from during training.
    """
    from datasets import Dataset

    return Dataset.from_dict({
        "prompt": [tc.name for tc in task_configs],
        "task_id": [tc.id for tc in task_configs],
    })

def load_model_unsloth(model_name, max_seq_length=4096, lora_r=16):
    """Load model with Unsloth for VRAM efficiency."""
    from unsloth import FastVisionModel

    logger.info("Loading model with Unsloth: %s", model_name)
    model, processor = FastVisionModel.from_pretrained(
        model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
        fast_inference=True,
        gpu_memory_utilization=0.6,
        float8_kv_cache=True,
    )
    model = FastVisionModel.get_peft_model(
        model,
        r=lora_r,
        lora_alpha=lora_r,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
    logger.info("Model loaded with Unsloth (4bit + LoRA r=%d)", lora_r)
    return model, processor

def load_model_standard(model_name, lora_r=16):
    """Load model with standard HuggingFace + PEFT."""
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForVision2Seq, AutoProcessor

    logger.info("Loading model (standard): %s", model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForVision2Seq.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_r,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model, processor

def create_mock_rollout_func(task_configs):
    """Create a mock rollout_func for pipeline validation without VM/GPU.

    Returns synthetic rewards matching milestone fractions so GRPO
    can compute advantages.
    """
    import random

    config_map = {tc.name: tc for tc in task_configs}

    def mock_rollout_func(prompts, trainer):
        num_generations = getattr(trainer.args, "num_generations", 8)
        all_prompt_ids = []
        all_completion_ids = []
        all_logprobs = []
        all_rewards = []

        for prompt in prompts:
            tc = config_map.get(prompt)
            n_milestones = len(tc.milestones) if tc else 3

            for _ in range(num_generations):
                # Simulate varying milestone completion
                passed = random.randint(0, n_milestones)
                reward = passed / max(n_milestones, 1)

                all_prompt_ids.append([1, 2, 3])
                all_completion_ids.append([4, 5, 6, 7])
                all_logprobs.append([-0.5, -0.3, -0.2, -0.1])
                all_rewards.append(reward)

        return {
            "prompt_ids": all_prompt_ids,
            "completion_ids": all_completion_ids,
            "logprobs": all_logprobs,
            "env_reward": all_rewards,
        }

    return mock_rollout_func

def main():
    parser = argparse.ArgumentParser(
        description="Train a VLM desktop agent with TRL GRPO + dense rewards"
    )

    # Task configuration
    parser.add_argument(
        "--task-dir", required=True,
        help="Directory of YAML task configs (e.g., ./example_tasks)",
    )

    # Environment
    parser.add_argument(
        "--server-url", default="http://localhost:5001",
        help="WAA server URL (default: localhost:5001)",
    )
    parser.add_argument(
        "--evaluate-url", default=None,
        help="Separate evaluate server URL (default: same as server-url)",
    )
    parser.add_argument(
        "--max-steps", type=int, default=15,
        help="Max steps per episode (default: 15)",
    )
    parser.add_argument(
        "--mock", action="store_true",
        help="Use mock adapter (no VM/GPU needed, validates pipeline)",
    )

    # Model
    parser.add_argument(
        "--model", default="Qwen/Qwen2.5-VL-3B-Instruct",
        help="Model name or path (default: Qwen2.5-VL-3B)",
    )
    parser.add_argument(
        "--use-unsloth", action="store_true",
        help="Use Unsloth for VRAM efficiency (recommended for GPU)",
    )
    parser.add_argument(
        "--lora-r", type=int, default=16,
        help="LoRA rank (default: 16)",
    )
    parser.add_argument(
        "--lora-checkpoint", default=None,
        help="Path to LoRA checkpoint to resume from",
    )

    # Training
    parser.add_argument("--output", default="./grpo_output", help="Output directory")
    parser.add_argument("--num-generations", type=int, default=4, help="GRPO group size")
    parser.add_argument("--max-completion-length", type=int, default=256)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--save-steps", type=int, default=50)
    parser.add_argument("--logging-steps", type=int, default=1)
    parser.add_argument("--loss-type", default="grpo", choices=["grpo", "dapo", "dr_grpo"])
    parser.add_argument(
        "--use-vllm", action="store_true",
        help="Use vLLM for generation (faster, requires vllm installed)",
    )

    # Reward
    parser.add_argument(
        "--reward-fn", default="env",
        choices=["env", "env+length"],
        help="Reward function: env (milestone rewards only) or env+length (penalize long episodes)",
    )

    args = parser.parse_args()

    # --- Load task configs ---
    from openadapt_evals.task_config import TaskConfig

    task_configs = TaskConfig.from_dir(args.task_dir)
    if not task_configs:
        logger.error("No task configs found in %s", args.task_dir)
        sys.exit(1)

    logger.info(
        "Loaded %d task configs from %s",
        len(task_configs), args.task_dir,
    )
    for tc in task_configs:
        logger.info(
            "  %s (%s): %d milestones, max_steps=%d",
            tc.id, tc.name[:40], len(tc.milestones), tc.max_steps,
        )

    dataset = load_task_dataset(task_configs)
    logger.info("Training dataset: %d tasks", len(dataset))

    # --- Mock mode ---
    if args.mock:
        logger.info("=== MOCK MODE — validating pipeline without VM/GPU ===")

        rollout_func = create_mock_rollout_func(task_configs)

        # Verify rollout_func output shape
        mock_trainer = type(
            "MockTrainer", (),
            {"args": type("Args", (), {"num_generations": args.num_generations})()},
        )()
        result = rollout_func(dataset["prompt"][:2], mock_trainer)

        logger.info("Mock rollout output keys: %s", list(result.keys()))
        logger.info("Rewards: %s", result["env_reward"])
        logger.info(
            "Reward spread (max - min): %.4f (need >0 for GRPO)",
            max(result["env_reward"]) - min(result["env_reward"]),
        )

        # Save mock results
        output_dir = Path(args.output)
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "mock_results.json", "w") as f:
            json.dump({
                "mode": "mock",
                "tasks": len(task_configs),
                "num_generations": args.num_generations,
                "rewards": result["env_reward"],
                "reward_spread": max(result["env_reward"]) - min(result["env_reward"]),
            }, f, indent=2)
        logger.info("Mock results saved to %s", output_dir / "mock_results.json")
        logger.info("=== Mock pipeline validation PASSED ===")
        return

    # --- Real training ---
    logger.info("=== Setting up GRPO training ===")

    # Load model
    if args.use_unsloth:
        model, processor = load_model_unsloth(args.model, lora_r=args.lora_r)
    else:
        model, processor = load_model_standard(args.model, lora_r=args.lora_r)

    # Load LoRA checkpoint if provided
    if args.lora_checkpoint:
        from peft import PeftModel

        logger.info("Loading LoRA checkpoint: %s", args.lora_checkpoint)
        model = PeftModel.from_pretrained(model, args.lora_checkpoint)

    # Create rollout function
    if args.mock:
        rollout_func = create_mock_rollout_func(task_configs)
    else:
        from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig
        from openadapt_evals.training.trl_rollout import make_waa_rollout_func

        adapter = WAALiveAdapter(
            WAALiveConfig(
                server_url=args.server_url,
                evaluate_url=args.evaluate_url,
            )
        )
        rollout_func = make_waa_rollout_func(
            adapter=adapter,
            task_configs=task_configs,
            max_steps=args.max_steps,
        )

    # Create reward function
    def env_reward_fn(completions, **kwargs):
        """Extract environment rewards from rollout_func output."""
        return kwargs.get("env_reward", [0.0] * len(completions))

    reward_funcs = [env_reward_fn]

    if args.reward_fn == "env+length":
        def length_penalty(completions, **kwargs):
            """Penalize very long completions (encourage efficiency)."""
            max_len = args.max_completion_length
            return [-0.1 * (len(c) / max_len) for c in completions]

        reward_funcs.append(length_penalty)

    # Configure training
    from trl import GRPOConfig, GRPOTrainer

    config = GRPOConfig(
        output_dir=args.output,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        loss_type=args.loss_type,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        bf16=True,
        report_to="none",  # set to "wandb" for W&B logging
    )

    if args.use_vllm:
        config.use_vllm = True
        config.vllm_mode = "colocate"
        config.vllm_gpu_memory_utilization = 0.3

    trainer = GRPOTrainer(
        model=model,
        processing_class=processor,
        args=config,
        train_dataset=dataset,
        reward_funcs=reward_funcs,
        rollout_func=rollout_func,
    )

    logger.info("=== Starting GRPO training ===")
    logger.info("  Model: %s", args.model)
    logger.info("  Tasks: %d", len(task_configs))
    logger.info("  Group size: %d", args.num_generations)
    logger.info("  Loss type: %s", args.loss_type)
    logger.info("  Output: %s", args.output)

    trainer.train()

    # Save final checkpoint
    trainer.save_model(args.output)
    logger.info("=== Training complete. Model saved to %s ===", args.output)


if __name__ == "__main__":
    main()
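For reference, GRPO turns the env_reward values returned by rollout_func into group-relative advantages by normalizing rewards within each group of num_generations rollouts. A minimal sketch of that standard computation (TRL's internals may differ in details such as scaling and std clamping):

    import statistics

    def group_advantages(rewards, group_size):
        """Z-score each reward within its GRPO group."""
        advantages = []
        for start in range(0, len(rewards), group_size):
            group = rewards[start:start + group_size]
            mean = statistics.fmean(group)
            std = statistics.pstdev(group)
            # A group where every rollout earns the same reward (std == 0)
            # yields zero learning signal; that is what the mock mode's
            # "need >0" spread check guards against.
            advantages.extend((r - mean) / (std + 1e-4) for r in group)
        return advantages

    # Dense milestone rewards for one group of 4 rollouts (3 milestones each):
    print(group_advantages([0.0, 1 / 3, 2 / 3, 1.0], group_size=4))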
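Similarly, when --reward-fn env+length is set, two reward functions are registered and TRL combines their outputs as a weighted sum (weights default to 1.0), so the effective per-completion reward is roughly env_reward + length_penalty. Worked with hypothetical numbers:

    # 2 of 4 milestones passed; 128-character completion; max length 256.
    env_reward = 2 / 4                    # 0.50 from milestones
    length_penalty = -0.1 * (128 / 256)   # -0.05 from the length term
    total = env_reward + length_penalty   # 0.45 enters advantage estimation
    print(total)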

0 commit comments
