Skip to content

Commit 43cac1c

Browse files
abrichr and claude authored
fix: align standalone GRPO with WAA API format and add retry logic (#193)
The standalone GRPO trainer produced zero rewards due to two API format bugs in WAADirect:

1. screenshot() called resp.json() expecting base64-encoded JSON, but WAA's /screenshot returns raw PNG bytes via Flask's send_file(). Fixed to use resp.content (matching WAALiveAdapter).
2. execute_action() wrapped commands in `python -c "..."`, but WAA's /execute_windows uses exec() directly -- the wrapper caused a SyntaxError inside the VM. Fixed to send bare Python statements (matching WAALiveAdapter._build_pixel_command).

Additional improvements:
- Add probe() method for structured health checking
- Add screenshot retry logic (3 attempts with 2s delay)
- Add double_click, right_click, scroll action types
- Fix type action to click target first then type (match WAALiveAdapter)
- Add pre-rollout health check in trainer._collect_group()
- Handle empty rollouts gracefully in training loop
- Fix train script to bypass openadapt_evals/__init__.py eager imports (open_clip -> numpy ABI crash in minimal training environments)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 47f1126 commit 43cac1c

3 files changed

Lines changed: 160 additions & 25 deletions

File tree

openadapt_evals/training/standalone/trainer.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,13 @@ def _collect_rollout(self, task_id: str, instruction: str) -> Rollout:
7575
recent: list[bytes] = []
7676

7777
for step_idx in range(self._config.max_steps_per_episode):
78+
# screenshot() already has built-in retry (3 attempts by default)
7879
try:
7980
screenshot = self._env.screenshot()
8081
except Exception as e:
81-
logger.warning("Screenshot failed at step %d: %s", step_idx, e)
82+
logger.warning(
83+
"Screenshot failed at step %d after retries: %s", step_idx, e,
84+
)
8285
break
8386
recent.append(screenshot)
8487
if self._env.is_stuck(recent, window=self._config.stuck_window):
@@ -124,6 +127,19 @@ def _collect_rollout(self, task_id: str, instruction: str) -> Rollout:
124127

125128
def _collect_group(self, task_id: str) -> list[Rollout]:
126129
"""Collect N rollouts for one GRPO gradient step."""
130+
assert self._env is not None
131+
132+
# Pre-rollout health check: verify WAA is responsive before committing
133+
# to a full group of rollouts (avoids wasting time on a dead server).
134+
probe = self._env.probe()
135+
if not probe.get("screenshot_ok"):
136+
logger.error(
137+
"Pre-rollout health check FAILED for task %s: %s — "
138+
"skipping group (returning empty rollouts)",
139+
task_id, probe,
140+
)
141+
return []
142+
127143
tc = self._task_configs.get(task_id)
128144
instruction = getattr(tc, "name", "") or task_id if tc else task_id
129145
if tc and self._env:
@@ -242,6 +258,12 @@ def train(self) -> str:
242258
self._model.eval()
243259
rollouts = self._collect_group(task_id)
244260
self._model.train()
261+
if not rollouts:
262+
logger.warning(
263+
"Step %d/%d: no rollouts collected (server may be down), skipping.",
264+
step + 1, self._config.num_training_steps,
265+
)
266+
continue
245267
m = self._training_step(rollouts)
246268
m.update({"step": step, "task_id": task_id, "elapsed": time.time() - t0, "step_time": time.time() - ts})
247269
logger.info("Step %d/%d: reward=%.2f loss=%.4f time=%.1fs",

openadapt_evals/training/standalone/waa_direct.py

Lines changed: 97 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import base64
65
import hashlib
76
import logging
87
import time
@@ -15,6 +14,10 @@
1514

1615
logger = logging.getLogger(__name__)
1716

17+
# Default retry parameters for screenshot
18+
SCREENSHOT_MAX_RETRIES = 3
19+
SCREENSHOT_RETRY_DELAY = 2.0 # seconds
20+
1821

1922
@dataclass
2023
class RolloutStep:
@@ -37,42 +40,97 @@ class Rollout:
3740

3841

3942
class WAADirect:
40-
"""Direct HTTP client for WAA Flask server. Screenshot/click/type/key."""
43+
"""Direct HTTP client for WAA Flask server. Screenshot/click/type/key.
44+
45+
WAA API contract (from WAA Flask server main.py):
46+
GET /screenshot -> raw PNG bytes (Content-Type: image/png)
47+
POST /execute_windows -> exec(command, {'computer': computer, 'human': human})
48+
Payload: {"command": "<python code>"}
49+
The command is Python code executed via exec() with pyautogui available.
50+
Do NOT wrap in ``python -c "..."`` -- send bare Python statements.
51+
"""
4152

4253
def __init__(self, server_url: str = "http://localhost:5001",
4354
screen_size: tuple[int, int] = (1920, 1080)) -> None:
4455
self.server_url = server_url.rstrip("/")
4556
self.screen_size = screen_size
4657
self._session = requests.Session()
4758

48-
def screenshot(self) -> bytes:
49-
"""Take a fresh screenshot. Returns PNG bytes."""
50-
resp = self._session.get(f"{self.server_url}/screenshot", timeout=30)
51-
if resp.status_code != 200:
52-
raise RuntimeError(f"Screenshot failed: {resp.status_code}")
53-
data = resp.json()
54-
img_b64 = data.get("screenshot", data.get("image", ""))
55-
if not img_b64:
56-
raise RuntimeError("No screenshot data in response")
57-
return base64.b64decode(img_b64)
59+
def screenshot(self, max_retries: int = SCREENSHOT_MAX_RETRIES,
               retry_delay: float = SCREENSHOT_RETRY_DELAY) -> bytes:
    """Take a fresh screenshot and return the raw PNG bytes.

    WAA's /screenshot endpoint returns raw PNG via Flask's send_file(),
    NOT base64-encoded JSON. Read resp.content, not resp.json().

    Args:
        max_retries: Total number of attempts. Values below 1 are treated
            as 1 so the server is always contacted at least once.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        The response body of a successful GET /screenshot.

    Raises:
        RuntimeError: If every attempt failed; the last underlying error
            is chained as ``__cause__``.
    """
    # Guard against max_retries <= 0: without this the loop body never
    # runs and we would report a failure without making a single request.
    attempts = max(1, max_retries)
    last_exc: Exception | None = None
    for attempt in range(1, attempts + 1):
        try:
            resp = self._session.get(
                f"{self.server_url}/screenshot", timeout=30,
            )
            if resp.status_code != 200:
                raise RuntimeError(
                    f"Screenshot HTTP {resp.status_code}: {resp.text[:200]}"
                )
            png_bytes = resp.content
            # A near-empty body is treated as "server not ready" and retried.
            if len(png_bytes) < 100:
                raise RuntimeError(
                    f"Screenshot too small ({len(png_bytes)} bytes) -- "
                    "server may not be ready"
                )
            return png_bytes
        except Exception as e:
            # Best-effort retry: any failure (transport, HTTP status, short
            # body) is logged and retried until the attempts run out.
            last_exc = e
            logger.warning(
                "Screenshot attempt %d/%d failed: %s",
                attempt, attempts, e,
            )
            if attempt < attempts:
                time.sleep(retry_delay)
    raise RuntimeError(
        f"Screenshot failed after {attempts} attempts"
    ) from last_exc
5894

5995
def execute_action(self, action: SimpleAction) -> dict[str, Any]:
60-
"""Execute action on VM via /execute_windows."""
96+
"""Execute action on VM via /execute_windows.
97+
98+
WAA's /execute_windows does ``exec(command, {'computer': ..., 'human': ...})``.
99+
The command must be bare Python code -- NOT wrapped in ``python -c "..."``.
100+
pyautogui is available via import inside the exec'd code.
101+
"""
61102
if action.type == "click":
62103
x, y = int(action.x or 0), int(action.y or 0)
63-
cmd = f'python -c "import pyautogui; pyautogui.click({x}, {y})"'
104+
cmd = f"import pyautogui; pyautogui.click({x}, {y})"
105+
elif action.type == "double_click":
106+
x, y = int(action.x or 0), int(action.y or 0)
107+
cmd = f"import pyautogui; pyautogui.doubleClick({x}, {y})"
108+
elif action.type == "right_click":
109+
x, y = int(action.x or 0), int(action.y or 0)
110+
cmd = f"import pyautogui; pyautogui.rightClick({x}, {y})"
64111
elif action.type == "type":
65-
text = (action.text or "").replace('"', '\\"')
66-
cmd = f'python -c "import pyautogui; pyautogui.typewrite(\'{text}\', interval=0.05)"'
112+
text = (action.text or "").replace("\\", "\\\\").replace("'", "\\'")
113+
x, y = int(action.x or 0), int(action.y or 0)
114+
# Click target first, then type (matches WAALiveAdapter pattern)
115+
cmd = (
116+
f"import pyautogui; import time; "
117+
f"pyautogui.click({x}, {y}); "
118+
f"time.sleep(0.2); "
119+
f"pyautogui.typewrite('{text}', interval=0.05)"
120+
)
67121
elif action.type == "key":
68-
cmd = f'python -c "import pyautogui; pyautogui.press(\'{action.key or "enter"}\')"'
122+
key = action.key or "enter"
123+
cmd = f"import pyautogui; pyautogui.press('{key}')"
124+
elif action.type == "scroll":
125+
x, y = int(action.x or 0), int(action.y or 0)
126+
cmd = f"import pyautogui; pyautogui.scroll(-3, x={x}, y={y})"
69127
elif action.type == "wait":
70128
time.sleep(2)
71129
return {"status": "ok", "action": "wait"}
72130
elif action.type == "done":
73131
return {"status": "ok", "action": "done"}
74132
else:
75-
return {"status": "error", "message": f"Unknown: {action.type}"}
133+
return {"status": "error", "message": f"Unknown action type: {action.type}"}
76134

77135
resp = self._session.post(
78136
f"{self.server_url}/execute_windows", json={"command": cmd}, timeout=30,
@@ -117,9 +175,25 @@ def is_stuck(self, recent: list[bytes], window: int = 3) -> bool:
117175
hashes = [hashlib.md5(s).hexdigest() for s in recent[-window:]]
118176
return len(set(hashes)) == 1
119177

120-
def probe(self, timeout: float = 10.0) -> dict[str, Any]:
    """Health-check the WAA server and return a status dict.

    Issues a single GET /screenshot (no retries) to verify the full
    pipeline, not just HTTP reachability.

    Args:
        timeout: Request timeout in seconds.

    Returns:
        A dict that always contains ``reachable`` and ``screenshot_ok``.
        ``status_code`` is added when a response was received,
        ``screenshot_bytes`` when that response was HTTP 200, and
        ``error`` when the request raised ``requests.RequestException``.
    """
    result: dict[str, Any] = {"reachable": False, "screenshot_ok": False}
    try:
        resp = self._session.get(
            f"{self.server_url}/screenshot", timeout=timeout,
        )
        result["reachable"] = True
        result["status_code"] = resp.status_code
        if resp.status_code == 200:
            body_len = len(resp.content)
            # Same threshold screenshot() applies: it rejects only bodies
            # shorter than 100 bytes, so exactly 100 bytes counts as OK.
            result["screenshot_ok"] = body_len >= 100
            result["screenshot_bytes"] = body_len
    except requests.RequestException as e:
        result["error"] = str(e)
    return result

def health_check(self) -> bool:
    """Return True if the WAA server responds with a valid screenshot."""
    return self.probe().get("screenshot_ok", False)

scripts/train_grpo_standalone.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,48 @@
1111
1212
Or equivalently via module:
1313
python -m openadapt_evals.training.standalone.trainer --task-dir ...
14+
15+
NOTE: We must avoid triggering openadapt_evals/__init__.py, which eagerly
16+
imports agents/adapters/demo_library/benchmarks. The demo_library import
17+
pulls in open_clip at module level, which can crash in minimal training
18+
environments (e.g., numpy ABI mismatch). We work around this by inserting
19+
a lightweight shim into sys.modules for the top-level package before any
20+
sub-imports run.
1421
"""
1522

16-
from openadapt_evals.training.standalone.trainer import main
23+
import importlib
24+
import sys
25+
import types
26+
from pathlib import Path
27+
28+
29+
def _ensure_lightweight_package(pkg_name: str, pkg_dir: Path) -> None:
    """Install an empty stand-in for *pkg_name* in ``sys.modules``.

    The stand-in is a bare module object whose ``__path__`` points at
    *pkg_dir*, so sub-modules can be imported without the package's own
    ``__init__.py`` ever being executed. If *pkg_name* is already present
    in ``sys.modules`` nothing is changed.
    """
    if pkg_name not in sys.modules:
        shim = types.ModuleType(pkg_name)
        shim.__package__ = pkg_name
        # __path__ is what marks the module as a package for sub-imports.
        shim.__path__ = [str(pkg_dir)]
        sys.modules[pkg_name] = shim
42+
43+
44+
def main() -> None:
    """Run the standalone GRPO trainer without the heavy top-level package init.

    Registers a stand-in for the top-level ``openadapt_evals`` package
    (located next to this script's parent directory) before importing the
    trainer module, then calls the trainer's own ``main()``.
    """
    repo_root = Path(__file__).resolve().parent.parent

    # Only the top-level package is shimmed; per the original note, the
    # sub-packages have lightweight __init__.py files.
    _ensure_lightweight_package("openadapt_evals", repo_root / "openadapt_evals")

    # With the shim in place this import does not pull in the full
    # agents/adapters/benchmarks dependency tree.
    trainer = importlib.import_module("openadapt_evals.training.standalone.trainer")
    trainer.main()
55+
1756

1857
if __name__ == "__main__":
1958
main()

0 commit comments

Comments
 (0)