Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions examples/tools/computer_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import asyncio
import base64
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Literal

from playwright.async_api import Browser, Page, Playwright, async_playwright
Expand Down Expand Up @@ -118,46 +120,77 @@ async def screenshot(self) -> str:
png_bytes = await self.page.screenshot(full_page=False)
return base64.b64encode(png_bytes).decode("utf-8")

async def click(self, x: int, y: int, button: Button = "left") -> None:
def _normalize_keys(self, keys: list[str] | None) -> list[str]:
if not keys:
return []
return [CUA_KEY_TO_PLAYWRIGHT_KEY.get(key.lower(), key) for key in keys]

@asynccontextmanager
async def _hold_keys(self, keys: list[str] | None) -> AsyncIterator[None]:
mapped_keys = self._normalize_keys(keys)
try:
for key in mapped_keys:
await self.page.keyboard.down(key)
yield
finally:
for key in reversed(mapped_keys):
await self.page.keyboard.up(key)

async def click(
self, x: int, y: int, button: Button = "left", *, keys: list[str] | None = None
) -> None:
playwright_button: Literal["left", "middle", "right"] = "left"

# Playwright only supports left, middle, right buttons
if button in ("left", "right", "middle"):
playwright_button = button # type: ignore

await self.page.mouse.click(x, y, button=playwright_button)
async with self._hold_keys(keys):
await self.page.mouse.click(x, y, button=playwright_button)

async def double_click(self, x: int, y: int) -> None:
await self.page.mouse.dblclick(x, y)
async def double_click(self, x: int, y: int, *, keys: list[str] | None = None) -> None:
async with self._hold_keys(keys):
await self.page.mouse.dblclick(x, y)

async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
await self.page.mouse.move(x, y)
await self.page.evaluate(f"window.scrollBy({scroll_x}, {scroll_y})")
async def scroll(
self,
x: int,
y: int,
scroll_x: int,
scroll_y: int,
*,
keys: list[str] | None = None,
) -> None:
async with self._hold_keys(keys):
await self.page.mouse.move(x, y)
await self.page.evaluate(f"window.scrollBy({scroll_x}, {scroll_y})")

async def type(self, text: str) -> None:
await self.page.keyboard.type(text)

async def wait(self) -> None:
await asyncio.sleep(1)

async def move(self, x: int, y: int) -> None:
await self.page.mouse.move(x, y)
async def move(self, x: int, y: int, *, keys: list[str] | None = None) -> None:
async with self._hold_keys(keys):
await self.page.mouse.move(x, y)

async def keypress(self, keys: list[str]) -> None:
mapped_keys = [CUA_KEY_TO_PLAYWRIGHT_KEY.get(key.lower(), key) for key in keys]
mapped_keys = self._normalize_keys(keys)
for key in mapped_keys:
await self.page.keyboard.down(key)
for key in reversed(mapped_keys):
await self.page.keyboard.up(key)

async def drag(self, path: list[tuple[int, int]]) -> None:
async def drag(self, path: list[tuple[int, int]], *, keys: list[str] | None = None) -> None:
if not path:
return
await self.page.mouse.move(path[0][0], path[0][1])
await self.page.mouse.down()
for px, py in path[1:]:
await self.page.mouse.move(px, py)
await self.page.mouse.up()
async with self._hold_keys(keys):
await self.page.mouse.move(path[0][0], path[0][1])
await self.page.mouse.down()
for px, py in path[1:]:
await self.page.mouse.move(px, py)
await self.page.mouse.up()


