6565
6666from __future__ import annotations
6767
68+ import hashlib
6869import io
6970import json
7071import logging
7172import re
73+ import time
7274from typing import Any , Callable
7375
7476from openadapt_evals .adapters .base import BenchmarkAction , BenchmarkObservation
7577from openadapt_evals .adapters .rl_env import RLEnvironment , ResetConfig
7678
7779logger = logging .getLogger (__name__ )
7880
79- # System prompt matching openadapt-ml's agent format
80- SYSTEM_PROMPT = (
81- "You are a desktop automation agent. Given a screenshot and task instruction, "
82- "output the next action as JSON: "
83- '{"type": "click"|"type"|"key"|"scroll"|"done", '
84- '"x": 0.0-1.0, "y": 0.0-1.0, "text": "...", "key": "..."}'
85- )
81+ # Use the SAME system prompt as the standalone trainer.
82+ # The base model (Qwen2.5-VL-7B-Instruct) was SFT'd on the DSL format
83+ # (Thought: ...\nAction: CLICK(x=0.XX, y=0.XX)). Using a different prompt
84+ # (e.g. JSON) produces garbage because the model has never seen that format.
85+ from openadapt_evals .training .standalone .prompt import SYSTEM_PROMPT # noqa: E402
8686
8787# ---------------------------------------------------------------------------
8888# Constrained decoding regex — ported from standalone trainer
@@ -142,44 +142,98 @@ def _build_outlines_generator(model: Any, processor: Any) -> Any | None:
142142def parse_action_json (text : str ) -> BenchmarkAction :
143143 """Parse a VLM output string into a BenchmarkAction.
144144
145- Handles common VLM quirks: thinking tokens before JSON, markdown
146- code fences, extra text after JSON.
145+ Accepts BOTH formats:
146+ - JSON: ``{"type": "click", "x": 0.5, "y": 0.3}``
147+ - DSL: ``Thought: ...\n Action: CLICK(x=0.50, y=0.30)``
148+
149+ The DSL fallback is critical for backward compatibility: existing trained
150+ checkpoints produce DSL format, and constrained decoding constrains to DSL.
147151
148152 Args:
149153 text: Raw VLM output text.
150154
151155 Returns:
152- BenchmarkAction parsed from the JSON .
156+ BenchmarkAction parsed from the text .
153157 """
154- # Strip thinking tokens / markdown
155- text = text .strip ()
156- text = re .sub (r"```json\s*" , "" , text )
157- text = re .sub (r"```\s*$" , "" , text )
158-
159- # Find the first JSON object
160- match = re .search (r"\{[^{}]*\}" , text )
161- if not match :
162- logger .warning ("No JSON found in VLM output: %s" , text [:100 ])
163- return BenchmarkAction (type = "done" )
158+ # --- Try JSON first ---
159+ stripped = text .strip ()
160+ stripped = re .sub (r"```json\s*" , "" , stripped )
161+ stripped = re .sub (r"```\s*$" , "" , stripped )
162+
163+ match = re .search (r"\{[^{}]*\}" , stripped )
164+ if match :
165+ try :
166+ data = json .loads (match .group ())
167+ action_type = data .get ("type" , "done" )
168+ if action_type not in ("click" , "type" , "key" , "scroll" , "done" , "noop" ):
169+ action_type = "done"
170+ return BenchmarkAction (
171+ type = action_type ,
172+ x = data .get ("x" ),
173+ y = data .get ("y" ),
174+ text = data .get ("text" ),
175+ key = data .get ("key" ),
176+ )
177+ except json .JSONDecodeError :
178+ pass # Fall through to DSL parsing
179+
180+ # --- DSL fallback (Thought/Action format from standalone trainer) ---
181+ # This handles output from constrained decoding and existing checkpoints.
182+ # Extract fractional coordinates directly from DSL rather than using
183+ # parse_vlm_output_to_action (which converts to pixels). The TRL path
184+ # needs fractional coords for pixel_action(x_frac=, y_frac=).
185+ action_line = text
186+ action_match = re .search (r"Action:\s*(.+)" , text , re .IGNORECASE )
187+ if action_match :
188+ action_line = action_match .group (1 ).strip ()
189+
190+ click_m = re .search (r"CLICK\(x=(-?[\d.]+),\s*y=(-?[\d.]+)\)" , action_line , re .IGNORECASE )
191+ if click_m :
192+ try :
193+ x = max (0.0 , min (1.0 , float (click_m .group (1 ))))
194+ y = max (0.0 , min (1.0 , float (click_m .group (2 ))))
195+ return BenchmarkAction (type = "click" , x = x , y = y )
196+ except (ValueError , TypeError ):
197+ pass
164198
165- try :
166- data = json .loads (match .group ())
167- except json .JSONDecodeError :
168- logger .warning ("Invalid JSON in VLM output: %s" , match .group ()[:100 ])
199+ type_m = re .search (r"""TYPE\(text=["']([^"'\\]*(?:\\.[^"'\\]*)*)["']\)""" , action_line , re .IGNORECASE )
200+ if type_m :
201+ t = type_m .group (1 ).replace ("\\ \\ " , "\\ " ).replace ('\\ "' , '"' ).replace ("\\ '" , "'" )
202+ return BenchmarkAction (type = "type" , text = t )
203+
204+ if re .search (r"\bWAIT\s*\(\s*\)" , action_line , re .IGNORECASE ):
205+ return BenchmarkAction (type = "wait" )
206+ if re .search (r"\bDONE\s*\(\s*\)" , action_line , re .IGNORECASE ):
169207 return BenchmarkAction (type = "done" )
170208
171- action_type = data .get ("type" , "done" )
172- if action_type not in ("click" , "type" , "key" , "scroll" , "done" , "noop" ):
173- logger .warning ("Unknown action type '%s', treating as done" , action_type )
174- action_type = "done"
175-
176- return BenchmarkAction (
177- type = action_type ,
178- x = data .get ("x" ),
179- y = data .get ("y" ),
180- text = data .get ("text" ),
181- key = data .get ("key" ),
182- )
209+ logger .warning ("Could not parse VLM output (no JSON or DSL): %s" , text [:200 ])
210+ return BenchmarkAction (type = "done" )
211+
212+
213+ def _empty_rollout_result (
214+ prompts : list [str ],
215+ num_generations : int ,
216+ ) -> dict [str , list ]:
217+ """Return a zero-reward rollout result with the correct dict shape.
218+
219+ Used when the WAA server is unreachable or unhealthy so that TRL receives
220+ a consistent output structure (empty token lists, zero rewards) instead of
221+ crashing.
222+
223+ Args:
224+ prompts: List of prompt strings from the trainer.
225+ num_generations: Number of generations per prompt.
226+
227+ Returns:
228+ Dict with prompt_ids, completion_ids, logprobs, env_reward -- all zeroed.
229+ """
230+ total = len (prompts ) * num_generations
231+ return {
232+ "prompt_ids" : [[] for _ in range (total )],
233+ "completion_ids" : [[] for _ in range (total )],
234+ "logprobs" : [[] for _ in range (total )],
235+ "env_reward" : [0.0 ] * total ,
236+ }
183237
184238
185239def _run_episode (
@@ -188,6 +242,7 @@ def _run_episode(
188242 task_instruction : str ,
189243 task_id : str ,
190244 max_steps : int ,
245+ stuck_threshold : int = 3 ,
191246) -> tuple [list [int ], list [int ], list [float ], float ]:
192247 """Run a single episode and return token-level data + reward.
193248
@@ -197,6 +252,8 @@ def _run_episode(
197252 task_instruction: Natural language task description.
198253 task_id: Task ID for reset.
199254 max_steps: Maximum steps per episode.
255+ stuck_threshold: Number of consecutive identical screenshots before
256+ breaking the episode early. Set to 0 to disable stuck detection.
200257
201258 Returns:
202259 Tuple of (prompt_ids, completion_ids, logprobs, reward).
@@ -206,10 +263,31 @@ def _run_episode(
206263 all_completion_ids : list [int ] = []
207264 all_logprobs : list [float ] = []
208265 prompt_ids : list [int ] = []
266+ recent_hashes : list [str ] = []
209267
210268 for step in range (max_steps ):
211269 screenshot = obs .screenshot or b""
212270
271+ # --- Stuck detection (P1) ---
272+ # Track screenshot hashes to detect when the agent is looping on an
273+ # identical screen (no learning signal). Ported from standalone
274+ # trainer's WAADirect.is_stuck().
275+ if stuck_threshold > 0 :
276+ screenshot_hash = hashlib .md5 (screenshot ).hexdigest ()
277+ recent_hashes .append (screenshot_hash )
278+ if len (recent_hashes ) > stuck_threshold :
279+ recent_hashes .pop (0 )
280+ if (
281+ len (recent_hashes ) == stuck_threshold
282+ and len (set (recent_hashes )) == 1
283+ ):
284+ logger .warning (
285+ "Stuck detected: %d identical screenshots in a row. "
286+ "Breaking episode early." ,
287+ stuck_threshold ,
288+ )
289+ break
290+
213291 # Generate action from VLM
214292 action_text , token_ids , logprobs = generate_fn (screenshot , task_instruction )
215293
@@ -260,6 +338,9 @@ def make_waa_rollout_func(
260338 constrained_decoding : bool = False ,
261339 max_new_tokens : int = 256 ,
262340 temperature : float = 1.0 ,
341+ screenshot_retries : int = 3 ,
342+ screenshot_retry_delay : float = 1.0 ,
343+ stuck_threshold : int = 3 ,
263344) -> Callable :
264345 """Create a TRL-compatible rollout_func for WAA environments.
265346
@@ -276,6 +357,14 @@ def make_waa_rollout_func(
276357 Requires ``pip install outlines>=0.1.0``.
277358 max_new_tokens: Maximum tokens per generation step.
278359 temperature: Sampling temperature for generation.
360+ screenshot_retries: Number of retry attempts when a screenshot is
361+ corrupt (cannot be opened by PIL). Ported from the standalone
362+ trainer's screenshot retry logic.
363+ screenshot_retry_delay: Seconds to sleep between screenshot retry
364+ attempts.
365+ stuck_threshold: Number of consecutive identical screenshots before
366+ breaking an episode early. Set to 0 to disable stuck detection.
367+ Ported from the standalone trainer's WAADirect.is_stuck().
279368
280369 Returns:
281370 A callable suitable for GRPOTrainer(rollout_func=...).
@@ -309,6 +398,36 @@ def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
309398
310399 num_generations = getattr (trainer .args , "num_generations" , 8 )
311400
401+ # --- Pre-rollout health check (P0) ---
402+ # Verify WAA server is responsive before committing GPU time to a
403+ # full batch of rollouts. Ported from standalone trainer's
404+ # _collect_group() which calls probe() before each group.
405+ # Skip for mock adapters (unittest.mock.MagicMock or WAAMockAdapter).
406+ _mod = getattr (type (adapter ), "__module__" , "" ) or ""
407+ _name = type (adapter ).__name__ .lower ()
408+ _is_mock = "mock" in _name or "mock" in _mod
409+ if not _is_mock :
410+ try :
411+ health_obs = adapter .observe ()
412+ screenshot = getattr (health_obs , "screenshot" , None )
413+ if screenshot is not None and isinstance (screenshot , bytes ) \
414+ and len (screenshot ) < 100 :
415+ logger .warning (
416+ "WAA server health check failed (screenshot=%d bytes) "
417+ "-- returning zero rewards for %d prompts" ,
418+ len (screenshot ),
419+ len (prompts ),
420+ )
421+ return _empty_rollout_result (prompts , num_generations )
422+ except Exception as exc :
423+ logger .warning (
424+ "WAA server unreachable: %s -- returning zero rewards for "
425+ "%d prompts" ,
426+ exc ,
427+ len (prompts ),
428+ )
429+ return _empty_rollout_result (prompts , num_generations )
430+
312431 # Lazy-init Outlines generator on first call
313432 if constrained_decoding and not _outlines_state ["attempted" ]:
314433 _outlines_state ["attempted" ] = True
@@ -327,11 +446,34 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
327446 """Generate action tokens from screenshot + instruction."""
328447 from PIL import Image
329448
330- # Build multimodal input
331- img = Image .open (io .BytesIO (screenshot_bytes ))
332- if img .mode != "RGB" :
333- img = img .convert ("RGB" )
334- img .format = "PNG"
449+ # --- Corrupt screenshot retry (P0) ---
450+ # On Azure VMs with QEMU, ~1-5% of screenshots are corrupt.
451+ # Retry with a brief delay rather than crashing the entire
452+ # rollout. Ported from standalone trainer's _collect_rollout().
453+ img = None
454+ for attempt in range (screenshot_retries ):
455+ try :
456+ img = Image .open (io .BytesIO (screenshot_bytes ))
457+ if img .mode != "RGB" :
458+ img = img .convert ("RGB" )
459+ img .format = "PNG"
460+ break
461+ except Exception as exc :
462+ if attempt < screenshot_retries - 1 :
463+ logger .warning (
464+ "Corrupt screenshot (attempt %d/%d): %s" ,
465+ attempt + 1 ,
466+ screenshot_retries ,
467+ exc ,
468+ )
469+ time .sleep (screenshot_retry_delay )
470+ else :
471+ logger .error (
472+ "Screenshot corrupt after %d attempts, "
473+ "returning DONE action" ,
474+ screenshot_retries ,
475+ )
476+ return "done" , [], []
335477
336478 messages = [
337479 {"role" : "system" , "content" : SYSTEM_PROMPT },
0 commit comments