"""TRL GRPOTrainer rollout function for WAA desktop environments.

Wraps WAADesktopEnv into TRL's experimental ``rollout_func`` API, enabling
GRPO training of VLM agents against live (or mock) Windows VMs.

The rollout_func receives prompts (task instructions) from the trainer,
runs multi-step episodes against the environment, collects action tokens
and logprobs, computes dense rewards via milestones, and returns everything
in the format TRL expects.

Usage with TRL:
    from trl import GRPOConfig, GRPOTrainer
    from openadapt_evals.training.trl_rollout import make_waa_rollout_func

    rollout_func = make_waa_rollout_func(
        adapter=WAALiveAdapter(WAALiveConfig(server_url="http://localhost:5001")),
        task_configs=TaskConfig.from_dir("./tasks/"),
        max_steps=15,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=processor,
        args=GRPOConfig(...),
        train_dataset=dataset,
        rollout_func=rollout_func,
    )
    trainer.train()

Usage with mock adapter (no VM):
    from openadapt_evals.training.trl_rollout import make_waa_rollout_func
    from openadapt_evals.adapters.waa.mock import WAAMockAdapter

    rollout_func = make_waa_rollout_func(
        adapter=WAAMockAdapter(),
        task_configs=task_configs,
    )
"""

from __future__ import annotations

import io
import json
import logging
import re
from typing import Any, Callable

from openadapt_evals.adapters.base import BenchmarkAction
from openadapt_evals.adapters.rl_env import RLEnvironment, ResetConfig

logger = logging.getLogger(__name__)

# System prompt matching openadapt-ml's agent format
SYSTEM_PROMPT = (
    "You are a desktop automation agent. Given a screenshot and task instruction, "
    "output the next action as JSON: "
    '{"type": "click"|"type"|"key"|"scroll"|"done", '
    '"x": 0.0-1.0, "y": 0.0-1.0, "text": "...", "key": "..."}'
)
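
# Illustrative model outputs that conform to the schema above (values made up):
#   {"type": "click", "x": 0.42, "y": 0.17}
#   {"type": "type", "text": "quarterly report"}
#   {"type": "key", "key": "enter"}
#   {"type": "done"}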


def parse_action_json(text: str) -> BenchmarkAction:
    """Parse a VLM output string into a BenchmarkAction.

    Handles common VLM quirks: thinking tokens before the JSON, markdown
    code fences, and extra text after the JSON.

    Args:
        text: Raw VLM output text.

    Returns:
        BenchmarkAction parsed from the JSON, or a ``done`` action if no
        valid JSON is found.
    """
    # Strip thinking tokens / markdown fences
    text = text.strip()
    text = re.sub(r"```json\s*", "", text)
    text = re.sub(r"```\s*$", "", text)

    # Find the first JSON object (non-nested; actions are flat dicts)
    match = re.search(r"\{[^{}]*\}", text)
    if not match:
        logger.warning("No JSON found in VLM output: %s", text[:100])
        return BenchmarkAction(type="done")

    try:
        data = json.loads(match.group())
    except json.JSONDecodeError:
        logger.warning("Invalid JSON in VLM output: %s", match.group()[:100])
        return BenchmarkAction(type="done")

    action_type = data.get("type", "done")
    if action_type not in ("click", "type", "key", "scroll", "done", "noop"):
        logger.warning("Unknown action type '%s', treating as done", action_type)
        action_type = "done"

    return BenchmarkAction(
        type=action_type,
        x=data.get("x"),
        y=data.get("y"),
        text=data.get("text"),
        key=data.get("key"),
    )
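
# Example (illustrative): a chatty, fenced completion such as
#   'Sure! ```json\n{"type": "click", "x": 0.5, "y": 0.25}\n```'
# parses to BenchmarkAction(type="click", x=0.5, y=0.25); anything
# unparseable falls back to a "done" action.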


def _run_episode(
    env: RLEnvironment,
    generate_fn: Callable[[bytes, str], tuple[str, list[int], list[float]]],
    task_instruction: str,
    task_id: str,
    max_steps: int,
) -> tuple[list[int], list[int], list[float], float]:
    """Run a single episode and return token-level data plus reward.

    Args:
        env: The RL environment (already has task_config loaded).
        generate_fn: Function(screenshot_bytes, instruction) -> (text, token_ids, logprobs).
        task_instruction: Natural language task description.
        task_id: Task ID for reset.
        max_steps: Maximum steps per episode.

    Returns:
        Tuple of (prompt_ids, completion_ids, logprobs, reward).
    """
    obs = env.reset(config=ResetConfig(task_id=task_id))

    all_completion_ids: list[int] = []
    all_logprobs: list[float] = []
    prompt_ids: list[int] = []

    for step in range(max_steps):
        screenshot = obs.screenshot or b""

        # Generate an action from the VLM
        action_text, token_ids, logprobs = generate_fn(screenshot, task_instruction)

        # Track token-level data. prompt_ids stays empty in this sketch:
        # generate_fn returns completion tokens only, so a fuller
        # implementation would also have it return the tokenized prompt
        # from the first step.
        all_completion_ids.extend(token_ids)
        all_logprobs.extend(logprobs)

        # Parse and execute the action
        action = parse_action_json(action_text)
        if action.type == "done":
            break

        if action.x is not None and action.y is not None:
            if 0 <= action.x <= 1 and 0 <= action.y <= 1:
                # Fractional coordinates in [0, 1] scale to the screen size
                step_result = env.pixel_action(
                    x_frac=action.x, y_frac=action.y,
                    action_type=action.type, text=action.text, key=action.key,
                )
            else:
                # Absolute pixel coordinates
                step_result = env.pixel_action(
                    x=int(action.x), y=int(action.y),
                    action_type=action.type, text=action.text, key=action.key,
                )
        else:
            # Coordinate-free actions (type, key, scroll, noop)
            step_result = env.step(action)

        obs = step_result.observation
        if step_result.done:
            break

    # Evaluate: dense rewards if milestones are defined, binary otherwise
    reward = env.evaluate_dense()

    return prompt_ids, all_completion_ids, all_logprobs, reward
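
# Illustrative _run_episode result for a short episode (token ids made up):
#   ([], [14986, 3407, 92], [-0.41, -0.02, -1.37], 0.6)
# prompt_ids is empty in this sketch; see the note inside the step loop.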


def make_waa_rollout_func(
    adapter: Any,
    task_configs: list | None = None,
    max_steps: int = 15,
) -> Callable:
    """Create a TRL-compatible rollout_func for WAA environments.

    The returned function has the signature:
        rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]

    Args:
        adapter: A BenchmarkAdapter (WAALiveAdapter or WAAMockAdapter).
        task_configs: List of TaskConfig objects. Each prompt in the training
            dataset should match a task config by name or id.
        max_steps: Maximum steps per episode.

    Returns:
        A callable suitable for GRPOTrainer(rollout_func=...).
    """
    # Index task configs by both name and id for prompt lookup
    config_map: dict[str, Any] = {}
    if task_configs:
        for tc in task_configs:
            config_map[tc.name] = tc
            config_map[tc.id] = tc
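
    # Illustrative (hypothetical names): a config with name "notepad_save_file"
    # and id "waa_task_042" becomes reachable under either key, so dataset
    # prompts can use whichever identifier is convenient.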

    def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
        """TRL GRPOTrainer rollout function.

        Args:
            prompts: Task instructions from the training dataset.
            trainer: Active GRPOTrainer instance (provides model + processor).

        Returns:
            Dict with prompt_ids, completion_ids, logprobs, env_reward.
        """
        processor = trainer.processing_class
        model = trainer.model
        device = next(model.parameters()).device

        num_generations = getattr(trainer.args, "num_generations", 8)

        all_prompt_ids: list[list[int]] = []
        all_completion_ids: list[list[int]] = []
        all_logprobs: list[list[float]] = []
        all_rewards: list[float] = []

        def generate_fn(screenshot_bytes: bytes, instruction: str):
            """Generate action text, token ids, and logprobs from a screenshot."""
            import torch
            from PIL import Image

            # Build the multimodal chat input
            img = Image.open(io.BytesIO(screenshot_bytes))
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": [
                    {"type": "image", "image": img},
                    {"type": "text", "text": instruction},
                ]},
            ]

            # Tokenize with the processor
            text_input = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = processor(
                text=[text_input], images=[img],
                return_tensors="pt", padding=True,
            ).to(device)

            # Sample a completion
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=1.0,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

            # Completion tokens are everything after the prompt
            prompt_len = inputs["input_ids"].shape[1]
            completion_ids = outputs.sequences[0][prompt_len:].tolist()

            # Per-token logprobs of the sampled tokens, from the step scores
            logprobs = []
            if hasattr(outputs, "scores") and outputs.scores:
                for score, token_id in zip(outputs.scores, completion_ids):
                    log_probs = torch.nn.functional.log_softmax(score[0], dim=-1)
                    logprobs.append(log_probs[token_id].item())

            # Decode the completion back to text for action parsing
            text = processor.decode(completion_ids, skip_special_tokens=True)

            return text, completion_ids, logprobs

        for prompt in prompts:
            # Look up the task config; prompts are expected to match a
            # config name or id (see config_map above)
            tc = config_map.get(prompt)

            for gen_idx in range(num_generations):
                env = RLEnvironment(adapter, task_config=tc)

                task_id = tc.id if tc else "default"

                try:
                    p_ids, c_ids, lps, reward = _run_episode(
                        env, generate_fn, prompt, task_id, max_steps,
                    )
                except Exception as exc:
                    logger.error(
                        "Rollout failed for prompt=%s gen=%d: %s",
                        prompt[:50], gen_idx, exc,
                    )
                    p_ids, c_ids, lps, reward = [], [], [], 0.0

                all_prompt_ids.append(p_ids)
                all_completion_ids.append(c_ids)
                all_logprobs.append(lps)
                all_rewards.append(reward)

        return {
            "prompt_ids": all_prompt_ids,
            "completion_ids": all_completion_ids,
            "logprobs": all_logprobs,
            "env_reward": all_rewards,
        }

    return rollout_func
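

if __name__ == "__main__":
    # Minimal smoke test for the parser only; no VM, model, or TRL required.
    # The sample string is illustrative, not real model output.
    sample = 'Sure! ```json\n{"type": "click", "x": 0.5, "y": 0.25}\n```'
    print(parse_action_json(sample))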