async def run_agent(
Expand Down
34 changes: 30 additions & 4 deletions src/agents/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@


class Computer(abc.ABC):
"""A computer implemented with sync operations. The Computer interface abstracts the
operations needed to control a computer or browser."""
"""A computer implemented with sync operations.

Subclasses provide the local runtime behind `ComputerTool`. Mouse action methods may
also accept a keyword-only `keys` argument to receive held modifier keys when the
driver supports them.
"""

@property
def environment(self) -> Environment | None:
Expand All @@ -21,44 +25,57 @@ def dimensions(self) -> tuple[int, int] | None:

@abc.abstractmethod
def screenshot(self) -> str:
"""Return a base64-encoded PNG screenshot of the current display."""
pass

@abc.abstractmethod
def click(self, x: int, y: int, button: Button) -> None:
"""Click `button` at the given `(x, y)` screen coordinates."""
pass

@abc.abstractmethod
def double_click(self, x: int, y: int) -> None:
"""Double-click at the given `(x, y)` screen coordinates."""
pass

@abc.abstractmethod
def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at `(x, y)` by `(scroll_x, scroll_y)` units."""
pass

@abc.abstractmethod
def type(self, text: str) -> None:
"""Type `text` into the currently focused target."""
pass

@abc.abstractmethod
def wait(self) -> None:
"""Wait until the computer is ready for the next action."""
pass

@abc.abstractmethod
def move(self, x: int, y: int) -> None:
"""Move the mouse cursor to the given `(x, y)` screen coordinates."""
pass

@abc.abstractmethod
def keypress(self, keys: list[str]) -> None:
"""Press the provided keys, such as `["ctrl", "c"]`."""
pass

@abc.abstractmethod
def drag(self, path: list[tuple[int, int]]) -> None:
"""Click-and-drag the mouse along the given sequence of `(x, y)` waypoints."""
pass


class AsyncComputer(abc.ABC):
"""A computer implemented with async operations. The Computer interface abstracts the
operations needed to control a computer or browser."""
"""A computer implemented with async operations.

Subclasses provide the local runtime behind `ComputerTool`. Mouse action methods may
also accept a keyword-only `keys` argument to receive held modifier keys when the
driver supports them.
"""

@property
def environment(self) -> Environment | None:
Expand All @@ -72,36 +89,45 @@ def dimensions(self) -> tuple[int, int] | None:

@abc.abstractmethod
async def screenshot(self) -> str:
"""Return a base64-encoded PNG screenshot of the current display."""
pass

@abc.abstractmethod
async def click(self, x: int, y: int, button: Button) -> None:
"""Click `button` at the given `(x, y)` screen coordinates."""
pass

@abc.abstractmethod
async def double_click(self, x: int, y: int) -> None:
"""Double-click at the given `(x, y)` screen coordinates."""
pass

@abc.abstractmethod
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at `(x, y)` by `(scroll_x, scroll_y)` units."""
pass

@abc.abstractmethod
async def type(self, text: str) -> None:
"""Type `text` into the currently focused target."""
pass

@abc.abstractmethod
async def wait(self) -> None:
"""Wait until the computer is ready for the next action."""
pass

@abc.abstractmethod
async def move(self, x: int, y: int) -> None:
"""Move the mouse cursor to the given `(x, y)` screen coordinates."""
pass

@abc.abstractmethod
async def keypress(self, keys: list[str]) -> None:
"""Press the provided keys, such as `["ctrl", "c"]`."""
pass

@abc.abstractmethod
async def drag(self, path: list[tuple[int, int]]) -> None:
"""Click-and-drag the mouse along the given sequence of `(x, y)` waypoints."""
pass
73 changes: 71 additions & 2 deletions src/agents/run_internal/tool_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,30 +189,38 @@ async def _execute_action_and_capture(
) -> str:
"""Execute computer actions (sync or async drivers) and return the final screenshot."""

async def maybe_call(method_name: str, *args: Any) -> Any:
async def maybe_call(method_name: str, *args: Any, **kwargs: Any) -> Any:
method = getattr(computer, method_name, None)
if method is None or not callable(method):
raise ModelBehaviorError(f"Computer driver missing method {method_name}")
result = method(*args)
filtered_kwargs = cls._filter_supported_kwargs(
method_name=method_name,
method=method,
kwargs=kwargs,
)
result = method(*args, **filtered_kwargs)
return await result if inspect.isawaitable(result) else result

last_action_was_screenshot = False
last_screenshot_result: Any = None
for action in cls._iter_actions(tool_call):
action_type = get_mapping_or_attr(action, "type")
action_keys = cls._normalize_modifier_keys(get_mapping_or_attr(action, "keys"))
last_action_was_screenshot = False
if action_type == "click":
await maybe_call(
"click",
get_mapping_or_attr(action, "x"),
get_mapping_or_attr(action, "y"),
get_mapping_or_attr(action, "button"),
keys=action_keys,
)
elif action_type == "double_click":
await maybe_call(
"double_click",
get_mapping_or_attr(action, "x"),
get_mapping_or_attr(action, "y"),
keys=action_keys,
)
elif action_type == "drag":
path = get_mapping_or_attr(action, "path") or []
Expand All @@ -225,6 +233,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
)
for point in path
],
keys=action_keys,
)
elif action_type == "keypress":
await maybe_call("keypress", get_mapping_or_attr(action, "keys"))
Expand All @@ -233,6 +242,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
"move",
get_mapping_or_attr(action, "x"),
get_mapping_or_attr(action, "y"),
keys=action_keys,
)
elif action_type == "screenshot":
last_screenshot_result = await maybe_call("screenshot")
Expand All @@ -244,6 +254,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
get_mapping_or_attr(action, "y"),
get_mapping_or_attr(action, "scroll_x"),
get_mapping_or_attr(action, "scroll_y"),
keys=action_keys,
)
elif action_type == "type":
await maybe_call("type", get_mapping_or_attr(action, "text"))
Expand Down Expand Up @@ -289,6 +300,64 @@ def _serialize_action_payload(action: Any) -> Any:
return dataclasses.asdict(action)
return action

@staticmethod
def _normalize_modifier_keys(keys: Any) -> list[str] | None:
if not keys:
return None
return cast(list[str], keys)

@classmethod
def _filter_supported_kwargs(
cls,
*,
method_name: str,
method: Any,
kwargs: dict[str, Any],
) -> dict[str, Any]:
filtered_kwargs = {key: value for key, value in kwargs.items() if value is not None}
if not filtered_kwargs:
return {}

supported_kwargs = cls._supported_keyword_arguments(method)
unsupported_kwargs = [
key
for key in filtered_kwargs
if key not in supported_kwargs and None not in supported_kwargs
]
if unsupported_kwargs:
logger.warning(
"Computer driver method %r does not accept keyword argument(s) %s; "
"dropping them and continuing.",
method_name,
", ".join(sorted(unsupported_kwargs)),
)
for key in unsupported_kwargs:
filtered_kwargs.pop(key, None)

return filtered_kwargs

@staticmethod
def _supported_keyword_arguments(method: Any) -> set[str | None]:
try:
signature = inspect.signature(method)
except (TypeError, ValueError):
return set()
supported: set[str | None] = {
parameter.name
for parameter in signature.parameters.values()
if parameter.kind
in {
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
}
}
if any(
parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
):
supported.add(None)
return supported


class LocalShellAction:
"""Execute local shell commands via the LocalShellTool with lifecycle hooks."""
Expand Down
Loading
Loading