|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Self-contained GRPO training example for GUI agents. |
| 3 | +
|
| 4 | +Demonstrates the full RL loop: connect to WAA, collect rollouts with a VLM |
| 5 | +policy, compute group-relative advantages, and update LoRA weights with a |
| 6 | +REINFORCE-style policy gradient (equivalent to single-epoch GRPO). |
| 7 | +No openadapt-ml dependency -- all math and parsing are inline. |
| 8 | +
|
| 9 | +Requirements: |
| 10 | + pip install torch transformers peft pillow |
| 11 | + # openadapt-evals must be installed (provides RLEnvironment, adapters) |
| 12 | +
|
| 13 | +Usage: |
| 14 | + # Mock mode (no VM, random rewards -- for testing the loop): |
| 15 | + python scripts/train_grpo_example.py --mock --num-steps 2 --group-size 3 |
| 16 | +
|
| 17 | + # Live WAA server: |
| 18 | + python scripts/train_grpo_example.py \\ |
| 19 | + --server http://localhost:5001 \\ |
| 20 | + --task-id <WAA_UUID> \\ |
| 21 | + --num-steps 10 \\ |
| 22 | + --group-size 4 |
| 23 | +
|
| 24 | + # Custom model and learning rate: |
| 25 | + python scripts/train_grpo_example.py \\ |
| 26 | + --mock \\ |
| 27 | + --model-name Qwen/Qwen2.5-VL-7B-Instruct \\ |
| 28 | + --lr 1e-5 \\ |
| 29 | + --num-steps 5 |
| 30 | +""" |
| 31 | + |
| 32 | +from __future__ import annotations |
| 33 | + |
| 34 | +import io |
| 35 | +import re |
| 36 | +import time |
| 37 | + |
| 38 | +import fire |
| 39 | +import torch |
| 40 | +from peft import LoraConfig, get_peft_model |
| 41 | +from PIL import Image |
| 42 | +from transformers import AutoModelForVision2Seq, AutoProcessor |
| 43 | + |
| 44 | +from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation |
| 45 | +from openadapt_evals.adapters.rl_env import RLEnvironment |
| 46 | + |
| 47 | + |
| 48 | +# -- Policy gradient loss ------------------------------------------------------ |
| 49 | + |
| 50 | + |
def policy_gradient_loss(
    current_logps: torch.Tensor,
    old_logps: torch.Tensor,
    advantages: torch.Tensor,
    epsilon: float = 0.2,
) -> torch.Tensor:
    """Clipped surrogate policy-gradient loss (PPO-style).

    The importance ratio ``exp(current_logps - old_logps)`` is limited to
    ``[1 - epsilon, 1 + epsilon]`` and the pessimistic (smaller) of the
    clipped and unclipped objectives is averaged and negated.

    When ``old_logps`` equals ``current_logps`` (single-epoch training) the
    ratio is 1, so the gradient reduces to plain REINFORCE.
    """
    importance = (current_logps - old_logps).exp()
    bounded = importance.clamp(1.0 - epsilon, 1.0 + epsilon)

    unclipped_objective = importance * advantages
    clipped_objective = bounded * advantages

    surrogate = torch.minimum(unclipped_objective, clipped_objective)
    return -surrogate.mean()
| 64 | + |
| 65 | + |
def compute_advantages(rewards: list[float]) -> list[float]:
    """Normalize rewards within a group: ``(r - mean) / (std + eps)``.

    Uses the population standard deviation. An empty group yields an empty
    list, and a group with (near-)zero spread yields all zeros because no
    rollout is better or worse than its peers.
    """
    count = len(rewards)
    if not count:
        return []

    baseline = sum(rewards) / count
    centered = [r - baseline for r in rewards]
    spread = (sum(delta ** 2 for delta in centered) / count) ** 0.5

    if spread < 1e-8:
        return [0.0] * count

    denominator = spread + 1e-8
    return [delta / denominator for delta in centered]
| 76 | + |
| 77 | + |
| 78 | +# -- Helpers ------------------------------------------------------------------- |
| 79 | + |
# Aligned with openadapt_ml.datasets.next_action.SYSTEM_PROMPT
# The prompt is assembled from sections (separated by a blank line), each of
# which is a sequence of lines.
SYSTEM_PROMPT = "\n\n".join(
    "\n".join(section)
    for section in (
        (
            "You are a GUI automation agent. Given a screenshot and a user goal, "
            "predict the single next action.",
        ),
        (
            "COORDINATE SYSTEM:",
            "- x=0.0 is the LEFT edge, x=1.0 is the RIGHT edge",
            "- y=0.0 is the TOP edge, y=1.0 is the BOTTOM edge",
            "- To click the CENTER of an element, estimate its center position "
            "as a fraction of screen width/height",
            "- Example: An element in the middle of the screen would be "
            "approximately x=0.5, y=0.5",
        ),
        (
            "ALLOWED ACTIONS (use exactly this format):",
            "- CLICK(x=0.XX, y=0.XX) → click at normalized coordinates",
            '- TYPE(text="...") → type text into the currently focused field',
            "- WAIT() → wait for UI to update",
            "- DONE() → task is complete",
        ),
        (
            "RESPONSE FORMAT (required):",
            "Thought: [Brief reasoning: what element to interact with and why]",
            "Action: [Exactly one action, e.g., CLICK(x=0.35, y=0.42)]",
        ),
        (
            "IMPORTANT: Output coordinates with 2 decimal places. "
            "Estimate the center of target elements.",
        ),
    )
)

# Fallback (width, height) in pixels. Aligned with openadapt-ml.
DEFAULT_SCREEN_SIZE = (1920, 1080)
| 104 | + |
| 105 | + |
def parse_action(
    text: str,
    width: int = 1920,
    height: int = 1080,
) -> BenchmarkAction:
    """Parse VLM text output into a BenchmarkAction.

    Supports: CLICK(x=0.XX, y=0.XX), TYPE(text="..."), WAIT(), DONE().
    Aligned with openadapt_ml.training.grpo.trainer.parse_vlm_output_to_action.

    Patterns are tried in the order CLICK, TYPE, WAIT, DONE, and matching is
    case-insensitive. Optional whitespace is tolerated around the
    parentheses, ``=`` and ``,``.

    Args:
        text: Raw model output; may contain extra text around the action.
        width: Screen width in pixels used to scale the normalized x.
        height: Screen height in pixels used to scale the normalized y.

    Returns:
        The parsed action. CLICK coordinates are clamped to [0, 1] and then
        converted to whole-pixel floats. Output that matches no known action
        yields a ``done`` action rather than raising.
    """
    text = text.strip()

    # A single well-formed decimal ("5", "0.5", "1.", ".5", optional sign).
    # A looser "[\d.]+" would also accept "1.2.3" or ".", which makes
    # float() raise ValueError on malformed model output.
    number = r"-?(?:\d+(?:\.\d*)?|\.\d+)"

    # CLICK
    m = re.search(
        rf"CLICK\s*\(\s*x\s*=\s*({number})\s*,\s*y\s*=\s*({number})\s*\)",
        text,
        re.IGNORECASE,
    )
    if m:
        x_frac = max(0.0, min(1.0, float(m.group(1))))
        y_frac = max(0.0, min(1.0, float(m.group(2))))
        return BenchmarkAction(
            type="click", x=float(int(x_frac * width)), y=float(int(y_frac * height))
        )

    # TYPE: the closing quote must be the same character as the opening one
    # (backreference \1), so the other quote character may appear unescaped
    # in the payload, e.g. TYPE(text="don't"). Backslash escapes are skipped
    # as pairs so an escaped quote does not terminate the payload.
    m = re.search(
        r"""TYPE\s*\(\s*text\s*=\s*(["'])((?:\\.|(?!\1)[^\\])*)\1\s*\)""",
        text,
        re.IGNORECASE,
    )
    if m:
        # Single left-to-right pass: unescape only \\, \" and \'. Chained
        # str.replace calls can re-interpret a backslash produced by an
        # earlier replacement; other escapes (e.g. \n) are left untouched.
        typed_text = re.sub(r"""\\([\\"'])""", r"\1", m.group(2))
        return BenchmarkAction(type="type", text=typed_text)

    # WAIT
    if re.search(r"\bWAIT\s*\(\s*\)", text, re.IGNORECASE):
        return BenchmarkAction(type="wait")

    # DONE, and the fallback for unparseable output: ending the episode is
    # the safe default when the model's intent cannot be recovered.
    return BenchmarkAction(type="done")
| 142 | + |
| 143 | + |
def build_agent_messages(instruction: str) -> list[dict[str, str]]:
    """Build the system + user chat turns for one step.

    Aligned with openadapt-ml _build_agent_messages. The user turn states the
    goal, asks for the next action and repeats the expected action syntax.
    """
    user_paragraphs = (
        f"Goal: {instruction}",
        "Look at the screenshot and determine the NEXT action.",
        'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]',
    )
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": "\n\n".join(user_paragraphs)}
    return [system_turn, user_turn]
| 155 | + |
| 156 | + |
def format_action_as_text(
    action: BenchmarkAction,
    width: int = 1920,
    height: int = 1080,
) -> str:
    """Render an action in the DSL the model is prompted to emit.

    Used to obtain text for log-prob computation. Pixel coordinates are
    converted to screen fractions with two decimals; backslashes and double
    quotes in typed text are escaped. Any action type other than click, type
    or wait is rendered as ``DONE()``.
    """
    kind = getattr(action, "type", "done")

    if kind == "wait":
        return "WAIT()"

    if kind == "type":
        raw = getattr(action, "text", "") or ""
        quoted = raw.replace("\\", "\\\\").replace('"', '\\"')
        return f'TYPE(text="{quoted}")'

    if kind == "click":
        # Missing or None coordinates are treated as 0.
        x_px, y_px = (getattr(action, axis, 0) or 0 for axis in ("x", "y"))
        # A non-positive screen dimension maps to fraction 0.0 instead of
        # dividing by zero.
        fx = x_px / width if width > 0 else 0.0
        fy = y_px / height if height > 0 else 0.0
        return f"CLICK(x={fx:.2f}, y={fy:.2f})"

    return "DONE()"
| 177 | + |
| 178 | + |
def trajectory_logprob(
    model, processor, screenshot_bytes: bytes, instruction: str, action_text: str
) -> torch.Tensor:
    """Score ``action_text`` under the model given screenshot + prompt.

    Runs one forward pass over the prompt tokens followed by the action
    tokens and returns the summed log-probability of the action tokens as a
    scalar tensor.
    """
    screenshot = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB")

    # Use chat template if available (aligned with openadapt-ml trainer);
    # otherwise fall back to the bare user message.
    chat = build_agent_messages(instruction)
    if hasattr(processor, "apply_chat_template"):
        prompt_text = processor.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt_text = chat[-1]["content"]

    encoded = processor(prompt_text, images=[screenshot], return_tensors="pt")
    model_inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    # Tokenize the action on its own (no special tokens) so it can be
    # appended directly after the prompt tokens.
    tokenizer = getattr(processor, "tokenizer", processor)
    action_encoding = tokenizer(action_text, return_tensors="pt", add_special_tokens=False)
    action_ids = action_encoding["input_ids"].to(model.device)

    prompt_len = model_inputs["input_ids"].shape[1]
    n_action_tokens = action_ids.shape[1]

    sequence = torch.cat([model_inputs["input_ids"], action_ids], dim=1)
    model_inputs["input_ids"] = sequence
    model_inputs["attention_mask"] = torch.ones_like(sequence)

    logits = model(**model_inputs).logits

    # The logits at position t predict token t + 1, so the predictions for
    # the action tokens start one position before the first action token.
    first = prompt_len - 1
    action_logits = logits[:, first : first + n_action_tokens, :]
    per_token_logps = (
        action_logits.log_softmax(dim=-1)
        .gather(2, action_ids.unsqueeze(-1))
        .squeeze(-1)
    )
    return per_token_logps.sum()
| 210 | + |
| 211 | + |
| 212 | +# -- Main training loop ------------------------------------------------------- |
| 213 | + |
| 214 | + |
def main(
    server: str = "http://localhost:5001",
    task_id: str | None = None,
    num_steps: int = 5,
    group_size: int = 4,
    max_episode_steps: int = 15,
    model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",
    lr: float = 1e-5,
    checkpoint_dir: str = "grpo_checkpoint",
    mock: bool = False,
) -> None:
    """Run GRPO training: rollouts -> advantages -> policy gradient -> update.

    Args:
        server: URL of the live WAA server (ignored when ``mock`` is set).
        task_id: Task to train on; when None the adapter's first task is used.
        num_steps: Number of training steps (one group of rollouts each).
        group_size: Rollouts collected per training step.
        max_episode_steps: Maximum environment steps per rollout.
        model_name: Hugging Face model identifier to load.
        lr: AdamW learning rate for the LoRA parameters.
        checkpoint_dir: Directory the adapter weights are saved to at the end.
        mock: Use the mock WAA adapter instead of a live server.

    Raises:
        RuntimeError: If no ``task_id`` is given and the adapter lists no tasks.
    """
    # 1. Load model with LoRA
    print(f"Loading {model_name} ...")
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForVision2Seq.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )
    lora = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora)
    model.print_trainable_parameters()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    # 2. Create RL environment
    if mock:
        from openadapt_evals.adapters.waa.mock import WAAMockAdapter

        adapter = WAAMockAdapter(num_tasks=20)
    else:
        from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig

        adapter = WAALiveAdapter(WAALiveConfig(server_url=server))

    if task_id is None:
        tasks = adapter.list_tasks()
        if not tasks:
            # Fail with an actionable message instead of a bare IndexError.
            raise RuntimeError("Adapter returned no tasks; pass --task-id explicitly.")
        task_id = tasks[0].task_id
        print(f"Auto-selected task: {task_id}")

    env = RLEnvironment(adapter=adapter, default_task_id=task_id)
    w, h = env.screen_size

    def current_instruction() -> str:
        """Return the active task's instruction, with one shared fallback.

        Sampling and log-prob recomputation must condition on the same
        prompt, so both go through this helper.
        """
        task = getattr(env, "_current_task", None)
        return task.instruction if task else "Complete the task."

    # 3. Agent function: screenshot -> VLM -> action
    def agent_fn(obs: BenchmarkObservation) -> BenchmarkAction:
        if not obs.screenshot:
            return BenchmarkAction(type="done")
        image = Image.open(io.BytesIO(obs.screenshot)).convert("RGB")

        messages = build_agent_messages(current_instruction())
        if hasattr(processor, "apply_chat_template"):
            prompt = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = messages[-1]["content"]

        inputs = processor(prompt, images=[image], return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
            )
        # Decode only the newly generated tokens (skip the prompt prefix).
        text = processor.decode(
            out[0][inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
        )
        action = parse_action(text, w, h)
        action._raw_text = text  # stash for log-prob recomputation
        return action

    # 4. Training loop
    for step in range(num_steps):
        t0 = time.time()
        print(f"\n{'=' * 50}\nStep {step + 1}/{num_steps}: collecting {group_size} rollouts ...")

        # -- Collect rollouts --
        model.eval()
        rollouts, rewards = [], []
        for g in range(group_size):
            trajectory = env.collect_rollout(agent_fn=agent_fn, max_steps=max_episode_steps)
            # The rollout's reward is taken from its final step.
            reward = trajectory[-1].reward if trajectory else 0.0
            rollouts.append(trajectory)
            rewards.append(reward)
            print(f" rollout {g + 1}: {len(trajectory)} steps, reward={reward:.2f}")

        # -- Compute advantages --
        advantages = compute_advantages(rewards)
        if all(a == 0.0 for a in advantages):
            print(" No variance in rewards, skipping gradient step.")
            continue

        # -- Policy gradient update --
        model.train()
        optimizer.zero_grad()
        total_loss = 0.0
        n_terms = 0

        instruction = current_instruction()

        n_valid = sum(1 for a in advantages if abs(a) >= 1e-8)

        for traj, adv in zip(rollouts, advantages):
            if abs(adv) < 1e-8:
                continue
            # Deliberately NOT named num_steps: rebinding the function
            # parameter here would corrupt the "Step i/N" banner above.
            traj_len = max(len(traj), 1)
            for s in traj:
                if not s.observation.screenshot:
                    continue
                # Use raw VLM text if available, else reconstruct from action
                action_text = getattr(s.action, "_raw_text", None)
                if not action_text:
                    action_text = format_action_as_text(s.action, w, h)
                logp = trajectory_logprob(
                    model, processor, s.observation.screenshot, instruction, action_text
                )
                adv_t = torch.tensor(adv, device=logp.device, dtype=logp.dtype)
                # old_logps is the detached current value, so the ratio is 1
                # and the update is single-epoch REINFORCE.
                loss = policy_gradient_loss(
                    logp.unsqueeze(0),
                    logp.detach().unsqueeze(0),
                    adv_t.unsqueeze(0),
                )
                # Scale by 1/(num_valid_rollouts * num_steps) to match ml trainer
                scaled = loss / (n_valid * traj_len)
                # Backprop per step so gradients accumulate without keeping
                # every step's graph alive at once.
                scaled.backward()
                total_loss += loss.item()
                n_terms += 1

        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()

        avg_loss = total_loss / max(n_terms, 1)
        avg_reward = sum(rewards) / len(rewards)
        print(f" loss={avg_loss:.4f} mean_reward={avg_reward:.3f} time={time.time() - t0:.1f}s")

    # 5. Save checkpoint
    model.save_pretrained(checkpoint_dir)
    print(f"\nCheckpoint saved to {checkpoint_dir}/")
    print("Done.")
| 363 | + |
| 364 | + |
if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags.
    fire.Fire(component=main)
0 commit comments