Skip to content

Commit 97c144b

Browse files
abrichr and claude authored
feat: add self-contained GRPO training example script (#81)
* feat: add self-contained GRPO training example script 250-line example showing the full RL training loop: model loading → rollout collection → GRPO loss → weight update → checkpoint. No openadapt-ml dependency — all GRPO math, action parsing, and log-prob computation are inline. Uses RLEnvironment from openadapt-evals. Includes --mock flag for testing without a VM. Usage: python scripts/train_grpo_example.py --mock --num-steps 3 python scripts/train_grpo_example.py --server http://localhost:5001 --task-id <UUID> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: align GRPO training example with openadapt-ml trainer - Align SYSTEM_PROMPT with openadapt_ml.datasets.next_action.SYSTEM_PROMPT - Use chat template for prompt construction (not raw string concatenation) - Fix screen height default: 1080 (was 1200) - Fix LoRA target_modules: 4 projections (was 2) matching ml trainer - Fix coordinate fallback: use format_action_as_text with normalized fractions (was using raw pixel coords like x=960) - Add WAIT() handler in parse_action (was falling through to DONE) - Fix TYPE regex to handle escaped quotes and backslashes - Fix loss scaling: divide by (n_valid * num_steps) matching ml trainer - Rename grpo_loss to policy_gradient_loss with honest docstring - Add build_agent_messages and format_action_as_text helper functions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 840f9ef commit 97c144b

1 file changed

Lines changed: 366 additions & 0 deletions

File tree

scripts/train_grpo_example.py

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
#!/usr/bin/env python3
2+
"""Self-contained GRPO training example for GUI agents.
3+
4+
Demonstrates the full RL loop: connect to WAA, collect rollouts with a VLM
5+
policy, compute group-relative advantages, update LoRA weights via
6+
REINFORCE with group-relative advantages (equivalent to single-epoch GRPO).
7+
No openadapt-ml dependency -- all math and parsing are inline.
8+
9+
Requirements:
10+
pip install torch transformers peft pillow fire
11+
# openadapt-evals must be installed (provides RLEnvironment, adapters)
12+
13+
Usage:
14+
# Mock mode (no VM, random rewards -- for testing the loop):
15+
python scripts/train_grpo_example.py --mock --num-steps 2 --group-size 3
16+
17+
# Live WAA server:
18+
python scripts/train_grpo_example.py \\
19+
--server http://localhost:5001 \\
20+
--task-id <WAA_UUID> \\
21+
--num-steps 10 \\
22+
--group-size 4
23+
24+
# Custom model and learning rate:
25+
python scripts/train_grpo_example.py \\
26+
--mock \\
27+
--model-name Qwen/Qwen2.5-VL-7B-Instruct \\
28+
--lr 1e-5 \\
29+
--num-steps 5
30+
"""
31+
32+
from __future__ import annotations
33+
34+
import io
35+
import re
36+
import time
37+
38+
import fire
39+
import torch
40+
from peft import LoraConfig, get_peft_model
41+
from PIL import Image
42+
from transformers import AutoModelForVision2Seq, AutoProcessor
43+
44+
from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation
45+
from openadapt_evals.adapters.rl_env import RLEnvironment
46+
47+
48+
# -- Policy gradient loss ------------------------------------------------------
49+
50+
51+
def policy_gradient_loss(
    current_logps: torch.Tensor,
    old_logps: torch.Tensor,
    advantages: torch.Tensor,
    epsilon: float = 0.2,
) -> torch.Tensor:
    """Return the negated, PPO-style clipped surrogate averaged over all elements.

    The importance ratio ``exp(current_logps - old_logps)`` is multiplied by
    ``advantages`` twice: once unmodified and once after clamping the ratio to
    ``[1 - epsilon, 1 + epsilon]``. The element-wise smaller product is kept,
    averaged, and negated so that minimising the result maximises the
    surrogate objective.

    When ``old_logps`` equals ``current_logps`` (single-epoch training) the
    ratio is 1 everywhere, so the clamp is inactive and the gradient is the
    plain REINFORCE gradient.

    Args:
        current_logps: Log-probabilities under the policy being optimised.
        old_logps: Log-probabilities under the policy that produced the samples.
        advantages: Advantage weights, broadcastable against the log-probs.
        epsilon: Half-width of the clipping interval around 1.

    Returns:
        A scalar tensor holding the loss.
    """
    importance = (current_logps - old_logps).exp()
    bounded = importance.clamp(1.0 - epsilon, 1.0 + epsilon)
    # Pessimistic choice between the raw and the clipped objective.
    surrogate = torch.min(importance * advantages, bounded * advantages)
    return -surrogate.mean()
64+
65+
66+
def compute_advantages(rewards: list[float]) -> list[float]:
    """Normalise a group of rewards into group-relative advantages.

    Each advantage is ``(r - mean) / (std + 1e-8)``, where ``mean`` and ``std``
    are the population mean and standard deviation of ``rewards``.

    Args:
        rewards: One scalar reward per rollout in the group.

    Returns:
        A list the same length as ``rewards``. It is empty for empty input and
        all zeros when the rewards have (numerically) no spread, since such a
        group carries no learning signal.
    """
    count = len(rewards)
    if count == 0:
        return []

    mean = sum(rewards) / count
    variance = sum((r - mean) ** 2 for r in rewards) / count
    std = variance ** 0.5

    # Identical rewards: avoid dividing by a near-zero spread.
    if std < 1e-8:
        return [0.0 for _ in range(count)]

    return [(r - mean) / (std + 1e-8) for r in rewards]
76+
77+
78+
# -- Helpers -------------------------------------------------------------------
79+
80+
# Aligned with openadapt_ml.datasets.next_action.SYSTEM_PROMPT
81+
SYSTEM_PROMPT = (
    # Role statement.
    "You are a GUI automation agent. "
    "Given a screenshot and a user goal, predict the single next action.\n\n"
    # Coordinates are fractions of the screen, origin at the top-left corner.
    "COORDINATE SYSTEM:\n"
    "- x=0.0 is the LEFT edge, x=1.0 is the RIGHT edge\n"
    "- y=0.0 is the TOP edge, y=1.0 is the BOTTOM edge\n"
    "- To click the CENTER of an element, "
    "estimate its center position as a fraction of screen width/height\n"
    "- Example: An element in the middle of the screen "
    "would be approximately x=0.5, y=0.5\n\n"
    # The action DSL; parse_action() understands exactly these four forms.
    "ALLOWED ACTIONS (use exactly this format):\n"
    "- CLICK(x=0.XX, y=0.XX) → click at normalized coordinates\n"
    '- TYPE(text="...") → type text into the currently focused field\n'
    "- WAIT() → wait for UI to update\n"
    "- DONE() → task is complete\n\n"
    # Expected layout of the model's reply.
    "RESPONSE FORMAT (required):\n"
    "Thought: [Brief reasoning: what element to interact with and why]\n"
    "Action: [Exactly one action, e.g., CLICK(x=0.35, y=0.42)]\n\n"
    "IMPORTANT: "
    "Output coordinates with 2 decimal places. Estimate the center of target elements."
)

# (width, height) in pixels.
DEFAULT_SCREEN_SIZE = (1920, 1080)  # Aligned with openadapt-ml
104+
105+
106+
def parse_action(
    text: str,
    width: int = 1920,
    height: int = 1080,
) -> BenchmarkAction:
    """Parse VLM text output into a BenchmarkAction.

    Supports: CLICK(x=0.XX, y=0.XX), TYPE(text="..."), WAIT(), DONE().
    Patterns are searched anywhere in ``text`` (case-insensitively) in that
    order, so a leading "Thought: ..." line is tolerated.

    Args:
        text: Raw text generated by the model.
        width: Screen width in pixels, used to scale the normalized x fraction.
        height: Screen height in pixels, used to scale the normalized y fraction.

    Returns:
        A ``click`` action with pixel coordinates (fractions are clamped to
        ``[0, 1]`` before scaling), a ``type`` action with the unescaped text,
        a ``wait`` action, or a ``done`` action. Output that matches none of
        the known forms also yields ``done``.
    """
    text = text.strip()

    # CLICK -- the number pattern only admits strings float() can parse, so
    # malformed output such as "x=0.5.2" falls through instead of raising.
    number = r"-?(?:\d+(?:\.\d*)?|\.\d+)"
    m = re.search(rf"CLICK\(x=({number}),\s*y=({number})\)", text, re.IGNORECASE)
    if m:
        x_frac = max(0.0, min(1.0, float(m.group(1))))
        y_frac = max(0.0, min(1.0, float(m.group(2))))
        return BenchmarkAction(
            type="click", x=float(int(x_frac * width)), y=float(int(y_frac * height))
        )

    # TYPE -- the closing quote must be the same character as the opening one,
    # so the other quote character may appear unescaped inside the text
    # (e.g. TYPE(text="don't")). Backslash escapes are consumed as pairs.
    m = re.search(
        r"""TYPE\(text=(["'])((?:\\.|(?!\1)[^\\])*)\1\)""", text, re.IGNORECASE
    )
    if m:
        # Single pass so an escaped backslash is never re-read as the start of
        # another escape sequence.
        typed_text = re.sub(r"""\\([\\"'])""", r"\1", m.group(2))
        return BenchmarkAction(type="type", text=typed_text)

    # WAIT
    if re.search(r"\bWAIT\s*\(\s*\)", text, re.IGNORECASE):
        return BenchmarkAction(type="wait")

    # DONE, and the fallback for unparseable output: both end the episode.
    return BenchmarkAction(type="done")
142+
143+
144+
def build_agent_messages(instruction: str) -> list[dict[str, str]]:
    """Assemble the two-message chat (system + user) sent to the policy.

    Aligned with openadapt-ml ``_build_agent_messages``.

    Args:
        instruction: Natural-language goal of the current task.

    Returns:
        A list with the system message (``SYSTEM_PROMPT``) followed by a user
        message that states the goal and restates the action format.
    """
    # Three paragraphs separated by blank lines.
    paragraphs = [
        f"Goal: {instruction}",
        "Look at the screenshot and determine the NEXT action.",
        'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]',
    ]
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": "\n\n".join(paragraphs)}
    return [system_message, user_message]
155+
156+
157+
def format_action_as_text(
    action: BenchmarkAction,
    width: int = 1920,
    height: int = 1080,
) -> str:
    """Render an action in the prompt's DSL, for log-prob computation.

    Args:
        action: Object whose ``type`` attribute selects the form; ``x``/``y``
            (pixels) are read for clicks and ``text`` for typing. Missing or
            falsy attributes are treated as 0 / empty string.
        width: Screen width in pixels used to normalise ``x``.
        height: Screen height in pixels used to normalise ``y``.

    Returns:
        ``CLICK(x=0.XX, y=0.XX)`` with two-decimal fractions, ``TYPE(text="...")``
        with backslashes and double quotes escaped, ``WAIT()``, or ``DONE()``
        for any other action type.
    """
    kind = getattr(action, "type", "done")

    if kind == "wait":
        return "WAIT()"

    if kind == "type":
        raw = getattr(action, "text", "") or ""
        # Backslashes first, so the ones added for quotes are not doubled.
        quoted = raw.replace("\\", "\\\\").replace('"', '\\"')
        return f'TYPE(text="{quoted}")'

    if kind != "click":
        return "DONE()"

    px_x = getattr(action, "x", 0) or 0
    px_y = getattr(action, "y", 0) or 0
    # A non-positive dimension would divide by zero; report 0.0 instead.
    frac_x = px_x / width if width > 0 else 0.0
    frac_y = px_y / height if height > 0 else 0.0
    return f"CLICK(x={frac_x:.2f}, y={frac_y:.2f})"
177+
178+
179+
def trajectory_logprob(
    model, processor, screenshot_bytes: bytes, instruction: str, action_text: str
) -> torch.Tensor:
    """Forward pass: compute log-prob of action_text given screenshot + prompt.

    The prompt is built the same way as during rollout collection, the
    tokenized ``action_text`` is appended to it, and the log-probabilities the
    model assigns to each appended token are summed.

    Args:
        model: Causal vision-language model; called as ``model(**inputs)`` and
            expected to expose ``.logits`` and ``.device``.
        processor: Matching processor; its ``tokenizer`` attribute (or the
            processor itself, if it has none) tokenizes ``action_text``.
        screenshot_bytes: Encoded screenshot image.
        instruction: Task goal inserted into the user message.
        action_text: Text whose log-probability is evaluated.

    Returns:
        A scalar tensor: the sum of per-token log-probs of ``action_text``.
        Gradients flow through it unless the caller disables them.
    """
    image = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB")

    # Use chat template if available (aligned with openadapt-ml trainer)
    # NOTE(review): the messages carry text only -- confirm the processor's
    # chat template inserts whatever image placeholder tokens it requires.
    messages = build_agent_messages(instruction)
    if hasattr(processor, "apply_chat_template"):
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        prompt = messages[-1]["content"]

    # Pass the prompt by keyword: processors differ in positional order (the
    # Qwen2-VL family takes ``images`` first), so a positional prompt collides
    # with ``images=``.
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    tokenizer = getattr(processor, "tokenizer", processor)
    action_ids = tokenizer(action_text, return_tensors="pt", add_special_tokens=False)[
        "input_ids"
    ].to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Append the action tokens; every position of the combined sequence is real.
    full_ids = torch.cat([inputs["input_ids"], action_ids], dim=1)
    inputs["input_ids"] = full_ids
    inputs["attention_mask"] = torch.ones_like(full_ids)

    logits = model(**inputs).logits
    # Logits at position i predict token i + 1, so the predictions for the
    # action tokens start one step before the end of the prompt.
    action_logits = logits[:, prompt_len - 1 : prompt_len - 1 + action_ids.shape[1], :]
    log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
    token_lps = log_probs.gather(2, action_ids.unsqueeze(-1)).squeeze(-1)
    return token_lps.sum()
210+
211+
212+
# -- Main training loop -------------------------------------------------------
213+
214+
215+
def main(
    server: str = "http://localhost:5001",
    task_id: str | None = None,
    num_steps: int = 5,
    group_size: int = 4,
    max_episode_steps: int = 15,
    model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",
    lr: float = 1e-5,
    checkpoint_dir: str = "grpo_checkpoint",
    mock: bool = False,
) -> None:
    """Run GRPO training: rollouts -> advantages -> policy gradient -> update.

    Args:
        server: URL of the live WAA server (ignored when ``mock`` is set).
        task_id: Task to train on; when omitted, the first task listed by the
            adapter is used.
        num_steps: Number of training steps (one group of rollouts each).
        group_size: Rollouts collected per training step.
        max_episode_steps: Step limit passed to each rollout.
        model_name: Hugging Face model identifier of the policy.
        lr: AdamW learning rate for the LoRA parameters.
        checkpoint_dir: Directory the adapter weights are saved to at the end.
        mock: Use the mock WAA adapter instead of a live server.

    Raises:
        RuntimeError: If ``task_id`` is omitted and the adapter lists no tasks.
    """
    # 1. Load model with LoRA
    print(f"Loading {model_name} ...")
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForVision2Seq.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )
    lora = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora)
    model.print_trainable_parameters()
    # Collected once; reused for the optimizer and for gradient clipping.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    # 2. Create RL environment
    if mock:
        from openadapt_evals.adapters.waa.mock import WAAMockAdapter

        adapter = WAAMockAdapter(num_tasks=20)
    else:
        from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig

        adapter = WAALiveAdapter(WAALiveConfig(server_url=server))

    if task_id is None:
        tasks = adapter.list_tasks()
        if not tasks:
            raise RuntimeError("Adapter returned no tasks; pass --task-id explicitly.")
        task_id = tasks[0].task_id
        print(f"Auto-selected task: {task_id}")

    env = RLEnvironment(adapter=adapter, default_task_id=task_id)
    w, h = env.screen_size

    # 3. Agent function: screenshot -> VLM -> action
    def agent_fn(obs: BenchmarkObservation) -> BenchmarkAction:
        """Sample one action from the policy for the given observation."""
        if not obs.screenshot:
            return BenchmarkAction(type="done")
        image = Image.open(io.BytesIO(obs.screenshot)).convert("RGB")
        task = getattr(env, "_current_task", None)
        goal = task.instruction if task else "Complete the task."

        messages = build_agent_messages(goal)
        if hasattr(processor, "apply_chat_template"):
            prompt = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = messages[-1]["content"]

        # Prompt passed by keyword: some processors (Qwen2-VL family) take
        # ``images`` as their first positional parameter.
        inputs = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        text = processor.decode(
            out[0][inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
        )
        action = parse_action(text, w, h)
        action._raw_text = text  # stash for log-prob recomputation
        return action

    # 4. Training loop
    for step in range(num_steps):
        t0 = time.time()
        print(f"\n{'=' * 50}\nStep {step + 1}/{num_steps}: collecting {group_size} rollouts ...")

        # -- Collect rollouts --
        model.eval()
        rollouts, rewards = [], []
        for g in range(group_size):
            trajectory = env.collect_rollout(agent_fn=agent_fn, max_steps=max_episode_steps)
            # The episode reward is read from the final step.
            reward = trajectory[-1].reward if trajectory else 0.0
            rollouts.append(trajectory)
            rewards.append(reward)
            print(f"  rollout {g + 1}: {len(trajectory)} steps, reward={reward:.2f}")

        # -- Compute advantages --
        advantages = compute_advantages(rewards)
        if all(a == 0.0 for a in advantages):
            print("  No variance in rewards, skipping gradient step.")
            continue

        # -- Policy gradient update --
        model.train()
        optimizer.zero_grad()
        total_loss = 0.0
        n_terms = 0

        task = getattr(env, "_current_task", None)
        instruction = task.instruction if task else ""

        n_valid = sum(1 for a in advantages if abs(a) >= 1e-8)

        for traj, adv in zip(rollouts, advantages):
            if abs(adv) < 1e-8:
                continue
            # Kept distinct from the ``num_steps`` parameter so the step
            # banner above keeps printing the real number of training steps.
            traj_len = max(len(traj), 1)
            for s in traj:
                if not s.observation.screenshot:
                    continue
                # Use raw VLM text if available, else reconstruct from action
                action_text = getattr(s.action, "_raw_text", None)
                if not action_text:
                    action_text = format_action_as_text(s.action, w, h)
                logp = trajectory_logprob(
                    model, processor, s.observation.screenshot, instruction, action_text
                )
                adv_t = torch.tensor(adv, device=logp.device, dtype=logp.dtype)
                # Single epoch: the "old" log-prob is the detached current one,
                # so the ratio is 1 and only the gradient carries information.
                loss = policy_gradient_loss(
                    logp.unsqueeze(0),
                    logp.detach().unsqueeze(0),
                    adv_t.unsqueeze(0),
                )
                # Scale by 1/(num_valid_rollouts * traj_len) to match ml trainer
                scaled = loss / (n_valid * traj_len)
                # Backward per step so gradients accumulate without holding
                # every step's graph in memory at once.
                scaled.backward()
                total_loss += loss.item()
                n_terms += 1

        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()

        avg_loss = total_loss / max(n_terms, 1)
        avg_reward = sum(rewards) / len(rewards)
        print(f"  loss={avg_loss:.4f}  mean_reward={avg_reward:.3f}  time={time.time() - t0:.1f}s")

    # 5. Save checkpoint
    model.save_pretrained(checkpoint_dir)
    print(f"\nCheckpoint saved to {checkpoint_dir}/")
    print("Done.")
363+
364+
365+
# Script entry point: python-fire exposes main()'s keyword arguments as
# command-line flags (e.g. --mock, --num-steps 3).
if __name__ == "__main__":
    fire.Fire(main)

0 commit comments

Comments
 (0)