Skip to content

Commit 57a8094

Browse files
abrichr and claude authored
feat: add desktop cleanup, manual demo tools, and fix trainer OOM bugs (#195)
- Add clean_desktop() to WAADirect to kill known distracting apps between episodes, preventing stale desktop state from leaking across phases
- Handle close_all config entry type in WAADirect.setup_task()
- Create manual notepad-hello demo (DemoLibrary-compatible, no screenshots)
- Add scripts/create_manual_demo.py CLI for authoring demos from text specs
- Fix vision tensor exclusion in GRPO loss computation (OOM on L40S)
- Add try/except for float parsing in parse_vlm_output_to_action
- Lower max_new_tokens default from 2048 to 512 (prevents OOM, sufficient for Thought+Action format)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c6b1fdd commit 57a8094

6 files changed

Lines changed: 400 additions & 6 deletions

File tree

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
{
2+
"task_id": "custom-notepad-hello",
3+
"demo_id": "manual",
4+
"description": "Open Notepad via Run dialog and type Hello World",
5+
"created_at": "2026-03-26T00:00:00+00:00",
6+
"metadata": {
7+
"resolution": {
8+
"width": 1280,
9+
"height": 720
10+
},
11+
"source": "manual",
12+
"note": "Manually authored demo (no screenshots). Descriptions provide step-by-step guidance for planners."
13+
},
14+
"steps": [
15+
{
16+
"step_index": 0,
17+
"screenshot_path": "",
18+
"action_type": "key",
19+
"action_description": "KEY(win+r)",
20+
"target_description": "Desktop / taskbar",
21+
"action_value": "win+r",
22+
"x": null,
23+
"y": null,
24+
"description": "Press Win+R to open the Run dialog",
25+
"metadata": {}
26+
},
27+
{
28+
"step_index": 1,
29+
"screenshot_path": "",
30+
"action_type": "type",
31+
"action_description": "TYPE('notepad')",
32+
"target_description": "Run dialog Open field",
33+
"action_value": "notepad",
34+
"x": null,
35+
"y": null,
36+
"description": "Type 'notepad' in the Run dialog's Open field",
37+
"metadata": {}
38+
},
39+
{
40+
"step_index": 2,
41+
"screenshot_path": "",
42+
"action_type": "key",
43+
"action_description": "KEY(enter)",
44+
"target_description": "Run dialog",
45+
"action_value": "enter",
46+
"x": null,
47+
"y": null,
48+
"description": "Press Enter to launch Notepad",
49+
"metadata": {}
50+
},
51+
{
52+
"step_index": 3,
53+
"screenshot_path": "",
54+
"action_type": "click",
55+
"action_description": "CLICK(0.40, 0.40)",
56+
"target_description": "Notepad text editing area",
57+
"action_value": "",
58+
"x": 0.4,
59+
"y": 0.4,
60+
"description": "Click in the Notepad text editing area to ensure focus",
61+
"metadata": {}
62+
},
63+
{
64+
"step_index": 4,
65+
"screenshot_path": "",
66+
"action_type": "type",
67+
"action_description": "TYPE('Hello World')",
68+
"target_description": "Notepad text editing area",
69+
"action_value": "Hello World",
70+
"x": null,
71+
"y": null,
72+
"description": "Type 'Hello World' in Notepad",
73+
"metadata": {}
74+
}
75+
]
76+
}

openadapt_evals/training/standalone/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TrainingConfig:
1717
num_rollouts_per_step: int = 8
1818
max_steps_per_episode: int = 15
1919
temperature: float = 0.7
20-
max_new_tokens: int = 2048 # 100 truncates reasoning -- keep high
20+
max_new_tokens: int = 512 # 2048 OOMs on L40S; 512 sufficient for Thought+Action
2121
server_url: str = "http://localhost:5001"
2222
task_ids: list[str] = field(default_factory=list)
2323
task_dir: str | None = None

openadapt_evals/training/standalone/prompt.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,12 @@ def parse_vlm_output_to_action(
109109
# CLICK(x=..., y=...)
110110
m = re.search(r"CLICK\(x=(-?[\d.]+),\s*y=(-?[\d.]+)\)", text, re.IGNORECASE)
111111
if m:
112-
xf = max(0.0, min(1.0, float(m.group(1))))
113-
yf = max(0.0, min(1.0, float(m.group(2))))
114-
return SimpleAction(type="click", x=int(xf * width), y=int(yf * height))
112+
try:
113+
xf = max(0.0, min(1.0, float(m.group(1))))
114+
yf = max(0.0, min(1.0, float(m.group(2))))
115+
return SimpleAction(type="click", x=int(xf * width), y=int(yf * height))
116+
except (ValueError, OverflowError):
117+
logger.warning("Malformed CLICK coords: x=%s y=%s", m.group(1), m.group(2))
115118

116119
# TYPE(text="...")
117120
m = re.search(r"""TYPE\(text=["']([^"'\\]*(?:\\.[^"'\\]*)*)["']\)""", text, re.IGNORECASE)

openadapt_evals/training/standalone/trainer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,14 @@ def _compute_rollout_loss(self, rollout: Rollout, advantage: float, scale: float
186186
continue
187187

188188
full_ids = torch.cat([prompt_inputs["input_ids"], action_ids.to(prompt_inputs["input_ids"].device)], dim=1)
189-
full_inputs = dict(prompt_inputs)
189+
# Exclude vision tensors from loss forward pass to avoid OOM.
190+
# The vision encoder backward pass is expensive and unnecessary
191+
# since we only compute loss on action tokens (past prompt_len).
192+
# Proven fix from 7 training runs on L40S GPUs.
193+
_VISION_KEYS = {"pixel_values", "pixel_values_videos",
194+
"image_grid_thw", "video_grid_thw"}
195+
full_inputs = {k: v for k, v in prompt_inputs.items()
196+
if k not in _VISION_KEYS}
190197
full_inputs["input_ids"] = full_ids
191198
full_inputs["attention_mask"] = torch.ones_like(full_ids)
192199
full_inputs = {k: v.to(device) for k, v in full_inputs.items()}
@@ -288,7 +295,7 @@ def main() -> None:
288295
p.add_argument("--num-steps", type=int, default=10)
289296
p.add_argument("--num-rollouts", type=int, default=8)
290297
p.add_argument("--max-steps-per-episode", type=int, default=15)
291-
p.add_argument("--max-new-tokens", type=int, default=2048)
298+
p.add_argument("--max-new-tokens", type=int, default=512)
292299
p.add_argument("--output", default="checkpoints/grpo")
293300
p.add_argument("--no-4bit", action="store_true")
294301
p.add_argument("--eval-model", default="gpt-4.1-mini")

openadapt_evals/training/standalone/waa_direct.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ def setup_task(self, task_config: dict[str, Any]) -> bool:
147147
params = entry.get("parameters", {})
148148
if etype == "sleep":
149149
time.sleep(params.get("seconds", 5))
150+
elif etype == "close_all":
151+
# Kill common desktop apps for a clean state
152+
self.clean_desktop()
150153
elif etype in ("execute", "command", "launch"):
151154
cmd = params.get("command", "")
152155
if cmd:
@@ -168,6 +171,50 @@ def setup_task(self, task_config: dict[str, Any]) -> bool:
168171
time.sleep(2)
169172
return True
170173

174+
def clean_desktop(
    self,
    kill_apps: list[str] | None = None,
) -> bool:
    """Kill known distracting apps and show desktop for a clean state.

    Call between episodes (flywheel phases, GRPO rollouts) to prevent
    stale desktop state from leaking into the next episode.

    Each cleanup step is sent as a Python one-liner to the server's
    ``/execute_windows`` endpoint. Transport and HTTP errors are logged
    as warnings and never raised.

    Args:
        kill_apps: Process image names to kill. Defaults to common desktop
            apps that interfere with task execution. An empty list skips
            the kill step and only shows the desktop.

    Returns:
        True if cleanup commands executed (does not verify success).
    """
    if kill_apps is None:
        kill_apps = [
            "notepad.exe", "Code.exe", "msedge.exe", "chrome.exe",
            "WINWORD.EXE", "EXCEL.EXE", "POWERPNT.EXE", "wordpad.exe",
            "mspaint.exe", "calc.exe", "explorer.exe",
        ]

    commands: list[str] = []
    if kill_apps:
        # One taskkill invocation per image name, each passed as an argv
        # list. Gluing several "taskkill /F /IM x" strings together with
        # spaces yields a single malformed taskkill line, so nothing would
        # be killed. repr() of the argv list also keeps quotes or
        # backslashes in a name from breaking the generated code.
        # A non-zero exit (app not running) is fine: no check=True.
        kill_calls = []
        for app in kill_apps:
            argv = ["taskkill", "/F", "/IM", app]
            kill_calls.append(f"subprocess.run({argv!r}, capture_output=True)")
        kill_script = "import subprocess; " + "; ".join(kill_calls)
        if any(app.lower() == "explorer.exe" for app in kill_apps):
            # NOTE(review): explorer.exe is also the Windows shell. After a
            # forced kill the taskbar and Win+D / Win+R hotkeys are gone
            # until it is relaunched, so start it again and give it a
            # moment to come up. Confirm this on the target VM image.
            kill_script += (
                "; subprocess.Popen(['explorer.exe'])"
                "; import time; time.sleep(2)"
            )
        commands.append(kill_script)
    # Show desktop (Win+D) to minimize any remaining windows
    commands.append(
        "import pyautogui; import time; pyautogui.hotkey('win', 'd'); time.sleep(1)"
    )

    for cmd in commands:
        try:
            response = self._session.post(
                f"{self.server_url}/execute_windows",
                json={"command": cmd}, timeout=30,
            )
            # Surface HTTP 4xx/5xx through the same warning path instead of
            # silently ignoring a failed cleanup step.
            response.raise_for_status()
        except requests.RequestException as e:
            logger.warning("clean_desktop error: %s", e)
    # Let the desktop settle before the next episode starts.
    time.sleep(2)
    logger.info("Desktop cleanup completed (killed %d app types)", len(kill_apps))
    return True
217+
171218
def is_stuck(self, recent: list[bytes], window: int = 3) -> bool:
172219
"""True if last N screenshots are identical."""
173220
if len(recent) < window:

0 commit comments

Comments
 (0)