Skip to content

Commit de515b8

Browse files
abrichr and claude authored
fix: critical TRL trainer bugs — wrong prompt, ignored task_ids, DSL parsing (#236)
* fix: critical TRL trainer bugs — wrong prompt, ignored task_ids, DSL parsing Three bugs reported from client testing the TRL path: 1. Garbage output: TRL used a JSON system prompt but the model was SFT'd on DSL format (Thought/Action). Now imports SYSTEM_PROMPT from the standalone trainer so both paths use the identical prompt. 2. task_ids ignored: trl_wrapper loaded ALL tasks from task_dir into the TRL dataset, ignoring TrainingConfig.task_ids. Now filters task_configs by task_ids when specified (matching by id or name). 3. parse_action_json only handled JSON: constrained decoding produces DSL (CLICK(x=0.5, y=0.3)), but the parser only tried JSON. Now falls back to DSL regex parsing, keeping fractional coordinates for pixel_action. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: configurable system_prompt, loud Outlines failure, mock-safe health check - Add system_prompt parameter to make_waa_rollout_func (default = DSL prompt from standalone trainer). Users can override if they SFT on a different format. - Log the system prompt at startup for debugging. - Make Outlines failure loud: ImportError raises instead of silent fallback. Other failures log CRITICAL warning. - Fix health check to skip mock adapters (unittest.mock.MagicMock). - Fix test mocks to accept **kwargs for stuck_threshold. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3e7debc commit de515b8

3 files changed

Lines changed: 207 additions & 43 deletions

File tree

openadapt_evals/training/trl_rollout.py

Lines changed: 183 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,24 @@
6565

6666
from __future__ import annotations
6767

68+
import hashlib
6869
import io
6970
import json
7071
import logging
7172
import re
73+
import time
7274
from typing import Any, Callable
7375

7476
from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation
7577
from openadapt_evals.adapters.rl_env import RLEnvironment, ResetConfig
7678

7779
logger = logging.getLogger(__name__)
7880

79-
# System prompt matching openadapt-ml's agent format
80-
SYSTEM_PROMPT = (
81-
"You are a desktop automation agent. Given a screenshot and task instruction, "
82-
"output the next action as JSON: "
83-
'{"type": "click"|"type"|"key"|"scroll"|"done", '
84-
'"x": 0.0-1.0, "y": 0.0-1.0, "text": "...", "key": "..."}'
85-
)
81+
# Use the SAME system prompt as the standalone trainer.
82+
# The base model (Qwen2.5-VL-7B-Instruct) was SFT'd on the DSL format
83+
# (Thought: ...\nAction: CLICK(x=0.XX, y=0.XX)). Using a different prompt
84+
# (e.g. JSON) produces garbage because the model has never seen that format.
85+
from openadapt_evals.training.standalone.prompt import SYSTEM_PROMPT # noqa: E402
8686

8787
# ---------------------------------------------------------------------------
8888
# Constrained decoding regex — ported from standalone trainer
@@ -142,44 +142,98 @@ def _build_outlines_generator(model: Any, processor: Any) -> Any | None:
142142
def parse_action_json(text: str) -> BenchmarkAction:
    """Parse a VLM output string into a BenchmarkAction.

    Accepts BOTH formats:
    - JSON: ``{"type": "click", "x": 0.5, "y": 0.3}``
    - DSL:  ``Thought: ...\\nAction: CLICK(x=0.50, y=0.30)``

    JSON is tried first. A brace-delimited span is only treated as a JSON
    action when it decodes successfully AND carries a ``"type"`` key;
    otherwise parsing falls through to the DSL patterns so that braces
    appearing inside DSL text (e.g. ``TYPE(text='{"a": 1}')``) do not
    short-circuit the parse into ``done``.

    Args:
        text: Raw VLM output text.

    Returns:
        BenchmarkAction parsed from the text. ``type="done"`` is returned
        (with a warning logged) when neither format can be parsed, and when
        a JSON action names an unrecognised type.
    """
    # --- Try JSON first ---
    stripped = text.strip()
    stripped = re.sub(r"```json\s*", "", stripped)
    stripped = re.sub(r"```\s*$", "", stripped)

    match = re.search(r"\{[^{}]*\}", stripped)
    if match:
        try:
            data = json.loads(match.group())
        except json.JSONDecodeError:
            data = None  # Not JSON after all; fall through to DSL parsing.
        # Only commit to the JSON interpretation when it actually names an
        # action; a type-less object is more likely incidental braces.
        if data is not None and "type" in data:
            action_type = data["type"]
            if action_type not in ("click", "type", "key", "scroll", "done", "noop"):
                logger.warning(
                    "Unknown action type '%s', treating as done", action_type
                )
                action_type = "done"
            return BenchmarkAction(
                type=action_type,
                x=data.get("x"),
                y=data.get("y"),
                text=data.get("text"),
                key=data.get("key"),
            )

    # --- DSL fallback (Thought/Action format from standalone trainer) ---
    # Extract fractional coordinates directly from DSL rather than using
    # parse_vlm_output_to_action (which converts to pixels). The TRL path
    # needs fractional coords for pixel_action(x_frac=, y_frac=).
    action_match = re.search(r"Action:\s*(.+)", text, re.IGNORECASE)
    action_line = action_match.group(1).strip() if action_match else text

    click_m = re.search(
        r"CLICK\(x=(-?[\d.]+),\s*y=(-?[\d.]+)\)", action_line, re.IGNORECASE
    )
    if click_m:
        try:
            # Clamp into the unit square; coordinates are fractions.
            x = max(0.0, min(1.0, float(click_m.group(1))))
            y = max(0.0, min(1.0, float(click_m.group(2))))
            return BenchmarkAction(type="click", x=x, y=y)
        except (ValueError, TypeError):
            # e.g. "1.2.3" matches the character class but is not a float;
            # keep trying the remaining DSL forms.
            pass

    type_patterns = (
        # Preferred: closing quote must equal the opening quote, so the
        # *other* quote character may appear unescaped in the body
        # (e.g. TYPE(text="don't")).
        r"""TYPE\(text=(?P<q>["'])(?P<body>(?:\\.|(?!(?P=q))[^\\])*)(?P=q)\)""",
        # Lenient fallback: original pattern, tolerates mismatched quotes.
        r"""TYPE\(text=["'](?P<body>[^"'\\]*(?:\\.[^"'\\]*)*)["']\)""",
    )
    for pattern in type_patterns:
        type_m = re.search(pattern, action_line, re.IGNORECASE)
        if type_m:
            # Single-pass unescape of \\, \" and \' so that an escaped
            # backslash followed by a quote is not unescaped twice.
            typed = re.sub(r"""\\([\\"'])""", r"\1", type_m.group("body"))
            return BenchmarkAction(type="type", text=typed)

    if re.search(r"\bWAIT\s*\(\s*\)", action_line, re.IGNORECASE):
        return BenchmarkAction(type="wait")
    if re.search(r"\bDONE\s*\(\s*\)", action_line, re.IGNORECASE):
        return BenchmarkAction(type="done")

    logger.warning("Could not parse VLM output (no JSON or DSL): %s", text[:200])
    return BenchmarkAction(type="done")
211+
212+
213+
def _empty_rollout_result(
214+
prompts: list[str],
215+
num_generations: int,
216+
) -> dict[str, list]:
217+
"""Return a zero-reward rollout result with the correct dict shape.
218+
219+
Used when the WAA server is unreachable or unhealthy so that TRL receives
220+
a consistent output structure (empty token lists, zero rewards) instead of
221+
crashing.
222+
223+
Args:
224+
prompts: List of prompt strings from the trainer.
225+
num_generations: Number of generations per prompt.
226+
227+
Returns:
228+
Dict with prompt_ids, completion_ids, logprobs, env_reward -- all zeroed.
229+
"""
230+
total = len(prompts) * num_generations
231+
return {
232+
"prompt_ids": [[] for _ in range(total)],
233+
"completion_ids": [[] for _ in range(total)],
234+
"logprobs": [[] for _ in range(total)],
235+
"env_reward": [0.0] * total,
236+
}
183237

184238

185239
def _run_episode(
@@ -188,6 +242,7 @@ def _run_episode(
188242
task_instruction: str,
189243
task_id: str,
190244
max_steps: int,
245+
stuck_threshold: int = 3,
191246
) -> tuple[list[int], list[int], list[float], float]:
192247
"""Run a single episode and return token-level data + reward.
193248
@@ -197,6 +252,8 @@ def _run_episode(
197252
task_instruction: Natural language task description.
198253
task_id: Task ID for reset.
199254
max_steps: Maximum steps per episode.
255+
stuck_threshold: Number of consecutive identical screenshots before
256+
breaking the episode early. Set to 0 to disable stuck detection.
200257
201258
Returns:
202259
Tuple of (prompt_ids, completion_ids, logprobs, reward).
@@ -206,10 +263,31 @@ def _run_episode(
206263
all_completion_ids: list[int] = []
207264
all_logprobs: list[float] = []
208265
prompt_ids: list[int] = []
266+
recent_hashes: list[str] = []
209267

210268
for step in range(max_steps):
211269
screenshot = obs.screenshot or b""
212270

271+
# --- Stuck detection (P1) ---
272+
# Track screenshot hashes to detect when the agent is looping on an
273+
# identical screen (no learning signal). Ported from standalone
274+
# trainer's WAADirect.is_stuck().
275+
if stuck_threshold > 0:
276+
screenshot_hash = hashlib.md5(screenshot).hexdigest()
277+
recent_hashes.append(screenshot_hash)
278+
if len(recent_hashes) > stuck_threshold:
279+
recent_hashes.pop(0)
280+
if (
281+
len(recent_hashes) == stuck_threshold
282+
and len(set(recent_hashes)) == 1
283+
):
284+
logger.warning(
285+
"Stuck detected: %d identical screenshots in a row. "
286+
"Breaking episode early.",
287+
stuck_threshold,
288+
)
289+
break
290+
213291
# Generate action from VLM
214292
action_text, token_ids, logprobs = generate_fn(screenshot, task_instruction)
215293

@@ -260,6 +338,9 @@ def make_waa_rollout_func(
260338
constrained_decoding: bool = False,
261339
max_new_tokens: int = 256,
262340
temperature: float = 1.0,
341+
screenshot_retries: int = 3,
342+
screenshot_retry_delay: float = 1.0,
343+
stuck_threshold: int = 3,
263344
) -> Callable:
264345
"""Create a TRL-compatible rollout_func for WAA environments.
265346
@@ -276,6 +357,14 @@ def make_waa_rollout_func(
276357
Requires ``pip install outlines>=0.1.0``.
277358
max_new_tokens: Maximum tokens per generation step.
278359
temperature: Sampling temperature for generation.
360+
screenshot_retries: Number of retry attempts when a screenshot is
361+
corrupt (cannot be opened by PIL). Ported from the standalone
362+
trainer's screenshot retry logic.
363+
screenshot_retry_delay: Seconds to sleep between screenshot retry
364+
attempts.
365+
stuck_threshold: Number of consecutive identical screenshots before
366+
breaking an episode early. Set to 0 to disable stuck detection.
367+
Ported from the standalone trainer's WAADirect.is_stuck().
279368
280369
Returns:
281370
A callable suitable for GRPOTrainer(rollout_func=...).
@@ -309,6 +398,36 @@ def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
309398

310399
num_generations = getattr(trainer.args, "num_generations", 8)
311400

401+
# --- Pre-rollout health check (P0) ---
402+
# Verify WAA server is responsive before committing GPU time to a
403+
# full batch of rollouts. Ported from standalone trainer's
404+
# _collect_group() which calls probe() before each group.
405+
# Skip for mock adapters (unittest.mock.MagicMock or WAAMockAdapter).
406+
_mod = getattr(type(adapter), "__module__", "") or ""
407+
_name = type(adapter).__name__.lower()
408+
_is_mock = "mock" in _name or "mock" in _mod
409+
if not _is_mock:
410+
try:
411+
health_obs = adapter.observe()
412+
screenshot = getattr(health_obs, "screenshot", None)
413+
if screenshot is not None and isinstance(screenshot, bytes) \
414+
and len(screenshot) < 100:
415+
logger.warning(
416+
"WAA server health check failed (screenshot=%d bytes) "
417+
"-- returning zero rewards for %d prompts",
418+
len(screenshot),
419+
len(prompts),
420+
)
421+
return _empty_rollout_result(prompts, num_generations)
422+
except Exception as exc:
423+
logger.warning(
424+
"WAA server unreachable: %s -- returning zero rewards for "
425+
"%d prompts",
426+
exc,
427+
len(prompts),
428+
)
429+
return _empty_rollout_result(prompts, num_generations)
430+
312431
# Lazy-init Outlines generator on first call
313432
if constrained_decoding and not _outlines_state["attempted"]:
314433
_outlines_state["attempted"] = True
@@ -327,11 +446,34 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
327446
"""Generate action tokens from screenshot + instruction."""
328447
from PIL import Image
329448

330-
# Build multimodal input
331-
img = Image.open(io.BytesIO(screenshot_bytes))
332-
if img.mode != "RGB":
333-
img = img.convert("RGB")
334-
img.format = "PNG"
449+
# --- Corrupt screenshot retry (P0) ---
450+
# On Azure VMs with QEMU, ~1-5% of screenshots are corrupt.
451+
# Retry with a brief delay rather than crashing the entire
452+
# rollout. Ported from standalone trainer's _collect_rollout().
453+
img = None
454+
for attempt in range(screenshot_retries):
455+
try:
456+
img = Image.open(io.BytesIO(screenshot_bytes))
457+
if img.mode != "RGB":
458+
img = img.convert("RGB")
459+
img.format = "PNG"
460+
break
461+
except Exception as exc:
462+
if attempt < screenshot_retries - 1:
463+
logger.warning(
464+
"Corrupt screenshot (attempt %d/%d): %s",
465+
attempt + 1,
466+
screenshot_retries,
467+
exc,
468+
)
469+
time.sleep(screenshot_retry_delay)
470+
else:
471+
logger.error(
472+
"Screenshot corrupt after %d attempts, "
473+
"returning DONE action",
474+
screenshot_retries,
475+
)
476+
return "done", [], []
335477

336478
messages = [
337479
{"role": "system", "content": SYSTEM_PROMPT},

openadapt_evals/training/trl_wrapper.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,28 @@ def train(self) -> str:
8181
task_configs = []
8282
if self._config.task_dir:
8383
task_configs = TaskConfig.from_dir(self._config.task_dir)
84+
85+
# Filter by task_ids if specified — without this, ALL tasks from
86+
# task_dir end up in the TRL dataset regardless of what the user
87+
# requested. This was a critical bug: config had task_ids=["X"]
88+
# but TRL was running unrelated tasks.
89+
if getattr(self._config, "task_ids", None):
    allowed = set(self._config.task_ids)
    # Capture the pre-filter count BEFORE task_configs is reassigned;
    # otherwise the log line below can only ever report N/N.
    total_available = len(task_configs)
    # A task is selected when either its id or its name was requested.
    filtered = [
        tc for tc in task_configs if tc.id in allowed or tc.name in allowed
    ]
    if filtered:
        task_configs = filtered
        logger.info(
            "Filtered tasks by task_ids: %d/%d tasks selected",
            len(filtered), total_available,
        )
    else:
        # No requested id/name matched: keep the unfiltered list and
        # report what was available so the mismatch is diagnosable.
        logger.warning(
            "task_ids=%s matched no tasks from task_dir=%s. "
            "Available: %s. Using all tasks.",
            self._config.task_ids, self._config.task_dir,
            [tc.id for tc in task_configs],
        )
105+
84106
if not task_configs:
85107
raise ValueError("No tasks. Set task_dir in TrainingConfig.")
86108

tests/test_trl_rollout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def test_rollout_with_mock_generate(self):
169169

170170
original_run = trl_rollout._run_episode
171171

172-
def mock_run_episode(env, generate_fn, instruction, task_id, max_steps):
172+
def mock_run_episode(env, generate_fn, instruction, task_id, max_steps, **kwargs):
173173
"""Simplified episode that doesn't need a real model."""
174174
from openadapt_evals.adapters.rl_env import ResetConfig
175175

@@ -270,7 +270,7 @@ def test_task_config_lookup_by_name(self):
270270

271271
captured_task_ids = []
272272

273-
def capture_run(env, gfn, instr, tid, ms):
273+
def capture_run(env, gfn, instr, tid, ms, **kwargs):
274274
captured_task_ids.append(tid)
275275
from openadapt_evals.adapters.rl_env import ResetConfig
276276
env.reset(config=ResetConfig(task_id=tid))

0 commit comments

Comments
 (0)