Skip to content

Commit 879c53c

Browse files
abrichr and claude authored
fix: address review findings in verl-agent adapter (#88)
- Fix SCROLL direction not forwarded to BenchmarkAction.scroll_direction
- Fix DRAG parsing to include end_x/end_y coordinates
- Fix is_action_valid logic: use pattern match instead of inverted condition
- Fix fractional coord conversion: trust _use_fractional flag instead of checking value ranges (0 and 1 are ambiguous between frac and pixel)
- Convert drag end coordinates (end_x/end_y) from fractional to pixel
- Add health_check() method returning ready/busy/needs_recovery/not_initialized
- Add DRAG to system prompt DSL documentation
- Fix vendored VAGEN source URL (mll-lab-nu -> RAGEN-AI)
- Add 12 new tests: scroll direction, drag coords, health_check, is_action_valid

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 03b4f3d commit 879c53c

4 files changed

Lines changed: 143 additions & 23 deletions

File tree

openadapt_evals/adapters/_vendored/gym_base_env.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Vendored from https://github.com/mll-lab-nu/VAGEN
1+
# Vendored from https://github.com/RAGEN-AI/VAGEN
22
# These are pure abstract base classes with no heavy dependencies.
33
# Vendored to avoid requiring the full VAGEN installation.
44
# Last synced: 2026-03-02

openadapt_evals/adapters/_vendored/gym_image_env.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Vendored from https://github.com/mll-lab-nu/VAGEN
1+
# Vendored from https://github.com/RAGEN-AI/VAGEN
22
# These are pure abstract base classes with no heavy dependencies.
33
# Vendored to avoid requiring the full VAGEN installation.
44
# Last synced: 2026-03-02

openadapt_evals/adapters/verl_env.py

Lines changed: 53 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ def _parse_action_str(action_str: str) -> BenchmarkAction:
8181
TYPE(text="hello world")
8282
KEY(key="enter")
8383
SCROLL(x=0.50, y=0.50, direction="down")
84+
DRAG(x=0.20, y=0.30, end_x=0.80, end_y=0.70)
8485
WAIT()
8586
DONE()
8687
"""
@@ -113,6 +114,7 @@ def _parse_action_str(action_str: str) -> BenchmarkAction:
113114
type="scroll",
114115
x=float(kwargs.get("x", 0.5)),
115116
y=float(kwargs.get("y", 0.5)),
117+
scroll_direction=kwargs.get("direction", "down"),
116118
)
117119
elif cmd == "WAIT":
118120
return BenchmarkAction(type="wait")
@@ -123,6 +125,8 @@ def _parse_action_str(action_str: str) -> BenchmarkAction:
123125
type="drag",
124126
x=float(kwargs.get("x", 0.5)),
125127
y=float(kwargs.get("y", 0.5)),
128+
end_x=float(kwargs["end_x"]) if "end_x" in kwargs else None,
129+
end_y=float(kwargs["end_y"]) if "end_y" in kwargs else None,
126130
)
127131
else:
128132
return BenchmarkAction(type="done")
@@ -163,6 +167,7 @@ def _build_obs_dict(
163167
" TYPE(text=\"<text>\") - type text\n"
164168
" KEY(key=\"<name>\") - press a key (enter, tab, escape, ctrl+a, etc.)\n"
165169
" SCROLL(x=<frac>, y=<frac>, direction=\"up\"|\"down\") - scroll\n"
170+
" DRAG(x=<frac>, y=<frac>, end_x=<frac>, end_y=<frac>) - drag\n"
166171
" WAIT() - wait for the screen to update\n"
167172
" DONE() - task is complete\n"
168173
"\n"
@@ -216,6 +221,42 @@ def _ensure_env(self) -> RLEnvironment:
216221
self._rl_env = RLEnvironment(adapter, default_task_id=self._task_id)
217222
return self._rl_env
218223

224+
async def health_check(self) -> dict[str, Any]:
225+
"""Check environment health status.
226+
227+
Returns a dict with:
228+
status: "ready" | "busy" | "needs_recovery" | "not_initialized"
229+
server_url: The WAA server URL
230+
step_count: Current step count in episode
231+
232+
Use this from a pool controller to decide whether to send work
233+
to this environment or retire/restart it.
234+
"""
235+
if self._rl_env is None:
236+
return {"status": "not_initialized", "server_url": self._server_url}
237+
238+
adapter = self._rl_env.adapter
239+
# Check if the WAA server is reachable
240+
try:
241+
check_fn = getattr(adapter, "check_connection", None)
242+
reachable = await asyncio.to_thread(check_fn) if check_fn else True
243+
except Exception:
244+
reachable = False
245+
246+
if not reachable:
247+
return {
248+
"status": "needs_recovery",
249+
"server_url": self._server_url,
250+
"step_count": self._step_count,
251+
}
252+
253+
status = "busy" if self._step_count > 0 and not self._rl_env.done else "ready"
254+
return {
255+
"status": status,
256+
"server_url": self._server_url,
257+
"step_count": self._step_count,
258+
}
259+
219260
async def close(self) -> None:
220261
"""Release resources."""
221262
if self._rl_env is not None:
@@ -270,13 +311,20 @@ async def step(
270311
# Parse the LLM output into a BenchmarkAction
271312
action = _parse_action_str(action_str)
272313

273-
# Handle fractional → pixel coordinate conversion
314+
# Handle fractional → pixel coordinate conversion.
315+
# When _use_fractional is True, all coordinates from the parser are
316+
# fractions (0.0-1.0). We convert unconditionally rather than checking
317+
# value ranges, since pixel values 0 and 1 would be ambiguous.
274318
if self._use_fractional and action.type in ("click", "scroll", "drag"):
275319
w, h = env.screen_size
276-
if action.x is not None and 0.0 <= action.x <= 1.0:
320+
if action.x is not None:
277321
action.x = int(action.x * w)
278-
if action.y is not None and 0.0 <= action.y <= 1.0:
322+
if action.y is not None:
279323
action.y = int(action.y * h)
324+
if action.end_x is not None:
325+
action.end_x = int(action.end_x * w)
326+
if action.end_y is not None:
327+
action.end_y = int(action.end_y * h)
280328

281329
# Execute action in a thread
282330
rollout_step = await asyncio.to_thread(env.step, action)
@@ -287,7 +335,8 @@ async def step(
287335
# Compute reward
288336
reward = 0.0
289337
info: dict[str, Any] = rollout_step.info
290-
info["is_action_valid"] = action.type != "done" or action_str.strip() == ""
338+
# Action is valid if it was explicitly parsed (not the fallback for unparseable input)
339+
info["is_action_valid"] = _ACTION_PATTERN.search(action_str) is not None
291340

292341
if done and self._evaluate_at_done:
293342
try:

tests/test_verl_env.py

Lines changed: 88 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from openadapt_evals.adapters.rl_env import RLEnvironment
1414
from openadapt_evals.adapters.verl_env import (
1515
WAADesktopEnv,
16+
_ACTION_PATTERN,
1617
_build_obs_dict,
1718
_parse_action_str,
1819
)
@@ -52,14 +53,37 @@ def test_done(self):
5253
action = _parse_action_str("DONE()")
5354
assert action.type == "done"
5455

55-
def test_scroll(self):
56+
def test_scroll_with_direction(self):
5657
action = _parse_action_str('SCROLL(x=0.50, y=0.50, direction="down")')
5758
assert action.type == "scroll"
59+
assert action.scroll_direction == "down"
60+
61+
def test_scroll_up(self):
62+
action = _parse_action_str('SCROLL(x=0.50, y=0.50, direction="up")')
63+
assert action.scroll_direction == "up"
64+
65+
def test_scroll_default_direction(self):
66+
action = _parse_action_str("SCROLL(x=0.50, y=0.50)")
67+
assert action.scroll_direction == "down"
5868

5969
def test_invalid_returns_done(self):
6070
action = _parse_action_str("random garbage text")
6171
assert action.type == "done"
6272

73+
def test_drag_with_end_coords(self):
74+
action = _parse_action_str("DRAG(x=0.20, y=0.30, end_x=0.80, end_y=0.70)")
75+
assert action.type == "drag"
76+
assert action.x == pytest.approx(0.20)
77+
assert action.y == pytest.approx(0.30)
78+
assert action.end_x == pytest.approx(0.80)
79+
assert action.end_y == pytest.approx(0.70)
80+
81+
def test_drag_without_end_coords(self):
82+
action = _parse_action_str("DRAG(x=0.20, y=0.30)")
83+
assert action.type == "drag"
84+
assert action.end_x is None
85+
assert action.end_y is None
86+
6387
def test_with_thinking(self):
6488
action = _parse_action_str(
6589
"<think>I need to click the button</think>\nCLICK(x=0.25, y=0.75)"
@@ -68,14 +92,21 @@ def test_with_thinking(self):
6892
assert action.x == pytest.approx(0.25)
6993
assert action.y == pytest.approx(0.75)
7094

95+
def test_invalid_action_not_matched(self):
96+
"""Unparseable input should not match the action pattern."""
97+
assert _ACTION_PATTERN.search("random garbage") is None
98+
99+
def test_explicit_done_is_matched(self):
100+
"""Explicit DONE() should match the action pattern."""
101+
assert _ACTION_PATTERN.search("DONE()") is not None
102+
71103

72104
# --- Observation building tests ---
73105

74106

75107
class TestBuildObsDict:
76108
def test_with_screenshot(self):
77109
"""Test obs dict with PNG bytes."""
78-
# Create a minimal valid PNG (1x1 red pixel)
79110
from PIL import Image
80111
import io
81112

@@ -137,6 +168,7 @@ def test_system_prompt(self):
137168
assert "obs_str" in result
138169
assert "CLICK" in result["obs_str"]
139170
assert "TYPE" in result["obs_str"]
171+
assert "DRAG" in result["obs_str"]
140172
assert "DONE" in result["obs_str"]
141173

142174
def test_reset_returns_obs_dict(self):
@@ -165,7 +197,6 @@ def test_step_done_triggers_eval(self):
165197
asyncio.run(env.reset(seed=42))
166198
obs, reward, done, info = asyncio.run(env.step("DONE()"))
167199
assert done is True
168-
# Reward should be a float from evaluation (mock evaluator)
169200
assert isinstance(reward, float)
170201

171202
def test_max_steps_triggers_done(self):
@@ -186,15 +217,13 @@ def test_close(self):
186217
assert env._rl_env is None
187218

188219
def test_full_episode_flow(self):
189-
"""Test a complete episode: reset multiple steps done evaluate."""
220+
"""Test a complete episode: reset -> multiple steps -> done -> evaluate."""
190221
env = _make_mock_env()
191222
env._max_steps = 5
192223

193-
# Reset
194224
obs, info = asyncio.run(env.reset(seed=1))
195225
assert "obs_str" in obs
196226

197-
# Take some actions
198227
obs, r, done, _ = asyncio.run(env.step("CLICK(x=0.05, y=0.08)"))
199228
assert not done
200229
assert r == 0.0
@@ -203,28 +232,70 @@ def test_full_episode_flow(self):
203232
assert not done
204233
assert r == 0.0
205234

206-
# Finish
207235
obs, r, done, info = asyncio.run(env.step("DONE()"))
208236
assert done
209237
assert isinstance(r, float)
210238

211239
def test_protocol_has_required_methods(self):
212240
"""Verify WAADesktopEnv has all GymImageEnv protocol methods."""
213241
env = _make_mock_env()
214-
assert hasattr(env, "reset")
215-
assert hasattr(env, "step")
216-
assert hasattr(env, "close")
217-
assert hasattr(env, "system_prompt")
218-
assert callable(env.reset)
219-
assert callable(env.step)
220-
assert callable(env.close)
221-
assert callable(env.system_prompt)
242+
for method in ("reset", "step", "close", "system_prompt", "health_check"):
243+
assert hasattr(env, method)
244+
assert callable(getattr(env, method))
222245

223246
def test_obs_contains_image_placeholder(self):
224247
"""Test that observations with screenshots include <image> placeholder."""
225248
env = _make_mock_env()
226249
obs, _ = asyncio.run(env.reset(seed=42))
227-
# Mock adapter returns observations that may or may not have screenshots
228-
# At minimum, obs_str should be present
229250
assert "obs_str" in obs
230251
assert isinstance(obs["obs_str"], str)
252+
253+
# --- health_check tests ---
254+
255+
def test_health_check_not_initialized(self):
256+
"""Health check before reset returns not_initialized."""
257+
env = WAADesktopEnv.__new__(WAADesktopEnv)
258+
env._rl_env = None
259+
env._server_url = "mock"
260+
env._step_count = 0
261+
result = asyncio.run(env.health_check())
262+
assert result["status"] == "not_initialized"
263+
264+
def test_health_check_ready_after_episode(self):
265+
"""Health check after completed episode returns ready."""
266+
env = _make_mock_env()
267+
asyncio.run(env.reset(seed=42))
268+
asyncio.run(env.step("DONE()"))
269+
result = asyncio.run(env.health_check())
270+
assert result["status"] == "ready"
271+
272+
def test_health_check_busy_mid_episode(self):
273+
"""Health check mid-episode returns busy."""
274+
env = _make_mock_env()
275+
asyncio.run(env.reset(seed=42))
276+
asyncio.run(env.step("CLICK(x=0.5, y=0.5)"))
277+
result = asyncio.run(env.health_check())
278+
assert result["status"] == "busy"
279+
280+
# --- is_action_valid tests ---
281+
282+
def test_is_action_valid_for_parsed_action(self):
283+
"""Actions that parse successfully should be marked valid."""
284+
env = _make_mock_env()
285+
asyncio.run(env.reset(seed=42))
286+
_, _, _, info = asyncio.run(env.step("CLICK(x=0.5, y=0.5)"))
287+
assert info["is_action_valid"] is True
288+
289+
def test_is_action_valid_for_done(self):
290+
"""Explicit DONE() should be marked valid."""
291+
env = _make_mock_env()
292+
asyncio.run(env.reset(seed=42))
293+
_, _, _, info = asyncio.run(env.step("DONE()"))
294+
assert info["is_action_valid"] is True
295+
296+
def test_is_action_invalid_for_garbage(self):
297+
"""Unparseable input should be marked invalid."""
298+
env = _make_mock_env()
299+
asyncio.run(env.reset(seed=42))
300+
_, _, _, info = asyncio.run(env.step("random garbage"))
301+
assert info["is_action_valid"] is False

0 commit comments

Comments
 (0)