|
| 1 | +"""OpenEnv-compatible WAA desktop environment. |
| 2 | +
|
| 3 | +Wraps RLEnvironment + WAALiveAdapter into the OpenEnv Environment |
| 4 | +interface. Each instance maps to one Windows VM session. |
| 5 | +
|
| 6 | +Can be used standalone (direct Python) or served via OpenEnv's |
| 7 | +create_app() as an HTTP+WebSocket server. |
| 8 | +
|
| 9 | +Usage (standalone): |
| 10 | + env = WAAOpenEnvEnvironment(server_url="http://localhost:5001") |
| 11 | + obs = env.reset(task_id="custom-notepad-hello") |
| 12 | + obs = env.step(WAAAction(type="click", x=0.5, y=0.3)) |
| 13 | + print(env.state) |
| 14 | +
|
| 15 | +Usage (server): |
| 16 | + from openenv.core.env_server.http_server import create_app |
| 17 | + from openadapt_evals.openenv.models import WAAAction, WAAObservation |
| 18 | + from openadapt_evals.openenv.environment import WAAOpenEnvEnvironment |
| 19 | +
|
| 20 | + app = create_app(WAAOpenEnvEnvironment, WAAAction, WAAObservation, |
| 21 | + env_name="waa_desktop") |
| 22 | + uvicorn.run(app, host="0.0.0.0", port=8000) |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import base64 |
| 28 | +import logging |
| 29 | +from typing import Any, Optional |
| 30 | +from uuid import uuid4 |
| 31 | + |
| 32 | +from openadapt_evals.openenv.models import WAAAction, WAAObservation, WAAState |
| 33 | + |
| 34 | +logger = logging.getLogger(__name__) |
| 35 | + |
| 36 | + |
class WAAOpenEnvEnvironment:
    """OpenEnv-compatible WAA desktop environment.

    Follows the OpenEnv Environment protocol (reset/step/state) without
    requiring openenv-core as an import-time dependency. When served via
    create_app(), OpenEnv discovers the methods via duck typing.

    The wrapped RLEnvironment and its WAALiveAdapter are created lazily on
    the first reset()/step() call, so constructing this object does not
    import the adapter modules.

    Args:
        server_url: WAA Flask server URL (default: http://localhost:5001).
        evaluate_url: Separate evaluate server URL (default: same as server_url).
        default_task_id: Task ID to use when reset() is called without one.
        max_steps: Maximum steps per episode.
        task_config_dir: Directory of YAML task configs for dense rewards.
        **kwargs: Accepted for call compatibility and ignored.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Maximum number of accessibility-tree characters kept per observation
    # (cap for transport).
    _A11Y_MAX_CHARS = 10000

    def __init__(
        self,
        server_url: str = "http://localhost:5001",
        evaluate_url: str | None = None,
        default_task_id: str | None = None,
        max_steps: int = 15,
        task_config_dir: str | None = None,
        **kwargs: Any,
    ):
        self._server_url = server_url
        self._evaluate_url = evaluate_url
        self._default_task_id = default_task_id
        self._max_steps = max_steps
        self._task_config_dir = task_config_dir
        self._rl_env = None  # built on first use by _ensure_rl_env()
        self._task_configs: dict[str, Any] = {}
        self._state = WAAState()

        # Load task configs if directory provided
        if task_config_dir:
            self._load_task_configs(task_config_dir)

    def _load_task_configs(self, dir_path: str) -> None:
        """Load YAML task configs from a directory.

        Each config is registered under both its ``id`` and its ``name`` so
        that reset() can look it up by either.
        """
        from openadapt_evals.task_config import TaskConfig

        for tc in TaskConfig.from_dir(dir_path):
            self._task_configs[tc.id] = tc
            self._task_configs[tc.name] = tc

    def _ensure_rl_env(self):
        """Lazily create the RLEnvironment + adapter.

        Returns:
            The cached RLEnvironment, creating it on the first call.
        """
        if self._rl_env is not None:
            return self._rl_env

        from openadapt_evals.adapters.rl_env import RLEnvironment
        from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig

        adapter = WAALiveAdapter(
            WAALiveConfig(
                server_url=self._server_url,
                evaluate_url=self._evaluate_url,
            )
        )
        self._rl_env = RLEnvironment(
            adapter, default_task_id=self._default_task_id
        )
        return self._rl_env

    def _resolve_task_id(self, kwargs: dict[str, Any]) -> str | None:
        """Return the task ID requested in *kwargs*, else the default.

        A missing ``task_id`` key and an explicit ``task_id=None`` are
        treated the same way: both fall back to ``default_task_id``.
        """
        requested = kwargs.get("task_id")
        if requested is not None:
            return requested
        return self._default_task_id

    @staticmethod
    def _has_fractional_coords(action: Any) -> bool:
        """Return True when both x and y are set and lie within [0, 1]."""
        x, y = action.x, action.y
        return x is not None and y is not None and 0 <= x <= 1 and 0 <= y <= 1

    def _dispatch_action(self, env: Any, action: Any) -> Any:
        """Send *action* to *env* and return the resulting step result.

        Routing, in order:
          1. ``type == "done"`` -> a bare ``BenchmarkAction(type="done")``.
          2. x and y both within [0, 1] -> ``env.pixel_action`` with the
             values passed as fractions.
          3. anything else -> ``env.step`` with the fields copied verbatim.
        """
        if action.type != "done" and self._has_fractional_coords(action):
            return env.pixel_action(
                x_frac=action.x,
                y_frac=action.y,
                action_type=action.type,
                text=action.text,
                key=action.key,
            )

        from openadapt_evals.adapters.base import BenchmarkAction

        if action.type == "done":
            return env.step(BenchmarkAction(type="done"))

        return env.step(
            BenchmarkAction(
                type=action.type,
                x=action.x,
                y=action.y,
                text=action.text,
                key=action.key,
            )
        )

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> WAAObservation:
        """Reset the environment to a task's initial state.

        Args:
            seed: Random seed (unused, for OpenEnv compatibility).
            episode_id: Episode identifier. A 12-character hex ID is
                generated when omitted or empty.
            **kwargs: May include task_id to override default. A task_id of
                None falls back to the default as well.

        Returns:
            Initial WAAObservation with screenshot.
        """
        env = self._ensure_rl_env()
        task_id = self._resolve_task_id(kwargs)

        # Load TaskConfig for dense rewards if available
        tc = self._task_configs.get(task_id)
        if tc:
            env.load_task_config(tc)

        from openadapt_evals.adapters.rl_env import ResetConfig

        obs = env.reset(config=ResetConfig(task_id=task_id))

        # Start a fresh state object so nothing carries over from the
        # previous episode.
        self._state = WAAState(
            episode_id=episode_id or uuid4().hex[:12],
            step_count=0,
            task_id=task_id,
            task_name=tc.name if tc else task_id,
            status="running",
        )

        return self._to_observation(obs)

    def step(
        self,
        action: WAAAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> WAAObservation:
        """Execute an action in the environment.

        The episode ends when the underlying step reports done or when the
        step count reaches ``max_steps``. The reward is only computed on
        that final step; earlier steps return ``reward=None``.

        Args:
            action: The action to execute.
            timeout_s: Timeout (unused, for OpenEnv compatibility).
            **kwargs: Ignored.

        Returns:
            WAAObservation with new screenshot, reward, and done flag.
        """
        env = self._ensure_rl_env()
        self._state.step_count += 1

        step_result = self._dispatch_action(env, action)

        done = step_result.done or self._state.step_count >= self._max_steps

        # Compute reward at episode end
        reward = None
        if done:
            try:
                reward = env.evaluate_dense()
                self._state.score = reward
                self._state.status = "completed"

                # Update milestone info (read from the last step's info dict)
                last_info = step_result.info or {}
                self._state.milestones_passed = last_info.get("milestones_passed", 0)
                self._state.milestones_total = last_info.get("milestones_total", 0)
            except Exception as exc:
                # Best effort: a failed evaluation ends the episode with a
                # zero reward instead of raising to the caller.
                logger.warning("Evaluation failed: %s", exc)
                reward = 0.0
                # Keep the recorded score in agreement with the reward
                # returned in the observation.
                self._state.score = reward
                self._state.status = "failed"

        self._state.done = done
        return self._to_observation(step_result.observation, reward=reward, done=done)

    @property
    def state(self) -> WAAState:
        """Current environment state."""
        return self._state

    def close(self) -> None:
        """Clean up resources.

        Drops the reference to the wrapped RLEnvironment; a later
        reset()/step() builds a new one.
        """
        # NOTE(review): only the reference is dropped here. If RLEnvironment
        # or its adapter holds connections that need an explicit release,
        # call that here - TODO confirm.
        self._rl_env = None

    def _to_observation(
        self,
        obs: Any,
        reward: float | None = None,
        done: bool = False,
    ) -> WAAObservation:
        """Convert a BenchmarkObservation to WAAObservation.

        Args:
            obs: Source observation; ``screenshot`` and
                ``accessibility_tree`` are read, and ``window_title`` if
                the attribute exists.
            reward: Reward to attach, or None.
            done: Whether the episode has ended.

        Returns:
            WAAObservation with the screenshot base64-encoded and the
            accessibility tree stringified and truncated for transport.
        """
        screenshot_b64 = None
        if obs.screenshot:
            screenshot_b64 = base64.b64encode(obs.screenshot).decode()

        a11y = None
        if obs.accessibility_tree:
            a11y = str(obs.accessibility_tree)[: self._A11Y_MAX_CHARS]

        return WAAObservation(
            screenshot_b64=screenshot_b64,
            accessibility_tree=a11y,
            window_title=getattr(obs, "window_title", None),
            step_index=self._state.step_count,
            done=done,
            reward=reward,
        )
0 commit comments