diff --git a/AGENTS.md b/AGENTS.md index 8724609508..9122b6aab6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -85,6 +85,7 @@ The OpenAI Agents Python repository provides the Python Agents SDK, examples, an - `src/agents/run_state.py` (RunState serialization/deserialization) - `src/agents/run_internal/session_persistence.py` (session save/rewind) - If the serialized RunState shape changes, update `CURRENT_SCHEMA_VERSION` in `src/agents/run_state.py` and the related serialization/deserialization logic. Keep released schema versions readable, and feel free to renumber or squash unreleased schema versions before release when those intermediate snapshots are intentionally unsupported. +- When bumping `CURRENT_SCHEMA_VERSION`, also add or update the matching entry in `SCHEMA_VERSION_SUMMARIES` in `src/agents/run_state.py` so every supported version keeps a short historical note describing what changed in that schema. ## Operation Guide diff --git a/examples/sandbox/__init__.py b/examples/sandbox/__init__.py new file mode 100644 index 0000000000..f34898d916 --- /dev/null +++ b/examples/sandbox/__init__.py @@ -0,0 +1 @@ +# Make the examples/sandbox directory a package for tooling consistency. diff --git a/examples/sandbox/basic.py b/examples/sandbox/basic.py new file mode 100644 index 0000000000..2b8fe00e3d --- /dev/null +++ b/examples/sandbox/basic.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import argparse +import asyncio +import sys +from pathlib import Path +from typing import Any, Literal, cast + +from openai.types.responses import ResponseTextDeltaEvent + +from agents import ModelSettings, Runner +from agents.run import RunConfig +from agents.sandbox import Manifest, SandboxAgent, SandboxRunConfig +from agents.sandbox.entries import File + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +Backend = Literal["docker", "modal"] +WorkspacePersistenceMode = Literal["tar", "snapshot_filesystem"] + +DEFAULT_QUESTION = "Summarize this sandbox project in 2 sentences." +DEFAULT_BACKEND: Backend = "docker" +DEFAULT_MODAL_APP_NAME = "openai-agents-python-sandbox-example" +DEFAULT_MODAL_WORKSPACE_PERSISTENCE: WorkspacePersistenceMode = "tar" + + +def _stream_event_banner(event_name: str) -> str | None: + if event_name == "tool_called": + return "[tool call] shell" + if event_name == "tool_output": + return "[tool output] shell" + return None + + +def _build_manifest(backend: Backend) -> Manifest: + backend_label = "Docker" if backend == "docker" else "Modal" + return Manifest( + entries={ + "README.md": File( + content=( + b"# Demo Project\n\n" + + ( + f"This sandbox contains a tiny demo project for the {backend_label} " + "sandbox runner.\n" + ).encode() + + b"The goal is to show how Runner can prepare a sandbox workspace.\n" + ) + ), + "src/app.py": File( + content=b'def greet(name: str) -> str:\n return f"Hello, {name}!"\n' + ), + "docs/notes.md": File( + content=( + b"# Notes\n\n" + b"- The example is intentionally minimal.\n" + b"- The model should inspect files through the shell tool.\n" + ) + ), + } + ) + + +def _build_agent(*, model: str, manifest: Manifest, backend: Backend) -> SandboxAgent: + backend_label = "Docker" if backend == "docker" else "Modal" + return SandboxAgent( + name=f"{backend_label} Sandbox Assistant", + model=model, + # `instructions` is the base agent instructions for this example's task. + instructions=( + "Answer questions about the sandbox workspace. Inspect the project before answering, " + "and keep the response concise." + ), + # `developer_instructions` is appended after that as additional deterministic instructions. + # Here, the tiny-workspace constraint is kept in `developer_instructions`. + developer_instructions=( + "Do not guess file names like package.json or pyproject.toml. " + "This demo intentionally contains a tiny workspace." + ), + # `default_manifest` tells the sandbox agent which workspace it should expect. + default_manifest=manifest, + # `WorkspaceShellCapability()` exposes one shell tool so the model can inspect files. + capabilities=[WorkspaceShellCapability()], + # `tool_choice="required"` makes the demo more deterministic by forcing the model + # to look at the workspace instead of answering from prior assumptions. + model_settings=ModelSettings(tool_choice="required"), + ) + + +def _require_modal_dependency() -> tuple[Any, Any]: + try: + from agents.extensions.sandbox import ModalSandboxClient, ModalSandboxClientOptions + except Exception as exc: # pragma: no cover - import path depends on optional extras + raise SystemExit( + "Modal-backed runs require the optional repo extra.\n" + "Install it with: uv sync --extra modal" + ) from exc + + return ModalSandboxClient, ModalSandboxClientOptions + + +def _require_docker_dependency() -> tuple[Any, Any, Any]: + try: + from docker import from_env as docker_from_env # type: ignore[import-untyped] + except Exception as exc: # pragma: no cover - import path depends on local Docker setup + raise SystemExit( + "Docker-backed runs require the Docker SDK.\n" + "Install the repo dependencies with: make sync" + ) from exc + + from agents.sandbox.sandboxes.docker import DockerSandboxClient, DockerSandboxClientOptions + + return docker_from_env, DockerSandboxClient, DockerSandboxClientOptions + + +async def _create_session( + *, + backend: Backend, + manifest: Manifest, + agent: SandboxAgent, +): + if backend == "docker": + docker_from_env, DockerSandboxClient, DockerSandboxClientOptions = ( + _require_docker_dependency() + ) + client = DockerSandboxClient(docker_from_env()) + session = await client.create( + manifest=manifest, + codex=agent.codex, + options=DockerSandboxClientOptions(image="python:3.14-slim"), + ) + return client, session + + ModalSandboxClient, ModalSandboxClientOptions = _require_modal_dependency() + client = ModalSandboxClient() + session = await client.create( + manifest=manifest, + codex=agent.codex, + options=ModalSandboxClientOptions( + app_name=DEFAULT_MODAL_APP_NAME, + workspace_persistence=DEFAULT_MODAL_WORKSPACE_PERSISTENCE, + ), + ) + return client, session + + +async def main( + model: str, + question: str, + backend: Backend, +) -> None: + manifest = _build_manifest(backend) + agent = _build_agent(model=model, manifest=manifest, backend=backend) + client, session = await _create_session( + backend=backend, + manifest=manifest, + agent=agent, + ) + + await session.start() + print(await session.ls(".codex_bin/codex")) + + try: + # `async with session` keeps the example on the public session lifecycle API. + # `Runner` reuses the already-running session without starting it a second time. + async with session: + # `Runner.run_streamed()` drives the model and yields text and tool events in real time. + result = Runner.run_streamed( + agent, + question, + run_config=RunConfig( + sandbox=SandboxRunConfig(session=session), + workflow_name=f"{backend.title()} sandbox example", + ), + ) + saw_text_delta = False + saw_any_text = False + + # The stream contains raw text deltas from the assistant plus structured tool events. + async for event in result.stream_events(): + if event.type == "raw_response_event" and isinstance( + event.data, ResponseTextDeltaEvent + ): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + saw_any_text = True + continue + + if event.type != "run_item_stream_event": + continue + + banner = _stream_event_banner(event.name) + if banner is not None: + if saw_text_delta: + print() + saw_text_delta = False + print(banner) + + if saw_text_delta: + print() + if not saw_any_text: + print(result.final_output) + finally: + await client.delete(session) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + parser.add_argument( + "--backend", + default=DEFAULT_BACKEND, + choices=["docker", "modal"], + help="Sandbox backend to use for this example.", + ) + args = parser.parse_args() + asyncio.run( + main( + args.model, + args.question, + cast(Backend, args.backend), + ) + ) diff --git a/examples/sandbox/data/f1040.pdf b/examples/sandbox/data/f1040.pdf new file mode 100644 index 0000000000..77556e80ec Binary files /dev/null and b/examples/sandbox/data/f1040.pdf differ diff --git a/examples/sandbox/data/sample_w2.pdf b/examples/sandbox/data/sample_w2.pdf new file mode 100644 index 0000000000..ecc05d994b Binary files /dev/null and b/examples/sandbox/data/sample_w2.pdf differ diff --git a/examples/sandbox/extensions/README.md b/examples/sandbox/extensions/README.md new file mode 100644 index 0000000000..2816fc3302 --- /dev/null +++ b/examples/sandbox/extensions/README.md @@ -0,0 +1,125 @@ +# Cloud Sandbox Extension Examples + +These examples are for manual verification of the cloud sandbox backends that +live under `agents.extensions.sandbox`. + +They intentionally keep the flow simple: + +1. Build a tiny manifest in memory. +2. Create a `SandboxAgent` that inspects that workspace through one shell tool. +3. Run the agent against either E2B or Modal. + +Both examples require `OPENAI_API_KEY`, because they call the model through the +normal `Runner` path. + +## E2B + +### Setup + +Install the repo extra: + +```bash +uv sync --extra e2b +``` + +Create an E2B account, create an API key, and export it as `E2B_API_KEY`. +The official setup docs are: + +- +- + +Export the required environment variables: + +```bash +export OPENAI_API_KEY=... +export E2B_API_KEY=... +``` + +### Run + +```bash +uv run python examples/sandbox/extensions/e2b_runner.py --stream +``` + +Useful flags: + +- `--sandbox-type e2b_code_interpreter_async` +- `--template ` +- `--timeout 300` +- `--pause-on-exit` + +The example defaults to `e2b_code_interpreter_async`, which matches the async +Code Interpreter backend supported by this repo. + +## Modal + +If you want the same explicit session lifecycle shown in +`examples/sandbox/basic.py`, that example now accepts +`--backend modal` and reuses the same streamed tool-output flow: + +```bash +uv run python examples/sandbox/basic.py \ + --backend modal +``` + +The dedicated script below stays as the smaller extension-specific example. + +### Setup + +Install the repo extra: + +```bash +uv sync --extra modal +``` + +Authenticate Modal with either CLI token setup or environment variables. The +official references are: + +- +- +- + +If you want to configure credentials directly from the CLI: + +```bash +uv run modal token set --token-id --token-secret +``` + +Or export environment variables for the current shell: + +```bash +export OPENAI_API_KEY=... +export MODAL_TOKEN_ID=... +export MODAL_TOKEN_SECRET=... +``` + +### Run + +```bash +uv run python examples/sandbox/extensions/modal_runner.py \ + --app-name openai-agents-python-sandbox-example \ + --stream +``` + +Useful flags: + +- `--workspace-persistence tar` +- `--workspace-persistence snapshot_filesystem` +- `--sandbox-create-timeout-s 60` + +`app_name` is required by `ModalSandboxClientOptions`, so the example makes it +an explicit CLI flag instead of hiding it. + +## What to expect + +Each script asks the model to inspect a small workspace and summarize it. A +successful run should: + +1. Start the chosen cloud sandbox backend. +2. Materialize the manifest into the sandbox workspace. +3. Call the shell tool at least once. +4. Print either streamed text or a final short answer about the workspace. + +These examples are not live-validated in CI because they depend on external +cloud credentials, but they are shaped so contributors can verify backend +behavior locally with one command per provider. diff --git a/examples/sandbox/extensions/__init__.py b/examples/sandbox/extensions/__init__.py new file mode 100644 index 0000000000..fb3e80a2d0 --- /dev/null +++ b/examples/sandbox/extensions/__init__.py @@ -0,0 +1 @@ +"""Manual validation examples for cloud sandbox extensions.""" diff --git a/examples/sandbox/extensions/e2b_runner.py b/examples/sandbox/extensions/e2b_runner.py new file mode 100644 index 0000000000..973511cb4f --- /dev/null +++ b/examples/sandbox/extensions/e2b_runner.py @@ -0,0 +1,209 @@ +""" +Minimal E2B-backed sandbox example for manual validation. + +This example is intentionally small: it creates a tiny workspace, lets the +agent inspect it through one shell tool, and prints a short answer. +""" + +import argparse +import asyncio +import os +import sys +from pathlib import Path +from typing import Literal + +from openai.types.responses import ResponseTextDeltaEvent + +from agents import ModelSettings, Runner +from agents.run import RunConfig +from agents.sandbox import Manifest, SandboxAgent, SandboxRunConfig + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from examples.sandbox.misc.example_support import text_manifest +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +try: + from agents.extensions.sandbox import ( + E2BSandboxClient, + E2BSandboxClientOptions, + E2BSandboxType, + ) +except Exception as exc: # pragma: no cover - import path depends on optional extras + raise SystemExit( + "E2B sandbox examples require the optional repo extra.\n" + "Install it with: uv sync --extra e2b" + ) from exc + + +DEFAULT_QUESTION = "Summarize this cloud sandbox workspace in 2 sentences." +DEFAULT_SANDBOX_TYPE = E2BSandboxType.E2B_ASYNC.value + + +def _build_manifest() -> Manifest: + return text_manifest( + { + "README.md": ( + "# Renewal Notes\n\n" + "This workspace contains a tiny account review packet for manual sandbox testing.\n" + ), + "customer.md": ( + "# Customer\n\n" + "- Name: Northwind Health.\n" + "- Renewal date: 2026-04-15.\n" + "- Risk: unresolved SSO setup.\n" + ), + "next_steps.md": ( + "# Next steps\n\n" + "1. Finish the SSO fix.\n" + "2. Confirm legal language before procurement review.\n" + ), + } + ) + + +def _require_env(name: str) -> None: + if os.environ.get(name): + return + raise SystemExit(f"{name} must be set before running this example.") + + +def _rewrite_template_resolution_error(exc: Exception) -> None: + message = str(exc) + marker = "error resolving template '" + if marker not in message: + return + template = message.split(marker, 1)[1].split("'", 1)[0] + raise SystemExit( + f"E2B could not resolve template `{template}`.\n" + "Pass `--template ` with a template that exists for this E2B account/team. " + "If you were relying on the example default, the SDK default template for this backend is " + "not available in your current E2B environment." + ) from exc + + +async def main( + *, + model: str, + question: str, + sandbox_type: Literal[ + "e2b_code_interpreter_async", + "e2b_code_interpreter", + "e2b_async", + "e2b", + ], + template: str | None, + timeout: int | None, + pause_on_exit: bool, + stream: bool, +) -> None: + _require_env("OPENAI_API_KEY") + _require_env("E2B_API_KEY") + + manifest = _build_manifest() + agent = SandboxAgent( + name="E2B Sandbox Assistant", + model=model, + # `instructions` is the base agent instructions for this sandbox task. + instructions=( + "Answer questions about the sandbox workspace. Inspect the files before answering " + "and keep the response concise." + ), + # `developer_instructions` is appended after that as additional deterministic instructions. + # Here, the grounding constraints are kept in `developer_instructions`. + developer_instructions=( + "Do not invent files or statuses that are not present in the workspace. Cite the " + "file names you inspected." + ), + default_manifest=manifest, + capabilities=[WorkspaceShellCapability()], + model_settings=ModelSettings(tool_choice="required"), + ) + + run_config = RunConfig( + sandbox=SandboxRunConfig( + client=E2BSandboxClient(), + options=E2BSandboxClientOptions( + sandbox_type=E2BSandboxType(sandbox_type), + template=template, + timeout=timeout, + pause_on_exit=pause_on_exit, + ), + ), + workflow_name="E2B sandbox example", + ) + + if not stream: + try: + result = await Runner.run(agent, question, run_config=run_config) + except Exception as exc: + _rewrite_template_resolution_error(exc) + raise + print(result.final_output) + return + + try: + stream_result = Runner.run_streamed(agent, question, run_config=run_config) + except Exception as exc: + _rewrite_template_resolution_error(exc) + raise + saw_text_delta = False + try: + async for event in stream_result.stream_events(): + if event.type == "raw_response_event" and isinstance( + event.data, ResponseTextDeltaEvent + ): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + except Exception as exc: + _rewrite_template_resolution_error(exc) + raise + + if saw_text_delta: + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + parser.add_argument( + "--sandbox-type", + default=DEFAULT_SANDBOX_TYPE, + choices=[member.value for member in E2BSandboxType], + help=( + "E2B sandbox implementation to create. Defaults to the generic async sandbox because " + "some installed code-interpreter SDK versions still request the legacy " + "`code-interpreter-v1` template." + ), + ) + parser.add_argument("--template", default=None, help="Optional E2B template name.") + parser.add_argument( + "--timeout", + type=int, + default=300, + help="Optional E2B sandbox timeout in seconds.", + ) + parser.add_argument( + "--pause-on-exit", + action="store_true", + default=False, + help="Pause the sandbox on shutdown instead of killing it.", + ) + parser.add_argument("--stream", action="store_true", default=False, help="Stream the response.") + args = parser.parse_args() + + asyncio.run( + main( + model=args.model, + question=args.question, + sandbox_type=args.sandbox_type, + template=args.template, + timeout=args.timeout, + pause_on_exit=args.pause_on_exit, + stream=args.stream, + ) + ) diff --git a/examples/sandbox/extensions/modal_runner.py b/examples/sandbox/extensions/modal_runner.py new file mode 100644 index 0000000000..b9bec2cf42 --- /dev/null +++ b/examples/sandbox/extensions/modal_runner.py @@ -0,0 +1,161 @@ +""" +Minimal Modal-backed sandbox example for manual validation. + +This example mirrors the local and Docker sandbox demos, but it sends the +workspace to a Modal sandbox. +""" + +import argparse +import asyncio +import os +import sys +from pathlib import Path +from typing import Literal + +from openai.types.responses import ResponseTextDeltaEvent + +from agents import ModelSettings, Runner +from agents.run import RunConfig +from agents.sandbox import Manifest, SandboxAgent, SandboxRunConfig + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from examples.sandbox.misc.example_support import text_manifest +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +try: + from agents.extensions.sandbox import ModalSandboxClient, ModalSandboxClientOptions +except Exception as exc: # pragma: no cover - import path depends on optional extras + raise SystemExit( + "Modal sandbox examples require the optional repo extra.\n" + "Install it with: uv sync --extra modal" + ) from exc + + +DEFAULT_QUESTION = "Summarize this cloud sandbox workspace in 2 sentences." + + +def _build_manifest() -> Manifest: + return text_manifest( + { + "README.md": ( + "# Modal Demo Workspace\n\n" + "This workspace exists to validate the Modal sandbox backend manually.\n" + ), + "incident.md": ( + "# Incident\n\n" + "- Customer: Fabrikam Retail.\n" + "- Issue: delayed reporting rollout.\n" + "- Primary blocker: incomplete security questionnaire.\n" + ), + "plan.md": ( + "# Plan\n\n" + "1. Close the questionnaire.\n" + "2. Reconfirm the rollout date with the customer.\n" + ), + } + ) + + +def _require_env(name: str) -> None: + if os.environ.get(name): + return + raise SystemExit(f"{name} must be set before running this example.") + + +async def main( + *, + model: str, + question: str, + app_name: str, + workspace_persistence: Literal["tar", "snapshot_filesystem"], + sandbox_create_timeout_s: float | None, + stream: bool, +) -> None: + _require_env("OPENAI_API_KEY") + + manifest = _build_manifest() + agent = SandboxAgent( + name="Modal Sandbox Assistant", + model=model, + # `instructions` is the base agent instructions for this sandbox task. + instructions=( + "Answer questions about the sandbox workspace. Inspect the files before answering " + "and keep the response concise." + ), + # `developer_instructions` is appended after that as additional deterministic instructions. + # Here, the grounding constraints are kept in `developer_instructions`. + developer_instructions=( + "Do not invent files or statuses that are not present in the workspace. Cite the " + "file names you inspected." + ), + default_manifest=manifest, + capabilities=[WorkspaceShellCapability()], + model_settings=ModelSettings(tool_choice="required"), + ) + + run_config = RunConfig( + sandbox=SandboxRunConfig( + client=ModalSandboxClient(), + options=ModalSandboxClientOptions( + app_name=app_name, + workspace_persistence=workspace_persistence, + sandbox_create_timeout_s=sandbox_create_timeout_s, + ), + ), + workflow_name="Modal sandbox example", + ) + + if not stream: + result = await Runner.run(agent, question, run_config=run_config) + print(result.final_output) + return + + stream_result = Runner.run_streamed(agent, question, run_config=run_config) + saw_text_delta = False + async for event in stream_result.stream_events(): + if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + + if saw_text_delta: + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + parser.add_argument( + "--app-name", + default="openai-agents-python-sandbox-example", + help="Modal app name to create or reuse for the sandbox.", + ) + parser.add_argument( + "--workspace-persistence", + default="tar", + choices=["tar", "snapshot_filesystem"], + help="Workspace persistence mode for the Modal sandbox.", + ) + parser.add_argument( + "--sandbox-create-timeout-s", + type=float, + default=None, + help="Optional timeout for creating the Modal sandbox.", + ) + parser.add_argument("--stream", action="store_true", default=False, help="Stream the response.") + args = parser.parse_args() + + asyncio.run( + main( + model=args.model, + question=args.question, + app_name=args.app_name, + workspace_persistence=args.workspace_persistence, + sandbox_create_timeout_s=args.sandbox_create_timeout_s, + stream=args.stream, + ) + ) diff --git a/examples/sandbox/handoffs.py b/examples/sandbox/handoffs.py new file mode 100644 index 0000000000..b7f1dc8864 --- /dev/null +++ b/examples/sandbox/handoffs.py @@ -0,0 +1,107 @@ +""" +Show how a non-sandbox agent can hand work to a sandbox agent. + +The intake agent never sees a workspace directly. It hands document-heavy work +to a sandbox reviewer, and that reviewer then hands the synthesized result to a +plain account-facing writer. +""" + +import argparse +import asyncio +import sys +from pathlib import Path + +from agents import Agent, Runner +from agents.run import RunConfig +from agents.sandbox import SandboxAgent, SandboxRunConfig +from agents.sandbox.sandboxes.unix_local import UnixLocalSandboxClient + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from examples.sandbox.misc.example_support import text_manifest +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +DEFAULT_QUESTION = ( + "Review the attached onboarding packet and draft a short internal note for the account " + "executive about what to confirm before kickoff." +) + + +async def main(model: str, question: str) -> None: + # The manifest becomes the workspace that only the sandbox reviewer can inspect. + manifest = text_manifest( + { + "customer_background.md": ( + "# Customer background\n\n" + "- Customer: Bluebird Logistics.\n" + "- Region: North America.\n" + "- New purchase: analytics workspace plus SSO.\n" + ), + "kickoff_checklist.md": ( + "# Kickoff checklist\n\n" + "- Security questionnaire is still in review.\n" + "- Two customer admins still need to complete access training.\n" + "- Target kickoff date is next Tuesday.\n" + ), + "implementation_scope.md": ( + "# Implementation scope\n\n" + "- The customer wants historical data migration for 5 years of records.\n" + "- Data engineering support is available only starting next month.\n" + ), + } + ) + + # This final agent does not inspect files. It only rewrites reviewed facts into a note. + account_manager = Agent( + name="Account Executive Assistant", + model=model, + instructions=( + "You write concise internal updates for account teams. Convert the sandbox review " + "into a short note with a headline, the top risks, and a recommended next step." + ), + ) + + # This sandbox agent can inspect the workspace, then hand its findings to the writer above. + sandbox_reviewer = SandboxAgent( + name="Onboarding Packet Reviewer", + model=model, + # `instructions` is the base agent instructions for the review-and-handoff task. + instructions=( + "You inspect onboarding documents in the sandbox, verify the facts, then hand off " + "to the account executive assistant to draft the final note." + ), + # `developer_instructions` is appended after that as additional deterministic instructions. + # Here, "do not answer directly" is kept in `developer_instructions`. + developer_instructions="Do not answer the user directly after reviewing the packet.", + default_manifest=manifest, + handoffs=[account_manager], + capabilities=[WorkspaceShellCapability()], + ) + + # The starting agent is a normal agent. It only decides when to hand off into the sandbox. + intake_agent = Agent( + name="Deal Desk Intake", + model=model, + instructions=( + "You triage internal requests. If a request depends on attached documents, hand off " + "to the onboarding packet reviewer immediately." + ), + handoffs=[sandbox_reviewer], + ) + + result = await Runner.run( + intake_agent, + question, + run_config=RunConfig(sandbox=SandboxRunConfig(client=UnixLocalSandboxClient())), + ) + print(result.final_output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + args = parser.parse_args() + + asyncio.run(main(args.model, args.question)) diff --git a/examples/sandbox/misc/__init__.py b/examples/sandbox/misc/__init__.py new file mode 100644 index 0000000000..8a5a5231df --- /dev/null +++ b/examples/sandbox/misc/__init__.py @@ -0,0 +1 @@ +# Shared support code for sandbox examples. diff --git a/examples/sandbox/misc/example_support.py b/examples/sandbox/misc/example_support.py new file mode 100644 index 0000000000..0f6a1bb04a --- /dev/null +++ b/examples/sandbox/misc/example_support.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from collections.abc import Mapping + +from agents.sandbox import Manifest +from agents.sandbox.entries import File + + +def text_manifest(files: Mapping[str, str]) -> Manifest: + """Build a manifest from in-memory UTF-8 text files.""" + + return Manifest( + entries={path: File(content=contents.encode("utf-8")) for path, contents in files.items()} + ) + + +def tool_call_name(raw_item: object) -> str: + """Return a readable name for a raw tool call item.""" + + if isinstance(raw_item, dict): + name = raw_item.get("name") + item_type = raw_item.get("type") + else: + name = getattr(raw_item, "name", None) + item_type = getattr(raw_item, "type", None) + + if isinstance(name, str) and name: + return name + if item_type == "shell_call": + return "shell" + if isinstance(item_type, str): + return item_type + return "" diff --git a/examples/sandbox/misc/reference_policy_mcp_server.py b/examples/sandbox/misc/reference_policy_mcp_server.py new file mode 100644 index 0000000000..0e6486d575 --- /dev/null +++ b/examples/sandbox/misc/reference_policy_mcp_server.py @@ -0,0 +1,25 @@ +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("Reference Policy Server") + + +@mcp.tool() +def get_policy_reference(topic: str) -> str: + """Return short internal policy guidance for a supported topic.""" + normalized = topic.strip().lower() + if "discount" in normalized: + return ( + "Discount policy: discounts from 11 to 15 percent require regional sales director " + "approval. Discounts above 15 percent require both finance and the regional sales " + "director." + ) + if "security" in normalized or "review" in normalized: + return ( + "Security review policy: any new data export workflow must finish security review " + "before kickoff or production access." + ) + return "No policy reference is available for that topic in this demo." + + +if __name__ == "__main__": + mcp.run() diff --git a/examples/sandbox/misc/workspace_apply_patch.py b/examples/sandbox/misc/workspace_apply_patch.py new file mode 100644 index 0000000000..850dc526af --- /dev/null +++ b/examples/sandbox/misc/workspace_apply_patch.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import io +from pathlib import Path + +from agents import ApplyPatchTool, apply_diff +from agents.editor import ApplyPatchOperation, ApplyPatchResult +from agents.sandbox import Capability, Manifest +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.tool import Tool + + +def _read_text(handle: io.IOBase) -> str: + payload = handle.read() + if isinstance(payload, str): + return payload + if isinstance(payload, (bytes, bytearray)): + return bytes(payload).decode("utf-8", errors="replace") + return str(payload) + + +class _SandboxWorkspaceEditor: + def __init__(self, session: BaseSandboxSession) -> None: + self._session = session + + async def create_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + target = self._resolve_path(operation.path) + content = apply_diff("", operation.diff or "", mode="create") + await self._session.mkdir(target.parent, parents=True) + await self._session.write(target, io.BytesIO(content.encode("utf-8"))) + return ApplyPatchResult(output=f"Created {self._display_path(target)}") + + async def update_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + target = self._resolve_path(operation.path) + handle = await self._session.read(target) + try: + original = _read_text(handle) + finally: + handle.close() + updated = apply_diff(original, operation.diff or "") + await self._session.write(target, io.BytesIO(updated.encode("utf-8"))) + return ApplyPatchResult(output=f"Updated {self._display_path(target)}") + + async def delete_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + target = self._resolve_path(operation.path) + await self._session.rm(target) + return ApplyPatchResult(output=f"Deleted {self._display_path(target)}") + + def _resolve_path(self, raw_path: str) -> Path: + return self._session.normalize_path(raw_path) + + def _display_path(self, path: Path) -> str: + root = Path(self._session.state.manifest.root) + return path.relative_to(root).as_posix() + + +class WorkspaceApplyPatchCapability(Capability): + """Expose the hosted apply_patch tool against the active sandbox workspace.""" + + def __init__(self) -> None: + super().__init__(type="workspace_apply_patch") + self._session: BaseSandboxSession | None = None + + def bind(self, session: BaseSandboxSession) -> None: + self._session = session + + def tools(self) -> list[Tool]: + if self._session is None: + return [] + return [ApplyPatchTool(editor=_SandboxWorkspaceEditor(self._session))] + + async def instructions(self, manifest: Manifest) -> str | None: + _ = manifest + return ( + "Use the `apply_patch` tool for workspace text edits when you need to create or " + "update files inside the sandbox. Prefer saving final outputs in the requested " + "workspace directories instead of describing edits without writing them." + ) diff --git a/examples/sandbox/misc/workspace_shell.py b/examples/sandbox/misc/workspace_shell.py new file mode 100644 index 0000000000..766167a535 --- /dev/null +++ b/examples/sandbox/misc/workspace_shell.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from agents.sandbox import Capability, Manifest +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.tool import ( + ShellCallOutcome, + ShellCommandOutput, + ShellCommandRequest, + ShellResult, + ShellTool, + Tool, +) + + +class WorkspaceShellCapability(Capability): + """Expose one shell tool for inspecting the active sandbox workspace.""" + + def __init__(self) -> None: + super().__init__(type="workspace_shell") + self._session: BaseSandboxSession | None = None + + def bind(self, session: BaseSandboxSession) -> None: + self._session = session + + def tools(self) -> list[Tool]: + return [ShellTool(executor=self._execute_shell)] + + async def instructions(self, manifest: Manifest) -> str | None: + _ = manifest + return ( + "Use the `shell` tool to inspect the sandbox workspace before answering. " + "The workspace root is the current working directory, so prefer relative paths " + "with commands like `pwd`, `find .`, and `cat`. Only cite files you actually read." + ) + + async def _execute_shell(self, request: ShellCommandRequest) -> ShellResult: + if self._session is None: + raise RuntimeError("Workspace shell is not bound to a sandbox session.") + + timeout_s = ( + request.data.action.timeout_ms / 1000 + if request.data.action.timeout_ms is not None + else None + ) + outputs: list[ShellCommandOutput] = [] + for command in request.data.action.commands: + result = await self._session.exec(command, timeout=timeout_s, shell=True) + outputs.append( + ShellCommandOutput( + command=command, + stdout=result.stdout.decode("utf-8", errors="replace"), + stderr=result.stderr.decode("utf-8", errors="replace"), + outcome=ShellCallOutcome(type="exit", exit_code=result.exit_code), + ) + ) + return ShellResult(output=outputs) diff --git a/examples/sandbox/sandbox_agent_with_tools.py b/examples/sandbox/sandbox_agent_with_tools.py new file mode 100644 index 0000000000..2e7987f8d4 --- /dev/null +++ b/examples/sandbox/sandbox_agent_with_tools.py @@ -0,0 +1,121 @@ +""" +Show how a sandbox agent can combine three tool sources in one run. + +This example gives the model: + +1. A sandbox workspace to inspect with the shared shell capability. +2. A normal local function tool for approval routing. +3. A local stdio MCP server for reference policy lookups. +""" + +import argparse +import asyncio +import sys +from pathlib import Path + +from agents import Runner, function_tool +from agents.mcp import MCPServerStdio +from agents.run import RunConfig +from agents.sandbox import SandboxAgent, SandboxRunConfig +from agents.sandbox.sandboxes.unix_local import UnixLocalSandboxClient + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from examples.sandbox.misc.example_support import text_manifest, tool_call_name +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +DEFAULT_QUESTION = ( + "Review this enterprise renewal request. Tell me who needs to approve the discount, " + "whether security review is still open, and the most important note for the account team. " + "Confirm the approval and security answers against the reference policy server before you respond." +) + + +@function_tool +def get_discount_approval_path(discount_percent: int) -> str: + """Return the approver required for a proposed discount percentage.""" + if discount_percent <= 10: + return "The account executive can approve discounts up to 10 percent." + if discount_percent <= 15: + return "The regional sales director must approve discounts from 11 to 15 percent." + return "Finance and the regional sales director must both approve discounts above 15 percent." + + +async def main(model: str, question: str) -> None: + # This manifest becomes the workspace that the sandbox agent can inspect. + manifest = text_manifest( + { + "renewal_request.md": ( + "# Renewal request\n\n" + "- Customer: Contoso Manufacturing.\n" + "- Requested discount: 14 percent.\n" + "- Renewal term: 12 months.\n" + "- Requested close date: March 28.\n" + ), + "account_notes.md": ( + "# Account notes\n\n" + "- The customer expanded usage in two plants this quarter.\n" + "- Security review for the new data export workflow was opened last week.\n" + "- Procurement wants a final approval map before they send the order form.\n" + ), + } + ) + + # The reference MCP server is another local process. The agent can call its tools alongside + # the sandbox shell tool and the normal Python function tool. + async with MCPServerStdio( + name="Reference Policy Server", + params={ + "command": sys.executable, + "args": [ + str(Path(__file__).resolve().parent / "misc" / "reference_policy_mcp_server.py") + ], + }, + ) as server: + agent = SandboxAgent( + name="Renewal Review Assistant", + model=model, + # `instructions` is the base agent instructions for the multi-tool review task. + instructions=( + "You review renewal requests. Inspect the packet, use " + "`get_discount_approval_path` for discount routing, and use the MCP reference " + "policy server when you need confirmation. Before you answer, you must call " + "`get_discount_approval_path` and at least one MCP policy tool." + ), + # `developer_instructions` is appended after that as additional deterministic + # instructions. Here, the concise-answer constraint is kept there. + developer_instructions=( + "Keep the answer concise and business-ready. Mention which policy topic you " + "confirmed through MCP." + ), + default_manifest=manifest, + tools=[get_discount_approval_path], + mcp_servers=[server], + capabilities=[WorkspaceShellCapability()], + ) + + result = await Runner.run( + agent, + question, + run_config=RunConfig(sandbox=SandboxRunConfig(client=UnixLocalSandboxClient())), + ) + tool_names: list[str] = [] + for item in result.new_items: + if getattr(item, "type", None) != "tool_call_item": + continue + name = tool_call_name(item.raw_item) + if name: + tool_names.append(name) + if tool_names: + print(f"[tools used] {', '.join(tool_names)}") + print(result.final_output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + args = parser.parse_args() + + asyncio.run(main(args.model, args.question)) diff --git a/examples/sandbox/sandbox_agents_as_tools.py b/examples/sandbox/sandbox_agents_as_tools.py new file mode 100644 index 0000000000..4946931e9a --- /dev/null +++ b/examples/sandbox/sandbox_agents_as_tools.py @@ -0,0 +1,207 @@ +""" +Show how sandbox agents can be exposed as tools to a normal orchestrator. + +Each sandbox reviewer gets its own isolated workspace. The outer orchestrator +does not inspect files directly. It calls the reviewers as tools and combines +their outputs with a normal Python function tool. +""" + +import argparse +import asyncio +import json +import sys +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel, Field + +from agents import Agent, ModelSettings, Runner, function_tool +from agents.run import RunConfig +from agents.sandbox import SandboxAgent, SandboxRunConfig +from agents.sandbox.sandboxes.unix_local import UnixLocalSandboxClient + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from examples.sandbox.misc.example_support import text_manifest, tool_call_name +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +DEFAULT_QUESTION = ( + "Review the Acme renewal materials and give me a short recommendation for the deal desk. " + "Include pricing risk, rollout risk, and the most important next step." +) + + +class PricingPacketReview(BaseModel): + requested_discount_percent: int = Field( + description="Exact requested discount percentage from pricing_summary.md." + ) + requested_term_months: int = Field( + description="Exact requested renewal term in months from pricing_summary.md." + ) + pricing_risk: Literal["low", "medium", "high"] + summary: str = Field(description="Short pricing risk summary grounded in the reviewed files.") + recommended_next_step: str = Field( + description="Most important commercial next step for the deal desk." + ) + evidence_files: list[str] = Field( + description="File names that support the review.", min_length=1 + ) + + +class RolloutRiskReview(BaseModel): + rollout_risk: Literal["low", "medium", "high"] + summary: str = Field(description="Short rollout risk summary grounded in the reviewed files.") + blockers: list[str] = Field(description="Concrete rollout blockers from the reviewed files.") + recommended_next_step: str = Field( + description="Most important delivery next step for the deal desk." + ) + evidence_files: list[str] = Field( + description="File names that support the review.", min_length=1 + ) + + +async def _structured_tool_output_extractor(result) -> str: + final_output = result.final_output + if isinstance(final_output, BaseModel): + return json.dumps(final_output.model_dump(mode="json"), sort_keys=True) + return str(final_output) + + +@function_tool +def get_discount_approval_rule(discount_percent: int) -> str: + """Return the internal approver required for a proposed discount.""" + if discount_percent <= 10: + return "Discounts up to 10 percent can be approved by the account executive." + if discount_percent <= 15: + return "Discounts from 11 to 15 percent require regional sales director approval." + return "Discounts above 15 percent require finance and regional sales director approval." + + +async def main(model: str, question: str) -> None: + # This manifest is visible only to the pricing reviewer. + pricing_manifest = text_manifest( + { + "pricing_summary.md": ( + "# Pricing summary\n\n" + "- Current annual contract: $220,000.\n" + "- Requested renewal term: 24 months.\n" + "- Requested discount: 15 percent.\n" + "- Account executive target discount band: 8 to 10 percent.\n" + ), + "commercial_notes.md": ( + "# Commercial notes\n\n" + "- The customer expanded from 120 to 170 paid seats in the last 6 months.\n" + "- Procurement asked for one final concession to close before quarter end.\n" + ), + } + ) + + # This separate manifest is visible only to the rollout reviewer. + rollout_manifest = text_manifest( + { + "rollout_plan.md": ( + "# Rollout plan\n\n" + "- Customer wants a 30-day rollout for three new regional teams.\n" + "- Regional admins have not completed training yet.\n" + "- SSO migration is scheduled for the second week of the rollout.\n" + ), + "support_history.md": ( + "# Support history\n\n" + "- Two high-priority onboarding tickets were closed in the last quarter.\n" + "- No open production incidents.\n" + "- Customer success manager asked for a phased launch if the contract closes.\n" + ), + } + ) + + pricing_agent = SandboxAgent( + name="Pricing Packet Reviewer", + model=model, + instructions=( + "You inspect renewal pricing documents and return a structured commercial review. " + "Inspect the files before answering and extract the exact requested discount percent " + "and renewal term from pricing_summary.md." + ), + developer_instructions=( + "Use the shell tool before answering. requested_discount_percent must match the exact " + "integer in pricing_summary.md. requested_term_months must match the exact renewal " + "term from pricing_summary.md. Do not introduce any facts, incidents, or numbers that " + "are not present in pricing_summary.md or commercial_notes.md. evidence_files must " + "list only files you actually inspected." + ), + default_manifest=pricing_manifest, + capabilities=[WorkspaceShellCapability()], + model_settings=ModelSettings(tool_choice="required"), + output_type=PricingPacketReview, + ) + rollout_agent = SandboxAgent( + name="Rollout Risk Reviewer", + model=model, + instructions=( + "You inspect rollout plans and return a structured delivery review. Inspect the files " + "before answering and keep the output tightly grounded in the rollout documents." + ), + developer_instructions=( + "Use the shell tool before answering. blockers must only contain issues that appear in " + "rollout_plan.md or support_history.md. Do not introduce any extra numbers, incidents, " + "or stakeholders beyond those files. evidence_files must list only files you actually " + "inspected." + ), + default_manifest=rollout_manifest, + capabilities=[WorkspaceShellCapability()], + model_settings=ModelSettings(tool_choice="required"), + output_type=RolloutRiskReview, + ) + + # Each sandbox-backed tool gets its own run configuration so the workspaces stay isolated. + pricing_run_config = RunConfig(sandbox=SandboxRunConfig(client=UnixLocalSandboxClient())) + rollout_run_config = RunConfig(sandbox=SandboxRunConfig(client=UnixLocalSandboxClient())) + + orchestrator = Agent( + name="Revenue Operations Coordinator", + model=model, + instructions=( + "You coordinate renewal reviews. Before answering, you must use all three tools: " + "`review_pricing_packet`, `review_rollout_risk`, and `get_discount_approval_rule`. " + "The review tools return JSON. Use the exact `requested_discount_percent` field from " + "`review_pricing_packet` when calling `get_discount_approval_rule`. In the final " + "recommendation, use only facts and numbers that appear in the tool outputs, and do " + "not add any extra incidents, price points, or contract terms." + ), + model_settings=ModelSettings(tool_choice="required"), + tools=[ + pricing_agent.as_tool( + tool_name="review_pricing_packet", + tool_description="Inspect the pricing packet and summarize commercial risk.", + custom_output_extractor=_structured_tool_output_extractor, + run_config=pricing_run_config, + ), + rollout_agent.as_tool( + tool_name="review_rollout_risk", + tool_description="Inspect the rollout packet and summarize implementation risk.", + custom_output_extractor=_structured_tool_output_extractor, + run_config=rollout_run_config, + ), + get_discount_approval_rule, + ], + ) + + result = await Runner.run(orchestrator, question) + tool_names = [ + tool_call_name(item.raw_item) + for item in result.new_items + if getattr(item, "type", None) == "tool_call_item" + ] + if tool_names: + print(f"[tools used] {', '.join(tool_names)}") + print(result.final_output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + args = parser.parse_args() + + asyncio.run(main(args.model, args.question)) diff --git a/examples/sandbox/tax_prep.py b/examples/sandbox/tax_prep.py new file mode 100644 index 0000000000..38112e2e5f --- /dev/null +++ b/examples/sandbox/tax_prep.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import argparse +import asyncio +import sys +from pathlib import Path +from typing import cast + +from openai.types.responses import ResponseTextDeltaEvent + +from agents import Runner +from agents.items import TResponseInputItem +from agents.run import RunConfig +from agents.sandbox import Manifest, SandboxAgent, SandboxRunConfig +from agents.sandbox.capabilities import Skills +from agents.sandbox.entries import Dir, GitRepo, LocalFile + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from examples.sandbox.misc.workspace_apply_patch import WorkspaceApplyPatchCapability +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +DATA_PATH = Path(__file__).resolve().parent / "data" +W2_PATH = DATA_PATH / "sample_w2.pdf" +FORM_1040_PATH = DATA_PATH / "f1040.pdf" +DEFAULT_IMAGE = "tax-prep:latest" +DEFAULT_SKILLS_REPO = "sdcoffey/tax-prep-skills" +DEFAULT_SKILLS_REF = "main" +DEFAULT_QUESTION = "Please generate a 1040 for filing year 2025." + +INSTRUCTIONS = """ +You are a federal tax filing agent. Your job is to compute year-end taxes and +produce a filled-out Form 1040 for the specified tax year using the user's +provided documents. Use only the information in the supplied files. If required +data is missing or unclear, ask follow-up questions or note explicit +assumptions. Save the finalized, filled PDF in the `output/` directory and +provide a short summary of key amounts such as income, deductions, tax, and +refund or amount due. + +This is a demo, so assume the following unless the workspace says otherwise: +1. Filing status is single. +2. SSN is 123-45-6789. +3. Date of birth is 1991-01-01. +4. There are no other income documents. +5. If a minor data point is still needed, make up a clearly synthetic test value. + +Use the `federal-tax-prep` skill to accomplish this task. +""".strip() + + +def _require_docker_dependency(): + try: + from docker import from_env as docker_from_env # type: ignore[import-untyped] + except Exception as exc: # pragma: no cover - import path depends on local Docker setup + raise SystemExit( + "Docker-backed runs require the Docker SDK.\n" + "Install the repo dependencies with: make sync" + ) from exc + + from agents.sandbox.sandboxes.docker import DockerSandboxClient, DockerSandboxClientOptions + + return docker_from_env, DockerSandboxClient, DockerSandboxClientOptions + + +def _build_manifest() -> Manifest: + return Manifest( + entries={ + "taxpayer_data": Dir( + children={"sample_w2.pdf": LocalFile(src=W2_PATH)}, + description="Taxpayer income documents such as W-2s and 1099s.", + ), + "reference_forms": Dir( + children={"f1040.pdf": LocalFile(src=FORM_1040_PATH)}, + description="Blank tax forms the agent can use as templates.", + ), + "output": Dir(description="Write finalized tax documents here."), + } + ) + + +def _build_agent(*, model: str, skills_repo: str, skills_ref: str) -> SandboxAgent: + return SandboxAgent( + name="Tax Prep Assistant", + model=model, + instructions=INSTRUCTIONS, + developer_instructions=( + "Inspect the workspace before answering. Keep final explanations concise, and make " + "sure the final filled files are actually written into `output/`." + ), + default_manifest=_build_manifest(), + capabilities=[ + WorkspaceShellCapability(), + WorkspaceApplyPatchCapability(), + Skills( + from_=GitRepo(repo=skills_repo, ref=skills_ref), + ), + ], + codex=False, + ) + + +async def _copy_output_dir( + *, + session, + destination_root: Path, +) -> list[Path]: + destination_root.mkdir(parents=True, exist_ok=True) + remote_output_root = session.normalize_path("output") + + pending_dirs = [remote_output_root] + copied_files: list[Path] = [] + while pending_dirs: + current_dir = pending_dirs.pop() + for entry in await session.ls(current_dir): + entry_path = Path(entry.path) + if entry.is_dir(): + pending_dirs.append(entry_path) + continue + + relative_path = entry_path.relative_to(remote_output_root) + local_path = destination_root / relative_path + local_path.parent.mkdir(parents=True, exist_ok=True) + + handle = await session.read(entry_path) + try: + payload = handle.read() + finally: + handle.close() + + if isinstance(payload, str): + local_path.write_text(payload, encoding="utf-8") + else: + local_path.write_bytes(bytes(payload)) + copied_files.append(local_path) + + return copied_files + + +async def _run_turn( + *, + agent: SandboxAgent, + input_items: list[TResponseInputItem], + run_config: RunConfig, +) -> list[TResponseInputItem]: + stream_result = Runner.run_streamed(agent, input_items, run_config=run_config) + saw_text_delta = False + async for event in stream_result.stream_events(): + if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + continue + + if event.type == "run_item_stream_event" and event.name == "tool_called": + raw_item = getattr(event.item, "raw_item", None) + tool_name = "" + if isinstance(raw_item, dict): + tool_name = cast(str, raw_item.get("name") or raw_item.get("type") or "") + else: + tool_name = cast( + str, + getattr(raw_item, "name", None) or getattr(raw_item, "type", None) or "", + ) + if tool_name: + if saw_text_delta: + print() + saw_text_delta = False + print(f"[tool call] {tool_name}") + + if saw_text_delta: + print() + + return stream_result.to_input_list() + + +async def main( + *, + model: str, + image: str, + question: str, + output_dir: Path, + skills_repo: str, + skills_ref: str, +) -> None: + docker_from_env, DockerSandboxClient, DockerSandboxClientOptions = _require_docker_dependency() + agent = _build_agent(model=model, skills_repo=skills_repo, skills_ref=skills_ref) + client = DockerSandboxClient(docker_from_env()) + session = await client.create( + manifest=agent.default_manifest, + codex=agent.codex, + options=DockerSandboxClientOptions(image=image), + ) + + run_config = RunConfig( + sandbox=SandboxRunConfig(session=session), + workflow_name="Sandbox tax prep demo", + ) + + conversation: list[TResponseInputItem] = [{"role": "user", "content": question}] + + try: + async with session: + conversation = await _run_turn( + agent=agent, + input_items=conversation, + run_config=run_config, + ) + + while True: + try: + additional_input = input("> ") + except (EOFError, KeyboardInterrupt): + break + + conversation.append({"role": "user", "content": additional_input}) + conversation = await _run_turn( + agent=agent, + input_items=conversation, + run_config=run_config, + ) + + copied_files = await _copy_output_dir(session=session, destination_root=output_dir) + finally: + await client.delete(session) + + print(f"\nCopied {len(copied_files)} file(s) to {output_dir}") + for copied_file in copied_files: + print(copied_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--image", default=DEFAULT_IMAGE, help="Docker image for the sandbox.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + parser.add_argument( + "--output-dir", + default="tax-prep-results", + help="Local directory where files from sandbox output/ will be copied.", + ) + parser.add_argument( + "--skills-repo", + default=DEFAULT_SKILLS_REPO, + help="GitHub repo in owner/name form for the skills bundle.", + ) + parser.add_argument( + "--skills-ref", + default=DEFAULT_SKILLS_REF, + help="Git ref for the skills bundle.", + ) + args = parser.parse_args() + + asyncio.run( + main( + model=args.model, + image=args.image, + question=args.question, + output_dir=Path(args.output_dir).resolve(), + skills_repo=args.skills_repo, + skills_ref=args.skills_ref, + ) + ) diff --git a/examples/sandbox/unix_local_runner.py b/examples/sandbox/unix_local_runner.py new file mode 100644 index 0000000000..06c1937a3a --- /dev/null +++ b/examples/sandbox/unix_local_runner.py @@ -0,0 +1,115 @@ +""" +Start here if you want the simplest Unix-local sandbox example. + +This file mirrors the Docker example, but the sandbox runs as a temporary local +workspace on macOS or Linux instead of inside a Docker container. +""" + +import argparse +import asyncio +import sys +from pathlib import Path + +from openai.types.responses import ResponseTextDeltaEvent + +from agents import Runner +from agents.run import RunConfig +from agents.sandbox import SandboxAgent, SandboxRunConfig +from agents.sandbox.sandboxes.unix_local import UnixLocalSandboxClient + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from examples.sandbox.misc.example_support import text_manifest +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + +DEFAULT_QUESTION = ( + "Review this renewal packet. Summarize the customer's situation, the likely blockers, " + "and the next two actions an account team should take." +) + + +async def main(model: str, question: str, stream: bool) -> None: + # The manifest is the file tree that will be materialized into the sandbox workspace. + manifest = text_manifest( + { + "account_brief.md": ( + "# Northwind Health\n\n" + "- Segment: Mid-market healthcare analytics provider.\n" + "- Annual contract value: $148,000.\n" + "- Renewal date: 2026-04-15.\n" + "- Executive sponsor: Director of Data Operations.\n" + ), + "renewal_request.md": ( + "# Renewal request\n\n" + "Northwind requested a 12 percent discount in exchange for a two-year renewal. " + "They also want a 45-day implementation timeline for a new reporting workspace.\n" + ), + "usage_notes.md": ( + "# Usage notes\n\n" + "- Weekly active users increased 18 percent over the last quarter.\n" + "- API traffic is stable.\n" + "- The customer still has one unresolved SSO configuration issue from onboarding.\n" + ), + "implementation_risks.md": ( + "# Delivery risks\n\n" + "- Security questionnaire for the new reporting workspace is not complete.\n" + "- Customer procurement requires final legal language by April 1.\n" + ), + } + ) + + # The sandbox agent sees the manifest as its workspace and uses one shared shell tool + # to inspect the files before answering. + agent = SandboxAgent( + name="Renewal Packet Analyst", + model=model, + # `instructions` is the base agent instructions for the renewal review task. + instructions=( + "You review renewal packets for an account team. Inspect the packet before answering. " + "Keep the response concise, business-focused, and cite the file names that support " + "each conclusion." + ), + # `developer_instructions` is appended after that as additional deterministic instructions. + # Here, the grounding constraints live in `developer_instructions`. + developer_instructions=( + "If a conclusion depends on a file, mention that file by name. Do not invent numbers " + "or statuses that are not present in the workspace." + ), + default_manifest=manifest, + capabilities=[WorkspaceShellCapability()], + ) + + # With Unix-local sandboxes, the runner creates and cleans up the temporary workspace for us. + run_config = RunConfig( + sandbox=SandboxRunConfig(client=UnixLocalSandboxClient()), + workflow_name="Unix local sandbox review", + ) + + if not stream: + result = await Runner.run(agent, question, run_config=run_config) + print(result.final_output) + return + + # The streaming path prints text deltas as they arrive so the example behaves like a demo. + stream_result = Runner.run_streamed(agent, question, run_config=run_config) + saw_text_delta = False + async for event in stream_result.stream_events(): + if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + + if saw_text_delta: + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.4", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + parser.add_argument("--stream", action="store_true", default=False, help="Stream the response.") + args = parser.parse_args() + + asyncio.run(main(args.model, args.question, args.stream)) diff --git a/pyproject.toml b/pyproject.toml index f34a02e473..0d37883e32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "typing-extensions>=4.12.2, <5", "requests>=2.0, <3", "types-requests>=2.0, <3", + "websockets>=15.0, <16", "mcp>=1.19.0, <2; python_version >= '3.10'", ] classifiers = [ @@ -42,6 +43,9 @@ sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"] encrypt = ["cryptography>=45.0, <46"] redis = ["redis>=7"] dapr = ["dapr>=1.16.0", "grpcio>=1.60.0"] +docker = ["docker>=6.1"] +modal = ["modal>=1.3.1"] +e2b = ["e2b>=2.12.1", "e2b-code-interpreter>=1.0"] [dependency-groups] dev = [ @@ -51,7 +55,7 @@ dev = [ "pytest-asyncio", "pytest-mock>=3.14.0", "pytest-xdist", - "rich>=13.1.0, <14", + "rich>=13.1.0, <15", "mkdocs>=1.6.0", "mkdocs-material>=9.6.0", "mkdocstrings[python]>=0.28.0", @@ -112,6 +116,7 @@ convention = "google" [tool.ruff.lint.per-file-ignores] "examples/**/*.py" = ["E501"] +"src/agents/sandbox/app_server/generated/v2_all.py" = ["E501"] [tool.mypy] strict = true @@ -123,9 +128,27 @@ disallow_untyped_calls = false module = "sounddevice.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "agents.sandbox.app_server.generated.*" +ignore_errors = true + [tool.coverage.run] source = ["src/agents"] -omit = ["tests/*"] +omit = [ + "tests/*", + "src/agents/sandbox/sandboxes/*.py", + "src/agents/sandbox/task_context.py", + "src/agents/sandbox/task_runtime.py", + "src/agents/sandbox/materialization.py", + "src/agents/sandbox/entries/artifacts.py", + "src/agents/sandbox/entries/mounts/*.py", + "src/agents/sandbox/util/checksums.py", + "src/agents/sandbox/util/deep_merge.py", + "src/agents/sandbox/util/github.py", + "src/agents/sandbox/util/iterator_io.py", + "src/agents/sandbox/util/parse_utils.py", + "src/agents/sandbox/util/tar_utils.py", +] [tool.coverage.report] show_missing = true diff --git a/pyrightconfig.json b/pyrightconfig.json index 5ed525163c..9c1c39db87 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,5 +1,6 @@ { "include": ["src", "tests"], + "exclude": ["src/agents/sandbox/app_server/generated"], "extraPaths": ["."], "pythonVersion": "3.10", "typeCheckingMode": "basic", diff --git a/src/agents/_public_agent.py b/src/agents/_public_agent.py new file mode 100644 index 0000000000..e9550a31a2 --- /dev/null +++ b/src/agents/_public_agent.py @@ -0,0 +1,21 @@ +"""Helpers for preserving the user-visible agent identity during execution rewrites.""" + +from __future__ import annotations + +from .agent import Agent + +_PUBLIC_AGENT_ATTR = "_agents_public_agent" + + +def set_public_agent(execution_agent: Agent, public_agent: Agent) -> Agent: + """Tag an execution-only clone with the agent identity exposed to hooks and results.""" + setattr(execution_agent, _PUBLIC_AGENT_ATTR, public_agent) + return execution_agent + + +def get_public_agent(agent: Agent) -> Agent: + """Return the user-visible agent identity for hooks, tool execution, and results.""" + public_agent = getattr(agent, _PUBLIC_AGENT_ATTR, None) + if isinstance(public_agent, Agent): + return public_agent + return agent diff --git a/src/agents/agent.py b/src/agents/agent.py index dd291fcb8b..28999e6a0b 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -893,4 +893,10 @@ async def get_prompt( self, run_context: RunContextWrapper[TContext] ) -> ResponsePromptParam | None: """Get the prompt for the agent.""" - return await PromptUtil.to_model_input(self.prompt, run_context, self) + from ._public_agent import get_public_agent + + return await PromptUtil.to_model_input( + self.prompt, + run_context, + cast(Agent[TContext], get_public_agent(self)), + ) diff --git a/src/agents/extensions/sandbox/__init__.py b/src/agents/extensions/sandbox/__init__.py new file mode 100644 index 0000000000..3c0cb682d4 --- /dev/null +++ b/src/agents/extensions/sandbox/__init__.py @@ -0,0 +1,49 @@ +try: + from .sandboxes import ( + E2BSandboxClient as E2BSandboxClient, + E2BSandboxClientOptions as E2BSandboxClientOptions, + E2BSandboxSession as E2BSandboxSession, + E2BSandboxSessionState as E2BSandboxSessionState, + E2BSandboxTimeouts as E2BSandboxTimeouts, + E2BSandboxType as E2BSandboxType, + ) + + _HAS_E2B = True +except Exception: # pragma: no cover + _HAS_E2B = False + +try: + from .sandboxes import ( + ModalSandboxClient as ModalSandboxClient, + ModalSandboxClientOptions as ModalSandboxClientOptions, + ModalSandboxSession as ModalSandboxSession, + ModalSandboxSessionState as ModalSandboxSessionState, + ) + + _HAS_MODAL = True +except Exception: # pragma: no cover + _HAS_MODAL = False + +__all__: list[str] = [] + +if _HAS_E2B: + __all__.extend( + [ + "E2BSandboxClient", + "E2BSandboxClientOptions", + "E2BSandboxSession", + "E2BSandboxSessionState", + "E2BSandboxTimeouts", + "E2BSandboxType", + ] + ) + +if _HAS_MODAL: + __all__.extend( + [ + "ModalSandboxClient", + "ModalSandboxClientOptions", + "ModalSandboxSession", + "ModalSandboxSessionState", + ] + ) diff --git a/src/agents/extensions/sandbox/sandboxes/__init__.py b/src/agents/extensions/sandbox/sandboxes/__init__.py new file mode 100644 index 0000000000..95c089bdaf --- /dev/null +++ b/src/agents/extensions/sandbox/sandboxes/__init__.py @@ -0,0 +1,49 @@ +try: + from .e2b import ( + E2BSandboxClient as E2BSandboxClient, + E2BSandboxClientOptions as E2BSandboxClientOptions, + E2BSandboxSession as E2BSandboxSession, + E2BSandboxSessionState as E2BSandboxSessionState, + E2BSandboxTimeouts as E2BSandboxTimeouts, + E2BSandboxType as E2BSandboxType, + ) + + _HAS_E2B = True +except Exception: # pragma: no cover + _HAS_E2B = False + +try: + from .modal import ( + ModalSandboxClient as ModalSandboxClient, + ModalSandboxClientOptions as ModalSandboxClientOptions, + ModalSandboxSession as ModalSandboxSession, + ModalSandboxSessionState as ModalSandboxSessionState, + ) + + _HAS_MODAL = True +except Exception: # pragma: no cover + _HAS_MODAL = False + +__all__: list[str] = [] + +if _HAS_E2B: + __all__.extend( + [ + "E2BSandboxClient", + "E2BSandboxClientOptions", + "E2BSandboxSession", + "E2BSandboxSessionState", + "E2BSandboxTimeouts", + "E2BSandboxType", + ] + ) + +if _HAS_MODAL: + __all__.extend( + [ + "ModalSandboxClient", + "ModalSandboxClientOptions", + "ModalSandboxSession", + "ModalSandboxSessionState", + ] + ) diff --git a/src/agents/extensions/sandbox/sandboxes/e2b.py b/src/agents/extensions/sandbox/sandboxes/e2b.py new file mode 100644 index 0000000000..aa15b413fe --- /dev/null +++ b/src/agents/extensions/sandbox/sandboxes/e2b.py @@ -0,0 +1,1000 @@ +from __future__ import annotations + +import base64 +import binascii +import inspect +import io +import shlex +import tarfile +import uuid +from collections.abc import Awaitable, Mapping +from dataclasses import dataclass +from enum import Enum +from pathlib import Path, PurePosixPath +from typing import cast + +from pydantic import BaseModel, Field + +from ....sandbox.codex_config import ( + CodexConfig, + apply_codex_to_manifest, + apply_codex_to_session_state, +) +from ....sandbox.entries import Mount, resolve_workspace_path +from ....sandbox.errors import ( + ExecNonZeroError, + ExecTimeoutError, + ExecTransportError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceStartError, + WorkspaceWriteTypeError, +) +from ....sandbox.manifest import Manifest +from ....sandbox.session import SandboxSession, SandboxSessionState +from ....sandbox.session.base_sandbox_session import BaseSandboxSession +from ....sandbox.session.dependencies import Dependencies +from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.sandbox_client import BaseSandboxClient +from ....sandbox.snapshot import SnapshotSpec, resolve_snapshot +from ....sandbox.types import ExecResult +from ....sandbox.util.retry import ( + TRANSIENT_HTTP_STATUS_CODES, + exception_chain_contains_type, + exception_chain_has_status_code, + retry_async, +) + + +class _E2BFilesAPI: + def write( + self, + path: str, + data: bytes, + request_timeout: float | None = None, + ) -> object: + raise NotImplementedError + + def remove(self, path: str, request_timeout: float | None = None) -> object: + raise NotImplementedError + + def make_dir(self, path: str, request_timeout: float | None = None) -> object: + raise NotImplementedError + + def read(self, path: str, format: str = "bytes") -> object: + raise NotImplementedError + + +class _E2BCommandsAPI: + def run( + self, + command: str, + timeout: float | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + user: str | None = None, + ) -> object: + raise NotImplementedError + + +class _E2BSandboxAPI: + sandbox_id: object + files: _E2BFilesAPI + commands: _E2BCommandsAPI + + def beta_pause(self) -> object: + raise NotImplementedError + + def kill(self) -> object: + raise NotImplementedError + + def is_running(self, request_timeout: float | None = None) -> object: + raise NotImplementedError + + +class _E2BSandboxFactoryAPI: + def create( + self, + *, + template: str | None = None, + timeout: int | None = None, + metadata: dict[str, str] | None = None, + envs: dict[str, str] | None = None, + secure: bool = True, + allow_internet_access: bool = True, + ) -> object: + raise NotImplementedError + + def _cls_connect( + self, + *, + sandbox_id: str, + timeout: int | None = None, + ) -> object: + raise NotImplementedError + + +# NOTE: We avoid importing `e2b_code_interpreter` or `e2b` at module import time so that users +# without the optional dependency can still import the sandbox package (they just can't use the +# E2B sandbox). + + +class E2BSandboxType(str, Enum): + """Supported E2B sandbox implementations.""" + + CODE_INTERPRETER_ASYNC = "e2b_code_interpreter_async" + CODE_INTERPRETER = "e2b_code_interpreter" + E2B_ASYNC = "e2b_async" + E2B = "e2b" + + +def _coerce_sandbox_type(value: E2BSandboxType | str | None) -> E2BSandboxType: + if value is None: + raise ValueError( + "E2BSandboxClientOptions.sandbox_type is required. " + "Use one of: e2b_code_interpreter_async, e2b_code_interpreter, e2b_async, e2b." + ) + if isinstance(value, E2BSandboxType): + return value + try: + return E2BSandboxType(value) + except ValueError as e: + raise ValueError( + "Invalid E2BSandboxClientOptions.sandbox_type. " + "Use one of: e2b_code_interpreter_async, e2b_code_interpreter, e2b_async, e2b." + ) from e + + +def _import_sandbox_class(sandbox_type: E2BSandboxType) -> _E2BSandboxFactoryAPI: + if sandbox_type in { + E2BSandboxType.CODE_INTERPRETER_ASYNC, + E2BSandboxType.CODE_INTERPRETER, + }: + module_name = "e2b_code_interpreter" + class_name = ( + "AsyncSandbox" if sandbox_type is E2BSandboxType.CODE_INTERPRETER_ASYNC else "Sandbox" + ) + missing_msg = ( + "E2BSandboxClient requires the optional `e2b-code-interpreter` dependency.\n" + "Install the E2B extra before using this sandbox backend." + ) + else: + module_name = "e2b" + class_name = "AsyncSandbox" if sandbox_type is E2BSandboxType.E2B_ASYNC else "Sandbox" + missing_msg = ( + "E2BSandboxClient requires the optional `e2b` dependency.\n" + "Install the E2B extra before using this sandbox backend." + ) + + try: + module = __import__(module_name, fromlist=[class_name]) + sandbox_cls = getattr(module, class_name) + except Exception as e: # pragma: no cover - exercised via unit tests with fakes + if module_name == "e2b": + try: + module = __import__("e2b.sandbox", fromlist=[class_name]) + sandbox_cls = getattr(module, class_name) + except Exception: + raise ImportError(missing_msg) from e + else: + raise ImportError(missing_msg) from e + + return cast(_E2BSandboxFactoryAPI, sandbox_cls) + + +def _as_sandbox_api(sandbox: object) -> _E2BSandboxAPI: + return cast(_E2BSandboxAPI, sandbox) + + +def _sandbox_id(sandbox: object) -> object: + return _as_sandbox_api(sandbox).sandbox_id + + +def _sandbox_write_file( + sandbox: object, + path: str, + data: bytes, + *, + request_timeout: float | None = None, +) -> object: + return _as_sandbox_api(sandbox).files.write( + path, + data, + request_timeout=request_timeout, + ) + + +def _sandbox_remove_file( + sandbox: object, + path: str, + *, + request_timeout: float | None = None, +) -> object: + return _as_sandbox_api(sandbox).files.remove(path, request_timeout=request_timeout) + + +def _sandbox_make_dir( + sandbox: object, + path: str, + *, + request_timeout: float | None = None, +) -> object: + return _as_sandbox_api(sandbox).files.make_dir(path, request_timeout=request_timeout) + + +def _sandbox_read_file(sandbox: object, path: str, *, format: str = "bytes") -> object: + return _as_sandbox_api(sandbox).files.read(path, format=format) + + +def _sandbox_run_command( + sandbox: object, + command: str, + *, + timeout: float | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + user: str | None = None, +) -> object: + return _as_sandbox_api(sandbox).commands.run( + command, + timeout=timeout, + cwd=cwd, + envs=envs, + user=user, + ) + + +def _sandbox_pause(sandbox: object) -> object: + return _as_sandbox_api(sandbox).beta_pause() + + +def _sandbox_kill(sandbox: object) -> object: + return _as_sandbox_api(sandbox).kill() + + +def _sandbox_is_running(sandbox: object, *, request_timeout: float | None = None) -> object: + return _as_sandbox_api(sandbox).is_running(request_timeout=request_timeout) + + +def _sandbox_create( + sandbox_class: _E2BSandboxFactoryAPI, + *, + template: str | None = None, + timeout: int | None = None, + metadata: dict[str, str] | None = None, + envs: dict[str, str] | None = None, + secure: bool = True, + allow_internet_access: bool = True, +) -> object: + return sandbox_class.create( + template=template, + timeout=timeout, + metadata=metadata, + envs=envs, + secure=secure, + allow_internet_access=allow_internet_access, + ) + + +def _sandbox_connect( + sandbox_class: _E2BSandboxFactoryAPI, + *, + sandbox_id: str, + timeout: int | None = None, +) -> object: + return sandbox_class._cls_connect(sandbox_id=sandbox_id, timeout=timeout) + + +async def _maybe_await(value: object) -> object: + if inspect.isawaitable(value): + return await cast(Awaitable[object], value) + return value + + +def _import_e2b_exceptions() -> Mapping[str, type[BaseException]]: + """Best-effort import of E2B exception classes for classification.""" + + try: + from e2b.exceptions import ( # type: ignore[import-untyped] + NotFoundException, + SandboxException, + TimeoutException, + ) + except Exception: # pragma: no cover - handled by fallbacks + return {} + + return { + "not_found": cast(type[BaseException], NotFoundException), + "sandbox": cast(type[BaseException], SandboxException), + "timeout": cast(type[BaseException], TimeoutException), + } + + +def _import_command_exit_exception() -> type[BaseException] | None: + try: + from e2b.sandbox.commands.command_handle import ( # type: ignore[import-untyped] + CommandExitException, + ) + except Exception: # pragma: no cover - handled by fallbacks + return None + return cast(type[BaseException], CommandExitException) + + +def _retryable_persist_workspace_error_types() -> tuple[type[BaseException], ...]: + excs = _import_e2b_exceptions() + retryable: list[type[BaseException]] = [] + timeout_exc = excs.get("timeout") + if timeout_exc is not None: + retryable.append(timeout_exc) + return tuple(retryable) + + +class E2BSandboxTimeouts(BaseModel): + """Timeout configuration for E2B operations.""" + + # E2B commands default to a 60s timeout when `timeout=None`. Sandbox semantics + # for `timeout=None` are "no timeout", so we pass a large sentinel value instead. + exec_timeout_unbounded_s: float = Field(default=24 * 60 * 60, ge=1) # 24 hours + + # Keepalive / is_running should be quick; if it does not return promptly, + # the sandbox is unhealthy. + keepalive_s: float = Field(default=5, ge=1) + + # best-effort cleanup (e.g., removing temp tar files) should not block shutdown for long. + cleanup_s: float = Field(default=30, ge=1) + + # fast, small ops like `mkdir -p` / `cat` / metadata-ish operations. + fast_op_s: float = Field(default=10, ge=1) + + # uploading tar contents can take longer than fast ops. + file_upload_s: float = Field(default=30, ge=1) + + # snapshot tar ops can be heavier on large workspaces. + snapshot_tar_s: float = Field(default=60, ge=1) + + +@dataclass(frozen=True) +class E2BSandboxClientOptions: + """Client options for the E2B sandbox.""" + + sandbox_type: E2BSandboxType | str + template: str | None = None + timeout: int | None = None + metadata: dict[str, str] | None = None + envs: dict[str, str] | None = None + secure: bool = True + allow_internet_access: bool = True + timeouts: E2BSandboxTimeouts | dict[str, object] | None = None + pause_on_exit: bool = False + + +class E2BSandboxSessionState(SandboxSessionState): + sandbox_id: str + sandbox_type: E2BSandboxType = Field(default=E2BSandboxType.CODE_INTERPRETER_ASYNC) + template: str | None = None + sandbox_timeout: int | None = None + metadata: dict[str, str] | None = None + base_envs: dict[str, str] = Field(default_factory=dict) + secure: bool = True + allow_internet_access: bool = True + timeouts: E2BSandboxTimeouts = Field(default_factory=E2BSandboxTimeouts) + pause_on_exit: bool = False + workspace_root_ready: bool = False + + +class E2BSandboxSession(BaseSandboxSession): + """E2B-backed sandbox session implementation.""" + + state: E2BSandboxSessionState + _sandbox: _E2BSandboxAPI + _skip_start: bool + _workspace_root_ready: bool + _resume_preserves_system_state: bool + + def __init__( + self, + *, + state: E2BSandboxSessionState, + sandbox: object, + ) -> None: + self.state = state + self._sandbox = _as_sandbox_api(sandbox) + self._skip_start = False + self._workspace_root_ready = state.workspace_root_ready + self._resume_preserves_system_state = False + + @classmethod + def from_state( + cls, + state: E2BSandboxSessionState, + *, + sandbox: object, + ) -> E2BSandboxSession: + return cls(state=state, sandbox=sandbox) + + @property + def sandbox_id(self) -> str: + return self.state.sandbox_id + + async def _resolved_envs(self) -> dict[str, str]: + manifest_envs = await self.state.manifest.environment.resolve() + # Manifest envs take precedence over base envs supplied via client options. + return {**self.state.base_envs, **manifest_envs} + + def _coerce_exec_timeout(self, timeout_s: float | None) -> float: + if timeout_s is None: + return float(self.state.timeouts.exec_timeout_unbounded_s) + if timeout_s <= 0: + # Sandbox timeout cannot be <= 0; use 1s and rely on caller semantics. + return 1.0 + return float(timeout_s) + + async def _ensure_dir(self, path: Path, *, reason: str) -> None: + """Create a directory using the E2B Files API.""" + if path == Path("/"): + return + try: + await _maybe_await( + _sandbox_make_dir( + self._sandbox, + str(path), + request_timeout=self.state.timeouts.fast_op_s, + ) + ) + except Exception as e: # pragma: no cover - exercised via unit tests with fakes + raise WorkspaceArchiveWriteError(path=path, context={"reason": reason}, cause=e) from e + + async def _ensure_workspace_root(self) -> None: + """Ensure the workspace root exists before materialization starts.""" + await self._ensure_dir(Path(self.state.manifest.root), reason="root_make_failed") + + async def _prepare_workspace_root_for_exec(self) -> None: + """Create the workspace root through the command API before using it as `cwd`.""" + root = str(Path(self.state.manifest.root)) + envs = await self._resolved_envs() + result = await _maybe_await( + _sandbox_run_command( + self._sandbox, + f"mkdir -p -- {shlex.quote(root)}", + timeout=self.state.timeouts.fast_op_s, + cwd="/", + envs=envs, + ) + ) + exit_code = int(getattr(result, "exit_code", 0) or 0) + if exit_code != 0: + raise WorkspaceStartError( + path=Path(self.state.manifest.root), + context={ + "reason": "workspace_root_nonzero_exit", + "exit_code": exit_code, + "stderr": str(getattr(result, "stderr", "") or ""), + }, + ) + self._workspace_root_ready = True + self.state.workspace_root_ready = True + + def should_provision_manifest_accounts_on_resume(self) -> bool: + return not self._resume_preserves_system_state + + async def start(self) -> None: + if self._skip_start: + if not self._workspace_root_ready: + try: + await self._prepare_workspace_root_for_exec() + except WorkspaceStartError: + raise + except Exception as e: + raise WorkspaceStartError(path=Path(self.state.manifest.root), cause=e) from e + return + try: + # Ensure the workspace root exists before manifest materialization/hydration occurs. + await self._ensure_workspace_root() + await self._prepare_workspace_root_for_exec() + except WorkspaceStartError: + raise + except Exception as e: + raise WorkspaceStartError(path=Path(self.state.manifest.root), cause=e) from e + + await super().start() + + async def stop(self) -> None: + await super().stop() + + async def shutdown(self) -> None: + # Best-effort kill of the remote sandbox. + try: + if self.state.pause_on_exit: + await _maybe_await(_sandbox_pause(self._sandbox)) + else: + await _maybe_await(_sandbox_kill(self._sandbox)) + except Exception: + if self.state.pause_on_exit: + try: + await _maybe_await(_sandbox_kill(self._sandbox)) + except Exception: + pass + else: + pass + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + command_list = [str(c) for c in command] + envs = await self._resolved_envs() + cwd = self.state.manifest.root if self._workspace_root_ready else None + user: str | None = None + if command_list and command_list[0] == "sudo" and len(command_list) >= 4: + # Handle the `sudo -u -- ...` prefix introduced by SandboxSession.exec. + if command_list[1] == "-u" and command_list[3] == "--": + user = command_list[2] + command_list = command_list[4:] + + cmd_str = shlex.join(command_list) + exec_timeout = self._coerce_exec_timeout(timeout) + + e2b_exc = _import_e2b_exceptions() + timeout_exc = e2b_exc.get("timeout") + command_exit_exc = _import_command_exit_exception() + + try: + result = await _maybe_await( + _sandbox_run_command( + self._sandbox, + cmd_str, + timeout=exec_timeout, + cwd=cwd, + envs=envs, + user=user, + ) + ) + return ExecResult( + stdout=str(getattr(result, "stdout", "") or "").encode("utf-8", errors="replace"), + stderr=str(getattr(result, "stderr", "") or "").encode("utf-8", errors="replace"), + exit_code=int(getattr(result, "exit_code", 0) or 0), + ) + except Exception as e: # pragma: no cover - exercised via unit tests with fakes + if timeout_exc is not None and isinstance(e, timeout_exc): + raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e + + if command_exit_exc is not None and isinstance(e, command_exit_exc): + exit_code = int(getattr(e, "exit_code", 1) or 1) + stdout = str(getattr(e, "stdout", "") or "") + stderr = str(getattr(e, "stderr", "") or "") + return ExecResult( + stdout=stdout.encode("utf-8", errors="replace"), + stderr=stderr.encode("utf-8", errors="replace"), + exit_code=exit_code, + ) + + raise ExecTransportError(command=command, cause=e) from e + + async def read(self, path: Path) -> io.IOBase: + workspace_path = resolve_workspace_path( + Path(self.state.manifest.root), + path, + allow_absolute_within_root=True, + ) + + e2b_exc = _import_e2b_exceptions() + not_found_exc = e2b_exc.get("not_found") + + try: + content = await _maybe_await( + _sandbox_read_file(self._sandbox, str(workspace_path), format="bytes") + ) + if isinstance(content, (bytes, bytearray)): + data = bytes(content) + elif isinstance(content, str): + data = content.encode("utf-8", errors="replace") + else: + data = str(content).encode("utf-8", errors="replace") + return io.BytesIO(data) + except Exception as e: # pragma: no cover - exercised via unit tests with fakes + if not_found_exc is not None and isinstance(e, not_found_exc): + raise WorkspaceReadNotFoundError(path=path, cause=e) from e + raise WorkspaceArchiveReadError(path=path, cause=e) from e + + async def write(self, path: Path, data: io.IOBase) -> None: + payload = data.read() + if isinstance(payload, str): + payload = payload.encode("utf-8") + if not isinstance(payload, (bytes, bytearray)): + raise WorkspaceWriteTypeError(path=path, actual_type=type(payload).__name__) + + workspace_path = resolve_workspace_path( + Path(self.state.manifest.root), + path, + allow_absolute_within_root=True, + ) + + try: + await _maybe_await( + _sandbox_write_file( + self._sandbox, + str(workspace_path), + bytes(payload), + request_timeout=self.state.timeouts.file_upload_s, + ) + ) + except Exception as e: # pragma: no cover - exercised via unit tests with fakes + raise WorkspaceArchiveWriteError(path=workspace_path, cause=e) from e + + async def running(self) -> bool: + if not self._workspace_root_ready: + return False + try: + return bool( + await _maybe_await( + _sandbox_is_running( + self._sandbox, + request_timeout=self.state.timeouts.keepalive_s, + ) + ) + ) + except Exception: + return False + + async def mkdir(self, path: Path | str, *, parents: bool = False) -> None: + path = self.normalize_path(path) + if not parents: + parent = path.parent + test = await self.exec("test", "-d", str(parent), shell=False) + if not test.ok(): + raise ExecNonZeroError(test, command=("test", "-d", str(parent))) + await self._ensure_dir(path, reason="mkdir_failed") + + def _tar_exclude_args(self) -> list[str]: + excludes: list[str] = [] + for rel in sorted(self._persist_workspace_skip_relpaths(), key=lambda p: p.as_posix()): + rel_posix = rel.as_posix().lstrip("/") + if not rel_posix or rel_posix in {".", "/"}: + continue + excludes.append(f"--exclude={shlex.quote(rel_posix)}") + excludes.append(f"--exclude={shlex.quote(f'./{rel_posix}')}") + return excludes + + @retry_async( + retry_if=lambda exc, self, tar_cmd: exception_chain_contains_type( + exc, _retryable_persist_workspace_error_types() + ) + or exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) + ) + async def _run_persist_workspace_command(self, tar_cmd: str) -> str: + try: + envs = await self._resolved_envs() + result = await _maybe_await( + _sandbox_run_command( + self._sandbox, + tar_cmd, + timeout=self.state.timeouts.snapshot_tar_s, + cwd="/", + envs=envs, + ) + ) + exit_code = int(getattr(result, "exit_code", 0) or 0) + if exit_code != 0: + raise WorkspaceArchiveReadError( + path=Path(self.state.manifest.root), + context={ + "reason": "snapshot_nonzero_exit", + "exit_code": exit_code, + "stderr": str(getattr(result, "stderr", "") or ""), + }, + ) + return str(getattr(result, "stdout", "") or "") + except WorkspaceArchiveReadError: + raise + except Exception as e: # pragma: no cover - exercised via unit tests with fakes + raise WorkspaceArchiveReadError(path=Path(self.state.manifest.root), cause=e) from e + + async def persist_workspace(self) -> io.IOBase: + def _error_context_summary(error: WorkspaceArchiveReadError) -> dict[str, str]: + summary = {"message": error.message} + if error.cause is not None: + summary["cause_type"] = type(error.cause).__name__ + summary["cause"] = str(error.cause) + return summary + + root = Path(self.state.manifest.root) + excludes = " ".join(self._tar_exclude_args()) + tar_cmd = f"tar {excludes} -C {shlex.quote(str(root))} -cf - . | base64 -w0" + unmounted_mounts: list[tuple[Mount, Path]] = [] + unmount_error: WorkspaceArchiveReadError | None = None + for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets(): + try: + await mount_entry.unmount_path(self, mount_path) + except Exception as e: + unmount_error = WorkspaceArchiveReadError(path=root, cause=e) + break + unmounted_mounts.append((mount_entry, mount_path)) + + snapshot_error: WorkspaceArchiveReadError | None = None + raw: bytes | None = None + if unmount_error is None: + try: + encoded = await self._run_persist_workspace_command(tar_cmd) + try: + raw = base64.b64decode(encoded.encode("utf-8"), validate=True) + except (binascii.Error, ValueError) as e: + raise WorkspaceArchiveReadError( + path=root, + context={"reason": "snapshot_invalid_base64"}, + cause=e, + ) from e + except WorkspaceArchiveReadError as e: + snapshot_error = e + + remount_error: WorkspaceArchiveReadError | None = None + for mount_entry, mount_path in reversed(unmounted_mounts): + try: + await mount_entry.mount(self, mount_path) + except Exception as e: + current_error = WorkspaceArchiveReadError(path=root, cause=e) + if remount_error is None: + remount_error = current_error + if unmount_error is not None: + remount_error.context["earlier_unmount_error"] = _error_context_summary( + unmount_error + ) + else: + additional_remount_errors = remount_error.context.setdefault( + "additional_remount_errors", [] + ) + assert isinstance(additional_remount_errors, list) + additional_remount_errors.append(_error_context_summary(current_error)) + + if remount_error is not None: + if snapshot_error is not None: + remount_error.context["snapshot_error_before_remount_corruption"] = ( + _error_context_summary(snapshot_error) + ) + raise remount_error + if unmount_error is not None: + raise unmount_error + if snapshot_error is not None: + raise snapshot_error + + assert raw is not None + return io.BytesIO(raw) + + def _validate_tar_bytes(self, raw: bytes) -> None: + try: + with tarfile.open(fileobj=io.BytesIO(raw), mode="r:*") as tar: + for member in tar.getmembers(): + name = member.name + if name in ("", ".", "./"): + continue + rel = PurePosixPath(name) + if rel.is_absolute(): + raise ValueError(f"absolute path member: {name}") + if ".." in rel.parts: + raise ValueError(f"parent traversal member: {name}") + if member.issym() or member.islnk(): + raise ValueError(f"link member not allowed: {name}") + if not (member.isdir() or member.isreg()): + raise ValueError(f"unsupported member type: {name}") + except (tarfile.TarError, OSError) as e: + raise ValueError("invalid tar stream") from e + + async def hydrate_workspace(self, data: io.IOBase) -> None: + root = Path(self.state.manifest.root) + tar_path = f"/tmp/uc-hydrate-{self.state.session_id.hex}.tar" + + raw = data.read() + if isinstance(raw, str): + raw = raw.encode("utf-8") + if not isinstance(raw, (bytes, bytearray)): + raise WorkspaceWriteTypeError(path=Path(tar_path), actual_type=type(raw).__name__) + + try: + self._validate_tar_bytes(bytes(raw)) + except ValueError as e: + raise WorkspaceArchiveWriteError( + path=root, + context={"reason": "unsafe_or_invalid_tar", "detail": str(e)}, + cause=e, + ) from e + + try: + await self._ensure_workspace_root() + envs = await self._resolved_envs() + await _maybe_await( + _sandbox_write_file( + self._sandbox, + tar_path, + bytes(raw), + request_timeout=self.state.timeouts.file_upload_s, + ) + ) + result = await _maybe_await( + _sandbox_run_command( + self._sandbox, + f"tar -C {shlex.quote(str(root))} -xf {shlex.quote(tar_path)}", + timeout=self.state.timeouts.snapshot_tar_s, + cwd="/", + envs=envs, + ) + ) + exit_code = int(getattr(result, "exit_code", 0) or 0) + if exit_code != 0: + raise WorkspaceArchiveWriteError( + path=root, + context={ + "reason": "hydrate_nonzero_exit", + "exit_code": exit_code, + "stderr": str(getattr(result, "stderr", "") or ""), + }, + ) + self._workspace_root_ready = True + self.state.workspace_root_ready = True + except WorkspaceArchiveWriteError: + raise + except Exception as e: # pragma: no cover - exercised via unit tests with fakes + raise WorkspaceArchiveWriteError(path=root, cause=e) from e + finally: + try: + envs = await self._resolved_envs() + await _maybe_await( + _sandbox_run_command( + self._sandbox, + f"rm -f -- {shlex.quote(tar_path)}", + timeout=self.state.timeouts.cleanup_s, + cwd="/", + envs=envs, + ) + ) + except Exception: + pass + + +class E2BSandboxClient(BaseSandboxClient[E2BSandboxClientOptions]): + backend_id = "e2b" + _instrumentation: Instrumentation + + def __init__( + self, + *, + instrumentation: Instrumentation | None = None, + dependencies: Dependencies | None = None, + ) -> None: + self._instrumentation = instrumentation or Instrumentation() + self._dependencies = dependencies + + async def create( + self, + *, + snapshot: SnapshotSpec | None = None, + manifest: Manifest | None = None, + codex: bool | CodexConfig = False, + options: E2BSandboxClientOptions, + ) -> SandboxSession: + if options is None: + raise ValueError("E2BSandboxClient.create requires options") + + manifest = apply_codex_to_manifest(manifest, codex) + sandbox_type = _coerce_sandbox_type(options.sandbox_type) + + timeouts_in = options.timeouts + if isinstance(timeouts_in, E2BSandboxTimeouts): + timeouts = timeouts_in + elif timeouts_in is None: + timeouts = E2BSandboxTimeouts() + else: + timeouts = E2BSandboxTimeouts.model_validate(timeouts_in) + + base_envs = dict(options.envs or {}) + manifest_envs = await manifest.environment.resolve() + envs = {**base_envs, **manifest_envs} or None + + SandboxClass = _import_sandbox_class(sandbox_type) + sandbox = await _maybe_await( + _sandbox_create( + SandboxClass, + template=options.template, + timeout=options.timeout, + metadata=options.metadata, + envs=envs, + secure=options.secure, + allow_internet_access=options.allow_internet_access, + ) + ) + + session_id = uuid.uuid4() + snapshot_instance = resolve_snapshot(snapshot, str(session_id)) + state = E2BSandboxSessionState( + session_id=session_id, + manifest=manifest, + snapshot=snapshot_instance, + sandbox_id=str(_sandbox_id(sandbox)), + sandbox_type=sandbox_type, + template=options.template, + sandbox_timeout=options.timeout, + metadata=options.metadata, + base_envs=base_envs, + secure=options.secure, + allow_internet_access=options.allow_internet_access, + timeouts=timeouts, + pause_on_exit=options.pause_on_exit, + ) + inner = E2BSandboxSession.from_state(state, sandbox=sandbox) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + async def delete(self, session: SandboxSession) -> SandboxSession: + inner = session._inner + if not isinstance(inner, E2BSandboxSession): + raise TypeError("E2BSandboxClient.delete expects an E2BSandboxSession") + return session + + async def resume( + self, + state: SandboxSessionState, + *, + codex: bool | CodexConfig = False, + ) -> SandboxSession: + if not isinstance(state, E2BSandboxSessionState): + raise TypeError("E2BSandboxClient.resume expects an E2BSandboxSessionState") + state = apply_codex_to_session_state(state, codex) + + sandbox_type = _coerce_sandbox_type(state.sandbox_type) + SandboxClass = _import_sandbox_class(sandbox_type) + + base_envs = dict(state.base_envs) + manifest_envs = await state.manifest.environment.resolve() + envs = {**base_envs, **manifest_envs} or None + + sandbox: object + reconnected = False + try: + # `_cls_connect` is the current async entrypoint for re-attaching to a sandbox id. + sandbox = await _maybe_await( + _sandbox_connect( + SandboxClass, + sandbox_id=state.sandbox_id, + timeout=state.sandbox_timeout, + ) + ) + if not state.pause_on_exit: + is_running = await _maybe_await( + _sandbox_is_running(sandbox, request_timeout=state.timeouts.keepalive_s) + ) + if not is_running: + raise RuntimeError("sandbox_not_running") + reconnected = True + except Exception: + sandbox = await _maybe_await( + _sandbox_create( + SandboxClass, + template=state.template, + timeout=state.sandbox_timeout, + metadata=state.metadata, + envs=envs, + secure=state.secure, + allow_internet_access=state.allow_internet_access, + ) + ) + state.sandbox_id = str(_sandbox_id(sandbox)) + + inner = E2BSandboxSession.from_state(state, sandbox=sandbox) + inner._resume_preserves_system_state = reconnected + if state.pause_on_exit and reconnected: + inner._skip_start = True + else: + inner._skip_start = False + return self._wrap_session(inner, instrumentation=self._instrumentation) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return E2BSandboxSessionState.model_validate(payload) + + +__all__ = [ + "E2BSandboxClient", + "E2BSandboxClientOptions", + "E2BSandboxSession", + "E2BSandboxSessionState", + "E2BSandboxTimeouts", + "E2BSandboxType", +] diff --git a/src/agents/extensions/sandbox/sandboxes/modal.py b/src/agents/extensions/sandbox/sandboxes/modal.py new file mode 100644 index 0000000000..c09d7c2b9a --- /dev/null +++ b/src/agents/extensions/sandbox/sandboxes/modal.py @@ -0,0 +1,1045 @@ +""" +Modal sandbox (https://modal.com) implementation. + +Run `python -m modal setup` to configure Modal locally. + +This module provides a Modal-backed sandbox client/session implementation backed by +`modal.Sandbox`. + +Note: The `modal` dependency is intended to be optional (installed via an extra), +so package-level exports should guard imports of this module. Within this module, +we import Modal normally so IDEs can resolve and navigate Modal types. +""" + +from __future__ import annotations + +import asyncio +import functools +import io +import json +import logging +import math +import shlex +import tarfile +import uuid +from collections.abc import Awaitable +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Literal, TypeVar, cast + +import modal +from modal.container_process import ContainerProcess + +from ....sandbox.codex_config import ( + CodexConfig, + apply_codex_to_manifest, + apply_codex_to_session_state, +) +from ....sandbox.entries import resolve_workspace_path +from ....sandbox.errors import ( + ExecTimeoutError, + ExecTransportError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceStartError, + WorkspaceStopError, + WorkspaceWriteTypeError, +) +from ....sandbox.manifest import Manifest +from ....sandbox.session import SandboxSession, SandboxSessionState +from ....sandbox.session.base_sandbox_session import BaseSandboxSession +from ....sandbox.session.dependencies import Dependencies +from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.sandbox_client import BaseSandboxClient +from ....sandbox.snapshot import SnapshotSpec, resolve_snapshot +from ....sandbox.types import ExecResult +from ....sandbox.util.retry import ( + TRANSIENT_HTTP_STATUS_CODES, + exception_chain_contains_type, + exception_chain_has_status_code, + retry_async, +) +from ....sandbox.util.tar_utils import UnsafeTarMemberError, should_skip_tar_member + +_DEFAULT_TIMEOUT_S = 30.0 +_DEFAULT_IMAGE_TAG = "python:3.11-slim" +_DEFAULT_SNAPSHOT_FILESYSTEM_TIMEOUT_S = 60.0 +_MODAL_STDIN_CHUNK_SIZE = 8 * 1024 * 1024 + +WorkspacePersistenceMode = Literal["tar", "snapshot_filesystem"] + +_WORKSPACE_PERSISTENCE_TAR: WorkspacePersistenceMode = "tar" +_WORKSPACE_PERSISTENCE_SNAPSHOT_FILESYSTEM: WorkspacePersistenceMode = "snapshot_filesystem" + +# Magic prefix for snapshot_filesystem payloads that cannot be represented as tar bytes. +_UC_MODAL_SNAPSHOT_FS_MAGIC = b"UC_MODAL_SNAPSHOT_FS_V1\n" + +logger = logging.getLogger(__name__) +R = TypeVar("R") + + +def _write_process_stdin(proc: ContainerProcess[bytes], data: bytes | bytearray) -> None: + """ + Stream stdin to Modal in bounded chunks so command-router backed writers do not overflow. + """ + + view = memoryview(data) + for start in range(0, len(view), _MODAL_STDIN_CHUNK_SIZE): + proc.stdin.write(view[start : start + _MODAL_STDIN_CHUNK_SIZE]) + proc.stdin.drain() + proc.stdin.write_eof() + proc.stdin.drain() + + +@dataclass(frozen=True) +class ModalSandboxClientOptions: + app_name: str + sandbox_create_timeout_s: float | None = None + workspace_persistence: WorkspacePersistenceMode = _WORKSPACE_PERSISTENCE_TAR + snapshot_filesystem_timeout_s: float | None = None + snapshot_filesystem_restore_timeout_s: float | None = None + + +def _encode_snapshot_filesystem_ref(*, snapshot_id: str) -> bytes: + # Small JSON envelope so we can round-trip a non-tar snapshot reference + # through Snapshot.persist(). + body = json.dumps({"snapshot_id": snapshot_id}, separators=(",", ":"), sort_keys=True).encode( + "utf-8" + ) + return _UC_MODAL_SNAPSHOT_FS_MAGIC + body + + +def _decode_snapshot_filesystem_ref(raw: bytes) -> str | None: + if not raw.startswith(_UC_MODAL_SNAPSHOT_FS_MAGIC): + return None + body = raw[len(_UC_MODAL_SNAPSHOT_FS_MAGIC) :] + try: + obj = json.loads(body.decode("utf-8")) + except Exception: + return None + snapshot_id = obj.get("snapshot_id") + return snapshot_id if isinstance(snapshot_id, str) and snapshot_id else None + + +@dataclass(frozen=True) +class ModalImageSelector: + """ + A single "image selector" type to avoid juggling image/image_id/image_tag separately. + """ + + kind: Literal["image", "id", "tag"] + value: modal.Image | str + + @classmethod + def from_image(cls, image: modal.Image) -> ModalImageSelector: + return cls(kind="image", value=image) + + @classmethod + def from_id(cls, image_id: str) -> ModalImageSelector: + return cls(kind="id", value=image_id) + + @classmethod + def from_tag(cls, image_tag: str) -> ModalImageSelector: + return cls(kind="tag", value=image_tag) + + +@dataclass(frozen=True) +class ModalSandboxSelector: + """ + A single "sandbox selector" type to avoid juggling sandbox/sandbox_id separately. + """ + + kind: Literal["sandbox", "id"] + value: modal.Sandbox | str + + @classmethod + def from_sandbox(cls, sandbox: modal.Sandbox) -> ModalSandboxSelector: + return cls(kind="sandbox", value=sandbox) + + @classmethod + def from_id(cls, sandbox_id: str) -> ModalSandboxSelector: + return cls(kind="id", value=sandbox_id) + + +class ModalSandboxSessionState(SandboxSessionState): + """ + Serializable state for a Modal-backed session. + + We store only values that can be safely persisted and later used by `resume()`. + """ + + app_name: str + # Optional Modal image object id (enables reconstructing a custom image via Image.from_id()). + image_id: str | None = None + # Registry image tag (e.g. "debian:bookworm" or "ghcr.io/org/img:tag"). + # Used when `image_id` isn't available and no in-memory image override was provided. + image_tag: str | None = None + # Timeout for creating a sandbox (Modal calls are synchronous from the user's perspective + # and can block; we wrap them in a thread with asyncio timeout). + sandbox_create_timeout_s: float = _DEFAULT_TIMEOUT_S + sandbox_id: str | None = None + # Workspace persistence mode: + # - "tar": create a tar stream in the sandbox via `tar cf - ...` and pull bytes back via stdout. + # - "snapshot_filesystem": use Modal's `Sandbox.snapshot_filesystem()` + # (if available) and persist a snapshot reference. + workspace_persistence: WorkspacePersistenceMode = _WORKSPACE_PERSISTENCE_TAR + # Async timeouts for snapshot_filesystem-based persistence and restore. + snapshot_filesystem_timeout_s: float = _DEFAULT_SNAPSHOT_FILESYSTEM_TIMEOUT_S + snapshot_filesystem_restore_timeout_s: float = _DEFAULT_SNAPSHOT_FILESYSTEM_TIMEOUT_S + + +class ModalSandboxSession(BaseSandboxSession): + """ + SandboxSession implementation backed by a Modal Sandbox. + """ + + state: ModalSandboxSessionState + + _sandbox: modal.Sandbox | None + _image: modal.Image | None + _running: bool + + def __init__( + self, + *, + state: ModalSandboxSessionState, + # Optional in-memory handles. These are not guaranteed to be resumable; state holds ids. + image: modal.Image | None = None, + sandbox: modal.Sandbox | None = None, + ) -> None: + self.state = state + self._image = image + self._sandbox = sandbox + if image is not None: + self.state.image_id = getattr(image, "object_id", self.state.image_id) + if sandbox is not None: + self.state.sandbox_id = getattr(sandbox, "object_id", self.state.sandbox_id) + self._running = False + + @classmethod + def from_state( + cls, + state: ModalSandboxSessionState, + *, + image: modal.Image | None = None, + sandbox: modal.Sandbox | None = None, + ) -> ModalSandboxSession: + return cls(state=state, image=image, sandbox=sandbox) + + async def _call_modal( + self, + fn: Callable[..., R], + *args: object, + call_timeout: float | None = None, + **kwargs: object, + ) -> R: + """ + Prefer Modal's async interface (`fn.aio(...)`) when available. + + Falls back to running the blocking call in a thread to preserve compatibility + with SDK surfaces that do not expose `.aio`. + """ + + aio_fn = getattr(fn, "aio", None) + if callable(aio_fn): + coro = cast(Awaitable[R], aio_fn(*args, **kwargs)) + else: + loop = asyncio.get_running_loop() + bound = functools.partial(fn, *args, **kwargs) + coro = loop.run_in_executor(None, bound) + if call_timeout is None: + return await coro + return await asyncio.wait_for(coro, timeout=call_timeout) + + async def start(self) -> None: + try: + # Ensure workspace root exists before SandboxSession.start() needs it. + await self.exec("mkdir", "-p", "--", str(Path(self.state.manifest.root)), shell=False) + except Exception as e: + raise WorkspaceStartError(path=Path(self.state.manifest.root), cause=e) from e + + self._running = True + await super().start() + + async def stop(self) -> None: + try: + await super().stop() + except Exception as e: + raise WorkspaceStopError(path=Path(self.state.manifest.root), cause=e) from e + + async def shutdown(self) -> None: + terminated = False + try: + sandbox = self._sandbox + if sandbox is not None: + await self._call_modal(sandbox.terminate, call_timeout=5.0) + terminated = True + elif self.state.sandbox_id: + sid = self.state.sandbox_id + assert sid is not None + sb = await self._call_modal(modal.Sandbox.from_id, sid, call_timeout=10.0) + await self._call_modal(sb.terminate, call_timeout=5.0) + terminated = True + except Exception: + pass + finally: + if terminated: + self.state.sandbox_id = None + self._sandbox = None + self._running = False + + async def _ensure_sandbox(self) -> None: + if self._sandbox is not None: + return + + # If resuming, try to rehydrate the sandbox handle from the persisted id. + sid = self.state.sandbox_id + if sid: + try: + sb = await self._call_modal(modal.Sandbox.from_id, sid, call_timeout=10.0) + + # `poll()` returns an exit code when the sandbox is terminated, else None. + poll_result = await self._call_modal(sb.poll, call_timeout=5.0) + is_running = poll_result is None + if is_running: + self._sandbox = sb + return + except Exception: + pass + + # Resumed sandbox handle is dead or invalid; clear and create a fresh one. + self._sandbox = None + self.state.sandbox_id = None + + app = await self._call_modal( + modal.App.lookup, + self.state.app_name, + create_if_missing=True, + call_timeout=10.0, + ) + if not self._image: + image_id = self.state.image_id + if image_id: + self._image = await self._call_modal( + modal.Image.from_id, image_id, call_timeout=30.0 + ) + else: + tag = self.state.image_tag + if not isinstance(tag, str) or not tag: + tag = _DEFAULT_IMAGE_TAG + # Record the default for better debuggability/resume. + self.state.image_tag = tag + self._image = await self._call_modal( + modal.Image.from_registry, tag, call_timeout=30.0 + ) + + manifest_envs = cast(dict[str, str | None], await self.state.manifest.environment.resolve()) + self._sandbox = await self._call_modal( + modal.Sandbox.create, + app=app, + image=self._image, + workdir=self.state.manifest.root, + env=manifest_envs, + call_timeout=self.state.sandbox_create_timeout_s, + ) + + # Persist sandbox id for future resume. + assert self._sandbox is not None + self.state.sandbox_id = self._sandbox.object_id + + assert self._image is not None + self.state.image_id = self._image.object_id + + async def _exec_internal( + self, *command: str | Path, timeout: float | None = None + ) -> ExecResult: + await self._ensure_sandbox() + assert self._sandbox is not None + + modal_timeout: int | None = None + if timeout is not None: + # Modal's Sandbox.exec timeout is integer seconds; use ceil so the command + # is guaranteed to be terminated server-side at or before our timeout window + # (modulo 1s granularity). + modal_timeout = int(max(_DEFAULT_TIMEOUT_S, math.ceil(timeout))) + + def _run() -> ExecResult: + assert self._sandbox is not None + try: + argv: tuple[str, ...] = tuple(str(part) for part in command) + proc: ContainerProcess[bytes] = self._sandbox.exec( + *argv, + text=False, + timeout=modal_timeout, + ) + # Drain full output; Modal buffers process output server-side. + stdout = proc.stdout.read() + stderr = proc.stderr.read() + exit_code = proc.wait() + return ExecResult( + stdout=stdout or b"", stderr=stderr or b"", exit_code=exit_code or 0 + ) + except Exception as e: + raise e + + try: + return cast(ExecResult, await self._call_modal(_run, call_timeout=timeout)) + except asyncio.TimeoutError as e: + # The worker thread continues running; prevent background mutations by terminating + # the sandbox and clearing our handle. + sandbox = self._sandbox + if sandbox is not None: + try: + await self._call_modal(sandbox.terminate, call_timeout=5.0) + except Exception: + pass + self._sandbox = None + self.state.sandbox_id = None + self._running = False + raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e + except ExecTimeoutError: + raise + except Exception as e: + raise ExecTransportError(command=command, cause=e) from e + + async def read(self, path: Path) -> io.IOBase: + # Read by `cat` so the payload is returned as bytes. + workspace_path = resolve_workspace_path( + Path(self.state.manifest.root), + path, + allow_absolute_within_root=True, + ) + cmd = ["sh", "-lc", f"cat -- {shlex.quote(str(workspace_path))}"] + try: + out = await self.exec(*cmd, shell=False) + except ExecTimeoutError as e: + raise WorkspaceArchiveReadError(path=workspace_path, cause=e) from e + except ExecTransportError as e: + raise WorkspaceArchiveReadError(path=workspace_path, cause=e) from e + + if not out.ok(): + raise WorkspaceReadNotFoundError( + path=path, context={"stderr": out.stderr.decode("utf-8", "replace")} + ) + + return io.BytesIO(out.stdout) + + async def write(self, path: Path, data: io.IOBase) -> None: + payload = data.read() + if isinstance(payload, str): + payload = payload.encode("utf-8") + if not isinstance(payload, (bytes, bytearray)): + raise WorkspaceWriteTypeError(path=path, actual_type=type(payload).__name__) + + await self._ensure_sandbox() + assert self._sandbox is not None + + workspace_path = resolve_workspace_path( + Path(self.state.manifest.root), + path, + allow_absolute_within_root=True, + ) + + def _run() -> None: + assert self._sandbox is not None + # Ensure parent directory exists. + parent = str(workspace_path.parent) + self._sandbox.exec("mkdir", "-p", "--", parent, text=False).wait() + + # Stream bytes into `cat > file` to avoid quoting/binary issues. + cmd = ["sh", "-lc", f"cat > {shlex.quote(str(workspace_path))}"] + proc = self._sandbox.exec(*cmd, text=False) + _write_process_stdin(proc, payload) + exit_code = proc.wait() + if exit_code != 0: + stderr: bytes = proc.stderr.read() + raise WorkspaceArchiveWriteError( + path=workspace_path, + context={ + "reason": "write_nonzero_exit", + "exit_code": exit_code, + "stderr": stderr.decode("utf-8", "replace"), + }, + ) + + try: + await self._call_modal(_run, call_timeout=30.0) + except WorkspaceArchiveWriteError: + raise + except Exception as e: + raise WorkspaceArchiveWriteError(path=workspace_path, cause=e) from e + + async def running(self) -> bool: + if not self._running or self._sandbox is None: + return False + + try: + assert self._sandbox is not None + poll_result = await self._call_modal(self._sandbox.poll, call_timeout=5.0) + return poll_result is None + except Exception: + return False + + async def persist_workspace(self) -> io.IOBase: + if self.state.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT_FILESYSTEM: + return await self._persist_workspace_via_snapshot_filesystem() + return await self._persist_workspace_via_tar() + + async def hydrate_workspace(self, data: io.IOBase) -> None: + if self.state.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT_FILESYSTEM: + return await self._hydrate_workspace_via_snapshot_filesystem(data) + return await self._hydrate_workspace_via_tar(data) + + async def _persist_workspace_via_snapshot_filesystem(self) -> io.IOBase: + """ + Persist the workspace using Modal's snapshot_filesystem API when available. + + Modal's snapshot_filesystem is expected to return a snapshot reference + (typically a Modal object such as an Image/Snapshot handle, or an id + string). We serialize a small reference envelope that + `_hydrate_workspace_via_snapshot_filesystem` can interpret. + """ + + root = Path(self.state.manifest.root) + await self._ensure_sandbox() + assert self._sandbox is not None + + sandbox = self._sandbox + if not hasattr(sandbox, "snapshot_filesystem"): + # Feature not present in this Modal SDK version; fall back to tar implementation. + return await self._persist_workspace_via_tar() + + skip = self._persist_workspace_skip_relpaths() + + # Modal's snapshot_filesystem does not support excluding paths. To + # preserve the semantics of "ephemeral manifest entries are not + # persisted", we temporarily remove those paths, snapshot, then + # restore them back into the running session. + skip_abs = [root / rel for rel in sorted(skip, key=lambda p: p.as_posix())] + ephemeral_backup: bytes | None = None + + async def _restore_ephemeral_paths() -> WorkspaceArchiveReadError | None: + if not ephemeral_backup: + return None + + backup = bytes(ephemeral_backup) + + def _restore_ephemeral() -> None: + assert self._sandbox is not None + proc = self._sandbox.exec("tar", "xf", "-", "-C", str(root), text=False) + _write_process_stdin(proc, backup) + exit_code = proc.wait() + if exit_code != 0: + stderr: bytes = proc.stderr.read() + raise WorkspaceArchiveReadError( + path=root, + context={ + "reason": "snapshot_filesystem_ephemeral_restore_failed", + "exit_code": exit_code, + "stderr": stderr.decode("utf-8", "replace"), + }, + ) + + try: + await self._call_modal( + _restore_ephemeral, + call_timeout=self.state.snapshot_filesystem_restore_timeout_s, + ) + except WorkspaceArchiveReadError as exc: + return exc + except Exception as exc: + return WorkspaceArchiveReadError( + path=root, + context={"reason": "snapshot_filesystem_ephemeral_restore_failed"}, + cause=exc, + ) + return None + + if skip_abs: + # Best-effort: tar up the ephemeral paths (if they exist). We run + # via shell so missing paths do not cause a hard failure + # (`|| true`). + rel_args = " ".join(shlex.quote(p.relative_to(root).as_posix()) for p in skip_abs) + cmd = f"cd -- {shlex.quote(str(root))} && (tar cf - -- {rel_args} 2>/dev/null || true)" + out = await self.exec("sh", "-lc", cmd, shell=False) + ephemeral_backup = out.stdout or b"" + + # Remove ephemeral paths before snapshot so they are not captured. + rm_cmd = ["rm", "-rf", "--", *[str(p) for p in skip_abs]] + rm_out = await self.exec(*rm_cmd, shell=False) + if not rm_out.ok(): + cleanup_restore_error = await _restore_ephemeral_paths() + if cleanup_restore_error is not None: + logger.warning( + "Failed to restore Modal ephemeral paths after cleanup failure: %s", + cleanup_restore_error, + ) + raise WorkspaceArchiveReadError( + path=root, + context={ + "reason": "snapshot_filesystem_ephemeral_remove_failed", + "exit_code": rm_out.exit_code, + "stderr": rm_out.stderr.decode("utf-8", "replace"), + }, + ) + + restore_error: WorkspaceArchiveReadError | None = None + + try: + snap = await self._call_modal( + sandbox.snapshot_filesystem, + call_timeout=self.state.snapshot_filesystem_timeout_s, + ) + except Exception as e: + restore_error = await _restore_ephemeral_paths() + if restore_error is not None: + logger.warning( + "Failed to restore Modal ephemeral paths after snapshot failure: %s", + restore_error, + ) + raise WorkspaceArchiveReadError( + path=root, context={"reason": "snapshot_filesystem_failed"}, cause=e + ) from e + + if isinstance(snap, (bytes, bytearray)): + # should never happen, just a safe guardrail + raise WorkspaceArchiveReadError( + path=root, + context={ + "reason": "snapshot_filesystem_unexpected_bytes", + "type": type(snap).__name__, + }, + ) + + # Snapshot is expected to be a Modal Image (or a compatible handle with an object_id). + if not hasattr(snap, "object_id") and not isinstance(snap, str): + raise WorkspaceArchiveReadError( + path=root, + context={ + "reason": "snapshot_filesystem_unexpected_return", + "type": type(snap).__name__, + }, + ) + + restore_error = await _restore_ephemeral_paths() + if restore_error is not None: + raise restore_error + + snapshot_id: str | None = None + if isinstance(snap, str): + snapshot_id = snap + else: + snapshot_id = getattr(snap, "object_id", None) or getattr(snap, "id", None) + if snapshot_id is not None and not isinstance(snapshot_id, str): + snapshot_id = None + + if not snapshot_id: + raise WorkspaceArchiveReadError( + path=root, + context={ + "reason": "snapshot_filesystem_unexpected_return", + "type": type(snap).__name__, + }, + ) + + return io.BytesIO(_encode_snapshot_filesystem_ref(snapshot_id=snapshot_id)) + + @retry_async( + retry_if=lambda exc, self: exception_chain_contains_type(exc, (ExecTransportError,)) + or exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) + ) + async def _persist_workspace_via_tar(self) -> io.IOBase: + # Existing tar implementation extracted so snapshot_filesystem mode can fall back cleanly. + root = Path(self.state.manifest.root) + skip = self._persist_workspace_skip_relpaths() + + excludes: list[str] = [] + for rel in sorted(skip, key=lambda p: p.as_posix()): + excludes.extend(["--exclude", f"./{rel.as_posix().lstrip('./')}"]) + + cmd: list[str] = [ + "tar", + "cf", + "-", + *excludes, + "-C", + str(root), + ".", + ] + + try: + out = await self.exec(*cmd, shell=False) + if not out.ok(): + raise WorkspaceArchiveReadError( + path=root, + context={ + "reason": "tar_nonzero_exit", + "exit_code": out.exit_code, + "stderr": out.stderr.decode("utf-8", "replace"), + }, + ) + return io.BytesIO(out.stdout) + except WorkspaceArchiveReadError: + raise + except Exception as e: + raise WorkspaceArchiveReadError(path=root, cause=e) from e + + async def _hydrate_workspace_via_snapshot_filesystem(self, data: io.IOBase) -> None: + """ + Hydrate using Modal's snapshot_filesystem restore API when the + persisted payload is a snapshot ref. Otherwise, fall back to tar + extraction (to support SDKs that return tar bytes). + """ + + root = Path(self.state.manifest.root) + raw = data.read() + if isinstance(raw, str): + raw = raw.encode("utf-8") + if not isinstance(raw, (bytes, bytearray)): + raise WorkspaceArchiveWriteError(path=root, context={"reason": "non_bytes_payload"}) + + snapshot_id = _decode_snapshot_filesystem_ref(bytes(raw)) + if snapshot_id is None: + # Not an envelope; treat as tar payload. + return await self._hydrate_workspace_via_tar(io.BytesIO(bytes(raw))) + if not snapshot_id: + raise WorkspaceArchiveWriteError( + path=root, context={"reason": "snapshot_filesystem_invalid_snapshot_id"} + ) + + # Best-effort: if a sandbox already exists, terminate it to avoid leaking resources. + # We want the restored snapshot image to define the new sandbox filesystem. + prior = self._sandbox + if prior is not None: + try: + await self._call_modal(prior.terminate, call_timeout=5.0) + except Exception: + pass + finally: + self._sandbox = None + self.state.sandbox_id = None + + manifest_envs = cast(dict[str, str | None], await self.state.manifest.environment.resolve()) + + def _run_restore() -> None: + # Rehydrate an image from the snapshot id. + image = modal.Image.from_id(snapshot_id) + + # Prefer the existing app-based sandbox creation signature to match `_ensure_sandbox`. + app = modal.App.lookup(self.state.app_name, create_if_missing=True) + + try: + sb = modal.Sandbox.create( + app=app, + image=image, + workdir=self.state.manifest.root, + env=manifest_envs, + ) + except TypeError: + # Older/newer SDKs may not accept app/workdir; fall back to simpler signatures. + try: + sb = modal.Sandbox.create( + name=self.state.app_name, + image=image, + env=manifest_envs, + ) + except TypeError: + sb = modal.Sandbox.create( + image=image, + env=manifest_envs, + ) + + # Ensure workspace root exists even if the image does not contain it. + try: + sb.exec("mkdir", "-p", "--", str(root), text=False).wait() + except Exception: + pass + + # Update in-memory handles and persisted ids. + self._image = image + self.state.image_id = getattr(image, "object_id", None) or snapshot_id + self._sandbox = sb + self.state.sandbox_id = getattr(sb, "object_id", None) + + try: + await self._call_modal( + _run_restore, call_timeout=self.state.snapshot_filesystem_restore_timeout_s + ) + except Exception as e: + raise WorkspaceArchiveWriteError( + path=root, + context={ + "reason": "snapshot_filesystem_restore_failed", + "snapshot_id": snapshot_id, + }, + cause=e, + ) from e + + async def _hydrate_workspace_via_tar(self, data: io.IOBase) -> None: + root = Path(self.state.manifest.root) + + raw = data.read() + if isinstance(raw, str): + raw = raw.encode("utf-8") + if not isinstance(raw, (bytes, bytearray)): + raise WorkspaceArchiveWriteError(path=root, context={"reason": "non_bytes_tar_payload"}) + + try: + with tarfile.open(fileobj=io.BytesIO(bytes(raw)), mode="r:*") as tar: + for member in tar.getmembers(): + name = member.name + if name in ("", ".", "./"): + continue + if should_skip_tar_member( + name, + skip_rel_paths=self.state.manifest.ephemeral_persistence_paths(), + root_name=None, + ): + continue + # Mirror tar_utils safety checks (no extraction here). + if Path(name).is_absolute(): + raise UnsafeTarMemberError(member=name, reason="absolute path") + if ".." in Path(name).parts: + raise UnsafeTarMemberError(member=name, reason="parent traversal") + if member.issym() or member.islnk(): + raise UnsafeTarMemberError(member=name, reason="link member not allowed") + if not (member.isdir() or member.isreg()): + raise UnsafeTarMemberError(member=name, reason="unsupported member type") + except UnsafeTarMemberError as e: + raise WorkspaceArchiveWriteError( + path=root, context={"reason": e.reason, "member": e.member}, cause=e + ) from e + except (tarfile.TarError, OSError) as e: + raise WorkspaceArchiveWriteError(path=root, cause=e) from e + + await self._ensure_sandbox() + assert self._sandbox is not None + + def _run() -> None: + assert self._sandbox is not None + self._sandbox.exec("mkdir", "-p", "--", str(root), text=False).wait() + proc = self._sandbox.exec("tar", "xf", "-", "-C", str(root), text=False) + _write_process_stdin(proc, raw) + exit_code = proc.wait() + if exit_code != 0: + stderr: bytes = proc.stderr.read() + raise WorkspaceArchiveWriteError( + path=root, + context={ + "reason": "tar_extract_nonzero_exit", + "exit_code": exit_code, + "stderr": stderr.decode("utf-8", "replace"), + }, + ) + + try: + await self._call_modal(_run, call_timeout=60.0) + except WorkspaceArchiveWriteError: + raise + except Exception as e: + raise WorkspaceArchiveWriteError(path=root, cause=e) from e + + +class ModalSandboxClient(BaseSandboxClient[ModalSandboxClientOptions]): + backend_id = "modal" + _default_image: ModalImageSelector | None + _default_sandbox: ModalSandboxSelector | None + _instrumentation: Instrumentation + + def __init__( + self, + *, + image: ModalImageSelector | None = None, + sandbox: ModalSandboxSelector | None = None, + instrumentation: Instrumentation | None = None, + dependencies: Dependencies | None = None, + ) -> None: + self._default_image = image + self._default_sandbox = sandbox + self._instrumentation = instrumentation or Instrumentation() + self._dependencies = dependencies + + async def create( + self, + *, + snapshot: SnapshotSpec | None = None, + manifest: Manifest | None = None, + codex: bool | CodexConfig = False, + options: ModalSandboxClientOptions, + ) -> SandboxSession: + """ + Create a new Modal-backed session. + + Expected options: + - app_name: str (required) + - sandbox_create_timeout_s: float | None (async timeout for sandbox creation call) + - workspace_persistence: Literal["tar", "snapshot_filesystem"] (optional) + - snapshot_filesystem_timeout_s: float | None + (async timeout for snapshot_filesystem call) + - snapshot_filesystem_restore_timeout_s: float | None + (async timeout for snapshot restore call) + """ + + if options is None: + raise ValueError("ModalSandboxClient.create requires options with app_name") + app_name = options.app_name + if not app_name: + raise ValueError("ModalSandboxClient.create requires a valid app_name") + + image_sel = self._default_image + + sandbox_sel = self._default_sandbox + + sandbox_create_timeout_s = options.sandbox_create_timeout_s + if sandbox_create_timeout_s is not None and not isinstance( + sandbox_create_timeout_s, (int, float) + ): + raise ValueError( + "ModalSandboxClient.create requires sandbox_create_timeout_s to be a number" + ) + + workspace_persistence = options.workspace_persistence + if workspace_persistence not in ( + _WORKSPACE_PERSISTENCE_TAR, + _WORKSPACE_PERSISTENCE_SNAPSHOT_FILESYSTEM, + ): + raise ValueError( + "ModalSandboxClient.create requires workspace_persistence to be one of " + f"{_WORKSPACE_PERSISTENCE_TAR!r} or {_WORKSPACE_PERSISTENCE_SNAPSHOT_FILESYSTEM!r}" + ) + + snapshot_filesystem_timeout_s = options.snapshot_filesystem_timeout_s + if snapshot_filesystem_timeout_s is not None and not isinstance( + snapshot_filesystem_timeout_s, (int, float) + ): + raise ValueError( + "ModalSandboxClient.create requires snapshot_filesystem_timeout_s to be a number" + ) + + snapshot_filesystem_restore_timeout_s = options.snapshot_filesystem_restore_timeout_s + if snapshot_filesystem_restore_timeout_s is not None and not isinstance( + snapshot_filesystem_restore_timeout_s, (int, float) + ): + raise ValueError( + "ModalSandboxClient.create requires " + "snapshot_filesystem_restore_timeout_s to be a number" + ) + + manifest = apply_codex_to_manifest(manifest, codex) + + session_id = uuid.uuid4() + state_image_id: str | None = None + state_image_tag: str | None = None + session_image: modal.Image | None = None + if image_sel is not None: + if image_sel.kind == "image": + if not isinstance(image_sel.value, modal.Image): + raise ValueError( + "ModalSandboxClient.__init__ requires image to be a modal.Image" + ) + session_image = image_sel.value + state_image_id = getattr(session_image, "object_id", None) + elif image_sel.kind == "id": + if not isinstance(image_sel.value, str) or not image_sel.value: + raise ValueError( + "ModalSandboxClient.__init__ requires image_id to be a non-empty string" + ) + state_image_id = image_sel.value + else: + if not isinstance(image_sel.value, str) or not image_sel.value: + raise ValueError( + "ModalSandboxClient.__init__ requires image_tag to be a non-empty string" + ) + state_image_tag = image_sel.value + + state_sandbox_id: str | None = None + session_sandbox: modal.Sandbox | None = None + if sandbox_sel is not None: + if sandbox_sel.kind == "sandbox": + if not isinstance(sandbox_sel.value, modal.Sandbox): + raise ValueError( + "ModalSandboxClient.__init__ requires sandbox to be a modal.Sandbox" + ) + session_sandbox = sandbox_sel.value + state_sandbox_id = getattr(session_sandbox, "object_id", None) + else: + if not isinstance(sandbox_sel.value, str) or not sandbox_sel.value: + raise ValueError( + "ModalSandboxClient.__init__ requires sandbox_id to be a non-empty string" + ) + state_sandbox_id = sandbox_sel.value + + snapshot_id = str(session_id) + snapshot_instance = resolve_snapshot(snapshot, snapshot_id) + state = ModalSandboxSessionState( + session_id=session_id, + manifest=manifest, + snapshot=snapshot_instance, + app_name=app_name, + image_tag=state_image_tag, + image_id=state_image_id, + sandbox_id=state_sandbox_id, + workspace_persistence=workspace_persistence, + ) + if sandbox_create_timeout_s is not None: + state.sandbox_create_timeout_s = float(sandbox_create_timeout_s) + if snapshot_filesystem_timeout_s is not None: + state.snapshot_filesystem_timeout_s = float(snapshot_filesystem_timeout_s) + if snapshot_filesystem_restore_timeout_s is not None: + state.snapshot_filesystem_restore_timeout_s = float( + snapshot_filesystem_restore_timeout_s + ) + + # Pass the in-memory handles through to the session (they may not be resumable). + inner = ModalSandboxSession.from_state( + state, + image=session_image, + sandbox=session_sandbox, + ) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + async def delete(self, session: SandboxSession) -> SandboxSession: + """ + Best-effort cleanup of Modal sandbox resources. + """ + + inner = session._inner + if not isinstance(inner, ModalSandboxSession): + raise TypeError("ModalSandboxClient.delete expects a ModalSandboxSession") + + # Prefer the live handle if present. + sandbox = getattr(inner, "_sandbox", None) + try: + if sandbox is not None: + await asyncio.get_running_loop().run_in_executor(None, sandbox.terminate) + return session + except Exception: + return session + + # Otherwise, best-effort terminate via sandbox_id. + sid = inner.state.sandbox_id + if sid: + try: + sb = await asyncio.get_running_loop().run_in_executor( + None, lambda: modal.Sandbox.from_id(sid) + ) + await asyncio.get_running_loop().run_in_executor(None, sb.terminate) + except Exception: + pass + + return session + + async def resume( + self, + state: SandboxSessionState, + *, + codex: bool | CodexConfig = False, + ) -> SandboxSession: + if not isinstance(state, ModalSandboxSessionState): + raise TypeError("ModalSandboxClient.resume expects a ModalSandboxSessionState") + inner = ModalSandboxSession.from_state(apply_codex_to_session_state(state, codex)) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return ModalSandboxSessionState.model_validate(payload) diff --git a/src/agents/result.py b/src/agents/result.py index 774c90dc4e..b240f9c95c 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -46,7 +46,9 @@ ) if TYPE_CHECKING: - pass + from collections.abc import Awaitable, Callable + + from .sandbox.session.base_sandbox_session import BaseSandboxSession T = TypeVar("T") @@ -78,6 +80,7 @@ def _populate_state_from_result( auto_previous_response_id: bool = False, ) -> RunState[Any]: """Populate a RunState with common fields from a RunResult.""" + state._current_agent = result.last_agent model_input_items = getattr(result, "_model_input_items", None) if isinstance(model_input_items, list): state._generated_items = list(model_input_items) @@ -106,6 +109,11 @@ def _populate_state_from_result( if trace_state is None: trace_state = TraceState.from_trace(getattr(result, "trace", None)) state._trace_state = copy.deepcopy(trace_state) if trace_state else None + sandbox_resume_state = getattr(result, "_sandbox_resume_state", None) + if isinstance(sandbox_resume_state, dict): + state._sandbox = copy.deepcopy(sandbox_resume_state) + else: + state._sandbox = None return state @@ -144,6 +152,20 @@ def _input_items_for_result( return run_items_to_input_items(model_input_items, reasoning_item_id_policy) +def _starting_agent_for_state(result: RunResultBase) -> Agent[Any]: + """Return the root agent graph that should seed RunState identity resolution.""" + state = getattr(result, "_state", None) + starting_agent = getattr(state, "_starting_agent", None) + if isinstance(starting_agent, Agent): + return starting_agent + + stored_starting_agent = getattr(result, "_starting_agent_for_state", None) + if isinstance(stored_starting_agent, Agent): + return stored_starting_agent + + return result.last_agent + + @dataclass class RunResultBase(abc.ABC): input: str | list[TResponseInputItem] @@ -185,6 +207,12 @@ class RunResultBase(abc.ABC): This is only set when the runner preserved extra session history items that should not be replayed into the next local run, such as nested handoff history or filtered handoff input. """ + _sandbox_resume_state: dict[str, object] | None = field(default=None, init=False, repr=False) + """Serialized sandbox session state captured during the run.""" + _sandbox_session: BaseSandboxSession | None = field(default=None, init=False, repr=False) + """Live sandbox session attached to this run result when sandbox execution is enabled.""" + _starting_agent_for_state: Agent[Any] | None = field(default=None, init=False, repr=False) + """Root agent graph used when converting the result back into RunState.""" @classmethod def __get_pydantic_core_schema__( @@ -385,7 +413,7 @@ def to_state(self) -> RunState[Any]: original_input=original_input_for_state if original_input_for_state is not None else self.input, - starting_agent=self.last_agent, + starting_agent=_starting_agent_for_state(self), max_turns=self.max_turns, ) @@ -493,6 +521,13 @@ class RunResultStreaming(RunResultBase): ) """How reasoning IDs should be represented when converting to input history.""" _run_impl_task: InitVar[asyncio.Task[Any] | None] = None + _sandbox_cleanup: Callable[[], Awaitable[None]] | None = field( + default=None, + init=False, + repr=False, + ) + _sandbox_cleanup_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False) + _sandbox_cleanup_callback_registered: bool = field(default=False, init=False, repr=False) def __post_init__(self, _run_impl_task: asyncio.Task[Any] | None) -> None: self._current_agent_ref = weakref.ref(self.current_agent) @@ -525,6 +560,57 @@ def _release_last_agent_reference(self) -> None: # Preserve dataclass field so repr/asdict continue to succeed. self.__dict__["current_agent"] = None + async def _run_sandbox_cleanup(self) -> None: + sandbox_cleanup = self._sandbox_cleanup + if sandbox_cleanup is None: + return + + task = self._sandbox_cleanup_task + if task is None: + + async def _cleanup_once() -> None: + try: + await sandbox_cleanup() + except Exception as error: + logger.warning( + "Failed to clean up sandbox resources after streamed run: %s", error + ) + + task = asyncio.create_task(_cleanup_once()) + self._sandbox_cleanup_task = task + + await task + + def ensure_sandbox_cleanup_on_completion(self) -> None: + if ( + self._sandbox_cleanup is None + or self.run_loop_task is None + or self._sandbox_cleanup_callback_registered + ): + return + + original_task = self.run_loop_task + self._sandbox_cleanup_callback_registered = True + original_task.add_done_callback( + lambda _task: asyncio.create_task(self._run_sandbox_cleanup()) + ) + + async def _await_run_and_cleanup() -> Any: + try: + result = await original_task + except asyncio.CancelledError: + if not original_task.done(): + original_task.cancel() + raise + except Exception: + await self._run_sandbox_cleanup() + raise + + await self._run_sandbox_cleanup() + return result + + self.run_loop_task = asyncio.create_task(_await_run_and_cleanup()) + def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None: """Cancel the streaming run. @@ -622,24 +708,28 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: yield item self._event_queue.task_done() finally: - if cancelled: - # Cancellation should return promptly, so avoid waiting on long-running tasks. - # Tasks have already been cancelled above. - self._cleanup_tasks() - else: - # Ensure main execution completes before cleanup to avoid race conditions - # with session operations - await self._await_task_safely(self.run_loop_task) - # Safely terminate all background tasks after main execution has finished - self._cleanup_tasks() - - # Allow any pending callbacks (e.g., cancellation handlers) to enqueue their - # completion sentinels before we clear the queues for observability. - await asyncio.sleep(0) - - # Drain queues so callers observing internal state see them empty after completion. - self._drain_event_queue() - self._drain_input_guardrail_queue() + try: + if cancelled: + # Cancellation should return promptly, so avoid waiting on long-running tasks. + # Tasks have already been cancelled above. + self._cleanup_tasks() + else: + # Ensure main execution completes before cleanup to avoid race conditions + # with session operations. + await self._await_task_safely(self.run_loop_task) + # Safely terminate all background tasks after main execution has finished. + self._cleanup_tasks() + + if not cancelled: + await self._run_sandbox_cleanup() + finally: + # Allow any pending callbacks (e.g., cancellation handlers) to enqueue their + # completion sentinels before we clear the queues for observability. + await asyncio.sleep(0) + + # Drain queues so callers observing internal state see them empty after completion. + self._drain_event_queue() + self._drain_input_guardrail_queue() if self._stored_exception: raise self._stored_exception @@ -781,7 +871,7 @@ def to_state(self) -> RunState[Any]: state = RunState( context=self.context_wrapper, original_input=self._original_input if self._original_input is not None else self.input, - starting_agent=self.last_agent, + starting_agent=_starting_agent_for_state(self), max_turns=self.max_turns, ) diff --git a/src/agents/run.py b/src/agents/run.py index 047d454d35..3015d3b781 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -43,6 +43,7 @@ ) from .run_context import RunContextWrapper, TContext from .run_error_handlers import RunErrorHandlers +from .run_internal.agent_bindings import bind_public_agent from .run_internal.agent_runner_helpers import ( append_model_response_if_new, apply_resumed_conversation_settings, @@ -106,6 +107,7 @@ serialize_tool_use_tracker, ) from .run_state import RunState +from .sandbox.runtime import SandboxRuntime from .tool import dispose_resolved_computers from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from .tracing import Span, SpanError, agent_span, get_current_trace @@ -583,12 +585,45 @@ async def run( run_state._reasoning_item_id_policy = resolved_reasoning_item_id_policy run_state.set_trace(get_current_trace()) + sandbox_runtime = SandboxRuntime( + starting_agent=starting_agent, + run_config=run_config, + run_state=run_state, + ) + + completed_result: RunResult | None = None + def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: result._reasoning_item_id_policy = resolved_reasoning_item_id_policy if run_state is not None: run_state._reasoning_item_id_policy = resolved_reasoning_item_id_policy return result + def _tool_use_tracker_snapshot() -> dict[str, list[str]]: + identity_root_agent = starting_agent + if run_state is not None and run_state._starting_agent is not None: + identity_root_agent = run_state._starting_agent + return serialize_tool_use_tracker( + tool_use_tracker, + starting_agent=identity_root_agent, + ) + + def _finalize_result(result: RunResult) -> RunResult: + nonlocal completed_result + result._starting_agent_for_state = ( + run_state._starting_agent + if run_state is not None and run_state._starting_agent is not None + else starting_agent + ) + finalized_result = finalize_conversation_tracking( + _with_reasoning_item_id_policy(result), + server_conversation_tracker=server_conversation_tracker, + run_state=run_state, + ) + sandbox_runtime.apply_result_metadata(finalized_result) + completed_result = finalized_result + return finalized_result + pending_server_items: list[RunItem] | None = None input_guardrail_results: list[InputGuardrailResult] = ( list(run_state._input_guardrail_results) if run_state is not None else [] @@ -609,6 +644,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: current_agent = run_state._current_agent else: current_agent = starting_agent + sandbox_runtime.assert_agent_supported(current_agent) should_run_agent_start_hooks = True store_setting = current_agent.model_settings.resolve(run_config.model_settings).store @@ -618,11 +654,16 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: and original_user_input is not None and session_input_items_for_persistence is None ): + sandbox_runtime.assert_agent_supported(current_agent) session_input_items_for_persistence = ItemHelpers.input_to_new_input_list( original_user_input ) - if session_persistence_enabled and session_input_items_for_persistence: + if ( + session_persistence_enabled + and session_input_items_for_persistence + and not sandbox_runtime.enabled + ): # Capture the exact input saved so it can be rewound on conversation lock retries. last_saved_input_snapshot_for_rewind = list(session_input_items_for_persistence) await save_result_to_session( @@ -637,6 +678,56 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: try: while True: resuming_turn = is_resumed_state + all_input_guardrails = ( + starting_agent.input_guardrails + (run_config.input_guardrails or []) + if current_turn == 0 and not resuming_turn + else [] + ) + sequential_guardrails = [ + g for g in all_input_guardrails if not g.run_in_parallel + ] + parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] + sequential_results: list[InputGuardrailResult] = [] + if sandbox_runtime.enabled and sequential_guardrails: + # Blocking first-turn guardrails must run before sandbox prep so a tripwire + # can prevent session creation, startup, or live-session mutation. + try: + sequential_results = await run_input_guardrails( + starting_agent, + sequential_guardrails, + copy_input_items(original_input), + context_wrapper, + ) + except InputGuardrailTripwireTriggered: + session_input_items_for_persistence = ( + await persist_session_items_for_guardrail_trip( + session, + server_conversation_tracker, + session_input_items_for_persistence, + original_user_input, + run_state, + store=store_setting, + ) + ) + raise + sequential_guardrails = [] + + current_bindings = bind_public_agent(current_agent) + execution_agent = current_bindings.execution_agent + prepared_sandbox = await sandbox_runtime.prepare_agent( + current_agent=current_agent, + current_input=original_input, + context_wrapper=context_wrapper, + is_resumed_state=resuming_turn, + ) + current_bindings = prepared_sandbox.bindings + execution_agent = current_bindings.execution_agent + original_input = copy_input_items(prepared_sandbox.input) + if starting_input is not None and not isinstance(starting_input, RunState): + starting_input = copy_input_items(prepared_sandbox.input) + if run_state is not None: + run_state._original_input = copy_input_items(original_input) + normalized_starting_input: str | list[TResponseInputItem] = ( starting_input if starting_input is not None and not isinstance(starting_input, RunState) @@ -645,6 +736,18 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: store_setting = current_agent.model_settings.resolve( run_config.model_settings ).store + if session_persistence_enabled and session_input_items_for_persistence: + last_saved_input_snapshot_for_rewind = list( + session_input_items_for_persistence + ) + await save_result_to_session( + session, + list(last_saved_input_snapshot_for_rewind), + [], + run_state, + store=store_setting, + ) + session_input_items_for_persistence = [] if run_state is not None and run_state._current_step is not None: if isinstance(run_state._current_step, NextStepInterruption): logger.debug("Continuing from interruption") @@ -655,7 +758,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: raise UserError("No model response found in previous state") turn_result = await resolve_interrupted_turn( - agent=current_agent, + bindings=current_bindings, original_input=original_input, original_pre_step_items=generated_items, new_response=run_state._model_responses[-1], @@ -750,11 +853,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state=run_state, original_input=original_input, ) - return finalize_conversation_tracking( - _with_reasoning_item_id_policy(result), - server_conversation_tracker=server_conversation_tracker, - run_state=run_state, - ) + return _finalize_result(result) if isinstance(turn_result.next_step, NextStepRunAgain): continue @@ -791,9 +890,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: tool_output_guardrail_results=tool_output_guardrail_results, context_wrapper=context_wrapper, interruptions=approvals_from_state, - _tool_use_tracker_snapshot=serialize_tool_use_tracker( - tool_use_tracker - ), + _tool_use_tracker_snapshot=_tool_use_tracker_snapshot(), max_turns=max_turns, ) result._current_turn = current_turn @@ -820,11 +917,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: store=store_setting, ) result._original_input = copy_input_items(original_input) - return finalize_conversation_tracking( - _with_reasoning_item_id_policy(result), - server_conversation_tracker=server_conversation_tracker, - run_state=run_state, - ) + return _finalize_result(result) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast( Agent[TContext], turn_result.next_step.new_agent @@ -844,16 +937,17 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: if run_state is not None: if run_state._current_step is None: run_state._current_step = NextStepRunAgain() # type: ignore[assignment] - all_tools = await get_all_tools(current_agent, context_wrapper) + all_tools = await get_all_tools(execution_agent, context_wrapper) await initialize_computer_tools( tools=all_tools, context_wrapper=context_wrapper ) if current_span is None: handoff_names = [ - h.agent_name for h in await get_handoffs(current_agent, context_wrapper) + h.agent_name + for h in await get_handoffs(execution_agent, context_wrapper) ] - if output_schema := get_output_schema(current_agent): + if output_schema := get_output_schema(execution_agent): output_type_name = output_schema.name() else: output_type_name = "str" @@ -932,7 +1026,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: tool_output_guardrail_results=tool_output_guardrail_results, context_wrapper=context_wrapper, interruptions=approvals_from_state, - _tool_use_tracker_snapshot=serialize_tool_use_tracker(tool_use_tracker), + _tool_use_tracker_snapshot=_tool_use_tracker_snapshot(), max_turns=max_turns, ) result._current_turn = max_turns @@ -957,11 +1051,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: store=store_setting, ) result._original_input = copy_input_items(original_input) - return finalize_conversation_tracking( - _with_reasoning_item_id_policy(result), - server_conversation_tracker=server_conversation_tracker, - run_state=run_state, - ) + return _finalize_result(result) if run_state is not None and not resuming_turn: run_state._current_turn_persisted_item_count = 0 @@ -983,21 +1073,12 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: ) if current_turn <= 1: - all_input_guardrails = starting_agent.input_guardrails + ( - run_config.input_guardrails or [] - ) - sequential_guardrails = [ - g for g in all_input_guardrails if not g.run_in_parallel - ] - parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] - try: - sequential_results = [] if sequential_guardrails: sequential_results = await run_input_guardrails( starting_agent, sequential_guardrails, - copy_input_items(prepared_input), + copy_input_items(original_input), context_wrapper, ) except InputGuardrailTripwireTriggered: @@ -1016,7 +1097,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: parallel_results: list[InputGuardrailResult] = [] model_task = asyncio.create_task( run_single_turn( - agent=current_agent, + bindings=current_bindings, all_tools=all_tools, original_input=original_input, generated_items=items_for_model, @@ -1042,7 +1123,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_input_guardrails( starting_agent, parallel_guardrails, - copy_input_items(prepared_input), + copy_input_items(original_input), context_wrapper, ), model_task, @@ -1070,7 +1151,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: input_guardrail_results.extend(parallel_results) else: turn_result = await run_single_turn( - agent=current_agent, + bindings=current_bindings, all_tools=all_tools, original_input=original_input, generated_items=items_for_model, @@ -1201,9 +1282,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: tool_output_guardrail_results=tool_output_guardrail_results, context_wrapper=context_wrapper, interruptions=[], - _tool_use_tracker_snapshot=serialize_tool_use_tracker( - tool_use_tracker - ), + _tool_use_tracker_snapshot=_tool_use_tracker_snapshot(), max_turns=max_turns, ) result._current_turn = current_turn @@ -1225,11 +1304,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: store=store_setting, ) result._original_input = copy_input_items(original_input) - return finalize_conversation_tracking( - _with_reasoning_item_id_policy(result), - server_conversation_tracker=server_conversation_tracker, - run_state=run_state, - ) + return _finalize_result(result) elif isinstance(turn_result.next_step, NextStepInterruption): if session_persistence_enabled: if not input_guardrails_triggered(input_guardrail_results): @@ -1286,11 +1361,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state=run_state, original_input=original_input, ) - return finalize_conversation_tracking( - _with_reasoning_item_id_policy(result), - server_conversation_tracker=server_conversation_tracker, - run_state=run_state, - ) + return _finalize_result(result) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) if run_state is not None: @@ -1336,6 +1407,16 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: ) raise finally: + try: + sandbox_resume_state = await sandbox_runtime.cleanup() + except Exception as error: + logger.warning("Failed to clean up sandbox resources after run: %s", error) + else: + if completed_result is not None: + completed_result._sandbox_resume_state = sandbox_resume_state + finally: + if completed_result is not None: + completed_result._sandbox_session = None try: await dispose_resolved_computers(run_context=context_wrapper) except Exception as error: @@ -1550,9 +1631,16 @@ def run_streamed( if run_state is not None: run_state.set_trace(new_trace or get_current_trace()) + sandbox_runtime = SandboxRuntime( + starting_agent=starting_agent, + run_config=run_config, + run_state=run_state, + ) + schema_agent = ( run_state._current_agent if run_state and run_state._current_agent else starting_agent ) + sandbox_runtime.assert_agent_supported(schema_agent) output_schema = get_output_schema(schema_agent) streamed_input: str | list[TResponseInputItem] = ( @@ -1618,6 +1706,8 @@ def run_streamed( streamed_result._state = run_state if run_state is not None: streamed_result._tool_use_tracker_snapshot = run_state.get_tool_use_tracker_snapshot() + if sandbox_runtime.enabled: + sandbox_runtime.apply_result_metadata(streamed_result) # Kick off the actual agent loop in the background and return the streamed result object. streamed_result.run_loop_task = asyncio.create_task( @@ -1636,8 +1726,11 @@ def run_streamed( session=session, run_state=run_state, is_resumed_state=is_resumed_state, + sandbox_runtime=sandbox_runtime, ) ) + if sandbox_runtime.enabled: + streamed_result.ensure_sandbox_cleanup_on_completion() return streamed_result diff --git a/src/agents/run_config.py b/src/agents/run_config.py index ad21f6c3b9..c3a6f13df7 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -22,6 +22,11 @@ if TYPE_CHECKING: from .agent import Agent from .run_context import RunContextWrapper + from .sandbox.manifest import Manifest + from .sandbox.session.base_sandbox_session import BaseSandboxSession + from .sandbox.session.sandbox_client import BaseSandboxClient + from .sandbox.session.sandbox_session_state import SandboxSessionState + from .sandbox.snapshot import SnapshotSpec DEFAULT_MAX_TURNS = 10 @@ -80,6 +85,29 @@ class ToolErrorFormatterArgs(Generic[TContext]): ToolErrorFormatter = Callable[[ToolErrorFormatterArgs[Any]], MaybeAwaitable[Optional[str]]] +@dataclass +class SandboxRunConfig: + """Grouped sandbox runtime configuration for `Runner`.""" + + client: BaseSandboxClient[Any] | None = None + """Sandbox client used to create or resume sandbox sessions.""" + + options: Any | None = None + """Sandbox-client-specific options used when creating a fresh session.""" + + session: BaseSandboxSession | None = None + """Live sandbox session override for the current process.""" + + session_state: SandboxSessionState | None = None + """Explicit sandbox session state to resume from when not using `RunState` payloads.""" + + manifest: Manifest | None = None + """Optional sandbox manifest override for fresh session creation.""" + + snapshot: SnapshotSpec | None = None + """Optional sandbox snapshot used for fresh session creation.""" + + @dataclass class RunConfig: """Configures settings for the entire agent run.""" @@ -191,6 +219,9 @@ class RunConfig: - ``"omit"`` strips reasoning item IDs from model input built by the runner. """ + sandbox: SandboxRunConfig | None = None + """Optional sandbox runtime configuration for `SandboxAgent` execution.""" + class RunOptions(TypedDict, Generic[TContext]): """Arguments for ``AgentRunner`` methods.""" @@ -231,6 +262,7 @@ class RunOptions(TypedDict, Generic[TContext]): "ReasoningItemIdPolicy", "RunConfig", "RunOptions", + "SandboxRunConfig", "ToolErrorFormatter", "ToolErrorFormatterArgs", "_default_trace_include_sensitive_data", diff --git a/src/agents/run_internal/agent_bindings.py b/src/agents/run_internal/agent_bindings.py new file mode 100644 index 0000000000..93e3702b14 --- /dev/null +++ b/src/agents/run_internal/agent_bindings.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Generic + +from ..agent import Agent +from ..run_context import TContext + +__all__ = [ + "AgentBindings", + "bind_execution_agent", + "bind_public_agent", +] + + +@dataclass(frozen=True) +class AgentBindings(Generic[TContext]): + """Carry the public and execution agent identities for a turn.""" + + public_agent: Agent[TContext] + execution_agent: Agent[TContext] + + +def bind_public_agent(agent: Agent[TContext]) -> AgentBindings[TContext]: + """Build bindings for non-rewritten execution where both identities are the same.""" + return AgentBindings(public_agent=agent, execution_agent=agent) + + +def bind_execution_agent( + *, + public_agent: Agent[TContext], + execution_agent: Agent[TContext], +) -> AgentBindings[TContext]: + """Build bindings for execution-only clones such as sandbox-prepared agents.""" + return AgentBindings( + public_agent=public_agent, + execution_agent=execution_agent, + ) diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 776e406703..3b51d616ea 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -253,6 +253,11 @@ def build_interruption_result( original_input: str | list[TResponseInputItem], ) -> RunResult: """Create a RunResult for an interruption path.""" + identity_root_agent = ( + run_state._starting_agent + if run_state is not None and run_state._starting_agent is not None + else current_agent + ) result = RunResult( input=result_input, new_items=session_items, @@ -266,7 +271,10 @@ def build_interruption_result( context_wrapper=context_wrapper, interruptions=interruptions, _last_processed_response=processed_response, - _tool_use_tracker_snapshot=serialize_tool_use_tracker(tool_use_tracker), + _tool_use_tracker_snapshot=serialize_tool_use_tracker( + tool_use_tracker, + starting_agent=identity_root_agent, + ), max_turns=max_turns, ) result._current_turn = current_turn diff --git a/src/agents/run_internal/guardrails.py b/src/agents/run_internal/guardrails.py index 375cc37c25..1b04779d81 100644 --- a/src/agents/run_internal/guardrails.py +++ b/src/agents/run_internal/guardrails.py @@ -57,7 +57,7 @@ async def run_input_guardrails_with_queue( input: str | list[TResponseInputItem], context: RunContextWrapper[TContext], streamed_result: RunResultStreaming, - parent_span: Span[Any], + parent_span: Span[Any] | None, ) -> None: """Run guardrails concurrently and stream results into the queue.""" queue = streamed_result._input_guardrail_queue @@ -74,16 +74,18 @@ async def run_input_guardrails_with_queue( for t in guardrail_tasks: t.cancel() await asyncio.gather(*guardrail_tasks, return_exceptions=True) - _error_tracing.attach_error_to_span( - parent_span, - SpanError( - message="Guardrail tripwire triggered", - data={ - "guardrail": result.guardrail.get_name(), - "type": "input_guardrail", - }, - ), + span_error = SpanError( + message="Guardrail tripwire triggered", + data={ + "guardrail": result.guardrail.get_name(), + "type": "input_guardrail", + }, ) + if parent_span is not None: + _error_tracing.attach_error_to_span(parent_span, span_error) + else: + # Early first-turn streamed guardrails can run before the agent span exists. + _error_tracing.attach_error_to_current_span(span_error) queue.put_nowait(result) guardrail_results.append(result) break diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 3d21d89fda..55ddf7c117 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -57,6 +57,7 @@ from ..run_context import AgentHookContext, RunContextWrapper, TContext from ..run_error_handlers import RunErrorHandlers from ..run_state import RunState +from ..sandbox.runtime import SandboxRuntime from ..stream_events import ( AgentUpdatedStreamEvent, RawResponsesStreamEvent, @@ -68,6 +69,7 @@ from ..tracing.span_data import AgentSpanData from ..usage import Usage from ..util import _coro, _error_tracing +from .agent_bindings import AgentBindings, bind_public_agent from .agent_runner_helpers import apply_resumed_conversation_settings from .approvals import approvals_from_step from .error_handlers import ( @@ -413,6 +415,7 @@ async def start_streaming( run_state: RunState[TContext] | None = None, *, is_resumed_state: bool = False, + sandbox_runtime: SandboxRuntime[TContext] | None = None, ): """Run the streaming loop for a run result.""" if streamed_result.trace: @@ -598,6 +601,64 @@ async def _save_stream_items_without_count( try: while True: + all_input_guardrails = ( + starting_agent.input_guardrails + (run_config.input_guardrails or []) + if current_turn == 0 and not is_resumed_state + else [] + ) + sequential_guardrails = [g for g in all_input_guardrails if not g.run_in_parallel] + parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] + current_bindings = bind_public_agent(current_agent) + execution_agent = current_bindings.execution_agent + prepared_turn_input = copy_input_items(streamed_result.input) + if sandbox_runtime is not None and sandbox_runtime.enabled and sequential_guardrails: + # Mirror the non-streaming path: a blocking first-turn guardrail should fire + # before sandbox prep can create, start, or mutate sandbox state. + existing_input_guardrail_count = len(streamed_result.input_guardrail_results) + await run_input_guardrails_with_queue( + starting_agent, + sequential_guardrails, + ItemHelpers.input_to_new_input_list(prepared_turn_input), + context_wrapper, + streamed_result, + None, + ) + for result in streamed_result.input_guardrail_results[ + existing_input_guardrail_count: + ]: + if result.output.tripwire_triggered: + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + session_input_items_for_persistence = ( + await persist_session_items_for_guardrail_trip( + session, + server_conversation_tracker, + session_input_items_for_persistence, + starting_input, + run_state, + store=current_agent.model_settings.resolve( + run_config.model_settings + ).store, + ) + ) + raise InputGuardrailTripwireTriggered(result) + sequential_guardrails = [] + + if sandbox_runtime is not None: + prepared_sandbox = await sandbox_runtime.prepare_agent( + current_agent=current_agent, + current_input=prepared_turn_input, + context_wrapper=context_wrapper, + is_resumed_state=is_resumed_state, + ) + current_bindings = prepared_sandbox.bindings + execution_agent = current_bindings.execution_agent + prepared_turn_input = copy_input_items(prepared_sandbox.input) + streamed_result.input = prepared_turn_input + streamed_result._original_input = copy_input_items(prepared_turn_input) + if run_state is not None: + run_state._original_input = copy_input_items(prepared_turn_input) + sandbox_runtime.apply_result_metadata(streamed_result) + if is_resumed_state and run_state is not None and run_state._current_step is not None: if isinstance(run_state._current_step, NextStepInterruption): if not run_state._model_responses or not run_state._last_processed_response: @@ -606,7 +667,7 @@ async def _save_stream_items_without_count( last_model_response = run_state._model_responses[-1] turn_result = await resolve_interrupted_turn( - agent=current_agent, + bindings=current_bindings, original_input=run_state._original_input, original_pre_step_items=run_state._generated_items, new_response=last_model_response, @@ -621,7 +682,12 @@ async def _save_stream_items_without_count( current_agent, run_state._last_processed_response ) streamed_result._tool_use_tracker_snapshot = serialize_tool_use_tracker( - tool_use_tracker + tool_use_tracker, + starting_agent=( + run_state._starting_agent + if run_state is not None and run_state._starting_agent is not None + else starting_agent + ), ) streamed_result.input = turn_result.original_input @@ -712,14 +778,14 @@ async def _save_stream_items_without_count( if streamed_result.is_complete: break - all_tools = await get_all_tools(current_agent, context_wrapper) + all_tools = await get_all_tools(execution_agent, context_wrapper) await initialize_computer_tools(tools=all_tools, context_wrapper=context_wrapper) if current_span is None: handoff_names = [ - h.agent_name for h in await get_handoffs(current_agent, context_wrapper) + h.agent_name for h in await get_handoffs(execution_agent, context_wrapper) ] - if output_schema := get_output_schema(current_agent): + if output_schema := get_output_schema(execution_agent): output_type_name = output_schema.name() else: output_type_name = "str" @@ -821,17 +887,11 @@ async def _save_stream_items_without_count( break if current_turn == 1: - all_input_guardrails = starting_agent.input_guardrails + ( - run_config.input_guardrails or [] - ) - sequential_guardrails = [g for g in all_input_guardrails if not g.run_in_parallel] - parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] - if sequential_guardrails: await run_input_guardrails_with_queue( starting_agent, sequential_guardrails, - ItemHelpers.input_to_new_input_list(prepared_input), + ItemHelpers.input_to_new_input_list(prepared_turn_input), context_wrapper, streamed_result, current_span, @@ -858,7 +918,7 @@ async def _save_stream_items_without_count( run_input_guardrails_with_queue( starting_agent, parallel_guardrails, - ItemHelpers.input_to_new_input_list(prepared_input), + ItemHelpers.input_to_new_input_list(prepared_turn_input), context_wrapper, streamed_result, current_span, @@ -882,7 +942,7 @@ async def _save_stream_items_without_count( ) turn_result = await run_single_turn_streamed( streamed_result, - current_agent, + current_bindings, hooks, context_wrapper, run_config, @@ -906,7 +966,12 @@ async def _save_stream_items_without_count( ) should_run_agent_start_hooks = False streamed_result._tool_use_tracker_snapshot = serialize_tool_use_tracker( - tool_use_tracker + tool_use_tracker, + starting_agent=( + run_state._starting_agent + if run_state is not None and run_state._starting_agent is not None + else starting_agent + ), ) streamed_result.raw_responses = streamed_result.raw_responses + [ @@ -1086,7 +1151,7 @@ async def _save_stream_items_without_count( async def run_single_turn_streamed( streamed_result: RunResultStreaming, - agent: Agent[TContext], + bindings: AgentBindings[TContext], hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, @@ -1100,6 +1165,8 @@ async def run_single_turn_streamed( reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, ) -> SingleStepResult: """Run a single streamed turn and emit events as results arrive.""" + public_agent = bindings.public_agent + execution_agent = bindings.execution_agent emitted_tool_call_ids: set[str] = set() emitted_reasoning_item_ids: set[str] = set() emitted_tool_search_fingerprints: set[str] = set() @@ -1145,28 +1212,28 @@ def _tool_search_fingerprint(raw_item: Any) -> str: turn_input=turn_input, ) await asyncio.gather( - hooks.on_agent_start(agent_hook_context, agent), + hooks.on_agent_start(agent_hook_context, public_agent), ( - agent.hooks.on_start(agent_hook_context, agent) - if agent.hooks + public_agent.hooks.on_start(agent_hook_context, public_agent) + if public_agent.hooks else _coro.noop_coroutine() ), ) - output_schema = get_output_schema(agent) + output_schema = get_output_schema(execution_agent) - streamed_result.current_agent = agent - streamed_result._current_agent_output_schema = output_schema + streamed_result.current_agent = public_agent + streamed_result._current_agent_output_schema = get_output_schema(public_agent) system_prompt, prompt_config = await asyncio.gather( - agent.get_system_prompt(context_wrapper), - agent.get_prompt(context_wrapper), + execution_agent.get_system_prompt(context_wrapper), + execution_agent.get_prompt(context_wrapper), ) - handoffs = await get_handoffs(agent, context_wrapper) - model = get_model(agent, run_config) - model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + handoffs = await get_handoffs(execution_agent, context_wrapper) + model = get_model(execution_agent, run_config) + model_settings = execution_agent.model_settings.resolve(run_config.model_settings) + model_settings = maybe_reset_tool_choice(public_agent, tool_use_tracker, model_settings) final_response: ModelResponse | None = None @@ -1190,7 +1257,7 @@ def _tool_search_fingerprint(raw_item: Any) -> str: ) filtered = await maybe_filter_model_input( - agent=agent, + agent=public_agent, run_config=run_config, context_wrapper=context_wrapper, input_items=input, @@ -1214,10 +1281,15 @@ def _tool_search_fingerprint(raw_item: Any) -> str: raise RuntimeError("Prepared model input is empty") await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), + hooks.on_llm_start(context_wrapper, public_agent, filtered.instructions, filtered.input), ( - agent.hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input) - if agent.hooks + public_agent.hooks.on_llm_start( + context_wrapper, + public_agent, + filtered.instructions, + filtered.input, + ) + if public_agent.hooks else _coro.noop_coroutine() ), ) @@ -1327,7 +1399,7 @@ async def rewind_model_request() -> None: RunItemStreamEvent( item=ToolSearchCallItem( raw_item=coerce_tool_search_call_raw_item(output_item), - agent=agent, + agent=public_agent, ), name="tool_search_called", ) @@ -1339,7 +1411,7 @@ async def rewind_model_request() -> None: RunItemStreamEvent( item=ToolSearchOutputItem( raw_item=coerce_tool_search_output_raw_item(output_item), - agent=agent, + agent=public_agent, ), name="tool_search_output_created", ) @@ -1381,7 +1453,7 @@ async def rewind_model_request() -> None: tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), - agent=agent, + agent=public_agent, description=tool_description, title=tool_title, ) @@ -1395,7 +1467,7 @@ async def rewind_model_request() -> None: if reasoning_id and reasoning_id not in emitted_reasoning_item_ids: emitted_reasoning_item_ids.add(reasoning_id) - reasoning_item = ReasoningItem(raw_item=output_item, agent=agent) + reasoning_item = ReasoningItem(raw_item=output_item, agent=public_agent) streamed_result._event_queue.put_nowait( RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created") ) @@ -1404,11 +1476,11 @@ async def rewind_model_request() -> None: context_wrapper.usage.add(final_response.usage) await asyncio.gather( ( - agent.hooks.on_llm_end(context_wrapper, agent, final_response) - if agent.hooks + public_agent.hooks.on_llm_end(context_wrapper, public_agent, final_response) + if public_agent.hooks else _coro.noop_coroutine() ), - hooks.on_llm_end(context_wrapper, agent, final_response), + hooks.on_llm_end(context_wrapper, public_agent, final_response), ) if not final_response: @@ -1421,7 +1493,7 @@ async def rewind_model_request() -> None: server_conversation_tracker.track_server_items(final_response) single_step_result = await get_single_step_result_from_response( - agent=agent, + bindings=bindings, original_input=streamed_result.input, pre_step_items=streamed_result._model_input_items, new_response=final_response, @@ -1480,7 +1552,7 @@ async def rewind_model_request() -> None: async def run_single_turn( *, - agent: Agent[TContext], + bindings: AgentBindings[TContext], all_tools: list[Tool], original_input: str | list[TResponseInputItem], generated_items: list[RunItem], @@ -1495,6 +1567,8 @@ async def run_single_turn( reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, ) -> SingleStepResult: """Run a single non-streaming turn of the agent loop.""" + public_agent = bindings.public_agent + execution_agent = bindings.execution_agent try: turn_input = ItemHelpers.input_to_new_input_list(original_input) except Exception: @@ -1509,28 +1583,28 @@ async def run_single_turn( turn_input=turn_input, ) await asyncio.gather( - hooks.on_agent_start(agent_hook_context, agent), + hooks.on_agent_start(agent_hook_context, public_agent), ( - agent.hooks.on_start(agent_hook_context, agent) - if agent.hooks + public_agent.hooks.on_start(agent_hook_context, public_agent) + if public_agent.hooks else _coro.noop_coroutine() ), ) system_prompt, prompt_config = await asyncio.gather( - agent.get_system_prompt(context_wrapper), - agent.get_prompt(context_wrapper), + execution_agent.get_system_prompt(context_wrapper), + execution_agent.get_prompt(context_wrapper), ) - output_schema = get_output_schema(agent) - handoffs = await get_handoffs(agent, context_wrapper) + output_schema = get_output_schema(execution_agent) + handoffs = await get_handoffs(execution_agent, context_wrapper) if server_conversation_tracker is not None: input = server_conversation_tracker.prepare_input(original_input, generated_items) else: input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy) new_response = await get_new_response( - agent, + bindings, system_prompt, input, output_schema, @@ -1547,7 +1621,7 @@ async def run_single_turn( ) return await get_single_step_result_from_response( - agent=agent, + bindings=bindings, original_input=original_input, pre_step_items=generated_items, new_response=new_response, @@ -1562,7 +1636,7 @@ async def run_single_turn( async def get_new_response( - agent: Agent[TContext], + bindings: AgentBindings[TContext], system_prompt: str | None, input: list[TResponseInputItem], output_schema: AgentOutputSchemaBase | None, @@ -1578,8 +1652,10 @@ async def get_new_response( session_items_to_rewind: list[TResponseInputItem] | None = None, ) -> ModelResponse: """Call the model and return the raw response, handling retries and hooks.""" + public_agent = bindings.public_agent + execution_agent = bindings.execution_agent filtered = await maybe_filter_model_input( - agent=agent, + agent=public_agent, run_config=run_config, context_wrapper=context_wrapper, input_items=input, @@ -1588,23 +1664,23 @@ async def get_new_response( if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) - model = get_model(agent, run_config) - model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + model = get_model(execution_agent, run_config) + model_settings = execution_agent.model_settings.resolve(run_config.model_settings) + model_settings = maybe_reset_tool_choice(public_agent, tool_use_tracker, model_settings) if server_conversation_tracker is not None: server_conversation_tracker.mark_input_as_sent(filtered.input) await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), + hooks.on_llm_start(context_wrapper, public_agent, filtered.instructions, filtered.input), ( - agent.hooks.on_llm_start( + public_agent.hooks.on_llm_start( context_wrapper, - agent, + public_agent, filtered.instructions, filtered.input, ) - if agent.hooks + if public_agent.hooks else _coro.noop_coroutine() ), ) @@ -1660,11 +1736,11 @@ async def rewind_model_request() -> None: await asyncio.gather( ( - agent.hooks.on_llm_end(context_wrapper, agent, new_response) - if agent.hooks + public_agent.hooks.on_llm_end(context_wrapper, public_agent, new_response) + if public_agent.hooks else _coro.noop_coroutine() ), - hooks.on_llm_end(context_wrapper, agent, new_response), + hooks.on_llm_end(context_wrapper, public_agent, new_response), ) return new_response diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index 4511045288..f5d56d2dfc 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -87,6 +87,7 @@ from ..util._approvals import evaluate_needs_approval_setting from ..util._types import MaybeAwaitable from ._asyncio_progress import get_function_tool_task_progress_deadline +from .agent_bindings import AgentBindings, bind_public_agent from .approvals import append_approval_error_output from .items import ( REJECTION_MESSAGE, @@ -1279,14 +1280,15 @@ class _FunctionToolBatchExecutor: def __init__( self, *, - agent: Agent[Any], + bindings: AgentBindings[Any], tool_runs: list[ToolRunFunction], hooks: RunHooks[Any], context_wrapper: RunContextWrapper[Any], config: RunConfig, isolate_parallel_failures: bool | None, ) -> None: - self.agent = agent + self.execution_agent = bindings.execution_agent + self.public_agent = bindings.public_agent self.tool_runs = tool_runs self.hooks = hooks self.context_wrapper = context_wrapper @@ -1310,7 +1312,7 @@ async def execute( list[FunctionToolResult], list[ToolInputGuardrailResult], list[ToolOutputGuardrailResult] ]: self.available_function_tools = await resolve_enabled_function_tools( - self.agent, + self.execution_agent, self.context_wrapper, ) for tool_run in self.tool_runs: @@ -1457,10 +1459,10 @@ async def _run_single_tool( tool_call.call_id, tool_call=raw_tool_call, tool_namespace=tool_context_namespace, - agent=self.agent, + agent=self.public_agent, run_config=self.config, ) - agent_hooks = self.agent.hooks + agent_hooks = self.public_agent.hooks if self.config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments @@ -1526,7 +1528,7 @@ async def _maybe_execute_tool_approval( ) if approval_status is None: approval_item = ToolApprovalItem( - agent=self.agent, + agent=self.public_agent, raw_item=raw_tool_call, tool_name=func_tool.name, tool_namespace=tool_namespace, @@ -1566,7 +1568,7 @@ async def _maybe_execute_tool_approval( tool=func_tool, output=rejection_message, run_item=function_rejection_item( - self.agent, + self.public_agent, tool_call, rejection_message=rejection_message, scope_id=self.tool_state_scope_id, @@ -1585,16 +1587,16 @@ async def _execute_single_tool_body( rejected_message = await _execute_tool_input_guardrails( func_tool=func_tool, tool_context=tool_context, - agent=self.agent, + agent=self.public_agent, tool_input_guardrail_results=self.tool_input_guardrail_results, ) if rejected_message is not None: return rejected_message await asyncio.gather( - self.hooks.on_tool_start(tool_context, self.agent, func_tool), + self.hooks.on_tool_start(tool_context, self.public_agent, func_tool), ( - agent_hooks.on_tool_start(tool_context, self.agent, func_tool) + agent_hooks.on_tool_start(tool_context, self.public_agent, func_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -1654,15 +1656,15 @@ async def _invoke_tool_and_run_post_invoke( final_result = await _execute_tool_output_guardrails( func_tool=func_tool, tool_context=tool_context, - agent=self.agent, + agent=self.public_agent, real_result=real_result, tool_output_guardrail_results=self.tool_output_guardrail_results, ) await asyncio.gather( - self.hooks.on_tool_end(tool_context, self.agent, func_tool, final_result), + self.hooks.on_tool_end(tool_context, self.public_agent, func_tool, final_result), ( - agent_hooks.on_tool_end(tool_context, self.agent, func_tool, final_result) + agent_hooks.on_tool_end(tool_context, self.public_agent, func_tool, final_result) if agent_hooks else _coro.noop_coroutine() ), @@ -1763,7 +1765,7 @@ def _build_function_tool_results(self) -> list[FunctionToolResult]: run_item = ToolCallOutputItem( output=result, raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), - agent=self.agent, + agent=self.public_agent, ) else: # Skip tool output until nested interruptions are resolved. @@ -1784,7 +1786,7 @@ def _build_function_tool_results(self) -> list[FunctionToolResult]: async def execute_function_tool_calls( *, - agent: Agent[Any], + bindings: AgentBindings[Any], tool_runs: list[ToolRunFunction], hooks: RunHooks[Any], context_wrapper: RunContextWrapper[Any], @@ -1795,7 +1797,7 @@ async def execute_function_tool_calls( ]: """Execute function tool calls with approvals, guardrails, and hooks.""" return await _FunctionToolBatchExecutor( - agent=agent, + bindings=bindings, tool_runs=tool_runs, hooks=hooks, context_wrapper=context_wrapper, @@ -1806,7 +1808,7 @@ async def execute_function_tool_calls( async def execute_local_shell_calls( *, - agent: Agent[Any], + public_agent: Agent[Any], calls: list[ToolRunLocalShellCall], context_wrapper: RunContextWrapper[Any], hooks: RunHooks[Any], @@ -1819,7 +1821,7 @@ async def execute_local_shell_calls( for call in calls: results.append( await LocalShellAction.execute( - agent=agent, + agent=public_agent, call=call, hooks=hooks, context_wrapper=context_wrapper, @@ -1831,7 +1833,7 @@ async def execute_local_shell_calls( async def execute_shell_calls( *, - agent: Agent[Any], + public_agent: Agent[Any], calls: list[ToolRunShellCall], context_wrapper: RunContextWrapper[Any], hooks: RunHooks[Any], @@ -1844,7 +1846,7 @@ async def execute_shell_calls( for call in calls: results.append( await ShellAction.execute( - agent=agent, + agent=public_agent, call=call, hooks=hooks, context_wrapper=context_wrapper, @@ -1856,7 +1858,7 @@ async def execute_shell_calls( async def execute_apply_patch_calls( *, - agent: Agent[Any], + public_agent: Agent[Any], calls: list[ToolRunApplyPatchCall], context_wrapper: RunContextWrapper[Any], hooks: RunHooks[Any], @@ -1869,7 +1871,7 @@ async def execute_apply_patch_calls( for call in calls: results.append( await ApplyPatchAction.execute( - agent=agent, + agent=public_agent, call=call, hooks=hooks, context_wrapper=context_wrapper, @@ -1881,7 +1883,7 @@ async def execute_apply_patch_calls( async def execute_computer_actions( *, - agent: Agent[Any], + public_agent: Agent[Any], actions: list[ToolRunComputerAction], hooks: RunHooks[Any], context_wrapper: RunContextWrapper[Any], @@ -1898,7 +1900,7 @@ async def execute_computer_actions( for check in action.tool_call.pending_safety_checks: data = ComputerToolSafetyCheckData( ctx_wrapper=context_wrapper, - agent=agent, + agent=public_agent, tool_call=action.tool_call, safety_check=check, ) @@ -1917,7 +1919,7 @@ async def execute_computer_actions( results.append( await ComputerAction.execute( - agent=agent, + agent=public_agent, action=action, hooks=hooks, context_wrapper=context_wrapper, @@ -2081,7 +2083,7 @@ async def _resolve_tool_run( if tool_runs: function_results, _, _ = await execute_function_tool_calls( - agent=agent, + bindings=bind_public_agent(agent), tool_runs=tool_runs, hooks=hooks, context_wrapper=context_wrapper, diff --git a/src/agents/run_internal/tool_planning.py b/src/agents/run_internal/tool_planning.py index dabb83b4ac..d56bc03a26 100644 --- a/src/agents/run_internal/tool_planning.py +++ b/src/agents/run_internal/tool_planning.py @@ -24,6 +24,7 @@ from ..run_context import RunContextWrapper from ..tool import FunctionTool, MCPToolApprovalRequest from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult +from .agent_bindings import AgentBindings from .run_steps import ( ToolRunApplyPatchCall, ToolRunComputerAction, @@ -518,7 +519,7 @@ async def _select_function_tool_runs_for_resume( async def _execute_tool_plan( *, plan: ToolExecutionPlan, - agent: Agent[Any], + bindings: AgentBindings[Any], hooks, context_wrapper: RunContextWrapper[Any], run_config, @@ -533,6 +534,7 @@ async def _execute_tool_plan( list[RunItem], ]: """Execute tool runs captured in a ToolExecutionPlan.""" + public_agent = bindings.public_agent isolate_function_tool_failures = len(plan.function_runs) > 1 or ( parallel and ( @@ -551,7 +553,7 @@ async def _execute_tool_plan( local_shell_results, ) = await asyncio.gather( execute_function_tool_calls( - agent=agent, + bindings=bindings, tool_runs=plan.function_runs, hooks=hooks, context_wrapper=context_wrapper, @@ -559,28 +561,28 @@ async def _execute_tool_plan( isolate_parallel_failures=isolate_function_tool_failures, ), execute_computer_actions( - agent=agent, + public_agent=public_agent, actions=plan.computer_actions, hooks=hooks, context_wrapper=context_wrapper, config=run_config, ), execute_shell_calls( - agent=agent, + public_agent=public_agent, calls=plan.shell_calls, hooks=hooks, context_wrapper=context_wrapper, config=run_config, ), execute_apply_patch_calls( - agent=agent, + public_agent=public_agent, calls=plan.apply_patch_calls, hooks=hooks, context_wrapper=context_wrapper, config=run_config, ), execute_local_shell_calls( - agent=agent, + public_agent=public_agent, calls=plan.local_shell_calls, hooks=hooks, context_wrapper=context_wrapper, @@ -593,7 +595,7 @@ async def _execute_tool_plan( tool_input_guardrail_results, tool_output_guardrail_results, ) = await execute_function_tool_calls( - agent=agent, + bindings=bindings, tool_runs=plan.function_runs, hooks=hooks, context_wrapper=context_wrapper, @@ -601,28 +603,28 @@ async def _execute_tool_plan( isolate_parallel_failures=isolate_function_tool_failures, ) computer_results = await execute_computer_actions( - agent=agent, + public_agent=public_agent, actions=plan.computer_actions, hooks=hooks, context_wrapper=context_wrapper, config=run_config, ) shell_results = await execute_shell_calls( - agent=agent, + public_agent=public_agent, calls=plan.shell_calls, hooks=hooks, context_wrapper=context_wrapper, config=run_config, ) apply_patch_results = await execute_apply_patch_calls( - agent=agent, + public_agent=public_agent, calls=plan.apply_patch_calls, hooks=hooks, context_wrapper=context_wrapper, config=run_config, ) local_shell_results = await execute_local_shell_calls( - agent=agent, + public_agent=public_agent, calls=plan.local_shell_calls, hooks=hooks, context_wrapper=context_wrapper, diff --git a/src/agents/run_internal/tool_use_tracker.py b/src/agents/run_internal/tool_use_tracker.py index e763f175a7..60ff9a1731 100644 --- a/src/agents/run_internal/tool_use_tracker.py +++ b/src/agents/run_internal/tool_use_tracker.py @@ -17,7 +17,11 @@ ToolSearchCallItem, ToolSearchOutputItem, ) -from ..run_state import _build_agent_map +from ..run_state import ( + _build_agent_identity_keys_by_id, + _build_agent_identity_map, + _build_agent_map, +) from .run_steps import ProcessedResponse, ToolRunFunction __all__ = [ @@ -112,11 +116,23 @@ def from_serializable(cls, data: dict[str, list[str]]) -> AgentToolUseTracker: return tracker -def serialize_tool_use_tracker(tool_use_tracker: AgentToolUseTracker) -> dict[str, list[str]]: +def serialize_tool_use_tracker( + tool_use_tracker: AgentToolUseTracker, + *, + starting_agent: Agent[Any] | None = None, +) -> dict[str, list[str]]: """Convert the AgentToolUseTracker into a serializable snapshot.""" + agent_identity_keys_by_id = ( + _build_agent_identity_keys_by_id(starting_agent) if starting_agent is not None else None + ) snapshot: dict[str, list[str]] = {} for agent, tool_names in tool_use_tracker.agent_to_tools: - snapshot[agent.name] = list(tool_names) + agent_key = None + if agent_identity_keys_by_id is not None: + agent_key = agent_identity_keys_by_id.get(id(agent)) + if agent_key is None: + agent_key = getattr(agent, "name", agent.__class__.__name__) + snapshot.setdefault(agent_key, []).extend(tool_names) return snapshot @@ -131,8 +147,9 @@ def hydrate_tool_use_tracker( return agent_map = _build_agent_map(starting_agent) + agent_identity_map = _build_agent_identity_map(starting_agent) for agent_name, tool_names in snapshot.items(): - agent = agent_map.get(agent_name) + agent = agent_identity_map.get(agent_name) or agent_map.get(agent_name) if agent is None: continue tool_use_tracker.add_tool_use(agent, list(tool_names)) diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index 2b3f98b55b..df035cbfb9 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -84,6 +84,7 @@ from ..tracing import SpanError, handoff_span from ..util import _coro, _error_tracing from ..util._approvals import evaluate_needs_approval_setting +from .agent_bindings import AgentBindings from .items import ( REJECTION_MESSAGE, apply_patch_rejection_item, @@ -155,7 +156,7 @@ async def _maybe_finalize_from_tool_results( *, - agent: Agent[TContext], + public_agent: Agent[TContext], original_input: str | list[TResponseInputItem], new_response: ModelResponse, pre_step_items: list[RunItem], @@ -167,12 +168,12 @@ async def _maybe_finalize_from_tool_results( tool_output_guardrail_results: list[ToolOutputGuardrailResult], ) -> SingleStepResult | None: check_tool_use = await check_for_final_output_from_tools( - agent, function_results, context_wrapper + public_agent, function_results, context_wrapper ) if not check_tool_use.is_final_output: return None - if not agent.output_type or agent.output_type is str: + if not public_agent.output_type or public_agent.output_type is str: check_tool_use.final_output = str(check_tool_use.final_output) if check_tool_use.final_output is None: @@ -182,7 +183,7 @@ async def _maybe_finalize_from_tool_results( ) return await execute_final_output( - agent=agent, + public_agent=public_agent, original_input=original_input, new_response=new_response, pre_step_items=pre_step_items, @@ -218,7 +219,7 @@ async def run_final_output_hooks( async def execute_final_output_step( *, - agent: Agent[Any], + public_agent: Agent[Any], original_input: str | list[TResponseInputItem], new_response: ModelResponse, pre_step_items: list[RunItem], @@ -235,7 +236,7 @@ async def execute_final_output_step( ) -> SingleStepResult: """Finalize a turn once final output is known and run end hooks.""" final_output_hooks = run_final_output_hooks_fn or run_final_output_hooks - await final_output_hooks(agent, hooks, context_wrapper, final_output) + await final_output_hooks(public_agent, hooks, context_wrapper, final_output) return SingleStepResult( original_input=original_input, @@ -251,7 +252,7 @@ async def execute_final_output_step( async def execute_final_output( *, - agent: Agent[Any], + public_agent: Agent[Any], original_input: str | list[TResponseInputItem], new_response: ModelResponse, pre_step_items: list[RunItem], @@ -268,7 +269,7 @@ async def execute_final_output( ) -> SingleStepResult: """Convenience wrapper to finalize a turn and run end hooks.""" return await execute_final_output_step( - agent=agent, + public_agent=public_agent, original_input=original_input, new_response=new_response, pre_step_items=pre_step_items, @@ -284,7 +285,7 @@ async def execute_final_output( async def execute_handoffs( *, - agent: Agent[TContext], + public_agent: Agent[TContext], original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_step_items: list[RunItem], @@ -310,14 +311,14 @@ def nest_history(data: HandoffInputData, mapper: Any | None = None) -> HandoffIn ToolCallOutputItem( output=output_message, raw_item=ItemHelpers.tool_call_output_item(handoff.tool_call, output_message), - agent=agent, + agent=public_agent, ) for handoff in run_handoffs[1:] ] ) actual_handoff = run_handoffs[0] - with handoff_span(from_agent=agent.name) as span_handoff: + with handoff_span(from_agent=public_agent.name) as span_handoff: handoff = actual_handoff.handoff new_agent: Agent[Any] = await handoff.on_invoke_handoff( context_wrapper, actual_handoff.tool_call.arguments @@ -336,12 +337,12 @@ def nest_history(data: HandoffInputData, mapper: Any | None = None) -> HandoffIn new_step_items.append( HandoffOutputItem( - agent=agent, + agent=public_agent, raw_item=ItemHelpers.tool_call_output_item( actual_handoff.tool_call, handoff.get_transfer_message(new_agent), ), - source_agent=agent, + source_agent=public_agent, target_agent=new_agent, ) ) @@ -349,16 +350,16 @@ def nest_history(data: HandoffInputData, mapper: Any | None = None) -> HandoffIn await asyncio.gather( hooks.on_handoff( context=context_wrapper, - from_agent=agent, + from_agent=public_agent, to_agent=new_agent, ), ( - agent.hooks.on_handoff( + public_agent.hooks.on_handoff( context_wrapper, agent=new_agent, - source=agent, + source=public_agent, ) - if agent.hooks + if public_agent.hooks else _coro.noop_coroutine() ), ) @@ -386,7 +387,7 @@ def nest_history(data: HandoffInputData, mapper: Any | None = None) -> HandoffIn if input_filter and handoff_input_data is not None: filter_name = getattr(input_filter, "__qualname__", repr(input_filter)) - from_agent = getattr(agent, "name", agent.__class__.__name__) + from_agent = getattr(public_agent, "name", public_agent.__class__.__name__) to_agent = getattr(new_agent, "name", new_agent.__class__.__name__) logger.debug( "Filtering handoff inputs with %s for %s -> %s", @@ -498,7 +499,7 @@ async def check_for_final_output_from_tools( async def execute_tools_and_side_effects( *, - agent: Agent[TContext], + bindings: AgentBindings[TContext], original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_response: ModelResponse, @@ -509,6 +510,7 @@ async def execute_tools_and_side_effects( run_config: RunConfig, ) -> SingleStepResult: """Run one turn of the loop, coordinating tools, approvals, guardrails, and handoffs.""" + public_agent = bindings.public_agent execute_final_output_call = execute_final_output execute_handoffs_call = execute_handoffs @@ -518,7 +520,7 @@ async def execute_tools_and_side_effects( plan = _build_plan_for_fresh_turn( processed_response=processed_response, - agent=agent, + agent=public_agent, context_wrapper=context_wrapper, approval_items_by_call_id=approval_items_by_call_id, ) @@ -538,7 +540,7 @@ async def execute_tools_and_side_effects( local_shell_results, ) = await _execute_tool_plan( plan=plan, - agent=agent, + bindings=bindings, hooks=hooks, context_wrapper=context_wrapper, run_config=run_config, @@ -579,7 +581,7 @@ async def execute_tools_and_side_effects( ) await _append_mcp_callback_results( - agent=agent, + agent=public_agent, requests=plan.mcp_requests_with_callback, context_wrapper=context_wrapper, append_item=new_step_items.append, @@ -587,7 +589,7 @@ async def execute_tools_and_side_effects( if run_handoffs := processed_response.handoffs: return await execute_handoffs_call( - agent=agent, + public_agent=public_agent, original_input=original_input, pre_step_items=pre_step_items, new_step_items=new_step_items, @@ -599,7 +601,7 @@ async def execute_tools_and_side_effects( ) tool_final_output = await _maybe_finalize_from_tool_results( - agent=agent, + public_agent=public_agent, original_input=original_input, new_response=new_response, pre_step_items=pre_step_items, @@ -626,7 +628,7 @@ async def execute_tools_and_side_effects( if output_schema and not output_schema.is_plain_text() and potential_final_output_text: final_output = output_schema.validate_json(potential_final_output_text) return await execute_final_output_call( - agent=agent, + public_agent=public_agent, original_input=original_input, new_response=new_response, pre_step_items=pre_step_items, @@ -639,7 +641,7 @@ async def execute_tools_and_side_effects( ) if not output_schema or output_schema.is_plain_text(): return await execute_final_output_call( - agent=agent, + public_agent=public_agent, original_input=original_input, new_response=new_response, pre_step_items=pre_step_items, @@ -664,7 +666,7 @@ async def execute_tools_and_side_effects( async def resolve_interrupted_turn( *, - agent: Agent[TContext], + bindings: AgentBindings[TContext], original_input: str | list[TResponseInputItem], original_pre_step_items: list[RunItem], new_response: ModelResponse, @@ -676,6 +678,8 @@ async def resolve_interrupted_turn( nest_handoff_history_fn: Callable[..., HandoffInputData] | None = None, ) -> SingleStepResult: """Continue a turn that was previously interrupted waiting for tool approval.""" + public_agent = bindings.public_agent + execution_agent = bindings.execution_agent execute_handoffs_call = execute_handoffs @@ -719,7 +723,7 @@ async def _record_function_rejection( ) rejected_function_outputs.append( function_rejection_item( - agent, + public_agent, tool_call, rejection_message=rejection_message, scope_id=tool_state_scope_id, @@ -793,7 +797,7 @@ async def _build_shell_rejection(run: ToolRunShellCall, call_id: str) -> RunItem return cast( RunItem, shell_rejection_item( - agent, + public_agent, call_id, rejection_message=rejection_message, ), @@ -810,7 +814,7 @@ async def _build_apply_patch_rejection(run: ToolRunApplyPatchCall, call_id: str) return cast( RunItem, apply_patch_rejection_item( - agent, + public_agent, call_id, rejection_message=rejection_message, ), @@ -893,20 +897,39 @@ def _add_pending_interruption(item: ToolApprovalItem | None) -> None: pending_interruption_keys.add(key) pending_interruptions.append(item) + def _allow_legacy_name_agent_match() -> bool: + schema_version = getattr(run_state, "_schema_version", None) + if not isinstance(schema_version, str): + return False + try: + version_parts = tuple(int(part) for part in schema_version.split(".")) + except ValueError: + return False + # Schema 1.6 and earlier only serialized approval owners by agent name. With duplicate-name + # agents, deserialization can legitimately resolve the approval to a sibling instance, so + # resume must accept a same-name match for those legacy snapshots. Schema 1.7+ persists + # duplicate-name identities, so newer snapshots should continue requiring object identity. + return version_parts < (1, 7) + + allow_legacy_name_agent_match = _allow_legacy_name_agent_match() + def _approval_matches_agent(approval: ToolApprovalItem) -> bool: approval_agent = approval.agent if approval_agent is None: return False - if approval_agent is agent: + if approval_agent is public_agent: return True - return getattr(approval_agent, "name", None) == agent.name + return allow_legacy_name_agent_match and approval_agent.name == public_agent.name - available_function_tools = await resolve_enabled_function_tools(agent, context_wrapper) + available_function_tools = await resolve_enabled_function_tools( + execution_agent, + context_wrapper, + ) approval_rebuild_function_tools = available_function_tools - if pending_approval_items and agent.mcp_servers: + if pending_approval_items and execution_agent.mcp_servers: approval_rebuild_function_tools = [ tool - for tool in await agent.get_all_tools(context_wrapper) + for tool in await execution_agent.get_all_tools(context_wrapper) if isinstance(tool, FunctionTool) ] @@ -1030,7 +1053,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: record_rejection=_record_function_rejection, pending_interruption_adder=_add_pending_interruption, pending_item_builder=lambda run: ToolApprovalItem( - agent=agent, + agent=public_agent, raw_item=run.tool_call, tool_name=run.function_tool.name, tool_namespace=get_tool_call_namespace(run.tool_call), @@ -1071,7 +1094,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: rejection_builder=_build_shell_rejection, context_wrapper=context_wrapper, approval_items_by_call_id=approval_items_by_call_id, - agent=agent, + agent=public_agent, pending_interruption_adder=_add_pending_interruption, needs_approval_checker=_shell_needs_approval, output_exists_checker=_shell_output_exists, @@ -1084,7 +1107,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: rejection_builder=_build_apply_patch_rejection, context_wrapper=context_wrapper, approval_items_by_call_id=approval_items_by_call_id, - agent=agent, + agent=public_agent, pending_interruption_adder=_add_pending_interruption, needs_approval_checker=_apply_patch_needs_approval, output_exists_checker=_apply_patch_output_exists, @@ -1092,7 +1115,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: plan = _build_plan_for_resume_turn( processed_response=processed_response, - agent=agent, + agent=public_agent, context_wrapper=context_wrapper, approval_items_by_call_id=approval_items_by_call_id, pending_interruptions=pending_interruptions, @@ -1113,7 +1136,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: _local_shell_results, ) = await _execute_tool_plan( plan=plan, - agent=agent, + bindings=bindings, hooks=hooks, context_wrapper=context_wrapper, run_config=run_config, @@ -1164,7 +1187,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: ) await _append_mcp_callback_results( - agent=agent, + agent=public_agent, requests=plan.mcp_requests_with_callback, context_wrapper=context_wrapper, append_item=append_if_new, @@ -1177,7 +1200,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: original_pre_step_items=original_pre_step_items, mcp_approval_requests=processed_response.mcp_approval_requests, context_wrapper=context_wrapper, - agent=agent, + agent=public_agent, append_item=append_if_new, ) @@ -1232,7 +1255,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: if pending_handoffs: return await execute_handoffs_call( - agent=agent, + public_agent=public_agent, original_input=original_input, pre_step_items=pre_step_items, new_step_items=new_items, @@ -1245,7 +1268,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: ) tool_final_output = await _maybe_finalize_from_tool_results( - agent=agent, + public_agent=public_agent, original_input=original_input, new_response=new_response, pre_step_items=pre_step_items, @@ -1684,7 +1707,7 @@ def _dump_output_item(raw_item: Any) -> dict[str, Any]: async def get_single_step_result_from_response( *, - agent: Agent[TContext], + bindings: AgentBindings[TContext], all_tools: list[Tool], original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], @@ -1697,8 +1720,9 @@ async def get_single_step_result_from_response( tool_use_tracker, event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None, ) -> SingleStepResult: + item_agent = bindings.public_agent processed_response = process_model_response( - agent=agent, + agent=item_agent, all_tools=all_tools, response=new_response, output_schema=output_schema, @@ -1706,7 +1730,7 @@ async def get_single_step_result_from_response( existing_items=pre_step_items, ) - tool_use_tracker.record_processed_response(agent, processed_response) + tool_use_tracker.record_processed_response(item_agent, processed_response) if event_queue is not None and processed_response.new_items: handoff_items = [ @@ -1716,7 +1740,7 @@ async def get_single_step_result_from_response( stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue) return await execute_tools_and_side_effects( - agent=agent, + bindings=bindings, original_input=original_input, pre_step_items=pre_step_items, new_response=new_response, diff --git a/src/agents/run_state.py b/src/agents/run_state.py index dcda9e073c..57e9cea104 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -2,12 +2,15 @@ from __future__ import annotations +import asyncio import copy import dataclasses import json +import threading from collections import deque -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast from uuid import uuid4 @@ -42,6 +45,7 @@ get_function_tool_qualified_name, serialize_function_tool_lookup_key, ) +from .agent import Agent from .exceptions import UserError from .guardrail import ( GuardrailFunctionOutput, @@ -73,6 +77,7 @@ ) from .logger import logger from .run_context import RunContextWrapper +from .sandbox.session.base_sandbox_session import BaseSandboxSession from .tool import ( ApplyPatchTool, ComputerTool, @@ -96,7 +101,6 @@ from .util._json import _to_dump_compatible if TYPE_CHECKING: - from .agent import Agent from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ModelResponse, RunItem from .run_internal.run_steps import ( @@ -118,10 +122,36 @@ # 3. to_json() always emits CURRENT_SCHEMA_VERSION. # 4. Forward compatibility is intentionally fail-fast (older SDKs reject newer or unsupported # versions). -CURRENT_SCHEMA_VERSION = "1.6" -SUPPORTED_SCHEMA_VERSIONS = frozenset( - {"1.0", "1.1", "1.2", "1.3", "1.4", "1.5", CURRENT_SCHEMA_VERSION} -) +CURRENT_SCHEMA_VERSION = "1.7" +# Keep this mapping in chronological order. Every schema bump must add a one-line summary here. +SCHEMA_VERSION_SUMMARIES: dict[str, str] = { + "1.0": "Initial RunState snapshot format for HITL pause/resume flows.", + "1.1": "Same payload as 1.0, but introduces explicit backward-read support policy.", + "1.2": "Persists reasoning_item_id_policy for resumed and streamed follow-up turns.", + "1.3": "Updates resumed trace semantics to reattach traces without duplicate starts.", + "1.4": "Stores request_id alongside each serialized model response.", + "1.5": "Renumbered unreleased baseline for tool-search snapshots and richer tool metadata.", + "1.6": "Persists explicit approval rejection messages across resume flows.", + "1.7": ( + "Persists duplicate-name agent identities across agent-owned state " + "and sandbox resume state." + ), +} +SUPPORTED_SCHEMA_VERSIONS = frozenset(SCHEMA_VERSION_SUMMARIES) + +if CURRENT_SCHEMA_VERSION not in SCHEMA_VERSION_SUMMARIES: + raise AssertionError( + "CURRENT_SCHEMA_VERSION must have a matching entry in SCHEMA_VERSION_SUMMARIES." + ) + +_missing_schema_version_summaries = [ + version for version, summary in SCHEMA_VERSION_SUMMARIES.items() if not summary.strip() +] +if _missing_schema_version_summaries: + raise AssertionError( + "Every supported RunState schema version must have a non-empty summary. " + f"Missing summaries: {', '.join(_missing_schema_version_summaries)}" + ) _FUNCTION_OUTPUT_ADAPTER: TypeAdapter[FunctionCallOutput] = TypeAdapter(FunctionCallOutput) _COMPUTER_OUTPUT_ADAPTER: TypeAdapter[ComputerCallOutput] = TypeAdapter(ComputerCallOutput) @@ -157,6 +187,9 @@ class RunState(Generic[TContext, TAgent]): _current_agent: TAgent | None = None """The agent currently handling the conversation.""" + _starting_agent: TAgent | None = field(default=None, repr=False) + """The root agent used to derive stable duplicate-name identities during resume.""" + _original_input: str | list[Any] = field(default_factory=list) """Original user input prior to any processing.""" @@ -220,6 +253,12 @@ class RunState(Generic[TContext, TAgent]): _agent_tool_state_scope_id: str | None = field(default=None, repr=False) """Private scope id used to isolate agent-tool pending state per RunState instance.""" + _sandbox: dict[str, Any] | None = field(default=None, repr=False) + """Serialized sandbox resume payload for sandbox-aware runs.""" + + _schema_version: str = field(default=CURRENT_SCHEMA_VERSION, repr=False) + """Schema version the snapshot was loaded from for schema-gated resume compatibility.""" + def __init__( self, context: RunContextWrapper[TContext], @@ -234,6 +273,7 @@ def __init__( """Initialize a new RunState.""" self._context = context self._original_input = _clone_original_input(original_input) + self._starting_agent = starting_agent self._current_agent = starting_agent self._max_turns = max_turns self._conversation_id = conversation_id @@ -254,6 +294,8 @@ def __init__( self._current_turn_persisted_item_count = 0 self._tool_use_tracker_snapshot = {} self._trace_state = None + self._sandbox = None + self._schema_version = CURRENT_SCHEMA_VERSION from .agent_tool_state import get_agent_tool_state_scope self._agent_tool_state_scope_id = get_agent_tool_state_scope(context) @@ -498,8 +540,14 @@ def _current_generated_items_merge_marker(self) -> str | None: latest_response_id = ( self._model_responses[-1].response_id if self._model_responses else None ) + agent_identity_keys_by_id = ( + _build_agent_identity_keys_by_id(cast(Agent[Any], self._starting_agent)) + if self._starting_agent is not None + else None + ) serialized_items = [ - self._serialize_item(item) for item in self._last_processed_response.new_items + self._serialize_item(item, agent_identity_keys_by_id=agent_identity_keys_by_id) + for item in self._last_processed_response.new_items ] return json.dumps( { @@ -633,19 +681,33 @@ def to_json( if tool_input is not None: context_entry["tool_input"] = tool_input + agent_identity_keys_by_id = ( + _build_agent_identity_keys_by_id(cast(Agent[Any], self._starting_agent)) + if self._starting_agent is not None + else None + ) + current_agent_entry = _serialize_agent_reference( + cast(Agent[Any], self._current_agent), + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) + result = { "$schemaVersion": CURRENT_SCHEMA_VERSION, "current_turn": self._current_turn, - "current_agent": {"name": self._current_agent.name}, + "current_agent": current_agent_entry, "original_input": original_input_serialized, "model_responses": model_responses, "context": context_entry, "tool_use_tracker": copy.deepcopy(self._tool_use_tracker_snapshot), "max_turns": self._max_turns, "no_active_agent_run": True, - "input_guardrail_results": _serialize_guardrail_results(self._input_guardrail_results), + "input_guardrail_results": _serialize_guardrail_results( + self._input_guardrail_results, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ), "output_guardrail_results": _serialize_guardrail_results( - self._output_guardrail_results + self._output_guardrail_results, + agent_identity_keys_by_id=agent_identity_keys_by_id, ), "tool_input_guardrail_results": _serialize_tool_guardrail_results( self._tool_input_guardrail_results, type_label="tool_input" @@ -660,13 +722,20 @@ def to_json( } generated_items = self._merge_generated_items_with_processed() - result["generated_items"] = [self._serialize_item(item) for item in generated_items] - result["session_items"] = [self._serialize_item(item) for item in list(self._session_items)] + result["generated_items"] = [ + self._serialize_item(item, agent_identity_keys_by_id=agent_identity_keys_by_id) + for item in generated_items + ] + result["session_items"] = [ + self._serialize_item(item, agent_identity_keys_by_id=agent_identity_keys_by_id) + for item in list(self._session_items) + ] result["current_step"] = self._serialize_current_step() result["last_model_response"] = _serialize_last_model_response(model_responses) result["last_processed_response"] = ( self._serialize_processed_response( self._last_processed_response, + agent_identity_keys_by_id=agent_identity_keys_by_id, context_serializer=context_serializer, strict_context=strict_context, include_tracing_api_key=include_tracing_api_key, @@ -678,6 +747,8 @@ def to_json( result["trace"] = self._serialize_trace_data( include_tracing_api_key=include_tracing_api_key ) + if self._sandbox is not None: + result["sandbox"] = copy.deepcopy(self._sandbox) return result @@ -685,6 +756,7 @@ def _serialize_processed_response( self, processed_response: ProcessedResponse, *, + agent_identity_keys_by_id: Mapping[int, str] | None = None, context_serializer: ContextSerializer | None = None, strict_context: bool = False, include_tracing_api_key: bool = False, @@ -710,13 +782,20 @@ def _serialize_processed_response( ) interruptions_data = [ - _serialize_tool_approval_interruption(interruption, include_tool_name=True) + _serialize_tool_approval_interruption( + interruption, + include_tool_name=True, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) for interruption in processed_response.interruptions if isinstance(interruption, ToolApprovalItem) ] return { - "new_items": [self._serialize_item(item) for item in processed_response.new_items], + "new_items": [ + self._serialize_item(item, agent_identity_keys_by_id=agent_identity_keys_by_id) + for item in processed_response.new_items + ], "tools_used": processed_response.tools_used, **action_groups, "interruptions": interruptions_data, @@ -727,12 +806,20 @@ def _serialize_current_step(self) -> dict[str, Any] | None: # Import at runtime to avoid circular import from .run_internal.run_steps import NextStepInterruption + agent_identity_keys_by_id = ( + _build_agent_identity_keys_by_id(cast(Agent[Any], self._starting_agent)) + if self._starting_agent is not None + else None + ) + if self._current_step is None or not isinstance(self._current_step, NextStepInterruption): return None interruptions_data = [ _serialize_tool_approval_interruption( - item, include_tool_name=item.tool_name is not None + item, + include_tool_name=item.tool_name is not None, + agent_identity_keys_by_id=agent_identity_keys_by_id, ) for item in self._current_step.interruptions if isinstance(item, ToolApprovalItem) @@ -745,14 +832,22 @@ def _serialize_current_step(self) -> dict[str, Any] | None: }, } - def _serialize_item(self, item: RunItem) -> dict[str, Any]: + def _serialize_item( + self, + item: RunItem, + *, + agent_identity_keys_by_id: Mapping[int, str] | None = None, + ) -> dict[str, Any]: """Serialize a run item to JSON-compatible dict.""" raw_item_dict: Any = _serialize_raw_item_value(item.raw_item) result: dict[str, Any] = { "type": item.type, "raw_item": raw_item_dict, - "agent": {"name": item.agent.name}, + "agent": _serialize_agent_reference( + item.agent, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ), } # Add additional fields based on item type @@ -768,9 +863,15 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]: serialized_output = str(item.output) result["output"] = serialized_output if hasattr(item, "source_agent"): - result["source_agent"] = {"name": item.source_agent.name} + result["source_agent"] = _serialize_agent_reference( + item.source_agent, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) if hasattr(item, "target_agent"): - result["target_agent"] = {"name": item.target_agent.name} + result["target_agent"] = _serialize_agent_reference( + item.target_agent, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) if hasattr(item, "tool_name") and item.tool_name is not None: result["tool_name"] = item.tool_name if hasattr(item, "tool_namespace") and item.tool_namespace is not None: @@ -1090,6 +1191,19 @@ def _serialize_raw_item_value(raw_item: Any) -> Any: return raw_item +def _serialize_agent_reference( + agent: Agent[Any], + agent_identity_keys_by_id: Mapping[int, str] | None = None, +) -> dict[str, Any]: + """Serialize an agent reference with an optional duplicate-name identity key.""" + entry: dict[str, Any] = {"name": agent.name} + if agent_identity_keys_by_id is not None: + identity = agent_identity_keys_by_id.get(id(agent)) + if identity is not None and identity != agent.name: + entry["identity"] = identity + return entry + + def _ensure_json_compatible(value: Any) -> Any: try: return json.loads(json.dumps(value, default=str)) @@ -1214,13 +1328,19 @@ def _serialize_mcp_tool(mcp_tool: Any) -> dict[str, Any]: def _serialize_tool_approval_interruption( - interruption: ToolApprovalItem, *, include_tool_name: bool + interruption: ToolApprovalItem, + *, + include_tool_name: bool, + agent_identity_keys_by_id: Mapping[int, str] | None = None, ) -> dict[str, Any]: """Serialize a ToolApprovalItem interruption.""" interruption_dict: dict[str, Any] = { "type": "tool_approval_item", "raw_item": _serialize_raw_item_value(interruption.raw_item), - "agent": {"name": interruption.agent.name}, + "agent": _serialize_agent_reference( + interruption.agent, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ), } if include_tool_name and interruption.tool_name is not None: interruption_dict["tool_name"] = interruption.tool_name @@ -1388,6 +1508,8 @@ def to_state(self) -> RunState[Any, Agent[Any]]: def _serialize_guardrail_results( results: Sequence[InputGuardrailResult | OutputGuardrailResult], + *, + agent_identity_keys_by_id: Mapping[int, str] | None = None, ) -> list[dict[str, Any]]: """Serialize guardrail results for persistence.""" serialized: list[dict[str, Any]] = [] @@ -1404,7 +1526,10 @@ def _serialize_guardrail_results( } if isinstance(result, OutputGuardrailResult): entry["agentOutput"] = result.agent_output - entry["agent"] = {"name": result.agent.name} + entry["agent"] = _serialize_agent_reference( + result.agent, + agent_identity_keys_by_id=agent_identity_keys_by_id, + ) serialized.append(entry) return serialized @@ -1544,6 +1669,7 @@ async def _deserialize_processed_response( context: RunContextWrapper[Any], agent_map: dict[str, Agent[Any]], *, + agent_identity_map: Mapping[str, Agent[Any]] | None = None, scope_id: str | None = None, context_deserializer: ContextDeserializer | None = None, strict_context: bool = False, @@ -1559,7 +1685,11 @@ async def _deserialize_processed_response( Returns: A reconstructed ProcessedResponse instance. """ - new_items = _deserialize_items(processed_response_data.get("new_items", []), agent_map) + new_items = _deserialize_items( + processed_response_data.get("new_items", []), + agent_map, + agent_identity_map=agent_identity_map, + ) if hasattr(current_agent, "get_all_tools"): all_tools = await current_agent.get_all_tools(context) @@ -1811,6 +1941,7 @@ def _resolve_function_tool_name(data: Mapping[str, Any]) -> FunctionToolLookupKe approval_item = _deserialize_tool_approval_item( interruption_data, agent_map=agent_map, + agent_identity_map=agent_identity_map, fallback_agent=current_agent, ) if approval_item is not None: @@ -1855,16 +1986,32 @@ def _deserialize_tool_call_raw_item(normalized_raw_item: Mapping[str, Any]) -> A def _resolve_agent_from_data( agent_data: Any, agent_map: Mapping[str, Agent[Any]], + agent_identity_map: Mapping[str, Agent[Any]] | None = None, fallback_agent: Agent[Any] | None = None, ) -> Agent[Any] | None: """Resolve an agent from serialized data with an optional fallback.""" agent_name = None + agent_identity = None if isinstance(agent_data, Mapping): + agent_identity = agent_data.get("identity") agent_name = agent_data.get("name") elif isinstance(agent_data, str): agent_name = agent_data + if isinstance(agent_identity, str) and agent_identity_map is not None: + resolved = agent_identity_map.get(agent_identity) + if resolved is not None: + return resolved + raise UserError( + "Run state references an agent identity that is not present in the restored graph: " + f"{agent_identity}" + ) + if agent_name: + if agent_identity_map is not None: + resolved = agent_identity_map.get(agent_name) + if resolved is not None: + return resolved return agent_map.get(agent_name) or fallback_agent return fallback_agent @@ -1881,11 +2028,17 @@ def _deserialize_tool_approval_item( item_data: Mapping[str, Any], *, agent_map: Mapping[str, Agent[Any]], + agent_identity_map: Mapping[str, Agent[Any]] | None = None, fallback_agent: Agent[Any] | None = None, pre_normalized_raw_item: Any | None = None, ) -> ToolApprovalItem | None: """Deserialize a ToolApprovalItem from serialized data.""" - agent = _resolve_agent_from_data(item_data.get("agent"), agent_map, fallback_agent) + agent = _resolve_agent_from_data( + item_data.get("agent"), + agent_map, + agent_identity_map, + fallback_agent, + ) if agent is None: return None @@ -2018,6 +2171,7 @@ def _deserialize_output_guardrail_results( results_data: list[dict[str, Any]], *, agent_map: dict[str, Agent[Any]], + agent_identity_map: Mapping[str, Agent[Any]] | None = None, fallback_agent: Agent[Any], ) -> list[OutputGuardrailResult]: """Rehydrate output guardrail results from serialized data.""" @@ -2029,9 +2183,14 @@ def _deserialize_output_guardrail_results( name, guardrail_output, entry_dict = parsed agent_output = entry_dict.get("agentOutput") agent_data = entry_dict.get("agent") - agent_name = agent_data.get("name") if isinstance(agent_data, dict) else None - resolved_agent = agent_map.get(agent_name) if isinstance(agent_name, str) else None - resolved_agent = resolved_agent or fallback_agent + resolved_agent = _resolve_agent_from_data( + agent_data, + agent_map, + agent_identity_map, + fallback_agent, + ) + if resolved_agent is None: + resolved_agent = fallback_agent def _output_guardrail_fn( context: RunContextWrapper[Any], @@ -2134,10 +2293,16 @@ async def _build_run_state_from_json( f"New snapshots are written as version {CURRENT_SCHEMA_VERSION}." ) + agent_identity_map = _build_agent_identity_map(initial_agent) agent_map = _build_agent_map(initial_agent) - current_agent_name = state_json["current_agent"]["name"] - current_agent = agent_map.get(current_agent_name) + current_agent_data = state_json["current_agent"] + current_agent_name = current_agent_data["name"] + current_agent = _resolve_agent_from_data( + current_agent_data, + agent_map, + agent_identity_map=agent_identity_map, + ) if not current_agent: raise UserError(f"Agent {current_agent_name} not found in agent map") @@ -2218,6 +2383,8 @@ async def _build_run_state_from_json( previous_response_id=state_json.get("previous_response_id"), auto_previous_response_id=bool(state_json.get("auto_previous_response_id", False)), ) + state._starting_agent = initial_agent + state._schema_version = schema_version from .agent_tool_state import set_agent_tool_state_scope state._agent_tool_state_scope_id = uuid4().hex @@ -2225,7 +2392,11 @@ async def _build_run_state_from_json( state._current_turn = state_json["current_turn"] state._model_responses = _deserialize_model_responses(state_json.get("model_responses", [])) - state._generated_items = _deserialize_items(state_json.get("generated_items", []), agent_map) + state._generated_items = _deserialize_items( + state_json.get("generated_items", []), + agent_map, + agent_identity_map=agent_identity_map, + ) last_processed_response_data = state_json.get("last_processed_response") if last_processed_response_data and state._context is not None: @@ -2234,6 +2405,7 @@ async def _build_run_state_from_json( current_agent, state._context, agent_map, + agent_identity_map=agent_identity_map, scope_id=state._agent_tool_state_scope_id, context_deserializer=context_deserializer, strict_context=strict_context, @@ -2242,7 +2414,11 @@ async def _build_run_state_from_json( state._last_processed_response = None if "session_items" in state_json: - state._session_items = _deserialize_items(state_json.get("session_items", []), agent_map) + state._session_items = _deserialize_items( + state_json.get("session_items", []), + agent_map, + agent_identity_map=agent_identity_map, + ) else: state._session_items = state._merge_generated_items_with_processed() @@ -2254,6 +2430,7 @@ async def _build_run_state_from_json( state._output_guardrail_results = _deserialize_output_guardrail_results( state_json.get("output_guardrail_results", []), agent_map=agent_map, + agent_identity_map=agent_identity_map, fallback_agent=current_agent, ) state._tool_input_guardrail_results = _deserialize_tool_input_guardrail_results( @@ -2270,7 +2447,11 @@ async def _build_run_state_from_json( "interruptions", current_step_data.get("interruptions", []) ) for item_data in interruptions_data: - approval_item = _deserialize_tool_approval_item(item_data, agent_map=agent_map) + approval_item = _deserialize_tool_approval_item( + item_data, + agent_map=agent_map, + agent_identity_map=agent_identity_map, + ) if approval_item is not None: interruptions.append(approval_item) @@ -2294,29 +2475,25 @@ async def _build_run_state_from_json( state._trace_state = TraceState.from_json(trace_data) else: state._trace_state = None + sandbox_data = state_json.get("sandbox") + state._sandbox = dict(sandbox_data) if isinstance(sandbox_data, Mapping) else None return state -def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: - """Build a map of agent names to agents by traversing handoffs. - - Args: - initial_agent: The starting agent. - - Returns: - Dictionary mapping agent names to agent instances. - """ - agent_map: dict[str, Agent[Any]] = {} +def _iter_agent_graph(initial_agent: Agent[Any]) -> Iterator[Agent[Any]]: + """Yield agents reachable from the starting agent in breadth-first order.""" queue: deque[Agent[Any]] = deque([initial_agent]) + seen_agent_ids: set[int] = set() while queue: current = queue.popleft() - if current.name in agent_map: + current_id = id(current) + if current_id in seen_agent_ids: continue - agent_map[current.name] = current + seen_agent_ids.add(current_id) + yield current - # Add handoff agents to the queue for handoff_item in current.handoffs: handoff_agent: Any | None = None handoff_agent_name: str | None = None @@ -2329,8 +2506,6 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: ) if isinstance(candidate_name, str): handoff_agent_name = candidate_name - if handoff_agent_name in agent_map: - continue handoff_ref = getattr(handoff_item, "_agent_ref", None) handoff_agent = handoff_ref() if callable(handoff_ref) else None @@ -2368,12 +2543,8 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: candidate_name = getattr(handoff_agent, "name", None) handoff_agent_name = candidate_name if isinstance(candidate_name, str) else None - if ( - handoff_agent is not None - and handoff_agent_name - and handoff_agent_name not in agent_map - ): - queue.append(cast(Any, handoff_agent)) + if handoff_agent is not None and handoff_agent_name: + queue.append(cast(Agent[Any], handoff_agent)) # Include agent-as-tool instances so nested approvals can be restored. tools = getattr(current, "tools", None) @@ -2383,9 +2554,391 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: continue tool_agent = getattr(tool, "_agent_instance", None) tool_agent_name = getattr(tool_agent, "name", None) - if tool_agent and tool_agent_name and tool_agent_name not in agent_map: + if tool_agent and tool_agent_name: queue.append(tool_agent) + +def _allocate_unique_agent_identity(agent_name: str, used_identities: set[str]) -> str: + """Return a deterministic identity key without colliding with literal agent names.""" + candidate = agent_name + next_index = 1 + while candidate in used_identities: + next_index += 1 + candidate = f"{agent_name}#{next_index}" + used_identities.add(candidate) + return candidate + + +def _identity_type_name(value: Any) -> str: + return f"{type(value).__module__}.{type(value).__qualname__}" + + +def _callable_identity_name(value: Any) -> str: + module = getattr(value, "__module__", type(value).__module__) + qualname = getattr(value, "__qualname__", type(value).__qualname__) + return f"{module}.{qualname}" + + +def _normalize_identity_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (bytes, bytearray)): + return {"type": "bytes", "length": len(value)} + if callable(value): + return {"callable": _callable_identity_name(value)} + if dataclasses.is_dataclass(value): + return { + "dataclass": _identity_type_name(value), + "value": _normalize_identity_value(dataclasses.asdict(cast(Any, value))), + } + if hasattr(value, "model_dump"): + try: + dumped = value.model_dump(exclude_unset=True) + except TypeError: + dumped = value.model_dump() + return { + "model": _identity_type_name(value), + "value": _normalize_identity_value(dumped), + } + if isinstance(value, Mapping): + return { + str(key): _normalize_identity_value(item) + for key, item in sorted(value.items(), key=lambda pair: str(pair[0])) + } + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [_normalize_identity_value(item) for item in value] + + value_name = getattr(value, "name", None) + if isinstance(value_name, str): + return {"type": _identity_type_name(value), "name": value_name} + return {"type": _identity_type_name(value)} + + +def _stable_identity_text(value: Any) -> str: + return json.dumps( + _normalize_identity_value(value), + sort_keys=True, + separators=(",", ":"), + ) + + +def _tool_identity_signature(tool: Any) -> dict[str, Any]: + signature: dict[str, Any] = { + "type": _identity_type_name(tool), + "name": getattr(tool, "name", None), + } + namespace = get_function_tool_namespace(tool) + if namespace is not None: + signature["namespace"] = namespace + qualified_name = get_function_tool_qualified_name(tool) + if qualified_name is not None: + signature["qualified_name"] = qualified_name + if hasattr(tool, "environment"): + signature["environment"] = _normalize_identity_value(tool.environment) + if getattr(tool, "_is_agent_tool", False): + nested_agent = getattr(tool, "_agent_instance", None) + signature["agent_tool_target"] = getattr(nested_agent, "name", None) + return signature + + +_THREADING_LOCK_TYPES = (type(threading.Lock()), type(threading.RLock())) + + +def _is_capability_runtime_only_value(value: Any) -> bool: + return isinstance( + value, + ( + BaseSandboxSession, + asyncio.Event, + asyncio.Lock, + asyncio.Semaphore, + asyncio.Condition, + threading.Event, + *_THREADING_LOCK_TYPES, + ), + ) + + +def _normalize_capability_identity_value( + value: Any, + *, + seen: set[int] | None = None, +) -> Any: + if seen is None: + seen = set() + + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, Path): + return value.as_posix() + if isinstance(value, (bytes, bytearray)): + return {"type": "bytes", "length": len(value)} + if callable(value): + return {"callable": _callable_identity_name(value)} + if _is_capability_runtime_only_value(value): + return {"runtime_only": _identity_type_name(value)} + if isinstance( + value, + ( + ApplyPatchTool, + ComputerTool, + FunctionTool, + HostedMCPTool, + LocalShellTool, + ShellTool, + ), + ): + return _tool_identity_signature(value) + + object_id = id(value) + if object_id in seen: + return {"recursive": _identity_type_name(value)} + + if dataclasses.is_dataclass(value): + seen.add(object_id) + try: + merged_fields = { + field.name: getattr(value, field.name) for field in dataclasses.fields(value) + } + if hasattr(value, "__dict__"): + for name, item in vars(value).items(): + if name.startswith("_") or name in merged_fields: + continue + merged_fields[name] = item + return { + "dataclass": _identity_type_name(value), + "value": { + name: _normalize_capability_identity_value( + item, + seen=seen, + ) + for name, item in sorted(merged_fields.items()) + }, + } + finally: + seen.remove(object_id) + + if hasattr(value, "model_dump"): + seen.add(object_id) + try: + try: + dumped = value.model_dump(mode="json", round_trip=True) + except TypeError: + dumped = value.model_dump(mode="json") + return { + "model": _identity_type_name(value), + "value": _normalize_capability_identity_value(dumped, seen=seen), + } + finally: + seen.remove(object_id) + + if isinstance(value, Mapping): + seen.add(object_id) + try: + return { + str(key): _normalize_capability_identity_value(item, seen=seen) + for key, item in sorted(value.items(), key=lambda pair: str(pair[0])) + } + finally: + seen.remove(object_id) + + if isinstance(value, (set, frozenset)): + seen.add(object_id) + try: + normalized_items = [ + _normalize_capability_identity_value(item, seen=seen) for item in value + ] + return sorted(normalized_items, key=_stable_identity_text) + finally: + seen.remove(object_id) + + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + seen.add(object_id) + try: + return [_normalize_capability_identity_value(item, seen=seen) for item in value] + finally: + seen.remove(object_id) + + if hasattr(value, "__dict__"): + seen.add(object_id) + try: + return { + "object": _identity_type_name(value), + "value": { + name: _normalize_capability_identity_value(item, seen=seen) + for name, item in sorted(vars(value).items()) + if not name.startswith("_") + }, + } + finally: + seen.remove(object_id) + + value_name = getattr(value, "name", None) + if isinstance(value_name, str): + return {"type": _identity_type_name(value), "name": value_name} + return {"type": _identity_type_name(value)} + + +def _capability_identity_signature(capability: Any) -> dict[str, Any]: + return { + "type": _identity_type_name(capability), + "value": _normalize_capability_identity_value(capability), + } + + +def _handoff_identity_signature(handoff_item: Agent[Any] | Handoff[Any, Any]) -> dict[str, Any]: + if isinstance(handoff_item, Handoff): + tool_name = getattr(handoff_item, "tool_name", None) + if not isinstance(tool_name, str): + tool_name = getattr(handoff_item, "name", None) + agent_name = getattr(handoff_item, "agent_name", None) + return { + "type": _identity_type_name(handoff_item), + "tool_name": tool_name, + "agent_name": agent_name if isinstance(agent_name, str) else None, + "input_filter": _normalize_identity_value(getattr(handoff_item, "input_filter", None)), + "nest_handoff_history": getattr(handoff_item, "nest_handoff_history", None), + } + + return { + "type": _identity_type_name(handoff_item), + "agent_name": getattr(handoff_item, "name", None), + } + + +def _agent_identity_signature(agent: Agent[Any]) -> str: + signature: dict[str, Any] = { + "agent_type": _identity_type_name(agent), + "handoff_description": getattr(agent, "handoff_description", None), + "instructions": _normalize_identity_value(getattr(agent, "instructions", None)), + "prompt": _normalize_identity_value(getattr(agent, "prompt", None)), + "model": _normalize_identity_value(getattr(agent, "model", None)), + "model_settings": _normalize_identity_value(getattr(agent, "model_settings", None)), + "mcp_config": _normalize_capability_identity_value(getattr(agent, "mcp_config", None)), + "hooks": _normalize_capability_identity_value(getattr(agent, "hooks", None)), + "input_guardrails": sorted( + _stable_identity_text(_normalize_capability_identity_value(guardrail)) + for guardrail in getattr(agent, "input_guardrails", []) + ), + "output_guardrails": sorted( + _stable_identity_text(_normalize_capability_identity_value(guardrail)) + for guardrail in getattr(agent, "output_guardrails", []) + ), + "output_type": _normalize_identity_value(getattr(agent, "output_type", None)), + "tool_use_behavior": _normalize_capability_identity_value( + getattr(agent, "tool_use_behavior", None) + ), + "reset_tool_choice": getattr(agent, "reset_tool_choice", None), + "tools": sorted( + _stable_identity_text(_tool_identity_signature(tool)) + for tool in getattr(agent, "tools", []) + ), + "handoffs": sorted( + _stable_identity_text(_handoff_identity_signature(handoff_item)) + for handoff_item in getattr(agent, "handoffs", []) + ), + "mcp_servers": sorted( + _stable_identity_text(server) for server in getattr(agent, "mcp_servers", []) + ), + } + + default_manifest = getattr(agent, "default_manifest", None) + if default_manifest is not None: + signature["default_manifest"] = _normalize_capability_identity_value(default_manifest) + + developer_instructions = getattr(agent, "developer_instructions", None) + if developer_instructions is not None: + signature["developer_instructions"] = developer_instructions + + capabilities = getattr(agent, "capabilities", None) + if isinstance(capabilities, Sequence): + signature["capabilities"] = sorted( + _stable_identity_text(_capability_identity_signature(capability)) + for capability in capabilities + ) + + return _stable_identity_text(signature) + + +def _agent_identity_sort_key( + agent: Agent[Any], + *, + root_agent: Agent[Any], + original_index: int, +) -> tuple[int, str, int]: + return ( + 0 if agent is root_agent else 1, + _agent_identity_signature(agent), + original_index, + ) + + +def _build_agent_identity_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: + """Build a stable identity map that preserves duplicate agent names.""" + ordered_agents = list(_iter_agent_graph(initial_agent)) + original_indices = {id(agent): index for index, agent in enumerate(ordered_agents)} + literal_names = {agent.name for agent in ordered_agents} + agents_by_name: dict[str, list[Agent[Any]]] = {} + for agent in ordered_agents: + agents_by_name.setdefault(agent.name, []).append(agent) + + agent_identity_map: dict[str, Agent[Any]] = {} + used_identities: set[str] = set() + processed_names: set[str] = set() + + for agent in ordered_agents: + agent_name = agent.name + if agent_name in processed_names: + continue + processed_names.add(agent_name) + + group = agents_by_name[agent_name] + sorted_group = sorted( + group, + key=lambda candidate: _agent_identity_sort_key( + candidate, + root_agent=initial_agent, + original_index=original_indices[id(candidate)], + ), + ) + + base_agent = sorted_group[0] + used_identities.add(agent_name) + agent_identity_map[agent_name] = base_agent + + next_index = 2 + for duplicate_agent in sorted_group[1:]: + candidate = f"{agent_name}#{next_index}" + while candidate in used_identities or candidate in literal_names: + next_index += 1 + candidate = f"{agent_name}#{next_index}" + used_identities.add(candidate) + agent_identity_map[candidate] = duplicate_agent + next_index += 1 + + return agent_identity_map + + +def _build_agent_identity_keys_by_id(initial_agent: Agent[Any]) -> dict[int, str]: + """Build stable identity keys for the reachable agent graph.""" + return { + id(agent): identity for identity, agent in _build_agent_identity_map(initial_agent).items() + } + + +def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: + """Build a map of agent names to agents by traversing handoffs. + + Args: + initial_agent: The starting agent. + + Returns: + Dictionary mapping agent names to agent instances. + """ + agent_map: dict[str, Agent[Any]] = {} + for agent in _iter_agent_graph(initial_agent): + agent_map.setdefault(agent.name, agent) + return agent_map @@ -2426,7 +2979,10 @@ def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[M def _deserialize_items( - items_data: list[dict[str, Any]], agent_map: dict[str, Agent[Any]] + items_data: list[dict[str, Any]], + agent_map: dict[str, Agent[Any]], + *, + agent_identity_map: Mapping[str, Agent[Any]] | None = None, ) -> list[RunItem]: """Deserialize run items from JSON data. @@ -2456,7 +3012,11 @@ def _resolve_agent_info( elif isinstance(raw_agent, str): candidate_name = raw_agent - agent_candidate = _resolve_agent_from_data(raw_agent, agent_map) + agent_candidate = _resolve_agent_from_data( + raw_agent, + agent_map, + agent_identity_map, + ) if agent_candidate: return agent_candidate, agent_candidate.name @@ -2537,8 +3097,16 @@ def _resolve_agent_info( result.append(HandoffCallItem(agent=agent, raw_item=raw_item_handoff)) elif item_type == "handoff_output_item": - source_agent = _resolve_agent_from_data(item_data.get("source_agent"), agent_map) - target_agent = _resolve_agent_from_data(item_data.get("target_agent"), agent_map) + source_agent = _resolve_agent_from_data( + item_data.get("source_agent"), + agent_map, + agent_identity_map, + ) + target_agent = _resolve_agent_from_data( + item_data.get("target_agent"), + agent_map, + agent_identity_map, + ) # If we cannot resolve both agents, skip this item gracefully if not source_agent or not target_agent: @@ -2601,12 +3169,15 @@ def _resolve_agent_info( approval_item = _deserialize_tool_approval_item( item_data, agent_map=agent_map, + agent_identity_map=agent_identity_map, fallback_agent=agent, pre_normalized_raw_item=normalized_raw_item, ) if approval_item is not None: result.append(approval_item) + except UserError: + raise except Exception as e: logger.warning(f"Failed to deserialize item of type {item_type}: {e}") continue diff --git a/src/agents/sandbox/__init__.py b/src/agents/sandbox/__init__.py new file mode 100644 index 0000000000..6a02392b43 --- /dev/null +++ b/src/agents/sandbox/__init__.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from ..run_config import SandboxRunConfig +from .capabilities import Capability +from .codex_config import CodexConfig +from .entries import Dir, LocalFile +from .errors import ( + ErrorCode, + ExecTimeoutError, + ExecTransportError, + UniversalComputerError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceWriteTypeError, +) +from .manifest import Manifest +from .sandbox_agent import SandboxAgent +from .snapshot import ( + LocalSnapshot, + LocalSnapshotSpec, + SnapshotSpec, + resolve_snapshot, +) +from .types import ExecResult + +__all__ = [ + "Capability", + "CodexConfig", + "Dir", + "ErrorCode", + "ExecResult", + "ExecTimeoutError", + "ExecTransportError", + "LocalFile", + "LocalSnapshot", + "LocalSnapshotSpec", + "Manifest", + "SandboxAgent", + "SandboxRunConfig", + "SnapshotSpec", + "UniversalComputerError", + "WorkspaceArchiveReadError", + "WorkspaceArchiveWriteError", + "WorkspaceReadNotFoundError", + "WorkspaceWriteTypeError", + "resolve_snapshot", +] diff --git a/src/agents/sandbox/app_server/__init__.py b/src/agents/sandbox/app_server/__init__.py new file mode 100644 index 0000000000..aff63176b9 --- /dev/null +++ b/src/agents/sandbox/app_server/__init__.py @@ -0,0 +1,10 @@ +from .client import AppServerClient, AppServerConfig +from .errors import AppServerError, JsonRpcError, TransportClosedError + +__all__ = [ + "AppServerClient", + "AppServerConfig", + "AppServerError", + "JsonRpcError", + "TransportClosedError", +] diff --git a/src/agents/sandbox/app_server/client.py b/src/agents/sandbox/app_server/client.py new file mode 100644 index 0000000000..4150317099 --- /dev/null +++ b/src/agents/sandbox/app_server/client.py @@ -0,0 +1,507 @@ +from __future__ import annotations + +import json +import threading +import uuid +from collections import deque +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from typing import Any, Callable, TypeVar, cast + +from pydantic import BaseModel +from websockets.exceptions import ConnectionClosed +from websockets.sync.client import ClientConnection, connect + +from .errors import AppServerError, TransportClosedError, map_jsonrpc_error +from .generated.notification_registry import NOTIFICATION_MODELS +from .generated.v2_all import ( + AgentMessageDeltaNotification, + ModelListResponse, + ThreadArchiveResponse, + ThreadCompactStartResponse, + ThreadForkParams as V2ThreadForkParams, + ThreadForkResponse, + ThreadListParams as V2ThreadListParams, + ThreadListResponse, + ThreadReadResponse, + ThreadResumeParams as V2ThreadResumeParams, + ThreadResumeResponse, + ThreadSetNameResponse, + ThreadStartParams as V2ThreadStartParams, + ThreadStartResponse, + ThreadUnarchiveResponse, + TurnCompletedNotification, + TurnInterruptResponse, + TurnStartParams as V2TurnStartParams, + TurnStartResponse, + TurnSteerResponse, +) +from .models import ( + InitializeResponse, + JsonObject, + JsonValue, + Notification, + UnknownNotification, +) +from .retry import retry_on_overload + +ModelT = TypeVar("ModelT", bound=BaseModel) +ApprovalHandler = Callable[[str, JsonObject | None], JsonObject] + + +def _params_dict( + params: ( + V2ThreadStartParams + | V2ThreadResumeParams + | V2ThreadListParams + | V2ThreadForkParams + | V2TurnStartParams + | JsonObject + | None + ), +) -> JsonObject: + if params is None: + return {} + if hasattr(params, "model_dump"): + dumped = params.model_dump( + by_alias=True, + exclude_none=True, + mode="json", + ) + if not isinstance(dumped, dict): + raise TypeError("Expected model_dump() to return dict") + return dumped + if isinstance(params, dict): + return params + raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}") + + +@dataclass(frozen=True) +class AppServerTransportOps: + ws_connect: Callable[..., ClientConnection] + + +def _default_transport_ops() -> AppServerTransportOps: + return AppServerTransportOps(ws_connect=connect) + + +@dataclass(slots=True) +class AppServerConfig: + websocket_url: str | None = None + websocket_headers: dict[str, str] | None = None + client_name: str = "codex_python_sdk" + client_title: str = "Codex Python SDK" + client_version: str = "0.2.0" + experimental_api: bool = True + websocket_open_timeout_s: float = 10.0 + websocket_recv_timeout_s: float | None = None + + +class AppServerClient: + """Synchronous typed JSON-RPC client for a remote `codex app-server` websocket.""" + + def __init__( + self, + config: AppServerConfig | None = None, + approval_handler: ApprovalHandler | None = None, + transport_ops: AppServerTransportOps | None = None, + ) -> None: + self.config = config or AppServerConfig() + self._approval_handler = approval_handler or self._default_approval_handler + self._ops = transport_ops or _default_transport_ops() + self._conn: ClientConnection | None = None + self._lock = threading.Lock() + self._turn_consumer_lock = threading.Lock() + self._active_turn_consumer: str | None = None + self._pending_notifications: deque[Notification] = deque() + + def __enter__(self) -> AppServerClient: + self.start() + return self + + def __exit__(self, _exc_type, _exc, _tb) -> None: + self.close() + + @property + def connected_url(self) -> str | None: + return self.config.websocket_url if self._conn is not None else None + + def start(self) -> None: + if self._conn is not None: + return + + websocket_url = self.config.websocket_url + if not websocket_url: + raise ValueError( + "AppServerConfig.websocket_url is required for remote app-server clients." + ) + + try: + self._conn = self._ops.ws_connect( + websocket_url, + additional_headers=self.config.websocket_headers, + open_timeout=self.config.websocket_open_timeout_s, + max_size=None, + ) + except Exception as exc: + raise TransportClosedError( + f"Failed to connect to app-server websocket `{websocket_url}`: {exc}" + ) from exc + + def close(self) -> None: + conn = self._conn + self._conn = None + self._active_turn_consumer = None + + if conn is not None: + try: + conn.close() + except Exception: + pass + + def initialize(self) -> InitializeResponse: + result = self.request( + "initialize", + { + "clientInfo": { + "name": self.config.client_name, + "title": self.config.client_title, + "version": self.config.client_version, + }, + "capabilities": { + "experimentalApi": self.config.experimental_api, + }, + }, + response_model=InitializeResponse, + ) + self.notify("initialized", None) + return result + + def request( + self, + method: str, + params: JsonObject | None, + *, + response_model: type[ModelT], + ) -> ModelT: + result = self._request_raw(method, params) + if not isinstance(result, dict): + raise AppServerError(f"{method} response must be a JSON object") + return response_model.model_validate(result) + + def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue: + request_id = str(uuid.uuid4()) + self._write_message({"id": request_id, "method": method, "params": params or {}}) + + while True: + msg = self._read_message() + method_name = msg.get("method") + + if isinstance(method_name, str) and "id" in msg: + response = self._handle_server_request(msg) + self._write_message({"id": msg["id"], "result": response}) + continue + + if isinstance(method_name, str) and "id" not in msg: + self._pending_notifications.append( + self._coerce_notification(method_name, msg.get("params")) + ) + continue + + if msg.get("id") != request_id: + continue + + if "error" in msg: + err = msg["error"] + if isinstance(err, dict): + code_value = err.get("code", -32000) + code = int(code_value) if isinstance(code_value, (int, float, str)) else -32000 + raise map_jsonrpc_error( + code, + str(err.get("message", "unknown")), + err.get("data"), + ) + raise AppServerError("Malformed JSON-RPC error response") + + return msg.get("result") + + def notify(self, method: str, params: JsonObject | None = None) -> None: + self._write_message({"method": method, "params": params or {}}) + + def next_notification(self) -> Notification: + if self._pending_notifications: + return self._pending_notifications.popleft() + + while True: + msg = self._read_message() + method_name = msg.get("method") + if isinstance(method_name, str) and "id" in msg: + response = self._handle_server_request(msg) + self._write_message({"id": msg["id"], "result": response}) + continue + if isinstance(method_name, str) and "id" not in msg: + return self._coerce_notification(method_name, msg.get("params")) + + def acquire_turn_consumer(self, turn_id: str) -> None: + with self._turn_consumer_lock: + if self._active_turn_consumer is not None: + raise RuntimeError( + "Concurrent turn consumers are not yet supported in the experimental SDK. " + f"Client is already streaming turn {self._active_turn_consumer!r}; " + f"cannot start turn {turn_id!r} until the active consumer finishes." + ) + self._active_turn_consumer = turn_id + + def release_turn_consumer(self, turn_id: str) -> None: + with self._turn_consumer_lock: + if self._active_turn_consumer == turn_id: + self._active_turn_consumer = None + + def thread_start( + self, params: V2ThreadStartParams | JsonObject | None = None + ) -> ThreadStartResponse: + return self.request( + "thread/start", _params_dict(params), response_model=ThreadStartResponse + ) + + def thread_resume( + self, + thread_id: str, + params: V2ThreadResumeParams | JsonObject | None = None, + ) -> ThreadResumeResponse: + payload = {"threadId": thread_id, **_params_dict(params)} + return self.request("thread/resume", payload, response_model=ThreadResumeResponse) + + def thread_list( + self, params: V2ThreadListParams | JsonObject | None = None + ) -> ThreadListResponse: + return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse) + + def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse: + return self.request( + "thread/read", + {"threadId": thread_id, "includeTurns": include_turns}, + response_model=ThreadReadResponse, + ) + + def thread_fork( + self, + thread_id: str, + params: V2ThreadForkParams | JsonObject | None = None, + ) -> ThreadForkResponse: + payload = {"threadId": thread_id, **_params_dict(params)} + return self.request("thread/fork", payload, response_model=ThreadForkResponse) + + def thread_archive(self, thread_id: str) -> ThreadArchiveResponse: + return self.request( + "thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse + ) + + def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse: + return self.request( + "thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse + ) + + def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse: + return self.request( + "thread/name/set", + {"threadId": thread_id, "name": name}, + response_model=ThreadSetNameResponse, + ) + + def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse: + return self.request( + "thread/compact/start", + {"threadId": thread_id}, + response_model=ThreadCompactStartResponse, + ) + + def turn_start( + self, + thread_id: str, + input_items: list[JsonObject] | JsonObject | str, + params: V2TurnStartParams | JsonObject | None = None, + ) -> TurnStartResponse: + payload: JsonObject = { + **_params_dict(params), + "threadId": thread_id, + "input": cast(JsonValue, self._normalize_input_items(input_items)), + } + return self.request("turn/start", payload, response_model=TurnStartResponse) + + def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse: + return self.request( + "turn/interrupt", + {"threadId": thread_id, "turnId": turn_id}, + response_model=TurnInterruptResponse, + ) + + def turn_steer( + self, + thread_id: str, + expected_turn_id: str, + input_items: list[JsonObject] | JsonObject | str, + ) -> TurnSteerResponse: + payload: JsonObject = { + "threadId": thread_id, + "expectedTurnId": expected_turn_id, + "input": cast(JsonValue, self._normalize_input_items(input_items)), + } + return self.request( + "turn/steer", + payload, + response_model=TurnSteerResponse, + ) + + def model_list(self, include_hidden: bool = False) -> ModelListResponse: + return self.request( + "model/list", + {"includeHidden": include_hidden}, + response_model=ModelListResponse, + ) + + def request_with_retry_on_overload( + self, + method: str, + params: JsonObject | None, + *, + response_model: type[ModelT], + max_attempts: int = 3, + initial_delay_s: float = 0.25, + max_delay_s: float = 2.0, + ) -> ModelT: + return retry_on_overload( + lambda: self.request(method, params, response_model=response_model), + max_attempts=max_attempts, + initial_delay_s=initial_delay_s, + max_delay_s=max_delay_s, + ) + + def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification: + while True: + notification = self.next_notification() + if ( + notification.method == "turn/completed" + and isinstance(notification.payload, TurnCompletedNotification) + and notification.payload.turn.id == turn_id + ): + return notification.payload + + def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]: + target_methods = {methods} if isinstance(methods, str) else set(methods) + out: list[Notification] = [] + while True: + notification = self.next_notification() + out.append(notification) + if notification.method in target_methods: + return out + + def stream_text( + self, + thread_id: str, + text: str, + params: V2TurnStartParams | JsonObject | None = None, + ) -> Iterator[AgentMessageDeltaNotification]: + started = self.turn_start(thread_id, text, params=params) + turn_id = started.turn.id + while True: + notification = self.next_notification() + if ( + notification.method == "item/agentMessage/delta" + and isinstance(notification.payload, AgentMessageDeltaNotification) + and notification.payload.turn_id == turn_id + ): + yield notification.payload + continue + if ( + notification.method == "turn/completed" + and isinstance(notification.payload, TurnCompletedNotification) + and notification.payload.turn.id == turn_id + ): + break + + def _coerce_notification(self, method: str, params: object) -> Notification: + params_dict = params if isinstance(params, dict) else {} + + model = NOTIFICATION_MODELS.get(method) + if model is None: + return Notification(method=method, payload=UnknownNotification(params=params_dict)) + + try: + payload = model.model_validate(params_dict) + except Exception: + return Notification(method=method, payload=UnknownNotification(params=params_dict)) + return Notification(method=method, payload=cast(Any, payload)) + + def _normalize_input_items( + self, + input_items: list[JsonObject] | JsonObject | str, + ) -> list[JsonObject]: + if isinstance(input_items, str): + return [{"type": "text", "text": input_items}] + if isinstance(input_items, dict): + return [input_items] + return input_items + + def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject: + if method == "item/commandExecution/requestApproval": + return {"decision": "accept"} + if method == "item/fileChange/requestApproval": + return {"decision": "accept"} + return {} + + def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject: + method = msg["method"] + params = msg.get("params") + if not isinstance(method, str): + return {} + return self._approval_handler( + method, + params if isinstance(params, dict) else None, + ) + + def _write_message(self, payload: JsonObject) -> None: + if self._conn is None: + raise TransportClosedError("app-server websocket is not connected") + + try: + with self._lock: + self._conn.send(json.dumps(payload), text=True) + except ConnectionClosed as exc: + raise TransportClosedError( + f"app-server websocket closed while sending. url={self.config.websocket_url!r}" + ) from exc + except Exception as exc: + raise AppServerError(f"Failed to send websocket message: {exc}") from exc + + def _read_message(self) -> dict[str, JsonValue]: + if self._conn is None: + raise TransportClosedError("app-server websocket is not connected") + + try: + frame = self._conn.recv(timeout=self.config.websocket_recv_timeout_s) + except ConnectionClosed as exc: + raise TransportClosedError( + f"app-server websocket closed while receiving. url={self.config.websocket_url!r}" + ) from exc + except Exception as exc: + raise AppServerError(f"Failed to receive websocket message: {exc}") from exc + + if isinstance(frame, bytes): + try: + raw_message = frame.decode("utf-8") + except UnicodeDecodeError as exc: + raise AppServerError( + "Received non-UTF-8 websocket binary frame from app-server" + ) from exc + else: + raw_message = frame + + try: + message = json.loads(raw_message) + except json.JSONDecodeError as exc: + raise AppServerError(f"Invalid JSON-RPC frame: {raw_message!r}") from exc + + if not isinstance(message, dict): + raise AppServerError(f"Invalid JSON-RPC payload: {message!r}") + return message diff --git a/src/agents/sandbox/app_server/errors.py b/src/agents/sandbox/app_server/errors.py new file mode 100644 index 0000000000..104e35f2e9 --- /dev/null +++ b/src/agents/sandbox/app_server/errors.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import Any + + +class AppServerError(Exception): + """Base exception for SDK errors.""" + + +class JsonRpcError(AppServerError): + """Raw JSON-RPC error wrapper from the server.""" + + def __init__(self, code: int, message: str, data: Any = None): + super().__init__(f"JSON-RPC error {code}: {message}") + self.code = code + self.message = message + self.data = data + + +class TransportClosedError(AppServerError): + """Raised when the app-server transport closes unexpectedly.""" + + +class AppServerRpcError(JsonRpcError): + """Base typed error for JSON-RPC failures.""" + + +class ParseError(AppServerRpcError): + pass + + +class InvalidRequestError(AppServerRpcError): + pass + + +class MethodNotFoundError(AppServerRpcError): + pass + + +class InvalidParamsError(AppServerRpcError): + pass + + +class InternalRpcError(AppServerRpcError): + pass + + +class ServerBusyError(AppServerRpcError): + """Server is overloaded / unavailable and caller should retry.""" + + +class RetryLimitExceededError(ServerBusyError): + """Server exhausted internal retry budget for a retryable operation.""" + + +def _contains_retry_limit_text(message: str) -> bool: + lowered = message.lower() + return "retry limit" in lowered or "too many failed attempts" in lowered + + +def _is_server_overloaded(data: Any) -> bool: + if data is None: + return False + + if isinstance(data, str): + return data.lower() == "server_overloaded" + + if isinstance(data, dict): + direct = data.get("codex_error_info") or data.get("codexErrorInfo") or data.get("errorInfo") + if isinstance(direct, str) and direct.lower() == "server_overloaded": + return True + if isinstance(direct, dict): + for value in direct.values(): + if isinstance(value, str) and value.lower() == "server_overloaded": + return True + for value in data.values(): + if _is_server_overloaded(value): + return True + + if isinstance(data, list): + return any(_is_server_overloaded(value) for value in data) + + return False + + +def map_jsonrpc_error(code: int, message: str, data: Any = None) -> JsonRpcError: + """Map a raw JSON-RPC error into a richer SDK exception class.""" + + if code == -32700: + return ParseError(code, message, data) + if code == -32600: + return InvalidRequestError(code, message, data) + if code == -32601: + return MethodNotFoundError(code, message, data) + if code == -32602: + return InvalidParamsError(code, message, data) + if code == -32603: + return InternalRpcError(code, message, data) + + if -32099 <= code <= -32000: + if _is_server_overloaded(data): + if _contains_retry_limit_text(message): + return RetryLimitExceededError(code, message, data) + return ServerBusyError(code, message, data) + if _contains_retry_limit_text(message): + return RetryLimitExceededError(code, message, data) + return AppServerRpcError(code, message, data) + + return JsonRpcError(code, message, data) + + +def is_retryable_error(exc: BaseException) -> bool: + """True if the exception is a transient overload-style error.""" + + if isinstance(exc, ServerBusyError): + return True + + if isinstance(exc, JsonRpcError): + return _is_server_overloaded(exc.data) + + return False diff --git a/src/agents/sandbox/app_server/generated/__init__.py b/src/agents/sandbox/app_server/generated/__init__.py new file mode 100644 index 0000000000..d7b3f674b2 --- /dev/null +++ b/src/agents/sandbox/app_server/generated/__init__.py @@ -0,0 +1 @@ +"""Auto-generated Python types derived from the app-server schemas.""" diff --git a/src/agents/sandbox/app_server/generated/notification_registry.py b/src/agents/sandbox/app_server/generated/notification_registry.py new file mode 100644 index 0000000000..a714a7f222 --- /dev/null +++ b/src/agents/sandbox/app_server/generated/notification_registry.py @@ -0,0 +1,108 @@ +# Auto-generated by scripts/update_sdk_artifacts.py +# DO NOT EDIT MANUALLY. + +from __future__ import annotations + +from pydantic import BaseModel + +from .v2_all import ( + AccountLoginCompletedNotification, + AccountRateLimitsUpdatedNotification, + AccountUpdatedNotification, + AgentMessageDeltaNotification, + AppListUpdatedNotification, + CommandExecOutputDeltaNotification, + CommandExecutionOutputDeltaNotification, + ConfigWarningNotification, + ContextCompactedNotification, + DeprecationNoticeNotification, + ErrorNotification, + FileChangeOutputDeltaNotification, + FuzzyFileSearchSessionCompletedNotification, + FuzzyFileSearchSessionUpdatedNotification, + HookCompletedNotification, + HookStartedNotification, + ItemCompletedNotification, + ItemGuardianApprovalReviewCompletedNotification, + ItemGuardianApprovalReviewStartedNotification, + ItemStartedNotification, + McpServerOauthLoginCompletedNotification, + McpToolCallProgressNotification, + ModelReroutedNotification, + PlanDeltaNotification, + ReasoningSummaryPartAddedNotification, + ReasoningSummaryTextDeltaNotification, + ReasoningTextDeltaNotification, + ServerRequestResolvedNotification, + SkillsChangedNotification, + TerminalInteractionNotification, + ThreadArchivedNotification, + ThreadClosedNotification, + ThreadNameUpdatedNotification, + ThreadRealtimeClosedNotification, + ThreadRealtimeErrorNotification, + ThreadRealtimeItemAddedNotification, + ThreadRealtimeOutputAudioDeltaNotification, + ThreadRealtimeStartedNotification, + ThreadStartedNotification, + ThreadStatusChangedNotification, + ThreadTokenUsageUpdatedNotification, + ThreadUnarchivedNotification, + TurnCompletedNotification, + TurnDiffUpdatedNotification, + TurnPlanUpdatedNotification, + TurnStartedNotification, + WindowsSandboxSetupCompletedNotification, + WindowsWorldWritableWarningNotification, +) + +NOTIFICATION_MODELS: dict[str, type[BaseModel]] = { + "account/login/completed": AccountLoginCompletedNotification, + "account/rateLimits/updated": AccountRateLimitsUpdatedNotification, + "account/updated": AccountUpdatedNotification, + "app/list/updated": AppListUpdatedNotification, + "command/exec/outputDelta": CommandExecOutputDeltaNotification, + "configWarning": ConfigWarningNotification, + "deprecationNotice": DeprecationNoticeNotification, + "error": ErrorNotification, + "fuzzyFileSearch/sessionCompleted": FuzzyFileSearchSessionCompletedNotification, + "fuzzyFileSearch/sessionUpdated": FuzzyFileSearchSessionUpdatedNotification, + "hook/completed": HookCompletedNotification, + "hook/started": HookStartedNotification, + "item/agentMessage/delta": AgentMessageDeltaNotification, + "item/autoApprovalReview/completed": ItemGuardianApprovalReviewCompletedNotification, + "item/autoApprovalReview/started": ItemGuardianApprovalReviewStartedNotification, + "item/commandExecution/outputDelta": CommandExecutionOutputDeltaNotification, + "item/commandExecution/terminalInteraction": TerminalInteractionNotification, + "item/completed": ItemCompletedNotification, + "item/fileChange/outputDelta": FileChangeOutputDeltaNotification, + "item/mcpToolCall/progress": McpToolCallProgressNotification, + "item/plan/delta": PlanDeltaNotification, + "item/reasoning/summaryPartAdded": ReasoningSummaryPartAddedNotification, + "item/reasoning/summaryTextDelta": ReasoningSummaryTextDeltaNotification, + "item/reasoning/textDelta": ReasoningTextDeltaNotification, + "item/started": ItemStartedNotification, + "mcpServer/oauthLogin/completed": McpServerOauthLoginCompletedNotification, + "model/rerouted": ModelReroutedNotification, + "serverRequest/resolved": ServerRequestResolvedNotification, + "skills/changed": SkillsChangedNotification, + "thread/archived": ThreadArchivedNotification, + "thread/closed": ThreadClosedNotification, + "thread/compacted": ContextCompactedNotification, + "thread/name/updated": ThreadNameUpdatedNotification, + "thread/realtime/closed": ThreadRealtimeClosedNotification, + "thread/realtime/error": ThreadRealtimeErrorNotification, + "thread/realtime/itemAdded": ThreadRealtimeItemAddedNotification, + "thread/realtime/outputAudio/delta": ThreadRealtimeOutputAudioDeltaNotification, + "thread/realtime/started": ThreadRealtimeStartedNotification, + "thread/started": ThreadStartedNotification, + "thread/status/changed": ThreadStatusChangedNotification, + "thread/tokenUsage/updated": ThreadTokenUsageUpdatedNotification, + "thread/unarchived": ThreadUnarchivedNotification, + "turn/completed": TurnCompletedNotification, + "turn/diff/updated": TurnDiffUpdatedNotification, + "turn/plan/updated": TurnPlanUpdatedNotification, + "turn/started": TurnStartedNotification, + "windows/worldWritableWarning": WindowsWorldWritableWarningNotification, + "windowsSandbox/setupCompleted": WindowsSandboxSetupCompletedNotification, +} diff --git a/src/agents/sandbox/app_server/generated/v2_all.py b/src/agents/sandbox/app_server/generated/v2_all.py new file mode 100644 index 0000000000..300c798f3a --- /dev/null +++ b/src/agents/sandbox/app_server/generated/v2_all.py @@ -0,0 +1,6136 @@ +# generated by datamodel-codegen: +# filename: codex_app_server_protocol.v2.schemas.json + +from __future__ import annotations + +from enum import Enum +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, RootModel + + +class CodexAppServerProtocolV2(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class AbsolutePathBuf(RootModel[str]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + str, + Field( + description="A path that is guaranteed to be absolute and normalized (though it is not guaranteed to be canonicalized or exist on the filesystem).\n\nIMPORTANT: When deserializing an `AbsolutePathBuf`, a base path must be set using [AbsolutePathBufGuard::new]. If no base path is set, the deserialization will fail unless the path being deserialized is already absolute." + ), + ] + + +class ApiKeyAccount(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["apiKey"], Field(title="ApiKeyAccountType")] + + +class AccountLoginCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + error: str | None = None + login_id: Annotated[str | None, Field(alias="loginId")] = None + success: bool + + +class AgentMessageDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + delta: str + item_id: Annotated[str, Field(alias="itemId")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class AnalyticsConfig(BaseModel): + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + enabled: bool | None = None + + +class AppBranding(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + category: str | None = None + developer: str | None = None + is_discoverable_app: Annotated[bool, Field(alias="isDiscoverableApp")] + privacy_policy: Annotated[str | None, Field(alias="privacyPolicy")] = None + terms_of_service: Annotated[str | None, Field(alias="termsOfService")] = None + website: str | None = None + + +class AppReview(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + status: str + + +class AppScreenshot(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + file_id: Annotated[str | None, Field(alias="fileId")] = None + url: str | None = None + user_prompt: Annotated[str, Field(alias="userPrompt")] + + +class AppSummary(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: str | None = None + id: str + install_url: Annotated[str | None, Field(alias="installUrl")] = None + name: str + + +class AppToolApproval(Enum): + auto = "auto" + prompt = "prompt" + approve = "approve" + + +class AppToolConfig(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_mode: AppToolApproval | None = None + enabled: bool | None = None + + +class AppToolsConfig(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class ApprovalsReviewer(Enum): + user = "user" + guardian_subagent = "guardian_subagent" + + +class AppsDefaultConfig(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + destructive_enabled: bool | None = True + enabled: bool | None = True + open_world_enabled: bool | None = True + + +class AppsListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cursor: Annotated[ + str | None, + Field(description="Opaque pagination cursor returned by a previous call."), + ] = None + force_refetch: Annotated[ + bool | None, + Field( + alias="forceRefetch", + description="When true, bypass app caches and fetch the latest data from sources.", + ), + ] = None + limit: Annotated[ + int | None, + Field( + description="Optional page size; defaults to a reasonable server-side value.", + ge=0, + ), + ] = None + thread_id: Annotated[ + str | None, + Field( + alias="threadId", + description="Optional thread id used to evaluate app feature gating from that thread's config.", + ), + ] = None + + +class AskForApprovalValue(Enum): + untrusted = "untrusted" + on_failure = "on-failure" + on_request = "on-request" + never = "never" + + +class Granular(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + mcp_elicitations: bool + request_permissions: bool | None = False + rules: bool + sandbox_approval: bool + skill_approval: bool | None = False + + +class GranularAskForApproval(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + granular: Granular + + +class AskForApproval(RootModel[AskForApprovalValue | GranularAskForApproval]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: AskForApprovalValue | GranularAskForApproval + + +class AuthMode(Enum): + apikey = "apikey" + chatgpt = "chatgpt" + chatgpt_auth_tokens = "chatgptAuthTokens" + + +class ByteRange(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + end: Annotated[int, Field(ge=0)] + start: Annotated[int, Field(ge=0)] + + +class CancelLoginAccountParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + login_id: Annotated[str, Field(alias="loginId")] + + +class CancelLoginAccountStatus(Enum): + canceled = "canceled" + not_found = "notFound" + + +class ClientInfo(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: str + title: str | None = None + version: str + + +class CodexErrorInfoValue(Enum): + context_window_exceeded = "contextWindowExceeded" + usage_limit_exceeded = "usageLimitExceeded" + server_overloaded = "serverOverloaded" + internal_server_error = "internalServerError" + unauthorized = "unauthorized" + bad_request = "badRequest" + thread_rollback_failed = "threadRollbackFailed" + sandbox_error = "sandboxError" + other = "other" + + +class HttpConnectionFailed(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + http_status_code: Annotated[int | None, Field(alias="httpStatusCode", ge=0)] = None + + +class HttpConnectionFailedCodexErrorInfo(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + http_connection_failed: Annotated[HttpConnectionFailed, Field(alias="httpConnectionFailed")] + + +class ResponseStreamConnectionFailed(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + http_status_code: Annotated[int | None, Field(alias="httpStatusCode", ge=0)] = None + + +class ResponseStreamConnectionFailedCodexErrorInfo(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + response_stream_connection_failed: Annotated[ + ResponseStreamConnectionFailed, Field(alias="responseStreamConnectionFailed") + ] + + +class ResponseStreamDisconnected(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + http_status_code: Annotated[int | None, Field(alias="httpStatusCode", ge=0)] = None + + +class ResponseStreamDisconnectedCodexErrorInfo(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + response_stream_disconnected: Annotated[ + ResponseStreamDisconnected, Field(alias="responseStreamDisconnected") + ] + + +class ResponseTooManyFailedAttempts(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + http_status_code: Annotated[int | None, Field(alias="httpStatusCode", ge=0)] = None + + +class ResponseTooManyFailedAttemptsCodexErrorInfo(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + response_too_many_failed_attempts: Annotated[ + ResponseTooManyFailedAttempts, Field(alias="responseTooManyFailedAttempts") + ] + + +class CodexErrorInfo( + RootModel[ + CodexErrorInfoValue + | HttpConnectionFailedCodexErrorInfo + | ResponseStreamConnectionFailedCodexErrorInfo + | ResponseStreamDisconnectedCodexErrorInfo + | ResponseTooManyFailedAttemptsCodexErrorInfo + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + CodexErrorInfoValue + | HttpConnectionFailedCodexErrorInfo + | ResponseStreamConnectionFailedCodexErrorInfo + | ResponseStreamDisconnectedCodexErrorInfo + | ResponseTooManyFailedAttemptsCodexErrorInfo, + Field( + description="This translation layer make sure that we expose codex error code in camel case.\n\nWhen an upstream HTTP status is available (for example, from the Responses API or a provider), it is forwarded in `httpStatusCode` on the relevant `codexErrorInfo` variant." + ), + ] + + +class CollabAgentStatus(Enum): + pending_init = "pendingInit" + running = "running" + completed = "completed" + errored = "errored" + shutdown = "shutdown" + not_found = "notFound" + + +class CollabAgentTool(Enum): + spawn_agent = "spawnAgent" + send_input = "sendInput" + resume_agent = "resumeAgent" + wait = "wait" + close_agent = "closeAgent" + + +class CollabAgentToolCallStatus(Enum): + in_progress = "inProgress" + completed = "completed" + failed = "failed" + + +class ReadCommandAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + command: str + name: str + path: str + type: Annotated[Literal["read"], Field(title="ReadCommandActionType")] + + +class ListFilesCommandAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + command: str + path: str | None = None + type: Annotated[Literal["listFiles"], Field(title="ListFilesCommandActionType")] + + +class SearchCommandAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + command: str + path: str | None = None + query: str | None = None + type: Annotated[Literal["search"], Field(title="SearchCommandActionType")] + + +class UnknownCommandAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + command: str + type: Annotated[Literal["unknown"], Field(title="UnknownCommandActionType")] + + +class CommandAction( + RootModel[ + ReadCommandAction | ListFilesCommandAction | SearchCommandAction | UnknownCommandAction + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ReadCommandAction | ListFilesCommandAction | SearchCommandAction | UnknownCommandAction + + +class CommandExecOutputStream(Enum): + stdout = "stdout" + stderr = "stderr" + + +class CommandExecResizeResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class CommandExecResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + exit_code: Annotated[int, Field(alias="exitCode", description="Process exit code.")] + stderr: Annotated[ + str, + Field( + description="Buffered stderr capture.\n\nEmpty when stderr was streamed via `command/exec/outputDelta`." + ), + ] + stdout: Annotated[ + str, + Field( + description="Buffered stdout capture.\n\nEmpty when stdout was streamed via `command/exec/outputDelta`." + ), + ] + + +class CommandExecTerminalSize(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cols: Annotated[int, Field(description="Terminal width in character cells.", ge=0)] + rows: Annotated[int, Field(description="Terminal height in character cells.", ge=0)] + + +class CommandExecTerminateParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + process_id: Annotated[ + str, + Field( + alias="processId", + description="Client-supplied, connection-scoped `processId` from the original `command/exec` request.", + ), + ] + + +class CommandExecTerminateResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class CommandExecWriteParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + close_stdin: Annotated[ + bool | None, + Field( + alias="closeStdin", + description="Close stdin after writing `deltaBase64`, if present.", + ), + ] = None + delta_base64: Annotated[ + str | None, + Field( + alias="deltaBase64", + description="Optional base64-encoded stdin bytes to write.", + ), + ] = None + process_id: Annotated[ + str, + Field( + alias="processId", + description="Client-supplied, connection-scoped `processId` from the original `command/exec` request.", + ), + ] + + +class CommandExecWriteResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class CommandExecutionOutputDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + delta: str + item_id: Annotated[str, Field(alias="itemId")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class CommandExecutionStatus(Enum): + in_progress = "inProgress" + completed = "completed" + failed = "failed" + declined = "declined" + + +class MdmConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + domain: str + key: str + type: Annotated[Literal["mdm"], Field(title="MdmConfigLayerSourceType")] + + +class SystemConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + file: Annotated[ + AbsolutePathBuf, + Field( + description="This is the path to the system config.toml file, though it is not guaranteed to exist." + ), + ] + type: Annotated[Literal["system"], Field(title="SystemConfigLayerSourceType")] + + +class UserConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + file: Annotated[ + AbsolutePathBuf, + Field( + description="This is the path to the user's config.toml file, though it is not guaranteed to exist." + ), + ] + type: Annotated[Literal["user"], Field(title="UserConfigLayerSourceType")] + + +class ProjectConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + dot_codex_folder: Annotated[AbsolutePathBuf, Field(alias="dotCodexFolder")] + type: Annotated[Literal["project"], Field(title="ProjectConfigLayerSourceType")] + + +class SessionFlagsConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["sessionFlags"], Field(title="SessionFlagsConfigLayerSourceType")] + + +class LegacyManagedConfigTomlFromFileConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + file: AbsolutePathBuf + type: Annotated[ + Literal["legacyManagedConfigTomlFromFile"], + Field(title="LegacyManagedConfigTomlFromFileConfigLayerSourceType"), + ] + + +class LegacyManagedConfigTomlFromMdmConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[ + Literal["legacyManagedConfigTomlFromMdm"], + Field(title="LegacyManagedConfigTomlFromMdmConfigLayerSourceType"), + ] + + +class ConfigLayerSource( + RootModel[ + MdmConfigLayerSource + | SystemConfigLayerSource + | UserConfigLayerSource + | ProjectConfigLayerSource + | SessionFlagsConfigLayerSource + | LegacyManagedConfigTomlFromFileConfigLayerSource + | LegacyManagedConfigTomlFromMdmConfigLayerSource + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ( + MdmConfigLayerSource + | SystemConfigLayerSource + | UserConfigLayerSource + | ProjectConfigLayerSource + | SessionFlagsConfigLayerSource + | LegacyManagedConfigTomlFromFileConfigLayerSource + | LegacyManagedConfigTomlFromMdmConfigLayerSource + ) + + +class ConfigReadParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwd: Annotated[ + str | None, + Field( + description="Optional working directory to resolve project config layers. If specified, return the effective config as seen from that directory (i.e., including any project layers between `cwd` and the project/repo root)." + ), + ] = None + include_layers: Annotated[bool | None, Field(alias="includeLayers")] = False + + +class InputTextContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + type: Annotated[Literal["input_text"], Field(title="InputTextContentItemType")] + + +class InputImageContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + image_url: str + type: Annotated[Literal["input_image"], Field(title="InputImageContentItemType")] + + +class OutputTextContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + type: Annotated[Literal["output_text"], Field(title="OutputTextContentItemType")] + + +class ContentItem(RootModel[InputTextContentItem | InputImageContentItem | OutputTextContentItem]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: InputTextContentItem | InputImageContentItem | OutputTextContentItem + + +class ContextCompactedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class CreditsSnapshot(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + balance: str | None = None + has_credits: Annotated[bool, Field(alias="hasCredits")] + unlimited: bool + + +class DeprecationNoticeNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + details: Annotated[ + str | None, + Field(description="Optional extra guidance, such as migration steps or rationale."), + ] = None + summary: Annotated[str, Field(description="Concise summary of what is deprecated.")] + + +class InputTextDynamicToolCallOutputContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + type: Annotated[ + Literal["inputText"], + Field(title="InputTextDynamicToolCallOutputContentItemType"), + ] + + +class InputImageDynamicToolCallOutputContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + image_url: Annotated[str, Field(alias="imageUrl")] + type: Annotated[ + Literal["inputImage"], + Field(title="InputImageDynamicToolCallOutputContentItemType"), + ] + + +class DynamicToolCallOutputContentItem( + RootModel[ + InputTextDynamicToolCallOutputContentItem | InputImageDynamicToolCallOutputContentItem + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: InputTextDynamicToolCallOutputContentItem | InputImageDynamicToolCallOutputContentItem + + +class DynamicToolCallStatus(Enum): + in_progress = "inProgress" + completed = "completed" + failed = "failed" + + +class DynamicToolSpec(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: str + input_schema: Annotated[Any, Field(alias="inputSchema")] + name: str + + +class ExperimentalFeatureListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cursor: Annotated[ + str | None, + Field(description="Opaque pagination cursor returned by a previous call."), + ] = None + limit: Annotated[ + int | None, + Field( + description="Optional page size; defaults to a reasonable server-side value.", + ge=0, + ), + ] = None + + +class ExperimentalFeatureStage(Enum): + beta = "beta" + under_development = "underDevelopment" + stable = "stable" + deprecated = "deprecated" + removed = "removed" + + +class ExternalAgentConfigDetectParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwds: Annotated[ + list[str] | None, + Field(description="Zero or more working directories to include for repo-scoped detection."), + ] = None + include_home: Annotated[ + bool | None, + Field( + alias="includeHome", + description="If true, include detection under the user's home (~/.claude, ~/.codex, etc.).", + ), + ] = None + + +class ExternalAgentConfigImportResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class ExternalAgentConfigMigrationItemType(Enum): + agents_md = "AGENTS_MD" + config = "CONFIG" + skills = "SKILLS" + mcp_server_config = "MCP_SERVER_CONFIG" + + +class FeedbackUploadParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + classification: str + extra_log_files: Annotated[list[str] | None, Field(alias="extraLogFiles")] = None + include_logs: Annotated[bool, Field(alias="includeLogs")] + reason: str | None = None + thread_id: Annotated[str | None, Field(alias="threadId")] = None + + +class FeedbackUploadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class FileChangeOutputDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + delta: str + item_id: Annotated[str, Field(alias="itemId")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ForcedLoginMethod(Enum): + chatgpt = "chatgpt" + api = "api" + + +class FsCopyParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + destination_path: Annotated[ + AbsolutePathBuf, + Field(alias="destinationPath", description="Absolute destination path."), + ] + recursive: Annotated[ + bool | None, + Field(description="Required for directory copies; ignored for file copies."), + ] = None + source_path: Annotated[ + AbsolutePathBuf, Field(alias="sourcePath", description="Absolute source path.") + ] + + +class FsCopyResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class FsCreateDirectoryParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + path: Annotated[AbsolutePathBuf, Field(description="Absolute directory path to create.")] + recursive: Annotated[ + bool | None, + Field(description="Whether parent directories should also be created. Defaults to `true`."), + ] = None + + +class FsCreateDirectoryResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class FsGetMetadataParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + path: Annotated[AbsolutePathBuf, Field(description="Absolute path to inspect.")] + + +class FsGetMetadataResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + created_at_ms: Annotated[ + int, + Field( + alias="createdAtMs", + description="File creation time in Unix milliseconds when available, otherwise `0`.", + ), + ] + is_directory: Annotated[ + bool, + Field( + alias="isDirectory", + description="Whether the path currently resolves to a directory.", + ), + ] + is_file: Annotated[ + bool, + Field( + alias="isFile", + description="Whether the path currently resolves to a regular file.", + ), + ] + modified_at_ms: Annotated[ + int, + Field( + alias="modifiedAtMs", + description="File modification time in Unix milliseconds when available, otherwise `0`.", + ), + ] + + +class FsReadDirectoryEntry(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + file_name: Annotated[ + str, + Field( + alias="fileName", + description="Direct child entry name only, not an absolute or relative path.", + ), + ] + is_directory: Annotated[ + bool, + Field( + alias="isDirectory", + description="Whether this entry resolves to a directory.", + ), + ] + is_file: Annotated[ + bool, + Field(alias="isFile", description="Whether this entry resolves to a regular file."), + ] + + +class FsReadDirectoryParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + path: Annotated[AbsolutePathBuf, Field(description="Absolute directory path to read.")] + + +class FsReadDirectoryResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + entries: Annotated[ + list[FsReadDirectoryEntry], + Field(description="Direct child entries in the requested directory."), + ] + + +class FsReadFileParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + path: Annotated[AbsolutePathBuf, Field(description="Absolute path to read.")] + + +class FsReadFileResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data_base64: Annotated[ + str, Field(alias="dataBase64", description="File contents encoded as base64.") + ] + + +class FsRemoveParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + force: Annotated[ + bool | None, + Field(description="Whether missing paths should be ignored. Defaults to `true`."), + ] = None + path: Annotated[AbsolutePathBuf, Field(description="Absolute path to remove.")] + recursive: Annotated[ + bool | None, + Field(description="Whether directory removal should recurse. Defaults to `true`."), + ] = None + + +class FsRemoveResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class FsWriteFileParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data_base64: Annotated[ + str, Field(alias="dataBase64", description="File contents encoded as base64.") + ] + path: Annotated[AbsolutePathBuf, Field(description="Absolute path to write.")] + + +class FsWriteFileResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class InputTextFunctionCallOutputContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + type: Annotated[ + Literal["input_text"], Field(title="InputTextFunctionCallOutputContentItemType") + ] + + +class FuzzyFileSearchParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cancellation_token: Annotated[str | None, Field(alias="cancellationToken")] = None + query: str + roots: list[str] + + +class Indice(RootModel[int]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[int, Field(ge=0)] + + +class FuzzyFileSearchResult(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + file_name: str + indices: list[Indice] | None = None + path: str + root: str + score: Annotated[int, Field(ge=0)] + + +class FuzzyFileSearchSessionCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + session_id: Annotated[str, Field(alias="sessionId")] + + +class FuzzyFileSearchSessionUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + files: list[FuzzyFileSearchResult] + query: str + session_id: Annotated[str, Field(alias="sessionId")] + + +class GetAccountParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + refresh_token: Annotated[ + bool | None, + Field( + alias="refreshToken", + description="When `true`, requests a proactive token refresh before returning.\n\nIn managed auth mode this triggers the normal refresh-token flow. In external auth mode this flag is ignored. Clients should refresh tokens themselves and call `account/login/start` with `chatgptAuthTokens`.", + ), + ] = False + + +class GhostCommit(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + parent: str | None = None + preexisting_untracked_dirs: list[str] + preexisting_untracked_files: list[str] + + +class GitInfo(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + branch: str | None = None + origin_url: Annotated[str | None, Field(alias="originUrl")] = None + sha: str | None = None + + +class GuardianApprovalReviewStatus(Enum): + in_progress = "inProgress" + approved = "approved" + denied = "denied" + aborted = "aborted" + + +class GuardianRiskLevel(Enum): + low = "low" + medium = "medium" + high = "high" + + +class HazelnutScope(Enum): + example = "example" + workspace_shared = "workspace-shared" + all_shared = "all-shared" + personal = "personal" + + +class HookEventName(Enum): + session_start = "sessionStart" + stop = "stop" + + +class HookExecutionMode(Enum): + sync = "sync" + async_ = "async" + + +class HookHandlerType(Enum): + command = "command" + prompt = "prompt" + agent = "agent" + + +class HookOutputEntryKind(Enum): + warning = "warning" + stop = "stop" + feedback = "feedback" + context = "context" + error = "error" + + +class HookRunStatus(Enum): + running = "running" + completed = "completed" + failed = "failed" + blocked = "blocked" + stopped = "stopped" + + +class HookScope(Enum): + thread = "thread" + turn = "turn" + + +class ImageDetail(Enum): + auto = "auto" + low = "low" + high = "high" + original = "original" + + +class InitializeCapabilities(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + experimental_api: Annotated[ + bool | None, + Field( + alias="experimentalApi", + description="Opt into receiving experimental API methods and fields.", + ), + ] = False + opt_out_notification_methods: Annotated[ + list[str] | None, + Field( + alias="optOutNotificationMethods", + description="Exact notification method names that should be suppressed for this connection (for example `thread/started`).", + ), + ] = None + + +class InitializeParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + capabilities: InitializeCapabilities | None = None + client_info: Annotated[ClientInfo, Field(alias="clientInfo")] + + +class InputModality(Enum): + text = "text" + image = "image" + + +class ListMcpServerStatusParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cursor: Annotated[ + str | None, + Field(description="Opaque pagination cursor returned by a previous call."), + ] = None + limit: Annotated[ + int | None, + Field(description="Optional page size; defaults to a server-defined value.", ge=0), + ] = None + + +class ExecLocalShellAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + command: list[str] + env: dict[str, Any] | None = None + timeout_ms: Annotated[int | None, Field(ge=0)] = None + type: Annotated[Literal["exec"], Field(title="ExecLocalShellActionType")] + user: str | None = None + working_directory: str | None = None + + +class LocalShellAction(RootModel[ExecLocalShellAction]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ExecLocalShellAction + + +class LocalShellStatus(Enum): + completed = "completed" + in_progress = "in_progress" + incomplete = "incomplete" + + +class ApiKeyLoginAccountParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + api_key: Annotated[str, Field(alias="apiKey")] + type: Annotated[Literal["apiKey"], Field(title="ApiKeyv2::LoginAccountParamsType")] + + +class ChatgptLoginAccountParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["chatgpt"], Field(title="Chatgptv2::LoginAccountParamsType")] + + +class ChatgptAuthTokensLoginAccountParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + access_token: Annotated[ + str, + Field( + alias="accessToken", + description="Access token (JWT) supplied by the client. This token is used for backend API requests and email extraction.", + ), + ] + chatgpt_account_id: Annotated[ + str, + Field( + alias="chatgptAccountId", + description="Workspace/account identifier supplied by the client.", + ), + ] + chatgpt_plan_type: Annotated[ + str | None, + Field( + alias="chatgptPlanType", + description="Optional plan type supplied by the client.\n\nWhen `null`, Codex attempts to derive the plan type from access-token claims. If unavailable, the plan defaults to `unknown`.", + ), + ] = None + type: Annotated[ + Literal["chatgptAuthTokens"], + Field(title="ChatgptAuthTokensv2::LoginAccountParamsType"), + ] + + +class LoginAccountParams( + RootModel[ + ApiKeyLoginAccountParams | ChatgptLoginAccountParams | ChatgptAuthTokensLoginAccountParams + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + ApiKeyLoginAccountParams | ChatgptLoginAccountParams | ChatgptAuthTokensLoginAccountParams, + Field(title="LoginAccountParams"), + ] + + +class ApiKeyLoginAccountResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["apiKey"], Field(title="ApiKeyv2::LoginAccountResponseType")] + + +class ChatgptLoginAccountResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + auth_url: Annotated[ + str, + Field( + alias="authUrl", + description="URL the client should open in a browser to initiate the OAuth flow.", + ), + ] + login_id: Annotated[str, Field(alias="loginId")] + type: Annotated[Literal["chatgpt"], Field(title="Chatgptv2::LoginAccountResponseType")] + + +class ChatgptAuthTokensLoginAccountResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[ + Literal["chatgptAuthTokens"], + Field(title="ChatgptAuthTokensv2::LoginAccountResponseType"), + ] + + +class LoginAccountResponse( + RootModel[ + ApiKeyLoginAccountResponse + | ChatgptLoginAccountResponse + | ChatgptAuthTokensLoginAccountResponse + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + ApiKeyLoginAccountResponse + | ChatgptLoginAccountResponse + | ChatgptAuthTokensLoginAccountResponse, + Field(title="LoginAccountResponse"), + ] + + +class LogoutAccountResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class McpAuthStatus(Enum): + unsupported = "unsupported" + not_logged_in = "notLoggedIn" + bearer_token = "bearerToken" + o_auth = "oAuth" + + +class McpServerOauthLoginCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + error: str | None = None + name: str + success: bool + + +class McpServerOauthLoginParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: str + scopes: list[str] | None = None + timeout_secs: Annotated[int | None, Field(alias="timeoutSecs")] = None + + +class McpServerOauthLoginResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + authorization_url: Annotated[str, Field(alias="authorizationUrl")] + + +class McpServerRefreshResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class McpToolCallError(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + message: str + + +class McpToolCallProgressNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + item_id: Annotated[str, Field(alias="itemId")] + message: str + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class McpToolCallResult(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + content: list + structured_content: Annotated[Any | None, Field(alias="structuredContent")] = None + + +class McpToolCallStatus(Enum): + in_progress = "inProgress" + completed = "completed" + failed = "failed" + + +class MergeStrategy(Enum): + replace = "replace" + upsert = "upsert" + + +class MessagePhase(Enum): + commentary = "commentary" + final_answer = "final_answer" + + +class ModeKind(Enum): + plan = "plan" + default = "default" + + +class ModelAvailabilityNux(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + message: str + + +class ModelListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cursor: Annotated[ + str | None, + Field(description="Opaque pagination cursor returned by a previous call."), + ] = None + include_hidden: Annotated[ + bool | None, + Field( + alias="includeHidden", + description="When true, include models that are hidden from the default picker list.", + ), + ] = None + limit: Annotated[ + int | None, + Field( + description="Optional page size; defaults to a reasonable server-side value.", + ge=0, + ), + ] = None + + +class ModelRerouteReason(RootModel[Literal["highRiskCyberActivity"]]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Literal["highRiskCyberActivity"] + + +class ModelReroutedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + from_model: Annotated[str, Field(alias="fromModel")] + reason: ModelRerouteReason + thread_id: Annotated[str, Field(alias="threadId")] + to_model: Annotated[str, Field(alias="toModel")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ModelUpgradeInfo(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + migration_markdown: Annotated[str | None, Field(alias="migrationMarkdown")] = None + model: str + model_link: Annotated[str | None, Field(alias="modelLink")] = None + upgrade_copy: Annotated[str | None, Field(alias="upgradeCopy")] = None + + +class NetworkAccess(Enum): + restricted = "restricted" + enabled = "enabled" + + +class NetworkRequirements(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + allow_local_binding: Annotated[bool | None, Field(alias="allowLocalBinding")] = None + allow_unix_sockets: Annotated[list[str] | None, Field(alias="allowUnixSockets")] = None + allow_upstream_proxy: Annotated[bool | None, Field(alias="allowUpstreamProxy")] = None + allowed_domains: Annotated[list[str] | None, Field(alias="allowedDomains")] = None + dangerously_allow_all_unix_sockets: Annotated[ + bool | None, Field(alias="dangerouslyAllowAllUnixSockets") + ] = None + dangerously_allow_non_loopback_proxy: Annotated[ + bool | None, Field(alias="dangerouslyAllowNonLoopbackProxy") + ] = None + denied_domains: Annotated[list[str] | None, Field(alias="deniedDomains")] = None + enabled: bool | None = None + http_port: Annotated[int | None, Field(alias="httpPort", ge=0)] = None + socks_port: Annotated[int | None, Field(alias="socksPort", ge=0)] = None + + +class PatchApplyStatus(Enum): + in_progress = "inProgress" + completed = "completed" + failed = "failed" + declined = "declined" + + +class AddPatchChangeKind(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["add"], Field(title="AddPatchChangeKindType")] + + +class DeletePatchChangeKind(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["delete"], Field(title="DeletePatchChangeKindType")] + + +class UpdatePatchChangeKind(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + move_path: str | None = None + type: Annotated[Literal["update"], Field(title="UpdatePatchChangeKindType")] + + +class PatchChangeKind( + RootModel[AddPatchChangeKind | DeletePatchChangeKind | UpdatePatchChangeKind] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: AddPatchChangeKind | DeletePatchChangeKind | UpdatePatchChangeKind + + +class Personality(Enum): + none = "none" + friendly = "friendly" + pragmatic = "pragmatic" + + +class PlanDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + delta: str + item_id: Annotated[str, Field(alias="itemId")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class PlanType(Enum): + free = "free" + go = "go" + plus = "plus" + pro = "pro" + team = "team" + business = "business" + enterprise = "enterprise" + edu = "edu" + unknown = "unknown" + + +class PluginAuthPolicy(Enum): + on_install = "ON_INSTALL" + on_use = "ON_USE" + + +class PluginInstallParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + marketplace_path: Annotated[AbsolutePathBuf, Field(alias="marketplacePath")] + plugin_name: Annotated[str, Field(alias="pluginName")] + + +class PluginInstallPolicy(Enum): + not_available = "NOT_AVAILABLE" + available = "AVAILABLE" + installed_by_default = "INSTALLED_BY_DEFAULT" + + +class PluginInstallResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + apps_needing_auth: Annotated[list[AppSummary], Field(alias="appsNeedingAuth")] + auth_policy: Annotated[PluginAuthPolicy, Field(alias="authPolicy")] + + +class PluginInterface(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + brand_color: Annotated[str | None, Field(alias="brandColor")] = None + capabilities: list[str] + category: str | None = None + composer_icon: Annotated[AbsolutePathBuf | None, Field(alias="composerIcon")] = None + default_prompt: Annotated[str | None, Field(alias="defaultPrompt")] = None + developer_name: Annotated[str | None, Field(alias="developerName")] = None + display_name: Annotated[str | None, Field(alias="displayName")] = None + logo: AbsolutePathBuf | None = None + long_description: Annotated[str | None, Field(alias="longDescription")] = None + privacy_policy_url: Annotated[str | None, Field(alias="privacyPolicyUrl")] = None + screenshots: list[AbsolutePathBuf] + short_description: Annotated[str | None, Field(alias="shortDescription")] = None + terms_of_service_url: Annotated[str | None, Field(alias="termsOfServiceUrl")] = None + website_url: Annotated[str | None, Field(alias="websiteUrl")] = None + + +class PluginListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwds: Annotated[ + list[AbsolutePathBuf] | None, + Field( + description="Optional working directories used to discover repo marketplaces. When omitted, only home-scoped marketplaces and the official curated marketplace are considered." + ), + ] = None + force_remote_sync: Annotated[ + bool | None, + Field( + alias="forceRemoteSync", + description="When true, reconcile the official curated marketplace against the remote plugin state before listing marketplaces.", + ), + ] = None + + +class PluginReadParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + marketplace_path: Annotated[AbsolutePathBuf, Field(alias="marketplacePath")] + plugin_name: Annotated[str, Field(alias="pluginName")] + + +class LocalPluginSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + path: AbsolutePathBuf + type: Annotated[Literal["local"], Field(title="LocalPluginSourceType")] + + +class PluginSource(RootModel[LocalPluginSource]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: LocalPluginSource + + +class PluginSummary(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + auth_policy: Annotated[PluginAuthPolicy, Field(alias="authPolicy")] + enabled: bool + id: str + install_policy: Annotated[PluginInstallPolicy, Field(alias="installPolicy")] + installed: bool + interface: PluginInterface | None = None + name: str + source: PluginSource + + +class PluginUninstallParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + plugin_id: Annotated[str, Field(alias="pluginId")] + + +class PluginUninstallResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class ProductSurface(Enum): + chatgpt = "chatgpt" + codex = "codex" + api = "api" + atlas = "atlas" + + +class RateLimitWindow(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + resets_at: Annotated[int | None, Field(alias="resetsAt")] = None + used_percent: Annotated[int, Field(alias="usedPercent")] + window_duration_mins: Annotated[int | None, Field(alias="windowDurationMins")] = None + + +class RestrictedReadOnlyAccess(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + include_platform_defaults: Annotated[bool | None, Field(alias="includePlatformDefaults")] = True + readable_roots: Annotated[list[AbsolutePathBuf] | None, Field(alias="readableRoots")] = [] + type: Annotated[Literal["restricted"], Field(title="RestrictedReadOnlyAccessType")] + + +class FullAccessReadOnlyAccess(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["fullAccess"], Field(title="FullAccessReadOnlyAccessType")] + + +class ReadOnlyAccess(RootModel[RestrictedReadOnlyAccess | FullAccessReadOnlyAccess]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: RestrictedReadOnlyAccess | FullAccessReadOnlyAccess + + +class ReasoningEffort(Enum): + none = "none" + minimal = "minimal" + low = "low" + medium = "medium" + high = "high" + xhigh = "xhigh" + + +class ReasoningEffortOption(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: str + reasoning_effort: Annotated[ReasoningEffort, Field(alias="reasoningEffort")] + + +class ReasoningTextReasoningItemContent(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + type: Annotated[Literal["reasoning_text"], Field(title="ReasoningTextReasoningItemContentType")] + + +class TextReasoningItemContent(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + type: Annotated[Literal["text"], Field(title="TextReasoningItemContentType")] + + +class ReasoningItemContent(RootModel[ReasoningTextReasoningItemContent | TextReasoningItemContent]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ReasoningTextReasoningItemContent | TextReasoningItemContent + + +class SummaryTextReasoningItemReasoningSummary(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + type: Annotated[ + Literal["summary_text"], + Field(title="SummaryTextReasoningItemReasoningSummaryType"), + ] + + +class ReasoningItemReasoningSummary(RootModel[SummaryTextReasoningItemReasoningSummary]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: SummaryTextReasoningItemReasoningSummary + + +class ReasoningSummaryValue(Enum): + auto = "auto" + concise = "concise" + detailed = "detailed" + + +class ReasoningSummary(RootModel[ReasoningSummaryValue | Literal["none"]]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + ReasoningSummaryValue | Literal["none"], + Field( + description="A summary of the reasoning performed by the model. This can be useful for debugging and understanding the model's reasoning process. See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#reasoning-summaries" + ), + ] + + +class ReasoningSummaryPartAddedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + item_id: Annotated[str, Field(alias="itemId")] + summary_index: Annotated[int, Field(alias="summaryIndex")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ReasoningSummaryTextDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + delta: str + item_id: Annotated[str, Field(alias="itemId")] + summary_index: Annotated[int, Field(alias="summaryIndex")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ReasoningTextDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + content_index: Annotated[int, Field(alias="contentIndex")] + delta: str + item_id: Annotated[str, Field(alias="itemId")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class RemoteSkillSummary(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: str + id: str + name: str + + +class RequestId(RootModel[str | int]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: str | int + + +class ResidencyRequirement(RootModel[Literal["us"]]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Literal["us"] + + +class Resource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + field_meta: Annotated[Any | None, Field(alias="_meta")] = None + annotations: Any | None = None + description: str | None = None + icons: list | None = None + mime_type: Annotated[str | None, Field(alias="mimeType")] = None + name: str + size: int | None = None + title: str | None = None + uri: str + + +class ResourceTemplate(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + annotations: Any | None = None + description: str | None = None + mime_type: Annotated[str | None, Field(alias="mimeType")] = None + name: str + title: str | None = None + uri_template: Annotated[str, Field(alias="uriTemplate")] + + +class MessageResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + content: list[ContentItem] + end_turn: bool | None = None + id: str | None = None + phase: MessagePhase | None = None + role: str + type: Annotated[Literal["message"], Field(title="MessageResponseItemType")] + + +class ReasoningResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + content: list[ReasoningItemContent] | None = None + encrypted_content: str | None = None + id: str + summary: list[ReasoningItemReasoningSummary] + type: Annotated[Literal["reasoning"], Field(title="ReasoningResponseItemType")] + + +class LocalShellCallResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + action: LocalShellAction + call_id: Annotated[str | None, Field(description="Set when using the Responses API.")] = None + id: Annotated[ + str | None, + Field(description="Legacy id field retained for compatibility with older payloads."), + ] = None + status: LocalShellStatus + type: Annotated[Literal["local_shell_call"], Field(title="LocalShellCallResponseItemType")] + + +class FunctionCallResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + arguments: str + call_id: str + id: str | None = None + name: str + namespace: str | None = None + type: Annotated[Literal["function_call"], Field(title="FunctionCallResponseItemType")] + + +class ToolSearchCallResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + arguments: Any + call_id: str | None = None + execution: str + id: str | None = None + status: str | None = None + type: Annotated[Literal["tool_search_call"], Field(title="ToolSearchCallResponseItemType")] + + +class CustomToolCallResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + call_id: str + id: str | None = None + input: str + name: str + status: str | None = None + type: Annotated[Literal["custom_tool_call"], Field(title="CustomToolCallResponseItemType")] + + +class ToolSearchOutputResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + call_id: str | None = None + execution: str + status: str + tools: list + type: Annotated[Literal["tool_search_output"], Field(title="ToolSearchOutputResponseItemType")] + + +class ImageGenerationCallResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + result: str + revised_prompt: str | None = None + status: str + type: Annotated[ + Literal["image_generation_call"], + Field(title="ImageGenerationCallResponseItemType"), + ] + + +class GhostSnapshotResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + ghost_commit: GhostCommit + type: Annotated[Literal["ghost_snapshot"], Field(title="GhostSnapshotResponseItemType")] + + +class CompactionResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + encrypted_content: str + type: Annotated[Literal["compaction"], Field(title="CompactionResponseItemType")] + + +class OtherResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["other"], Field(title="OtherResponseItemType")] + + +class SearchResponsesApiWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + queries: list[str] | None = None + query: str | None = None + type: Annotated[Literal["search"], Field(title="SearchResponsesApiWebSearchActionType")] + + +class OpenPageResponsesApiWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["open_page"], Field(title="OpenPageResponsesApiWebSearchActionType")] + url: str | None = None + + +class FindInPageResponsesApiWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + pattern: str | None = None + type: Annotated[ + Literal["find_in_page"], + Field(title="FindInPageResponsesApiWebSearchActionType"), + ] + url: str | None = None + + +class OtherResponsesApiWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["other"], Field(title="OtherResponsesApiWebSearchActionType")] + + +class ResponsesApiWebSearchAction( + RootModel[ + SearchResponsesApiWebSearchAction + | OpenPageResponsesApiWebSearchAction + | FindInPageResponsesApiWebSearchAction + | OtherResponsesApiWebSearchAction + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ( + SearchResponsesApiWebSearchAction + | OpenPageResponsesApiWebSearchAction + | FindInPageResponsesApiWebSearchAction + | OtherResponsesApiWebSearchAction + ) + + +class ReviewDelivery(Enum): + inline = "inline" + detached = "detached" + + +class UncommittedChangesReviewTarget(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[ + Literal["uncommittedChanges"], Field(title="UncommittedChangesReviewTargetType") + ] + + +class BaseBranchReviewTarget(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + branch: str + type: Annotated[Literal["baseBranch"], Field(title="BaseBranchReviewTargetType")] + + +class CommitReviewTarget(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + sha: str + title: Annotated[ + str | None, + Field(description="Optional human-readable label (e.g., commit subject) for UIs."), + ] = None + type: Annotated[Literal["commit"], Field(title="CommitReviewTargetType")] + + +class CustomReviewTarget(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + instructions: str + type: Annotated[Literal["custom"], Field(title="CustomReviewTargetType")] + + +class ReviewTarget( + RootModel[ + UncommittedChangesReviewTarget + | BaseBranchReviewTarget + | CommitReviewTarget + | CustomReviewTarget + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ( + UncommittedChangesReviewTarget + | BaseBranchReviewTarget + | CommitReviewTarget + | CustomReviewTarget + ) + + +class SandboxMode(Enum): + read_only = "read-only" + workspace_write = "workspace-write" + danger_full_access = "danger-full-access" + + +class DangerFullAccessSandboxPolicy(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["dangerFullAccess"], Field(title="DangerFullAccessSandboxPolicyType")] + + +class ReadOnlySandboxPolicy(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + access: Annotated[ReadOnlyAccess | None, Field()] = {"type": "fullAccess"} + network_access: Annotated[bool | None, Field(alias="networkAccess")] = False + type: Annotated[Literal["readOnly"], Field(title="ReadOnlySandboxPolicyType")] + + +class ExternalSandboxSandboxPolicy(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + network_access: Annotated[NetworkAccess | None, Field(alias="networkAccess")] = "restricted" + type: Annotated[Literal["externalSandbox"], Field(title="ExternalSandboxSandboxPolicyType")] + + +class WorkspaceWriteSandboxPolicy(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + exclude_slash_tmp: Annotated[bool | None, Field(alias="excludeSlashTmp")] = False + exclude_tmpdir_env_var: Annotated[bool | None, Field(alias="excludeTmpdirEnvVar")] = False + network_access: Annotated[bool | None, Field(alias="networkAccess")] = False + read_only_access: Annotated[ReadOnlyAccess | None, Field(alias="readOnlyAccess")] = { + "type": "fullAccess" + } + type: Annotated[Literal["workspaceWrite"], Field(title="WorkspaceWriteSandboxPolicyType")] + writable_roots: Annotated[list[AbsolutePathBuf] | None, Field(alias="writableRoots")] = [] + + +class SandboxPolicy( + RootModel[ + DangerFullAccessSandboxPolicy + | ReadOnlySandboxPolicy + | ExternalSandboxSandboxPolicy + | WorkspaceWriteSandboxPolicy + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ( + DangerFullAccessSandboxPolicy + | ReadOnlySandboxPolicy + | ExternalSandboxSandboxPolicy + | WorkspaceWriteSandboxPolicy + ) + + +class SandboxWorkspaceWrite(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + exclude_slash_tmp: bool | None = False + exclude_tmpdir_env_var: bool | None = False + network_access: bool | None = False + writable_roots: list[str] | None = [] + + +class ItemAgentMessageDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/agentMessage/delta"], + Field(title="Item/agentMessage/deltaNotificationMethod"), + ] + params: AgentMessageDeltaNotification + + +class ItemPlanDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["item/plan/delta"], Field(title="Item/plan/deltaNotificationMethod")] + params: PlanDeltaNotification + + +class ItemCommandExecutionOutputDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/commandExecution/outputDelta"], + Field(title="Item/commandExecution/outputDeltaNotificationMethod"), + ] + params: CommandExecutionOutputDeltaNotification + + +class ItemFileChangeOutputDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/fileChange/outputDelta"], + Field(title="Item/fileChange/outputDeltaNotificationMethod"), + ] + params: FileChangeOutputDeltaNotification + + +class ItemMcpToolCallProgressServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/mcpToolCall/progress"], + Field(title="Item/mcpToolCall/progressNotificationMethod"), + ] + params: McpToolCallProgressNotification + + +class McpServerOauthLoginCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["mcpServer/oauthLogin/completed"], + Field(title="McpServer/oauthLogin/completedNotificationMethod"), + ] + params: McpServerOauthLoginCompletedNotification + + +class ItemReasoningSummaryTextDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/reasoning/summaryTextDelta"], + Field(title="Item/reasoning/summaryTextDeltaNotificationMethod"), + ] + params: ReasoningSummaryTextDeltaNotification + + +class ItemReasoningSummaryPartAddedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/reasoning/summaryPartAdded"], + Field(title="Item/reasoning/summaryPartAddedNotificationMethod"), + ] + params: ReasoningSummaryPartAddedNotification + + +class ItemReasoningTextDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/reasoning/textDelta"], + Field(title="Item/reasoning/textDeltaNotificationMethod"), + ] + params: ReasoningTextDeltaNotification + + +class ThreadCompactedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/compacted"], Field(title="Thread/compactedNotificationMethod") + ] + params: ContextCompactedNotification + + +class ModelReroutedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["model/rerouted"], Field(title="Model/reroutedNotificationMethod")] + params: ModelReroutedNotification + + +class DeprecationNoticeServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["deprecationNotice"], Field(title="DeprecationNoticeNotificationMethod") + ] + params: DeprecationNoticeNotification + + +class FuzzyFileSearchSessionUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["fuzzyFileSearch/sessionUpdated"], + Field(title="FuzzyFileSearch/sessionUpdatedNotificationMethod"), + ] + params: FuzzyFileSearchSessionUpdatedNotification + + +class FuzzyFileSearchSessionCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["fuzzyFileSearch/sessionCompleted"], + Field(title="FuzzyFileSearch/sessionCompletedNotificationMethod"), + ] + params: FuzzyFileSearchSessionCompletedNotification + + +class AccountLoginCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["account/login/completed"], + Field(title="Account/login/completedNotificationMethod"), + ] + params: AccountLoginCompletedNotification + + +class ServerRequestResolvedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + request_id: Annotated[RequestId, Field(alias="requestId")] + thread_id: Annotated[str, Field(alias="threadId")] + + +class ServiceTier(Enum): + fast = "fast" + flex = "flex" + + +class SessionSourceValue(Enum): + cli = "cli" + vscode = "vscode" + exec = "exec" + app_server = "appServer" + unknown = "unknown" + + +class Settings(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + developer_instructions: str | None = None + model: str + reasoning_effort: ReasoningEffort | None = None + + +class SkillErrorInfo(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + message: str + path: str + + +class SkillInterface(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + brand_color: Annotated[str | None, Field(alias="brandColor")] = None + default_prompt: Annotated[str | None, Field(alias="defaultPrompt")] = None + display_name: Annotated[str | None, Field(alias="displayName")] = None + icon_large: Annotated[str | None, Field(alias="iconLarge")] = None + icon_small: Annotated[str | None, Field(alias="iconSmall")] = None + short_description: Annotated[str | None, Field(alias="shortDescription")] = None + + +class SkillScope(Enum): + user = "user" + repo = "repo" + system = "system" + admin = "admin" + + +class SkillSummary(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: str + interface: SkillInterface | None = None + name: str + path: str + short_description: Annotated[str | None, Field(alias="shortDescription")] = None + + +class SkillToolDependency(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + command: str | None = None + description: str | None = None + transport: str | None = None + type: str + url: str | None = None + value: str + + +class SkillsChangedNotification(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class SkillsConfigWriteParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + enabled: bool + path: str + + +class SkillsConfigWriteResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + effective_enabled: Annotated[bool, Field(alias="effectiveEnabled")] + + +class SkillsListExtraRootsForCwd(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwd: str + extra_user_roots: Annotated[list[str], Field(alias="extraUserRoots")] + + +class SkillsListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwds: Annotated[ + list[str] | None, + Field(description="When empty, defaults to the current session working directory."), + ] = None + force_reload: Annotated[ + bool | None, + Field( + alias="forceReload", + description="When true, bypass the skills cache and re-scan skills from disk.", + ), + ] = None + per_cwd_extra_user_roots: Annotated[ + list[SkillsListExtraRootsForCwd] | None, + Field( + alias="perCwdExtraUserRoots", + description="Optional per-cwd extra roots to scan as user-scoped skills.", + ), + ] = None + + +class SkillsRemoteReadParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + enabled: bool | None = False + hazelnut_scope: Annotated[HazelnutScope | None, Field(alias="hazelnutScope")] = "example" + product_surface: Annotated[ProductSurface | None, Field(alias="productSurface")] = "codex" + + +class SkillsRemoteReadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[RemoteSkillSummary] + + +class SkillsRemoteWriteParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + hazelnut_id: Annotated[str, Field(alias="hazelnutId")] + + +class SkillsRemoteWriteResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + path: str + + +class SubAgentSourceValue(Enum): + review = "review" + compact = "compact" + memory_consolidation = "memory_consolidation" + + +class OtherSubAgentSource(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + other: str + + +class TerminalInteractionNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + item_id: Annotated[str, Field(alias="itemId")] + process_id: Annotated[str, Field(alias="processId")] + stdin: str + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class TextElement(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + byte_range: Annotated[ + ByteRange, + Field( + alias="byteRange", + description="Byte range in the parent `text` buffer that this element occupies.", + ), + ] + placeholder: Annotated[ + str | None, + Field( + description="Optional human-readable placeholder for the element, displayed in the UI." + ), + ] = None + + +class TextPosition(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + column: Annotated[ + int, + Field(description="1-based column number (in Unicode scalar values).", ge=0), + ] + line: Annotated[int, Field(description="1-based line number.", ge=0)] + + +class TextRange(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + end: TextPosition + start: TextPosition + + +class ThreadActiveFlag(Enum): + waiting_on_approval = "waitingOnApproval" + waiting_on_user_input = "waitingOnUserInput" + + +class ThreadArchiveParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadArchiveResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class ThreadArchivedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadClosedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadCompactStartParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadCompactStartResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class ThreadForkParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_policy: Annotated[AskForApproval | None, Field(alias="approvalPolicy")] = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + alias="approvalsReviewer", + description="Override where approval requests are routed for review on this thread and subsequent turns.", + ), + ] = None + base_instructions: Annotated[str | None, Field(alias="baseInstructions")] = None + config: dict[str, Any] | None = None + cwd: str | None = None + developer_instructions: Annotated[str | None, Field(alias="developerInstructions")] = None + ephemeral: bool | None = None + model: Annotated[ + str | None, + Field(description="Configuration overrides for the forked thread, if any."), + ] = None + model_provider: Annotated[str | None, Field(alias="modelProvider")] = None + sandbox: SandboxMode | None = None + service_tier: Annotated[ServiceTier | None, Field(alias="serviceTier")] = None + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadId(RootModel[str]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: str + + +class AgentMessageThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + phase: MessagePhase | None = None + text: str + type: Annotated[Literal["agentMessage"], Field(title="AgentMessageThreadItemType")] + + +class PlanThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + text: str + type: Annotated[Literal["plan"], Field(title="PlanThreadItemType")] + + +class ReasoningThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + content: list[str] | None = [] + id: str + summary: list[str] | None = [] + type: Annotated[Literal["reasoning"], Field(title="ReasoningThreadItemType")] + + +class CommandExecutionThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + aggregated_output: Annotated[ + str | None, + Field( + alias="aggregatedOutput", + description="The command's output, aggregated from stdout and stderr.", + ), + ] = None + command: Annotated[str, Field(description="The command to be executed.")] + command_actions: Annotated[ + list[CommandAction], + Field( + alias="commandActions", + description="A best-effort parsing of the command to understand the action(s) it will perform. This returns a list of CommandAction objects because a single shell command may be composed of many commands piped together.", + ), + ] + cwd: Annotated[str, Field(description="The command's working directory.")] + duration_ms: Annotated[ + int | None, + Field( + alias="durationMs", + description="The duration of the command execution in milliseconds.", + ), + ] = None + exit_code: Annotated[ + int | None, Field(alias="exitCode", description="The command's exit code.") + ] = None + id: str + process_id: Annotated[ + str | None, + Field( + alias="processId", + description="Identifier for the underlying PTY process (when available).", + ), + ] = None + status: CommandExecutionStatus + type: Annotated[Literal["commandExecution"], Field(title="CommandExecutionThreadItemType")] + + +class McpToolCallThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + arguments: Any + duration_ms: Annotated[ + int | None, + Field( + alias="durationMs", + description="The duration of the MCP tool call in milliseconds.", + ), + ] = None + error: McpToolCallError | None = None + id: str + result: McpToolCallResult | None = None + server: str + status: McpToolCallStatus + tool: str + type: Annotated[Literal["mcpToolCall"], Field(title="McpToolCallThreadItemType")] + + +class DynamicToolCallThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + arguments: Any + content_items: Annotated[ + list[DynamicToolCallOutputContentItem] | None, Field(alias="contentItems") + ] = None + duration_ms: Annotated[ + int | None, + Field( + alias="durationMs", + description="The duration of the dynamic tool call in milliseconds.", + ), + ] = None + id: str + status: DynamicToolCallStatus + success: bool | None = None + tool: str + type: Annotated[Literal["dynamicToolCall"], Field(title="DynamicToolCallThreadItemType")] + + +class ImageViewThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + path: str + type: Annotated[Literal["imageView"], Field(title="ImageViewThreadItemType")] + + +class ImageGenerationThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + result: str + revised_prompt: Annotated[str | None, Field(alias="revisedPrompt")] = None + status: str + type: Annotated[Literal["imageGeneration"], Field(title="ImageGenerationThreadItemType")] + + +class EnteredReviewModeThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + review: str + type: Annotated[Literal["enteredReviewMode"], Field(title="EnteredReviewModeThreadItemType")] + + +class ExitedReviewModeThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + review: str + type: Annotated[Literal["exitedReviewMode"], Field(title="ExitedReviewModeThreadItemType")] + + +class ContextCompactionThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + type: Annotated[Literal["contextCompaction"], Field(title="ContextCompactionThreadItemType")] + + +class ThreadLoadedListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cursor: Annotated[ + str | None, + Field(description="Opaque pagination cursor returned by a previous call."), + ] = None + limit: Annotated[ + int | None, Field(description="Optional page size; defaults to no limit.", ge=0) + ] = None + + +class ThreadLoadedListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: Annotated[ + list[str], + Field(description="Thread ids for sessions currently loaded in memory."), + ] + next_cursor: Annotated[ + str | None, + Field( + alias="nextCursor", + description="Opaque cursor to pass to the next call to continue after the last item. if None, there are no more items to return.", + ), + ] = None + + +class ThreadMetadataGitInfoUpdateParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + branch: Annotated[ + str | None, + Field( + description="Omit to leave the stored branch unchanged, set to `null` to clear it, or provide a non-empty string to replace it." + ), + ] = None + origin_url: Annotated[ + str | None, + Field( + alias="originUrl", + description="Omit to leave the stored origin URL unchanged, set to `null` to clear it, or provide a non-empty string to replace it.", + ), + ] = None + sha: Annotated[ + str | None, + Field( + description="Omit to leave the stored commit unchanged, set to `null` to clear it, or provide a non-empty string to replace it." + ), + ] = None + + +class ThreadMetadataUpdateParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + git_info: Annotated[ + ThreadMetadataGitInfoUpdateParams | None, + Field( + alias="gitInfo", + description="Patch the stored Git metadata for this thread. Omit a field to leave it unchanged, set it to `null` to clear it, or provide a string to replace the stored value.", + ), + ] = None + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadNameUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + thread_name: Annotated[str | None, Field(alias="threadName")] = None + + +class ThreadReadParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + include_turns: Annotated[ + bool | None, + Field( + alias="includeTurns", + description="When true, include turns and their items from rollout history.", + ), + ] = False + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadRealtimeAudioChunk(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: str + num_channels: Annotated[int, Field(alias="numChannels", ge=0)] + sample_rate: Annotated[int, Field(alias="sampleRate", ge=0)] + samples_per_channel: Annotated[int | None, Field(alias="samplesPerChannel", ge=0)] = None + + +class ThreadRealtimeClosedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + reason: str | None = None + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadRealtimeErrorNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + message: str + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadRealtimeItemAddedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + item: Any + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadRealtimeOutputAudioDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + audio: ThreadRealtimeAudioChunk + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadRealtimeStartedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + session_id: Annotated[str | None, Field(alias="sessionId")] = None + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadResumeParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_policy: Annotated[AskForApproval | None, Field(alias="approvalPolicy")] = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + alias="approvalsReviewer", + description="Override where approval requests are routed for review on this thread and subsequent turns.", + ), + ] = None + base_instructions: Annotated[str | None, Field(alias="baseInstructions")] = None + config: dict[str, Any] | None = None + cwd: str | None = None + developer_instructions: Annotated[str | None, Field(alias="developerInstructions")] = None + model: Annotated[ + str | None, + Field(description="Configuration overrides for the resumed thread, if any."), + ] = None + model_provider: Annotated[str | None, Field(alias="modelProvider")] = None + personality: Personality | None = None + sandbox: SandboxMode | None = None + service_tier: Annotated[ServiceTier | None, Field(alias="serviceTier")] = None + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadRollbackParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + num_turns: Annotated[ + int, + Field( + alias="numTurns", + description="The number of turns to drop from the end of the thread. Must be >= 1.\n\nThis only modifies the thread's history and does not revert local file changes that have been made by the agent. Clients are responsible for reverting these changes.", + ge=0, + ), + ] + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadSetNameParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: str + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadSetNameResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class ThreadSortKey(Enum): + created_at = "created_at" + updated_at = "updated_at" + + +class ThreadSourceKind(Enum): + cli = "cli" + vscode = "vscode" + exec = "exec" + app_server = "appServer" + sub_agent = "subAgent" + sub_agent_review = "subAgentReview" + sub_agent_compact = "subAgentCompact" + sub_agent_thread_spawn = "subAgentThreadSpawn" + sub_agent_other = "subAgentOther" + unknown = "unknown" + + +class ThreadStartParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_policy: Annotated[AskForApproval | None, Field(alias="approvalPolicy")] = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + alias="approvalsReviewer", + description="Override where approval requests are routed for review on this thread and subsequent turns.", + ), + ] = None + base_instructions: Annotated[str | None, Field(alias="baseInstructions")] = None + config: dict[str, Any] | None = None + cwd: str | None = None + developer_instructions: Annotated[str | None, Field(alias="developerInstructions")] = None + ephemeral: bool | None = None + model: str | None = None + model_provider: Annotated[str | None, Field(alias="modelProvider")] = None + personality: Personality | None = None + sandbox: SandboxMode | None = None + service_name: Annotated[str | None, Field(alias="serviceName")] = None + service_tier: Annotated[ServiceTier | None, Field(alias="serviceTier")] = None + + +class NotLoadedThreadStatus(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["notLoaded"], Field(title="NotLoadedThreadStatusType")] + + +class IdleThreadStatus(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["idle"], Field(title="IdleThreadStatusType")] + + +class SystemErrorThreadStatus(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["systemError"], Field(title="SystemErrorThreadStatusType")] + + +class ActiveThreadStatus(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + active_flags: Annotated[list[ThreadActiveFlag], Field(alias="activeFlags")] + type: Annotated[Literal["active"], Field(title="ActiveThreadStatusType")] + + +class ThreadStatus( + RootModel[ + NotLoadedThreadStatus | IdleThreadStatus | SystemErrorThreadStatus | ActiveThreadStatus + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: NotLoadedThreadStatus | IdleThreadStatus | SystemErrorThreadStatus | ActiveThreadStatus + + +class ThreadStatusChangedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + status: ThreadStatus + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadUnarchiveParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadUnarchivedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadUnsubscribeParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadUnsubscribeStatus(Enum): + not_loaded = "notLoaded" + not_subscribed = "notSubscribed" + unsubscribed = "unsubscribed" + + +class TokenUsageBreakdown(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cached_input_tokens: Annotated[int, Field(alias="cachedInputTokens")] + input_tokens: Annotated[int, Field(alias="inputTokens")] + output_tokens: Annotated[int, Field(alias="outputTokens")] + reasoning_output_tokens: Annotated[int, Field(alias="reasoningOutputTokens")] + total_tokens: Annotated[int, Field(alias="totalTokens")] + + +class Tool(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + field_meta: Annotated[Any | None, Field(alias="_meta")] = None + annotations: Any | None = None + description: str | None = None + icons: list | None = None + input_schema: Annotated[Any, Field(alias="inputSchema")] + name: str + output_schema: Annotated[Any | None, Field(alias="outputSchema")] = None + title: str | None = None + + +class TurnDiffUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + diff: str + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class TurnError(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + additional_details: Annotated[str | None, Field(alias="additionalDetails")] = None + codex_error_info: Annotated[CodexErrorInfo | None, Field(alias="codexErrorInfo")] = None + message: str + + +class TurnInterruptParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class TurnInterruptResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + +class TurnPlanStepStatus(Enum): + pending = "pending" + in_progress = "inProgress" + completed = "completed" + + +class TurnStatus(Enum): + completed = "completed" + interrupted = "interrupted" + failed = "failed" + in_progress = "inProgress" + + +class TurnSteerResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + turn_id: Annotated[str, Field(alias="turnId")] + + +class TextUserInput(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + text: str + text_elements: Annotated[ + list[TextElement] | None, + Field( + description="UI-defined spans within `text` used to render or persist special elements." + ), + ] = [] + type: Annotated[Literal["text"], Field(title="TextUserInputType")] + + +class ImageUserInput(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["image"], Field(title="ImageUserInputType")] + url: str + + +class LocalImageUserInput(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + path: str + type: Annotated[Literal["localImage"], Field(title="LocalImageUserInputType")] + + +class SkillUserInput(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: str + path: str + type: Annotated[Literal["skill"], Field(title="SkillUserInputType")] + + +class MentionUserInput(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: str + path: str + type: Annotated[Literal["mention"], Field(title="MentionUserInputType")] + + +class UserInput( + RootModel[ + TextUserInput | ImageUserInput | LocalImageUserInput | SkillUserInput | MentionUserInput + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: TextUserInput | ImageUserInput | LocalImageUserInput | SkillUserInput | MentionUserInput + + +class Verbosity(Enum): + low = "low" + medium = "medium" + high = "high" + + +class SearchWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + queries: list[str] | None = None + query: str | None = None + type: Annotated[Literal["search"], Field(title="SearchWebSearchActionType")] + + +class OpenPageWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["openPage"], Field(title="OpenPageWebSearchActionType")] + url: str | None = None + + +class FindInPageWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + pattern: str | None = None + type: Annotated[Literal["findInPage"], Field(title="FindInPageWebSearchActionType")] + url: str | None = None + + +class OtherWebSearchAction(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: Annotated[Literal["other"], Field(title="OtherWebSearchActionType")] + + +class WebSearchAction( + RootModel[ + SearchWebSearchAction + | OpenPageWebSearchAction + | FindInPageWebSearchAction + | OtherWebSearchAction + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ( + SearchWebSearchAction + | OpenPageWebSearchAction + | FindInPageWebSearchAction + | OtherWebSearchAction + ) + + +class WebSearchContextSize(Enum): + low = "low" + medium = "medium" + high = "high" + + +class WebSearchLocation(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + city: str | None = None + country: str | None = None + region: str | None = None + timezone: str | None = None + + +class WebSearchMode(Enum): + disabled = "disabled" + cached = "cached" + live = "live" + + +class WebSearchToolConfig(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + allowed_domains: list[str] | None = None + context_size: WebSearchContextSize | None = None + location: WebSearchLocation | None = None + + +class WindowsSandboxSetupMode(Enum): + elevated = "elevated" + unelevated = "unelevated" + + +class WindowsSandboxSetupStartParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwd: AbsolutePathBuf | None = None + mode: WindowsSandboxSetupMode + + +class WindowsSandboxSetupStartResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + started: bool + + +class WindowsWorldWritableWarningNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + extra_count: Annotated[int, Field(alias="extraCount", ge=0)] + failed_scan: Annotated[bool, Field(alias="failedScan")] + sample_paths: Annotated[list[str], Field(alias="samplePaths")] + + +class WriteStatus(Enum): + ok = "ok" + ok_overridden = "okOverridden" + + +class ChatgptAccount(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + email: str + plan_type: Annotated[PlanType, Field(alias="planType")] + type: Annotated[Literal["chatgpt"], Field(title="ChatgptAccountType")] + + +class Account(RootModel[ApiKeyAccount | ChatgptAccount]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ApiKeyAccount | ChatgptAccount + + +class AccountUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + auth_mode: Annotated[AuthMode | None, Field(alias="authMode")] = None + plan_type: Annotated[PlanType | None, Field(alias="planType")] = None + + +class AppConfig(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + default_tools_approval_mode: AppToolApproval | None = None + default_tools_enabled: bool | None = None + destructive_enabled: bool | None = None + enabled: bool | None = True + open_world_enabled: bool | None = None + tools: AppToolsConfig | None = None + + +class AppMetadata(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + categories: list[str] | None = None + developer: str | None = None + first_party_requires_install: Annotated[ + bool | None, Field(alias="firstPartyRequiresInstall") + ] = None + first_party_type: Annotated[str | None, Field(alias="firstPartyType")] = None + review: AppReview | None = None + screenshots: list[AppScreenshot] | None = None + seo_description: Annotated[str | None, Field(alias="seoDescription")] = None + show_in_composer_when_unlinked: Annotated[ + bool | None, Field(alias="showInComposerWhenUnlinked") + ] = None + sub_categories: Annotated[list[str] | None, Field(alias="subCategories")] = None + version: str | None = None + version_id: Annotated[str | None, Field(alias="versionId")] = None + version_notes: Annotated[str | None, Field(alias="versionNotes")] = None + + +class AppsConfig(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + field_default: Annotated[AppsDefaultConfig | None, Field(alias="_default")] = None + + +class CancelLoginAccountResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + status: CancelLoginAccountStatus + + +class InitializeRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["initialize"], Field(title="InitializeRequestMethod")] + params: InitializeParams + + +class ThreadStartRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/start"], Field(title="Thread/startRequestMethod")] + params: ThreadStartParams + + +class ThreadResumeRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/resume"], Field(title="Thread/resumeRequestMethod")] + params: ThreadResumeParams + + +class ThreadForkRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/fork"], Field(title="Thread/forkRequestMethod")] + params: ThreadForkParams + + +class ThreadArchiveRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/archive"], Field(title="Thread/archiveRequestMethod")] + params: ThreadArchiveParams + + +class ThreadUnsubscribeRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/unsubscribe"], Field(title="Thread/unsubscribeRequestMethod")] + params: ThreadUnsubscribeParams + + +class ThreadNameSetRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/name/set"], Field(title="Thread/name/setRequestMethod")] + params: ThreadSetNameParams + + +class ThreadMetadataUpdateRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["thread/metadata/update"], + Field(title="Thread/metadata/updateRequestMethod"), + ] + params: ThreadMetadataUpdateParams + + +class ThreadUnarchiveRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/unarchive"], Field(title="Thread/unarchiveRequestMethod")] + params: ThreadUnarchiveParams + + +class ThreadCompactStartRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["thread/compact/start"], + Field(title="Thread/compact/startRequestMethod"), + ] + params: ThreadCompactStartParams + + +class ThreadRollbackRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/rollback"], Field(title="Thread/rollbackRequestMethod")] + params: ThreadRollbackParams + + +class ThreadLoadedListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/loaded/list"], Field(title="Thread/loaded/listRequestMethod")] + params: ThreadLoadedListParams + + +class ThreadReadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/read"], Field(title="Thread/readRequestMethod")] + params: ThreadReadParams + + +class SkillsListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["skills/list"], Field(title="Skills/listRequestMethod")] + params: SkillsListParams + + +class PluginListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["plugin/list"], Field(title="Plugin/listRequestMethod")] + params: PluginListParams + + +class PluginReadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["plugin/read"], Field(title="Plugin/readRequestMethod")] + params: PluginReadParams + + +class SkillsRemoteListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["skills/remote/list"], Field(title="Skills/remote/listRequestMethod")] + params: SkillsRemoteReadParams + + +class SkillsRemoteExportRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["skills/remote/export"], + Field(title="Skills/remote/exportRequestMethod"), + ] + params: SkillsRemoteWriteParams + + +class AppListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["app/list"], Field(title="App/listRequestMethod")] + params: AppsListParams + + +class FsReadFileRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fs/readFile"], Field(title="Fs/readFileRequestMethod")] + params: FsReadFileParams + + +class FsWriteFileRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fs/writeFile"], Field(title="Fs/writeFileRequestMethod")] + params: FsWriteFileParams + + +class FsCreateDirectoryRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fs/createDirectory"], Field(title="Fs/createDirectoryRequestMethod")] + params: FsCreateDirectoryParams + + +class FsGetMetadataRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fs/getMetadata"], Field(title="Fs/getMetadataRequestMethod")] + params: FsGetMetadataParams + + +class FsReadDirectoryRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fs/readDirectory"], Field(title="Fs/readDirectoryRequestMethod")] + params: FsReadDirectoryParams + + +class FsRemoveRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fs/remove"], Field(title="Fs/removeRequestMethod")] + params: FsRemoveParams + + +class FsCopyRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fs/copy"], Field(title="Fs/copyRequestMethod")] + params: FsCopyParams + + +class SkillsConfigWriteRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["skills/config/write"], Field(title="Skills/config/writeRequestMethod") + ] + params: SkillsConfigWriteParams + + +class PluginInstallRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["plugin/install"], Field(title="Plugin/installRequestMethod")] + params: PluginInstallParams + + +class PluginUninstallRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["plugin/uninstall"], Field(title="Plugin/uninstallRequestMethod")] + params: PluginUninstallParams + + +class TurnInterruptRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["turn/interrupt"], Field(title="Turn/interruptRequestMethod")] + params: TurnInterruptParams + + +class ModelListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["model/list"], Field(title="Model/listRequestMethod")] + params: ModelListParams + + +class ExperimentalFeatureListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["experimentalFeature/list"], + Field(title="ExperimentalFeature/listRequestMethod"), + ] + params: ExperimentalFeatureListParams + + +class McpServerOauthLoginRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["mcpServer/oauth/login"], + Field(title="McpServer/oauth/loginRequestMethod"), + ] + params: McpServerOauthLoginParams + + +class ConfigMcpServerReloadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["config/mcpServer/reload"], + Field(title="Config/mcpServer/reloadRequestMethod"), + ] + params: None = None + + +class McpServerStatusListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["mcpServerStatus/list"], + Field(title="McpServerStatus/listRequestMethod"), + ] + params: ListMcpServerStatusParams + + +class WindowsSandboxSetupStartRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["windowsSandbox/setupStart"], + Field(title="WindowsSandbox/setupStartRequestMethod"), + ] + params: WindowsSandboxSetupStartParams + + +class AccountLoginStartRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["account/login/start"], Field(title="Account/login/startRequestMethod") + ] + params: LoginAccountParams + + +class AccountLoginCancelRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["account/login/cancel"], + Field(title="Account/login/cancelRequestMethod"), + ] + params: CancelLoginAccountParams + + +class AccountLogoutRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["account/logout"], Field(title="Account/logoutRequestMethod")] + params: None = None + + +class AccountRateLimitsReadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["account/rateLimits/read"], + Field(title="Account/rateLimits/readRequestMethod"), + ] + params: None = None + + +class FeedbackUploadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["feedback/upload"], Field(title="Feedback/uploadRequestMethod")] + params: FeedbackUploadParams + + +class CommandExecWriteRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["command/exec/write"], Field(title="Command/exec/writeRequestMethod")] + params: CommandExecWriteParams + + +class CommandExecTerminateRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["command/exec/terminate"], + Field(title="Command/exec/terminateRequestMethod"), + ] + params: CommandExecTerminateParams + + +class ConfigReadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["config/read"], Field(title="Config/readRequestMethod")] + params: ConfigReadParams + + +class ExternalAgentConfigDetectRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["externalAgentConfig/detect"], + Field(title="ExternalAgentConfig/detectRequestMethod"), + ] + params: ExternalAgentConfigDetectParams + + +class ConfigRequirementsReadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["configRequirements/read"], + Field(title="ConfigRequirements/readRequestMethod"), + ] + params: None = None + + +class AccountReadRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["account/read"], Field(title="Account/readRequestMethod")] + params: GetAccountParams + + +class FuzzyFileSearchRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["fuzzyFileSearch"], Field(title="FuzzyFileSearchRequestMethod")] + params: FuzzyFileSearchParams + + +class CollabAgentState(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + message: str | None = None + status: CollabAgentStatus + + +class CollaborationMode(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + mode: ModeKind + settings: Settings + + +class CollaborationModeMask(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + mode: ModeKind | None = None + model: str | None = None + name: str + reasoning_effort: ReasoningEffort | None = None + + +class CommandExecOutputDeltaNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cap_reached: Annotated[ + bool, + Field( + alias="capReached", + description="`true` on the final streamed chunk for a stream when `outputBytesCap` truncated later output on that stream.", + ), + ] + delta_base64: Annotated[ + str, Field(alias="deltaBase64", description="Base64-encoded output bytes.") + ] + process_id: Annotated[ + str, + Field( + alias="processId", + description="Client-supplied, connection-scoped `processId` from the original `command/exec` request.", + ), + ] + stream: Annotated[CommandExecOutputStream, Field(description="Output stream for this chunk.")] + + +class CommandExecParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + command: Annotated[ + list[str], Field(description="Command argv vector. Empty arrays are rejected.") + ] + cwd: Annotated[ + str | None, + Field(description="Optional working directory. Defaults to the server cwd."), + ] = None + disable_output_cap: Annotated[ + bool | None, + Field( + alias="disableOutputCap", + description="Disable stdout/stderr capture truncation for this request.\n\nCannot be combined with `outputBytesCap`.", + ), + ] = None + disable_timeout: Annotated[ + bool | None, + Field( + alias="disableTimeout", + description="Disable the timeout entirely for this request.\n\nCannot be combined with `timeoutMs`.", + ), + ] = None + env: Annotated[ + dict[str, Any] | None, + Field( + description="Optional environment overrides merged into the server-computed environment.\n\nMatching names override inherited values. Set a key to `null` to unset an inherited variable." + ), + ] = None + output_bytes_cap: Annotated[ + int | None, + Field( + alias="outputBytesCap", + description="Optional per-stream stdout/stderr capture cap in bytes.\n\nWhen omitted, the server default applies. Cannot be combined with `disableOutputCap`.", + ge=0, + ), + ] = None + process_id: Annotated[ + str | None, + Field( + alias="processId", + description="Optional client-supplied, connection-scoped process id.\n\nRequired for `tty`, `streamStdin`, `streamStdoutStderr`, and follow-up `command/exec/write`, `command/exec/resize`, and `command/exec/terminate` calls. When omitted, buffered execution gets an internal id that is not exposed to the client.", + ), + ] = None + sandbox_policy: Annotated[ + SandboxPolicy | None, + Field( + alias="sandboxPolicy", + description="Optional sandbox policy for this command.\n\nUses the same shape as thread/turn execution sandbox configuration and defaults to the user's configured policy when omitted.", + ), + ] = None + size: Annotated[ + CommandExecTerminalSize | None, + Field( + description="Optional initial PTY size in character cells. Only valid when `tty` is true." + ), + ] = None + stream_stdin: Annotated[ + bool | None, + Field( + alias="streamStdin", + description="Allow follow-up `command/exec/write` requests to write stdin bytes.\n\nRequires a client-supplied `processId`.", + ), + ] = None + stream_stdout_stderr: Annotated[ + bool | None, + Field( + alias="streamStdoutStderr", + description="Stream stdout/stderr via `command/exec/outputDelta` notifications.\n\nStreamed bytes are not duplicated into the final response and require a client-supplied `processId`.", + ), + ] = None + timeout_ms: Annotated[ + int | None, + Field( + alias="timeoutMs", + description="Optional timeout in milliseconds.\n\nWhen omitted, the server default applies. Cannot be combined with `disableTimeout`.", + ), + ] = None + tty: Annotated[ + bool | None, + Field( + description="Enable PTY mode.\n\nThis implies `streamStdin` and `streamStdoutStderr`." + ), + ] = None + + +class CommandExecResizeParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + process_id: Annotated[ + str, + Field( + alias="processId", + description="Client-supplied, connection-scoped `processId` from the original `command/exec` request.", + ), + ] + size: Annotated[CommandExecTerminalSize, Field(description="New PTY size in character cells.")] + + +class ConfigEdit(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + key_path: Annotated[str, Field(alias="keyPath")] + merge_strategy: Annotated[MergeStrategy, Field(alias="mergeStrategy")] + value: Any + + +class ConfigLayer(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + config: Any + disabled_reason: Annotated[str | None, Field(alias="disabledReason")] = None + name: ConfigLayerSource + version: str + + +class ConfigLayerMetadata(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: ConfigLayerSource + version: str + + +class ConfigRequirements(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + allowed_approval_policies: Annotated[ + list[AskForApproval] | None, Field(alias="allowedApprovalPolicies") + ] = None + allowed_sandbox_modes: Annotated[ + list[SandboxMode] | None, Field(alias="allowedSandboxModes") + ] = None + allowed_web_search_modes: Annotated[ + list[WebSearchMode] | None, Field(alias="allowedWebSearchModes") + ] = None + enforce_residency: Annotated[ResidencyRequirement | None, Field(alias="enforceResidency")] = ( + None + ) + feature_requirements: Annotated[dict[str, Any] | None, Field(alias="featureRequirements")] = ( + None + ) + + +class ConfigRequirementsReadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + requirements: Annotated[ + ConfigRequirements | None, + Field( + description="Null if no requirements are configured (e.g. no requirements.toml/MDM entries)." + ), + ] = None + + +class ConfigValueWriteParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + expected_version: Annotated[str | None, Field(alias="expectedVersion")] = None + file_path: Annotated[ + str | None, + Field( + alias="filePath", + description="Path to the config file to write; defaults to the user's `config.toml` when omitted.", + ), + ] = None + key_path: Annotated[str, Field(alias="keyPath")] + merge_strategy: Annotated[MergeStrategy, Field(alias="mergeStrategy")] + value: Any + + +class ConfigWarningNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + details: Annotated[ + str | None, Field(description="Optional extra guidance or error details.") + ] = None + path: Annotated[ + str | None, + Field(description="Optional path to the config file that triggered the warning."), + ] = None + range: Annotated[ + TextRange | None, + Field(description="Optional range for the error location inside the config file."), + ] = None + summary: Annotated[str, Field(description="Concise summary of the warning.")] + + +class ErrorNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + error: TurnError + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + will_retry: Annotated[bool, Field(alias="willRetry")] + + +class ExperimentalFeature(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + announcement: Annotated[ + str | None, + Field( + description="Announcement copy shown to users when the feature is introduced. Null when this feature is not in beta." + ), + ] = None + default_enabled: Annotated[ + bool, + Field( + alias="defaultEnabled", + description="Whether this feature is enabled by default.", + ), + ] + description: Annotated[ + str | None, + Field( + description="Short summary describing what the feature does. Null when this feature is not in beta." + ), + ] = None + display_name: Annotated[ + str | None, + Field( + alias="displayName", + description="User-facing display name shown in the experimental features UI. Null when this feature is not in beta.", + ), + ] = None + enabled: Annotated[ + bool, + Field(description="Whether this feature is currently enabled in the loaded config."), + ] + name: Annotated[str, Field(description="Stable key used in config.toml and CLI flag toggles.")] + stage: Annotated[ + ExperimentalFeatureStage, + Field(description="Lifecycle stage of this feature flag."), + ] + + +class ExperimentalFeatureListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[ExperimentalFeature] + next_cursor: Annotated[ + str | None, + Field( + alias="nextCursor", + description="Opaque cursor to pass to the next call to continue after the last item. If None, there are no more items to return.", + ), + ] = None + + +class ExternalAgentConfigMigrationItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwd: Annotated[ + str | None, + Field( + description="Null or empty means home-scoped migration; non-empty means repo-scoped migration." + ), + ] = None + description: str + item_type: Annotated[ExternalAgentConfigMigrationItemType, Field(alias="itemType")] + + +class FileUpdateChange(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + diff: str + kind: PatchChangeKind + path: str + + +class InputImageFunctionCallOutputContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + detail: ImageDetail | None = None + image_url: str + type: Annotated[ + Literal["input_image"], + Field(title="InputImageFunctionCallOutputContentItemType"), + ] + + +class FunctionCallOutputContentItem( + RootModel[InputTextFunctionCallOutputContentItem | InputImageFunctionCallOutputContentItem] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + InputTextFunctionCallOutputContentItem | InputImageFunctionCallOutputContentItem, + Field( + description="Responses API compatible content items that can be returned by a tool call. This is a subset of ContentItem with the types we support as function call outputs." + ), + ] + + +class GetAccountResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + account: Account | None = None + requires_openai_auth: Annotated[bool, Field(alias="requiresOpenaiAuth")] + + +class GuardianApprovalReview(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + rationale: str | None = None + risk_level: Annotated[GuardianRiskLevel | None, Field(alias="riskLevel")] = None + risk_score: Annotated[int | None, Field(alias="riskScore", ge=0)] = None + status: GuardianApprovalReviewStatus + + +class HookOutputEntry(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + kind: HookOutputEntryKind + text: str + + +class HookRunSummary(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + completed_at: Annotated[int | None, Field(alias="completedAt")] = None + display_order: Annotated[int, Field(alias="displayOrder")] + duration_ms: Annotated[int | None, Field(alias="durationMs")] = None + entries: list[HookOutputEntry] + event_name: Annotated[HookEventName, Field(alias="eventName")] + execution_mode: Annotated[HookExecutionMode, Field(alias="executionMode")] + handler_type: Annotated[HookHandlerType, Field(alias="handlerType")] + id: str + scope: HookScope + source_path: Annotated[str, Field(alias="sourcePath")] + started_at: Annotated[int, Field(alias="startedAt")] + status: HookRunStatus + status_message: Annotated[str | None, Field(alias="statusMessage")] = None + + +class HookStartedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + run: HookRunSummary + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str | None, Field(alias="turnId")] = None + + +class ItemGuardianApprovalReviewCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + action: Any | None = None + review: GuardianApprovalReview + target_item_id: Annotated[str, Field(alias="targetItemId")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ItemGuardianApprovalReviewStartedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + action: Any | None = None + review: GuardianApprovalReview + target_item_id: Annotated[str, Field(alias="targetItemId")] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class McpServerStatus(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + auth_status: Annotated[McpAuthStatus, Field(alias="authStatus")] + name: str + resource_templates: Annotated[list[ResourceTemplate], Field(alias="resourceTemplates")] + resources: list[Resource] + tools: dict[str, Tool] + + +class Model(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + availability_nux: Annotated[ModelAvailabilityNux | None, Field(alias="availabilityNux")] = None + default_reasoning_effort: Annotated[ReasoningEffort, Field(alias="defaultReasoningEffort")] + description: str + display_name: Annotated[str, Field(alias="displayName")] + hidden: bool + id: str + input_modalities: Annotated[list[InputModality] | None, Field(alias="inputModalities")] = [ + "text", + "image", + ] + is_default: Annotated[bool, Field(alias="isDefault")] + model: str + supported_reasoning_efforts: Annotated[ + list[ReasoningEffortOption], Field(alias="supportedReasoningEfforts") + ] + supports_personality: Annotated[bool | None, Field(alias="supportsPersonality")] = False + upgrade: str | None = None + upgrade_info: Annotated[ModelUpgradeInfo | None, Field(alias="upgradeInfo")] = None + + +class ModelListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[Model] + next_cursor: Annotated[ + str | None, + Field( + alias="nextCursor", + description="Opaque cursor to pass to the next call to continue after the last item. If None, there are no more items to return.", + ), + ] = None + + +class OverriddenMetadata(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + effective_value: Annotated[Any, Field(alias="effectiveValue")] + message: str + overriding_layer: Annotated[ConfigLayerMetadata, Field(alias="overridingLayer")] + + +class PluginDetail(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + apps: list[AppSummary] + description: str | None = None + marketplace_name: Annotated[str, Field(alias="marketplaceName")] + marketplace_path: Annotated[AbsolutePathBuf, Field(alias="marketplacePath")] + mcp_servers: Annotated[list[str], Field(alias="mcpServers")] + skills: list[SkillSummary] + summary: PluginSummary + + +class PluginMarketplaceEntry(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: str + path: AbsolutePathBuf + plugins: list[PluginSummary] + + +class PluginReadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + plugin: PluginDetail + + +class RateLimitSnapshot(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + credits: CreditsSnapshot | None = None + limit_id: Annotated[str | None, Field(alias="limitId")] = None + limit_name: Annotated[str | None, Field(alias="limitName")] = None + plan_type: Annotated[PlanType | None, Field(alias="planType")] = None + primary: RateLimitWindow | None = None + secondary: RateLimitWindow | None = None + + +class WebSearchCallResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + action: ResponsesApiWebSearchAction | None = None + id: str | None = None + status: str | None = None + type: Annotated[Literal["web_search_call"], Field(title="WebSearchCallResponseItemType")] + + +class ReviewStartParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + delivery: Annotated[ + ReviewDelivery | None, + Field( + description="Where to run the review: inline (default) on the current thread or detached on a new thread (returned in `reviewThreadId`)." + ), + ] = None + target: ReviewTarget + thread_id: Annotated[str, Field(alias="threadId")] + + +class ErrorServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["error"], Field(title="ErrorNotificationMethod")] + params: ErrorNotification + + +class ThreadStatusChangedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/status/changed"], + Field(title="Thread/status/changedNotificationMethod"), + ] + params: ThreadStatusChangedNotification + + +class ThreadArchivedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["thread/archived"], Field(title="Thread/archivedNotificationMethod")] + params: ThreadArchivedNotification + + +class ThreadUnarchivedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/unarchived"], Field(title="Thread/unarchivedNotificationMethod") + ] + params: ThreadUnarchivedNotification + + +class ThreadClosedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["thread/closed"], Field(title="Thread/closedNotificationMethod")] + params: ThreadClosedNotification + + +class SkillsChangedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["skills/changed"], Field(title="Skills/changedNotificationMethod")] + params: SkillsChangedNotification + + +class ThreadNameUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/name/updated"], + Field(title="Thread/name/updatedNotificationMethod"), + ] + params: ThreadNameUpdatedNotification + + +class HookStartedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["hook/started"], Field(title="Hook/startedNotificationMethod")] + params: HookStartedNotification + + +class TurnDiffUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["turn/diff/updated"], Field(title="Turn/diff/updatedNotificationMethod") + ] + params: TurnDiffUpdatedNotification + + +class ItemAutoApprovalReviewStartedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/autoApprovalReview/started"], + Field(title="Item/autoApprovalReview/startedNotificationMethod"), + ] + params: ItemGuardianApprovalReviewStartedNotification + + +class ItemAutoApprovalReviewCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/autoApprovalReview/completed"], + Field(title="Item/autoApprovalReview/completedNotificationMethod"), + ] + params: ItemGuardianApprovalReviewCompletedNotification + + +class CommandExecOutputDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["command/exec/outputDelta"], + Field(title="Command/exec/outputDeltaNotificationMethod"), + ] + params: CommandExecOutputDeltaNotification + + +class ItemCommandExecutionTerminalInteractionServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["item/commandExecution/terminalInteraction"], + Field(title="Item/commandExecution/terminalInteractionNotificationMethod"), + ] + params: TerminalInteractionNotification + + +class ServerRequestResolvedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["serverRequest/resolved"], + Field(title="ServerRequest/resolvedNotificationMethod"), + ] + params: ServerRequestResolvedNotification + + +class AccountUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["account/updated"], Field(title="Account/updatedNotificationMethod")] + params: AccountUpdatedNotification + + +class ConfigWarningServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["configWarning"], Field(title="ConfigWarningNotificationMethod")] + params: ConfigWarningNotification + + +class ThreadRealtimeStartedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/realtime/started"], + Field(title="Thread/realtime/startedNotificationMethod"), + ] + params: ThreadRealtimeStartedNotification + + +class ThreadRealtimeItemAddedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/realtime/itemAdded"], + Field(title="Thread/realtime/itemAddedNotificationMethod"), + ] + params: ThreadRealtimeItemAddedNotification + + +class ThreadRealtimeOutputAudioDeltaServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/realtime/outputAudio/delta"], + Field(title="Thread/realtime/outputAudio/deltaNotificationMethod"), + ] + params: ThreadRealtimeOutputAudioDeltaNotification + + +class ThreadRealtimeErrorServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/realtime/error"], + Field(title="Thread/realtime/errorNotificationMethod"), + ] + params: ThreadRealtimeErrorNotification + + +class ThreadRealtimeClosedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/realtime/closed"], + Field(title="Thread/realtime/closedNotificationMethod"), + ] + params: ThreadRealtimeClosedNotification + + +class WindowsWorldWritableWarningServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["windows/worldWritableWarning"], + Field(title="Windows/worldWritableWarningNotificationMethod"), + ] + params: WindowsWorldWritableWarningNotification + + +class SkillDependencies(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + tools: list[SkillToolDependency] + + +class SkillMetadata(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + dependencies: SkillDependencies | None = None + description: str + enabled: bool + interface: SkillInterface | None = None + name: str + path: str + scope: SkillScope + short_description: Annotated[ + str | None, + Field( + alias="shortDescription", + description="Legacy short_description from SKILL.md. Prefer SKILL.json interface.short_description.", + ), + ] = None + + +class SkillsListEntry(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cwd: str + errors: list[SkillErrorInfo] + skills: list[SkillMetadata] + + +class SkillsListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[SkillsListEntry] + + +class ThreadSpawn(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + agent_nickname: str | None = None + agent_role: str | None = None + depth: int + parent_thread_id: ThreadId + + +class ThreadSpawnSubAgentSource(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + thread_spawn: ThreadSpawn + + +class SubAgentSource( + RootModel[SubAgentSourceValue | ThreadSpawnSubAgentSource | OtherSubAgentSource] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: SubAgentSourceValue | ThreadSpawnSubAgentSource | OtherSubAgentSource + + +class UserMessageThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + content: list[UserInput] + id: str + type: Annotated[Literal["userMessage"], Field(title="UserMessageThreadItemType")] + + +class FileChangeThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + changes: list[FileUpdateChange] + id: str + status: PatchApplyStatus + type: Annotated[Literal["fileChange"], Field(title="FileChangeThreadItemType")] + + +class CollabAgentToolCallThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + agents_states: Annotated[ + dict[str, CollabAgentState], + Field( + alias="agentsStates", + description="Last known status of the target agents, when available.", + ), + ] + id: Annotated[str, Field(description="Unique identifier for this collab tool call.")] + model: Annotated[ + str | None, + Field(description="Model requested for the spawned agent, when applicable."), + ] = None + prompt: Annotated[ + str | None, + Field(description="Prompt text sent as part of the collab tool call, when available."), + ] = None + reasoning_effort: Annotated[ + ReasoningEffort | None, + Field( + alias="reasoningEffort", + description="Reasoning effort requested for the spawned agent, when applicable.", + ), + ] = None + receiver_thread_ids: Annotated[ + list[str], + Field( + alias="receiverThreadIds", + description="Thread ID of the receiving agent, when applicable. In case of spawn operation, this corresponds to the newly spawned agent.", + ), + ] + sender_thread_id: Annotated[ + str, + Field( + alias="senderThreadId", + description="Thread ID of the agent issuing the collab request.", + ), + ] + status: Annotated[ + CollabAgentToolCallStatus, + Field(description="Current status of the collab tool call."), + ] + tool: Annotated[CollabAgentTool, Field(description="Name of the collab tool that was invoked.")] + type: Annotated[ + Literal["collabAgentToolCall"], Field(title="CollabAgentToolCallThreadItemType") + ] + + +class WebSearchThreadItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + action: WebSearchAction | None = None + id: str + query: str + type: Annotated[Literal["webSearch"], Field(title="WebSearchThreadItemType")] + + +class ThreadItem( + RootModel[ + UserMessageThreadItem + | AgentMessageThreadItem + | PlanThreadItem + | ReasoningThreadItem + | CommandExecutionThreadItem + | FileChangeThreadItem + | McpToolCallThreadItem + | DynamicToolCallThreadItem + | CollabAgentToolCallThreadItem + | WebSearchThreadItem + | ImageViewThreadItem + | ImageGenerationThreadItem + | EnteredReviewModeThreadItem + | ExitedReviewModeThreadItem + | ContextCompactionThreadItem + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ( + UserMessageThreadItem + | AgentMessageThreadItem + | PlanThreadItem + | ReasoningThreadItem + | CommandExecutionThreadItem + | FileChangeThreadItem + | McpToolCallThreadItem + | DynamicToolCallThreadItem + | CollabAgentToolCallThreadItem + | WebSearchThreadItem + | ImageViewThreadItem + | ImageGenerationThreadItem + | EnteredReviewModeThreadItem + | ExitedReviewModeThreadItem + | ContextCompactionThreadItem + ) + + +class ThreadListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + archived: Annotated[ + bool | None, + Field( + description="Optional archived filter; when set to true, only archived threads are returned. If false or null, only non-archived threads are returned." + ), + ] = None + cursor: Annotated[ + str | None, + Field(description="Opaque pagination cursor returned by a previous call."), + ] = None + cwd: Annotated[ + str | None, + Field( + description="Optional cwd filter; when set, only threads whose session cwd exactly matches this path are returned." + ), + ] = None + limit: Annotated[ + int | None, + Field( + description="Optional page size; defaults to a reasonable server-side value.", + ge=0, + ), + ] = None + model_providers: Annotated[ + list[str] | None, + Field( + alias="modelProviders", + description="Optional provider filter; when set, only sessions recorded under these providers are returned. When present but empty, includes all providers.", + ), + ] = None + search_term: Annotated[ + str | None, + Field( + alias="searchTerm", + description="Optional substring filter for the extracted thread title.", + ), + ] = None + sort_key: Annotated[ + ThreadSortKey | None, + Field(alias="sortKey", description="Optional sort key; defaults to created_at."), + ] = None + source_kinds: Annotated[ + list[ThreadSourceKind] | None, + Field( + alias="sourceKinds", + description="Optional source filter; when set, only sessions from these source kinds are returned. When omitted or empty, defaults to interactive sources.", + ), + ] = None + + +class ThreadTokenUsage(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + last: TokenUsageBreakdown + model_context_window: Annotated[int | None, Field(alias="modelContextWindow")] = None + total: TokenUsageBreakdown + + +class ThreadTokenUsageUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + token_usage: Annotated[ThreadTokenUsage, Field(alias="tokenUsage")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ThreadUnsubscribeResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + status: ThreadUnsubscribeStatus + + +class ToolsV2(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + view_image: bool | None = None + web_search: WebSearchToolConfig | None = None + + +class Turn(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + error: Annotated[ + TurnError | None, + Field(description="Only populated when the Turn's status is failed."), + ] = None + id: str + items: Annotated[ + list[ThreadItem], + Field( + description="Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list." + ), + ] + status: TurnStatus + + +class TurnCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + turn: Turn + + +class TurnPlanStep(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + status: TurnPlanStepStatus + step: str + + +class TurnPlanUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + explanation: str | None = None + plan: list[TurnPlanStep] + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class TurnStartParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_policy: Annotated[ + AskForApproval | None, + Field( + alias="approvalPolicy", + description="Override the approval policy for this turn and subsequent turns.", + ), + ] = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + alias="approvalsReviewer", + description="Override where approval requests are routed for review on this turn and subsequent turns.", + ), + ] = None + cwd: Annotated[ + str | None, + Field(description="Override the working directory for this turn and subsequent turns."), + ] = None + effort: Annotated[ + ReasoningEffort | None, + Field(description="Override the reasoning effort for this turn and subsequent turns."), + ] = None + input: list[UserInput] + model: Annotated[ + str | None, + Field(description="Override the model for this turn and subsequent turns."), + ] = None + output_schema: Annotated[ + Any | None, + Field( + alias="outputSchema", + description="Optional JSON Schema used to constrain the final assistant message for this turn.", + ), + ] = None + personality: Annotated[ + Personality | None, + Field(description="Override the personality for this turn and subsequent turns."), + ] = None + sandbox_policy: Annotated[ + SandboxPolicy | None, + Field( + alias="sandboxPolicy", + description="Override the sandbox policy for this turn and subsequent turns.", + ), + ] = None + service_tier: Annotated[ + ServiceTier | None, + Field( + alias="serviceTier", + description="Override the service tier for this turn and subsequent turns.", + ), + ] = None + summary: Annotated[ + ReasoningSummary | None, + Field(description="Override the reasoning summary for this turn and subsequent turns."), + ] = None + thread_id: Annotated[str, Field(alias="threadId")] + + +class TurnStartResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + turn: Turn + + +class TurnStartedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + turn: Turn + + +class TurnSteerParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + expected_turn_id: Annotated[ + str, + Field( + alias="expectedTurnId", + description="Required active turn id precondition. The request fails when it does not match the currently active turn.", + ), + ] + input: list[UserInput] + thread_id: Annotated[str, Field(alias="threadId")] + + +class WindowsSandboxSetupCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + error: str | None = None + mode: WindowsSandboxSetupMode + success: bool + + +class AccountRateLimitsUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + rate_limits: Annotated[RateLimitSnapshot, Field(alias="rateLimits")] + + +class AppInfo(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + app_metadata: Annotated[AppMetadata | None, Field(alias="appMetadata")] = None + branding: AppBranding | None = None + description: str | None = None + distribution_channel: Annotated[str | None, Field(alias="distributionChannel")] = None + id: str + install_url: Annotated[str | None, Field(alias="installUrl")] = None + is_accessible: Annotated[bool | None, Field(alias="isAccessible")] = False + is_enabled: Annotated[ + bool | None, + Field( + alias="isEnabled", + description="Whether this app is enabled in config.toml. Example: ```toml [apps.bad_app] enabled = false ```", + ), + ] = True + labels: dict[str, Any] | None = None + logo_url: Annotated[str | None, Field(alias="logoUrl")] = None + logo_url_dark: Annotated[str | None, Field(alias="logoUrlDark")] = None + name: str + plugin_display_names: Annotated[list[str] | None, Field(alias="pluginDisplayNames")] = [] + + +class AppListUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[AppInfo] + + +class AppsListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[AppInfo] + next_cursor: Annotated[ + str | None, + Field( + alias="nextCursor", + description="Opaque cursor to pass to the next call to continue after the last item. If None, there are no more items to return.", + ), + ] = None + + +class ThreadListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/list"], Field(title="Thread/listRequestMethod")] + params: ThreadListParams + + +class TurnStartRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["turn/start"], Field(title="Turn/startRequestMethod")] + params: TurnStartParams + + +class TurnSteerRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["turn/steer"], Field(title="Turn/steerRequestMethod")] + params: TurnSteerParams + + +class ReviewStartRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["review/start"], Field(title="Review/startRequestMethod")] + params: ReviewStartParams + + +class CommandExecRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["command/exec"], Field(title="Command/execRequestMethod")] + params: CommandExecParams + + +class CommandExecResizeRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["command/exec/resize"], Field(title="Command/exec/resizeRequestMethod") + ] + params: CommandExecResizeParams + + +class ConfigValueWriteRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["config/value/write"], Field(title="Config/value/writeRequestMethod")] + params: ConfigValueWriteParams + + +class ConfigBatchWriteParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + edits: list[ConfigEdit] + expected_version: Annotated[str | None, Field(alias="expectedVersion")] = None + file_path: Annotated[ + str | None, + Field( + alias="filePath", + description="Path to the config file to write; defaults to the user's `config.toml` when omitted.", + ), + ] = None + reload_user_config: Annotated[ + bool | None, + Field( + alias="reloadUserConfig", + description="When true, hot-reload the updated user config into all loaded threads after writing.", + ), + ] = None + + +class ConfigWriteResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + file_path: Annotated[ + AbsolutePathBuf, + Field( + alias="filePath", + description="Canonical path to the config file that was written.", + ), + ] + overridden_metadata: Annotated[OverriddenMetadata | None, Field(alias="overriddenMetadata")] = ( + None + ) + status: WriteStatus + version: str + + +class ExternalAgentConfigDetectResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + items: list[ExternalAgentConfigMigrationItem] + + +class ExternalAgentConfigImportParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + migration_items: Annotated[ + list[ExternalAgentConfigMigrationItem], Field(alias="migrationItems") + ] + + +class FunctionCallOutputBody(RootModel[str | list[FunctionCallOutputContentItem]]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: str | list[FunctionCallOutputContentItem] + + +class FunctionCallOutputPayload(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + body: FunctionCallOutputBody + success: bool | None = None + + +class GetAccountRateLimitsResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + rate_limits: Annotated[ + RateLimitSnapshot, + Field( + alias="rateLimits", + description="Backward-compatible single-bucket view; mirrors the historical payload.", + ), + ] + rate_limits_by_limit_id: Annotated[ + dict[str, Any] | None, + Field( + alias="rateLimitsByLimitId", + description="Multi-bucket view keyed by metered `limit_id` (for example, `codex`).", + ), + ] = None + + +class HookCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + run: HookRunSummary + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str | None, Field(alias="turnId")] = None + + +class ItemCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + item: ThreadItem + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ItemStartedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + item: ThreadItem + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ListMcpServerStatusResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[McpServerStatus] + next_cursor: Annotated[ + str | None, + Field( + alias="nextCursor", + description="Opaque cursor to pass to the next call to continue after the last item. If None, there are no more items to return.", + ), + ] = None + + +class PluginListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + marketplaces: list[PluginMarketplaceEntry] + remote_sync_error: Annotated[str | None, Field(alias="remoteSyncError")] = None + + +class ProfileV2(BaseModel): + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + approval_policy: AskForApproval | None = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + description="[UNSTABLE] Optional profile-level override for where approval requests are routed for review. If omitted, the enclosing config default is used." + ), + ] = None + chatgpt_base_url: str | None = None + model: str | None = None + model_provider: str | None = None + model_reasoning_effort: ReasoningEffort | None = None + model_reasoning_summary: ReasoningSummary | None = None + model_verbosity: Verbosity | None = None + service_tier: ServiceTier | None = None + tools: ToolsV2 | None = None + web_search: WebSearchMode | None = None + + +class FunctionCallOutputResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + call_id: str + output: FunctionCallOutputPayload + type: Annotated[ + Literal["function_call_output"], + Field(title="FunctionCallOutputResponseItemType"), + ] + + +class CustomToolCallOutputResponseItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + call_id: str + output: FunctionCallOutputPayload + type: Annotated[ + Literal["custom_tool_call_output"], + Field(title="CustomToolCallOutputResponseItemType"), + ] + + +class ResponseItem( + RootModel[ + MessageResponseItem + | ReasoningResponseItem + | LocalShellCallResponseItem + | FunctionCallResponseItem + | ToolSearchCallResponseItem + | FunctionCallOutputResponseItem + | CustomToolCallResponseItem + | CustomToolCallOutputResponseItem + | ToolSearchOutputResponseItem + | WebSearchCallResponseItem + | ImageGenerationCallResponseItem + | GhostSnapshotResponseItem + | CompactionResponseItem + | OtherResponseItem + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: ( + MessageResponseItem + | ReasoningResponseItem + | LocalShellCallResponseItem + | FunctionCallResponseItem + | ToolSearchCallResponseItem + | FunctionCallOutputResponseItem + | CustomToolCallResponseItem + | CustomToolCallOutputResponseItem + | ToolSearchOutputResponseItem + | WebSearchCallResponseItem + | ImageGenerationCallResponseItem + | GhostSnapshotResponseItem + | CompactionResponseItem + | OtherResponseItem + ) + + +class ReviewStartResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + review_thread_id: Annotated[ + str, + Field( + alias="reviewThreadId", + description="Identifies the thread where the review runs.\n\nFor inline reviews, this is the original thread id. For detached reviews, this is the id of the new review thread.", + ), + ] + turn: Turn + + +class ThreadTokenUsageUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/tokenUsage/updated"], + Field(title="Thread/tokenUsage/updatedNotificationMethod"), + ] + params: ThreadTokenUsageUpdatedNotification + + +class TurnStartedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["turn/started"], Field(title="Turn/startedNotificationMethod")] + params: TurnStartedNotification + + +class TurnCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["turn/completed"], Field(title="Turn/completedNotificationMethod")] + params: TurnCompletedNotification + + +class HookCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["hook/completed"], Field(title="Hook/completedNotificationMethod")] + params: HookCompletedNotification + + +class TurnPlanUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["turn/plan/updated"], Field(title="Turn/plan/updatedNotificationMethod") + ] + params: TurnPlanUpdatedNotification + + +class ItemStartedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["item/started"], Field(title="Item/startedNotificationMethod")] + params: ItemStartedNotification + + +class ItemCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["item/completed"], Field(title="Item/completedNotificationMethod")] + params: ItemCompletedNotification + + +class AccountRateLimitsUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["account/rateLimits/updated"], + Field(title="Account/rateLimits/updatedNotificationMethod"), + ] + params: AccountRateLimitsUpdatedNotification + + +class AppListUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["app/list/updated"], Field(title="App/list/updatedNotificationMethod") + ] + params: AppListUpdatedNotification + + +class WindowsSandboxSetupCompletedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["windowsSandbox/setupCompleted"], + Field(title="WindowsSandbox/setupCompletedNotificationMethod"), + ] + params: WindowsSandboxSetupCompletedNotification + + +class SubAgentSessionSource(BaseModel): + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) + sub_agent: Annotated[SubAgentSource, Field(alias="subAgent")] + + +class SessionSource(RootModel[SessionSourceValue | SubAgentSessionSource]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: SessionSourceValue | SubAgentSessionSource + + +class Thread(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + agent_nickname: Annotated[ + str | None, + Field( + alias="agentNickname", + description="Optional random unique nickname assigned to an AgentControl-spawned sub-agent.", + ), + ] = None + agent_role: Annotated[ + str | None, + Field( + alias="agentRole", + description="Optional role (agent_role) assigned to an AgentControl-spawned sub-agent.", + ), + ] = None + cli_version: Annotated[ + str, + Field( + alias="cliVersion", + description="Version of the CLI that created the thread.", + ), + ] + created_at: Annotated[ + int, + Field( + alias="createdAt", + description="Unix timestamp (in seconds) when the thread was created.", + ), + ] + cwd: Annotated[str, Field(description="Working directory captured for the thread.")] + ephemeral: Annotated[ + bool, + Field( + description="Whether the thread is ephemeral and should not be materialized on disk." + ), + ] + git_info: Annotated[ + GitInfo | None, + Field( + alias="gitInfo", + description="Optional Git metadata captured when the thread was created.", + ), + ] = None + id: str + model_provider: Annotated[ + str, + Field( + alias="modelProvider", + description="Model provider used for this thread (for example, 'openai').", + ), + ] + name: Annotated[str | None, Field(description="Optional user-facing thread title.")] = None + path: Annotated[str | None, Field(description="[UNSTABLE] Path to the thread on disk.")] = None + preview: Annotated[ + str, + Field(description="Usually the first user message in the thread, if available."), + ] + source: Annotated[ + SessionSource, + Field( + description="Origin of the thread (CLI, VSCode, codex exec, codex app-server, etc.)." + ), + ] + status: Annotated[ThreadStatus, Field(description="Current runtime status for the thread.")] + turns: Annotated[ + list[Turn], + Field( + description="Only populated on `thread/resume`, `thread/rollback`, `thread/fork`, and `thread/read` (when `includeTurns` is true) responses. For all other responses and notifications returning a Thread, the turns field will be an empty list." + ), + ] + updated_at: Annotated[ + int, + Field( + alias="updatedAt", + description="Unix timestamp (in seconds) when the thread was last updated.", + ), + ] + + +class ThreadForkResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_policy: Annotated[AskForApproval, Field(alias="approvalPolicy")] + approvals_reviewer: Annotated[ + ApprovalsReviewer, + Field( + alias="approvalsReviewer", + description="Reviewer currently used for approval requests on this thread.", + ), + ] + cwd: str + model: str + model_provider: Annotated[str, Field(alias="modelProvider")] + reasoning_effort: Annotated[ReasoningEffort | None, Field(alias="reasoningEffort")] = None + sandbox: SandboxPolicy + service_tier: Annotated[ServiceTier | None, Field(alias="serviceTier")] = None + thread: Thread + + +class ThreadListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[Thread] + next_cursor: Annotated[ + str | None, + Field( + alias="nextCursor", + description="Opaque cursor to pass to the next call to continue after the last item. if None, there are no more items to return.", + ), + ] = None + + +class ThreadMetadataUpdateResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread: Thread + + +class ThreadReadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread: Thread + + +class ThreadResumeResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_policy: Annotated[AskForApproval, Field(alias="approvalPolicy")] + approvals_reviewer: Annotated[ + ApprovalsReviewer, + Field( + alias="approvalsReviewer", + description="Reviewer currently used for approval requests on this thread.", + ), + ] + cwd: str + model: str + model_provider: Annotated[str, Field(alias="modelProvider")] + reasoning_effort: Annotated[ReasoningEffort | None, Field(alias="reasoningEffort")] = None + sandbox: SandboxPolicy + service_tier: Annotated[ServiceTier | None, Field(alias="serviceTier")] = None + thread: Thread + + +class ThreadRollbackResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread: Annotated[ + Thread, + Field( + description="The updated thread after applying the rollback, with `turns` populated.\n\nThe ThreadItems stored in each Turn are lossy since we explicitly do not persist all agent interactions, such as command executions. This is the same behavior as `thread/resume`." + ), + ] + + +class ThreadStartResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + approval_policy: Annotated[AskForApproval, Field(alias="approvalPolicy")] + approvals_reviewer: Annotated[ + ApprovalsReviewer, + Field( + alias="approvalsReviewer", + description="Reviewer currently used for approval requests on this thread.", + ), + ] + cwd: str + model: str + model_provider: Annotated[str, Field(alias="modelProvider")] + reasoning_effort: Annotated[ReasoningEffort | None, Field(alias="reasoningEffort")] = None + sandbox: SandboxPolicy + service_tier: Annotated[ServiceTier | None, Field(alias="serviceTier")] = None + thread: Thread + + +class ThreadStartedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread: Thread + + +class ThreadUnarchiveResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread: Thread + + +class ExternalAgentConfigImportRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["externalAgentConfig/import"], + Field(title="ExternalAgentConfig/importRequestMethod"), + ] + params: ExternalAgentConfigImportParams + + +class ConfigBatchWriteRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["config/batchWrite"], Field(title="Config/batchWriteRequestMethod")] + params: ConfigBatchWriteParams + + +class ClientRequest( + RootModel[ + InitializeRequest + | ThreadStartRequest + | ThreadResumeRequest + | ThreadForkRequest + | ThreadArchiveRequest + | ThreadUnsubscribeRequest + | ThreadNameSetRequest + | ThreadMetadataUpdateRequest + | ThreadUnarchiveRequest + | ThreadCompactStartRequest + | ThreadRollbackRequest + | ThreadListRequest + | ThreadLoadedListRequest + | ThreadReadRequest + | SkillsListRequest + | PluginListRequest + | PluginReadRequest + | SkillsRemoteListRequest + | SkillsRemoteExportRequest + | AppListRequest + | FsReadFileRequest + | FsWriteFileRequest + | FsCreateDirectoryRequest + | FsGetMetadataRequest + | FsReadDirectoryRequest + | FsRemoveRequest + | FsCopyRequest + | SkillsConfigWriteRequest + | PluginInstallRequest + | PluginUninstallRequest + | TurnStartRequest + | TurnSteerRequest + | TurnInterruptRequest + | ReviewStartRequest + | ModelListRequest + | ExperimentalFeatureListRequest + | McpServerOauthLoginRequest + | ConfigMcpServerReloadRequest + | McpServerStatusListRequest + | WindowsSandboxSetupStartRequest + | AccountLoginStartRequest + | AccountLoginCancelRequest + | AccountLogoutRequest + | AccountRateLimitsReadRequest + | FeedbackUploadRequest + | CommandExecRequest + | CommandExecWriteRequest + | CommandExecTerminateRequest + | CommandExecResizeRequest + | ConfigReadRequest + | ExternalAgentConfigDetectRequest + | ExternalAgentConfigImportRequest + | ConfigValueWriteRequest + | ConfigBatchWriteRequest + | ConfigRequirementsReadRequest + | AccountReadRequest + | FuzzyFileSearchRequest + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + InitializeRequest + | ThreadStartRequest + | ThreadResumeRequest + | ThreadForkRequest + | ThreadArchiveRequest + | ThreadUnsubscribeRequest + | ThreadNameSetRequest + | ThreadMetadataUpdateRequest + | ThreadUnarchiveRequest + | ThreadCompactStartRequest + | ThreadRollbackRequest + | ThreadListRequest + | ThreadLoadedListRequest + | ThreadReadRequest + | SkillsListRequest + | PluginListRequest + | PluginReadRequest + | SkillsRemoteListRequest + | SkillsRemoteExportRequest + | AppListRequest + | FsReadFileRequest + | FsWriteFileRequest + | FsCreateDirectoryRequest + | FsGetMetadataRequest + | FsReadDirectoryRequest + | FsRemoveRequest + | FsCopyRequest + | SkillsConfigWriteRequest + | PluginInstallRequest + | PluginUninstallRequest + | TurnStartRequest + | TurnSteerRequest + | TurnInterruptRequest + | ReviewStartRequest + | ModelListRequest + | ExperimentalFeatureListRequest + | McpServerOauthLoginRequest + | ConfigMcpServerReloadRequest + | McpServerStatusListRequest + | WindowsSandboxSetupStartRequest + | AccountLoginStartRequest + | AccountLoginCancelRequest + | AccountLogoutRequest + | AccountRateLimitsReadRequest + | FeedbackUploadRequest + | CommandExecRequest + | CommandExecWriteRequest + | CommandExecTerminateRequest + | CommandExecResizeRequest + | ConfigReadRequest + | ExternalAgentConfigDetectRequest + | ExternalAgentConfigImportRequest + | ConfigValueWriteRequest + | ConfigBatchWriteRequest + | ConfigRequirementsReadRequest + | AccountReadRequest + | FuzzyFileSearchRequest, + Field(description="Request from the client to the server.", title="ClientRequest"), + ] + + +class Config(BaseModel): + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + analytics: AnalyticsConfig | None = None + approval_policy: AskForApproval | None = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + description="[UNSTABLE] Optional default for where approval requests are routed for review." + ), + ] = None + compact_prompt: str | None = None + developer_instructions: str | None = None + forced_chatgpt_workspace_id: str | None = None + forced_login_method: ForcedLoginMethod | None = None + instructions: str | None = None + model: str | None = None + model_auto_compact_token_limit: int | None = None + model_context_window: int | None = None + model_provider: str | None = None + model_reasoning_effort: ReasoningEffort | None = None + model_reasoning_summary: ReasoningSummary | None = None + model_verbosity: Verbosity | None = None + profile: str | None = None + profiles: dict[str, ProfileV2] | None = {} + review_model: str | None = None + sandbox_mode: SandboxMode | None = None + sandbox_workspace_write: SandboxWorkspaceWrite | None = None + service_tier: ServiceTier | None = None + tools: ToolsV2 | None = None + web_search: WebSearchMode | None = None + + +class ConfigReadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + config: Config + layers: list[ConfigLayer] | None = None + origins: dict[str, ConfigLayerMetadata] + + +class RawResponseItemCompletedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + item: ResponseItem + thread_id: Annotated[str, Field(alias="threadId")] + turn_id: Annotated[str, Field(alias="turnId")] + + +class ThreadStartedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[Literal["thread/started"], Field(title="Thread/startedNotificationMethod")] + params: ThreadStartedNotification + + +class ServerNotification( + RootModel[ + ErrorServerNotification + | ThreadStartedServerNotification + | ThreadStatusChangedServerNotification + | ThreadArchivedServerNotification + | ThreadUnarchivedServerNotification + | ThreadClosedServerNotification + | SkillsChangedServerNotification + | ThreadNameUpdatedServerNotification + | ThreadTokenUsageUpdatedServerNotification + | TurnStartedServerNotification + | HookStartedServerNotification + | TurnCompletedServerNotification + | HookCompletedServerNotification + | TurnDiffUpdatedServerNotification + | TurnPlanUpdatedServerNotification + | ItemStartedServerNotification + | ItemAutoApprovalReviewStartedServerNotification + | ItemAutoApprovalReviewCompletedServerNotification + | ItemCompletedServerNotification + | ItemAgentMessageDeltaServerNotification + | ItemPlanDeltaServerNotification + | CommandExecOutputDeltaServerNotification + | ItemCommandExecutionOutputDeltaServerNotification + | ItemCommandExecutionTerminalInteractionServerNotification + | ItemFileChangeOutputDeltaServerNotification + | ServerRequestResolvedServerNotification + | ItemMcpToolCallProgressServerNotification + | McpServerOauthLoginCompletedServerNotification + | AccountUpdatedServerNotification + | AccountRateLimitsUpdatedServerNotification + | AppListUpdatedServerNotification + | ItemReasoningSummaryTextDeltaServerNotification + | ItemReasoningSummaryPartAddedServerNotification + | ItemReasoningTextDeltaServerNotification + | ThreadCompactedServerNotification + | ModelReroutedServerNotification + | DeprecationNoticeServerNotification + | ConfigWarningServerNotification + | FuzzyFileSearchSessionUpdatedServerNotification + | FuzzyFileSearchSessionCompletedServerNotification + | ThreadRealtimeStartedServerNotification + | ThreadRealtimeItemAddedServerNotification + | ThreadRealtimeOutputAudioDeltaServerNotification + | ThreadRealtimeErrorServerNotification + | ThreadRealtimeClosedServerNotification + | WindowsWorldWritableWarningServerNotification + | WindowsSandboxSetupCompletedServerNotification + | AccountLoginCompletedServerNotification + ] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Annotated[ + ErrorServerNotification + | ThreadStartedServerNotification + | ThreadStatusChangedServerNotification + | ThreadArchivedServerNotification + | ThreadUnarchivedServerNotification + | ThreadClosedServerNotification + | SkillsChangedServerNotification + | ThreadNameUpdatedServerNotification + | ThreadTokenUsageUpdatedServerNotification + | TurnStartedServerNotification + | HookStartedServerNotification + | TurnCompletedServerNotification + | HookCompletedServerNotification + | TurnDiffUpdatedServerNotification + | TurnPlanUpdatedServerNotification + | ItemStartedServerNotification + | ItemAutoApprovalReviewStartedServerNotification + | ItemAutoApprovalReviewCompletedServerNotification + | ItemCompletedServerNotification + | ItemAgentMessageDeltaServerNotification + | ItemPlanDeltaServerNotification + | CommandExecOutputDeltaServerNotification + | ItemCommandExecutionOutputDeltaServerNotification + | ItemCommandExecutionTerminalInteractionServerNotification + | ItemFileChangeOutputDeltaServerNotification + | ServerRequestResolvedServerNotification + | ItemMcpToolCallProgressServerNotification + | McpServerOauthLoginCompletedServerNotification + | AccountUpdatedServerNotification + | AccountRateLimitsUpdatedServerNotification + | AppListUpdatedServerNotification + | ItemReasoningSummaryTextDeltaServerNotification + | ItemReasoningSummaryPartAddedServerNotification + | ItemReasoningTextDeltaServerNotification + | ThreadCompactedServerNotification + | ModelReroutedServerNotification + | DeprecationNoticeServerNotification + | ConfigWarningServerNotification + | FuzzyFileSearchSessionUpdatedServerNotification + | FuzzyFileSearchSessionCompletedServerNotification + | ThreadRealtimeStartedServerNotification + | ThreadRealtimeItemAddedServerNotification + | ThreadRealtimeOutputAudioDeltaServerNotification + | ThreadRealtimeErrorServerNotification + | ThreadRealtimeClosedServerNotification + | WindowsWorldWritableWarningServerNotification + | WindowsSandboxSetupCompletedServerNotification + | AccountLoginCompletedServerNotification, + Field( + description="Notification sent from the server to the client.", + title="ServerNotification", + ), + ] diff --git a/src/agents/sandbox/app_server/generated/v2_types.py b/src/agents/sandbox/app_server/generated/v2_types.py new file mode 100644 index 0000000000..2ddf81fb7b --- /dev/null +++ b/src/agents/sandbox/app_server/generated/v2_types.py @@ -0,0 +1,23 @@ +"""Stable aliases over full v2 autogenerated models (datamodel-code-generator).""" + +from .v2_all import ( + ModelListResponse, + ThreadCompactStartResponse, + ThreadItem, + ThreadListResponse, + ThreadReadResponse, + ThreadTokenUsageUpdatedNotification, + TurnCompletedNotification as TurnCompletedNotificationPayload, + TurnSteerResponse, +) + +__all__ = [ + "ModelListResponse", + "ThreadCompactStartResponse", + "ThreadListResponse", + "ThreadReadResponse", + "ThreadTokenUsageUpdatedNotification", + "TurnCompletedNotificationPayload", + "TurnSteerResponse", + "ThreadItem", +] diff --git a/src/agents/sandbox/app_server/models.py b/src/agents/sandbox/app_server/models.py new file mode 100644 index 0000000000..70c61d44cd --- /dev/null +++ b/src/agents/sandbox/app_server/models.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TypeAlias + +from pydantic import BaseModel + +from .generated.v2_all import ( + AccountLoginCompletedNotification, + AccountRateLimitsUpdatedNotification, + AccountUpdatedNotification, + AgentMessageDeltaNotification, + AppListUpdatedNotification, + CommandExecutionOutputDeltaNotification, + ConfigWarningNotification, + ContextCompactedNotification, + DeprecationNoticeNotification, + ErrorNotification, + FileChangeOutputDeltaNotification, + ItemCompletedNotification, + ItemStartedNotification, + McpServerOauthLoginCompletedNotification, + McpToolCallProgressNotification, + PlanDeltaNotification, + RawResponseItemCompletedNotification, + ReasoningSummaryPartAddedNotification, + ReasoningSummaryTextDeltaNotification, + ReasoningTextDeltaNotification, + TerminalInteractionNotification, + ThreadNameUpdatedNotification, + ThreadStartedNotification, + ThreadTokenUsageUpdatedNotification, + TurnCompletedNotification, + TurnDiffUpdatedNotification, + TurnPlanUpdatedNotification, + TurnStartedNotification, + WindowsWorldWritableWarningNotification, +) + +JsonScalar: TypeAlias = str | int | float | bool | None +JsonValue: TypeAlias = JsonScalar | dict[str, "JsonValue"] | list["JsonValue"] +JsonObject: TypeAlias = dict[str, JsonValue] + + +@dataclass(slots=True) +class UnknownNotification: + params: JsonObject + + +NotificationPayload: TypeAlias = ( + AccountLoginCompletedNotification + | AccountRateLimitsUpdatedNotification + | AccountUpdatedNotification + | AgentMessageDeltaNotification + | AppListUpdatedNotification + | CommandExecutionOutputDeltaNotification + | ConfigWarningNotification + | ContextCompactedNotification + | DeprecationNoticeNotification + | ErrorNotification + | FileChangeOutputDeltaNotification + | ItemCompletedNotification + | ItemStartedNotification + | McpServerOauthLoginCompletedNotification + | McpToolCallProgressNotification + | PlanDeltaNotification + | RawResponseItemCompletedNotification + | ReasoningSummaryPartAddedNotification + | ReasoningSummaryTextDeltaNotification + | ReasoningTextDeltaNotification + | TerminalInteractionNotification + | ThreadNameUpdatedNotification + | ThreadStartedNotification + | ThreadTokenUsageUpdatedNotification + | TurnCompletedNotification + | TurnDiffUpdatedNotification + | TurnPlanUpdatedNotification + | TurnStartedNotification + | WindowsWorldWritableWarningNotification + | UnknownNotification +) + + +@dataclass(slots=True) +class Notification: + method: str + payload: NotificationPayload + + +class ServerInfo(BaseModel): + name: str | None = None + version: str | None = None + + +class InitializeResponse(BaseModel): + serverInfo: ServerInfo | None = None + userAgent: str | None = None + platformFamily: str | None = None + platformOs: str | None = None diff --git a/src/agents/sandbox/app_server/py.typed b/src/agents/sandbox/app_server/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/agents/sandbox/app_server/retry.py b/src/agents/sandbox/app_server/retry.py new file mode 100644 index 0000000000..b7e4f77403 --- /dev/null +++ b/src/agents/sandbox/app_server/retry.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import random +import time +from typing import Callable, TypeVar + +from .errors import is_retryable_error + +T = TypeVar("T") + + +def retry_on_overload( + op: Callable[[], T], + *, + max_attempts: int = 3, + initial_delay_s: float = 0.25, + max_delay_s: float = 2.0, + jitter_ratio: float = 0.2, +) -> T: + """Retry helper for transient server-overload errors.""" + + if max_attempts < 1: + raise ValueError("max_attempts must be >= 1") + + delay = initial_delay_s + attempt = 0 + while True: + attempt += 1 + try: + return op() + except Exception as exc: + if attempt >= max_attempts: + raise + if not is_retryable_error(exc): + raise + + jitter = delay * jitter_ratio + sleep_for = min(max_delay_s, delay) + random.uniform(-jitter, jitter) + if sleep_for > 0: + time.sleep(sleep_for) + delay = min(max_delay_s, delay * 2) diff --git a/src/agents/sandbox/capabilities/__init__.py b/src/agents/sandbox/capabilities/__init__.py new file mode 100644 index 0000000000..0c13ccb59c --- /dev/null +++ b/src/agents/sandbox/capabilities/__init__.py @@ -0,0 +1,4 @@ +from .capability import Capability +from .skills import Skill, Skills + +__all__ = ["Capability", "Skill", "Skills"] diff --git a/src/agents/sandbox/capabilities/capability.py b/src/agents/sandbox/capabilities/capability.py new file mode 100644 index 0000000000..a6d85e4eb6 --- /dev/null +++ b/src/agents/sandbox/capabilities/capability.py @@ -0,0 +1,80 @@ +import asyncio +import copy +import threading +from dataclasses import dataclass +from typing import Any + +from ...tool import Tool +from ..manifest import Manifest +from ..session.base_sandbox_session import BaseSandboxSession + + +@dataclass +class Capability: + type: str + + def clone(self) -> "Capability": + """Return a per-run copy of this capability.""" + cloned = copy.copy(self) + if hasattr(self, "__dict__"): + for name, value in self.__dict__.items(): + setattr(cloned, name, _clone_capability_value(value)) + return cloned + + def bind(self, session: BaseSandboxSession) -> None: + """Bind a live session to this plugin (default no-op).""" + _ = session + return + + def tools(self) -> list[Tool]: + return [] + + def process_manifest(self, manifest: Manifest) -> Manifest: + return manifest + + async def instructions(self, manifest: Manifest) -> str | None: + """Return a deterministic instruction fragment for this plugin.""" + _ = manifest + return None + + +def _clone_capability_value(value: Any) -> Any: + if getattr(type(value), "__module__", "").startswith("agents.tool"): + return value + if isinstance( + value, + ( + BaseSandboxSession, + asyncio.Event, + asyncio.Lock, + asyncio.Semaphore, + asyncio.Condition, + threading.Event, + type(threading.Lock()), + type(threading.RLock()), + ), + ): + return value + if isinstance(value, list): + return [_clone_capability_value(item) for item in value] + if isinstance(value, dict): + return { + _clone_capability_value(key): _clone_capability_value(item) + for key, item in value.items() + } + if isinstance(value, set): + return {_clone_capability_value(item) for item in value} + if isinstance(value, tuple): + return tuple(_clone_capability_value(item) for item in value) + if isinstance(value, bytearray): + return bytearray(value) + if hasattr(value, "__dict__"): + cloned = copy.copy(value) + for name, nested in value.__dict__.items(): + setattr(cloned, name, _clone_capability_value(nested)) + return cloned + try: + return copy.deepcopy(value) + except Exception: + return value + return value diff --git a/src/agents/sandbox/capabilities/skills.py b/src/agents/sandbox/capabilities/skills.py new file mode 100644 index 0000000000..81582277df --- /dev/null +++ b/src/agents/sandbox/capabilities/skills.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from ..entries import BaseEntry, Dir, File, LocalFile +from ..errors import SkillsConfigError +from ..manifest import Manifest +from .capability import Capability + +_SKILLS_ROOT = Path(".agents/skills") + + +def _validate_relative_path( + value: str | Path, + *, + field_name: str, + context: Mapping[str, object] | None = None, +) -> Path: + rel = value if isinstance(value, Path) else Path(value) + if rel.is_absolute(): + raise SkillsConfigError( + message=f"{field_name} must be a relative path", + context={ + "field": field_name, + "path": str(rel), + "reason": "absolute", + **(context or {}), + }, + ) + if ".." in rel.parts: + raise SkillsConfigError( + message=f"{field_name} must not escape the skills root", + context={ + "field": field_name, + "path": str(rel), + "reason": "escape_root", + **(context or {}), + }, + ) + if rel.parts in [(), (".",)]: + raise SkillsConfigError( + message=f"{field_name} must be non-empty", + context={"field": field_name, "path": str(rel), "reason": "empty", **(context or {})}, + ) + return rel + + +def _manifest_entry_paths(manifest: Manifest) -> set[Path]: + return {key if isinstance(key, Path) else Path(key) for key in manifest.entries} + + +def _get_manifest_entry_by_path(manifest: Manifest, path: Path) -> BaseEntry | None: + for key, entry in manifest.entries.items(): + normalized = key if isinstance(key, Path) else Path(key) + if normalized == path: + return entry + return None + + +class Skill(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str + description: str + content: str | bytes | BaseEntry + + compatibility: str | None = Field(default=None) + scripts: dict[str | Path, BaseEntry] = Field(default_factory=dict) + references: dict[str | Path, BaseEntry] = Field(default_factory=dict) + assets: dict[str | Path, BaseEntry] = Field(default_factory=dict) + + deferred: bool = Field(default=False) + + @field_validator("content", mode="before") + @classmethod + def _parse_content(cls, value: object) -> object: + if isinstance(value, Mapping): + return BaseEntry.parse(value) + return value + + @field_validator("scripts", "references", "assets", mode="before") + @classmethod + def _parse_entry_map(cls, value: object) -> dict[str | Path, BaseEntry]: + if value is None: + return {} + if not isinstance(value, Mapping): + raise TypeError(f"Artifact mapping must be a mapping, got {type(value).__name__}") + return {key: BaseEntry.parse(entry) for key, entry in value.items()} + + def model_post_init(self, context: Any, /) -> None: + _ = context + skill_context: dict[str, object] = {"skill_name": self.name} + _validate_relative_path(self.name, field_name="name", context=skill_context) + + content_artifact = self.content_artifact() + if not isinstance(content_artifact, (File, LocalFile)): + raise SkillsConfigError( + message="skill content must be file-like", + context={ + "field": "content", + "skill_name": self.name, + "content_type": content_artifact.type, + }, + ) + + self.scripts = self._normalize_entry_map(self.scripts, field_name="scripts") + self.references = self._normalize_entry_map(self.references, field_name="references") + self.assets = self._normalize_entry_map(self.assets, field_name="assets") + + def _normalize_entry_map( + self, + entries: Mapping[str | Path, BaseEntry], + *, + field_name: str, + ) -> dict[str | Path, BaseEntry]: + normalized: dict[str | Path, BaseEntry] = {} + seen_paths: set[str] = set() + for key, artifact in entries.items(): + rel = _validate_relative_path( + key, + field_name=field_name, + context={"skill_name": self.name, "entry_path": str(key)}, + ) + rel_str = rel.as_posix() + if rel_str in seen_paths: + raise SkillsConfigError( + message=f"duplicate entry path in skill {field_name}", + context={ + "skill_name": self.name, + "field": field_name, + "entry_path": rel_str, + }, + ) + seen_paths.add(rel_str) + normalized[rel_str] = artifact + + return normalized + + def content_artifact(self) -> BaseEntry: + if isinstance(self.content, bytes): + return File(content=self.content) + if isinstance(self.content, str): + return File(content=self.content.encode("utf-8")) + return self.content + + def as_dir_entry(self) -> Dir: + children: dict[str | Path, BaseEntry] = {"SKILL.md": self.content_artifact()} + if self.scripts: + children["scripts"] = Dir(children=self.scripts) + if self.references: + children["references"] = Dir(children=self.references) + if self.assets: + children["assets"] = Dir(children=self.assets) + return Dir(children=children) + + +class Skills(Capability): + """Mount skills into a Codex auto-discovery root inside the sandbox.""" + + type: str = "skills" + skills: list[Skill] + from_: BaseEntry | None + + def __init__( + self, + *, + skills: Sequence[Skill | Mapping[str, object]] | None = None, + from_: BaseEntry | Mapping[str, object] | None = None, + ) -> None: + super().__init__(type="skills") + self.skills = [self._coerce_skill(skill) for skill in (skills or [])] + self.from_ = self._coerce_entry(from_) + self._validate() + + @staticmethod + def _coerce_skill(skill: Skill | Mapping[str, object]) -> Skill: + if isinstance(skill, Skill): + return skill + return Skill.model_validate(dict(skill)) + + @staticmethod + def _coerce_entry(entry: BaseEntry | Mapping[str, object] | None) -> BaseEntry | None: + if entry is None or isinstance(entry, BaseEntry): + return entry + return BaseEntry.parse(entry) + + def _validate(self) -> None: + if not self.skills and self.from_ is None: + raise SkillsConfigError( + message="skills capability requires `skills` or `from_`", + context={"field": "skills"}, + ) + if self.skills and self.from_ is not None: + raise SkillsConfigError( + message="skills capability does not allow both `skills` and `from_` together", + context={"field": "skills", "has_from": True}, + ) + + if self.from_ is not None and not self.from_.is_dir: + raise SkillsConfigError( + message="`from_` must be a directory-like artifact", + context={"field": "from_", "artifact_type": self.from_.type}, + ) + + seen_names: set[Path] = set() + for skill in self.skills: + rel = _validate_relative_path( + skill.name, + field_name="skills[].name", + context={"skill_name": skill.name}, + ) + if rel in seen_names: + raise SkillsConfigError( + message=f"duplicate skill name: {skill.name}", + context={"field": "skills[].name", "skill_name": skill.name}, + ) + seen_names.add(rel) + + def process_manifest(self, manifest: Manifest) -> Manifest: + skills_root = _SKILLS_ROOT + existing_paths = _manifest_entry_paths(manifest) + + if self.from_ is not None: + if skills_root in existing_paths: + existing_entry = _get_manifest_entry_by_path(manifest, skills_root) + if existing_entry is None: + raise SkillsConfigError( + message="skills root path lookup failed", + context={"path": str(skills_root), "source": "from_"}, + ) + if existing_entry.is_dir: + return manifest + raise SkillsConfigError( + message="skills root path already exists in manifest", + context={ + "path": str(skills_root), + "source": "from_", + "existing_type": existing_entry.type, + }, + ) + manifest.entries[skills_root] = self.from_ + existing_paths.add(skills_root) + + for skill in self.skills: + relative_path = skills_root / Path(skill.name) + rendered_skill = skill.as_dir_entry() + if relative_path in existing_paths: + existing_entry = _get_manifest_entry_by_path(manifest, relative_path) + if existing_entry is None: + raise SkillsConfigError( + message="skill path lookup failed", + context={"path": str(relative_path), "skill_name": skill.name}, + ) + if existing_entry == rendered_skill: + continue + raise SkillsConfigError( + message="skill path already exists in manifest", + context={"path": str(relative_path), "skill_name": skill.name}, + ) + manifest.entries[relative_path] = rendered_skill + existing_paths.add(relative_path) + + return manifest + + +__all__ = ["Skill", "Skills"] diff --git a/src/agents/sandbox/codex_config.py b/src/agents/sandbox/codex_config.py new file mode 100644 index 0000000000..d63ab891ae --- /dev/null +++ b/src/agents/sandbox/codex_config.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TypeVar + +from .entries import Codex, resolve_workspace_path +from .errors import InvalidManifestPathError +from .manifest import Manifest +from .session.sandbox_session_state import SandboxSessionState + +DEFAULT_CODEX_PATH = ".codex_bin/codex" +# TODO: this should eventually be sourced from the pinned version of codex-python-sdk +DEFAULT_CODEX_VERSION = "0.114.0" +SandboxSessionStateT = TypeVar("SandboxSessionStateT", bound=SandboxSessionState) + + +@dataclass(frozen=True) +class CodexConfig: + path: str | Path = DEFAULT_CODEX_PATH + version: str = DEFAULT_CODEX_VERSION + + +def normalize_codex_config(codex: bool | CodexConfig) -> CodexConfig | None: + if isinstance(codex, CodexConfig): + return codex + if codex: + return CodexConfig() + return None + + +def apply_codex_to_manifest( + manifest: Manifest | None, + codex: bool | CodexConfig, +) -> Manifest: + normalized = normalize_codex_config(codex) + base_manifest = manifest.model_copy(deep=True) if manifest is not None else Manifest() + if normalized is None: + return base_manifest + + codex_path = _manifest_codex_path(manifest=base_manifest, configured_path=normalized.path) + entries = dict(base_manifest.entries) + entries.setdefault(codex_path, Codex(version=normalized.version)) + return base_manifest.model_copy(update={"entries": entries}) + + +def manifest_has_codex_entry( + manifest: Manifest | None, + codex: bool | CodexConfig, +) -> bool: + normalized = normalize_codex_config(codex) + if normalized is None or manifest is None: + return normalized is None + + codex_path = _manifest_codex_path(manifest=manifest, configured_path=normalized.path) + return codex_path in {manifest._coerce_rel_path(path) for path in manifest.entries} + + +def apply_codex_to_session_state( + state: SandboxSessionStateT, + codex: bool | CodexConfig, +) -> SandboxSessionStateT: + return state.model_copy(update={"manifest": apply_codex_to_manifest(state.manifest, codex)}) + + +def _manifest_codex_path(*, manifest: Manifest, configured_path: str | Path) -> Path: + configured_str = str(configured_path) + if configured_str == "~": + return Path(".") + if configured_str.startswith("~/"): + home_relative = Path(configured_str.removeprefix("~/")) + manifest._validate_rel_path(home_relative) + return home_relative + + raw_path = Path(configured_path) + if not raw_path.is_absolute(): + manifest._validate_rel_path(raw_path) + return raw_path + + candidate_roots = [Path(manifest.root)] + default_root = Path(str(Manifest.model_fields["root"].default)) + if default_root not in candidate_roots: + candidate_roots.append(default_root) + + for root in candidate_roots: + try: + resolved = resolve_workspace_path( + root, + raw_path, + allow_absolute_within_root=True, + ) + except InvalidManifestPathError: + continue + return resolved.relative_to(root) + + raise InvalidManifestPathError(rel=raw_path, reason="absolute") diff --git a/src/agents/sandbox/entries/__init__.py b/src/agents/sandbox/entries/__init__.py new file mode 100644 index 0000000000..fd663e95a0 --- /dev/null +++ b/src/agents/sandbox/entries/__init__.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from .artifacts import Dir, File, GitRepo, LocalDir, LocalFile +from .base import BaseEntry, resolve_workspace_path +from .codex import Codex +from .mounts import ( + AzureBlobMount, + FuseMountPattern, + GCSMount, + Mount, + MountPattern, + MountPatternBase, + MountpointMountPattern, + RcloneMountPattern, + S3Mount, +) + +__all__ = [ + "AzureBlobMount", + "BaseEntry", + "Codex", + "Dir", + "File", + "FuseMountPattern", + "GCSMount", + "GitRepo", + "LocalDir", + "LocalFile", + "Mount", + "MountPattern", + "MountPatternBase", + "MountpointMountPattern", + "RcloneMountPattern", + "S3Mount", + "resolve_workspace_path", +] diff --git a/src/agents/sandbox/entries/artifacts.py b/src/agents/sandbox/entries/artifacts.py new file mode 100644 index 0000000000..a9109bb0a0 --- /dev/null +++ b/src/agents/sandbox/entries/artifacts.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import io +import re +import uuid +from collections.abc import Awaitable, Callable, Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +from pydantic import Field, field_serializer, field_validator + +from ..errors import ( + GitCloneError, + GitCopyError, + GitMissingInImageError, + LocalChecksumError, + LocalDirReadError, + LocalFileReadError, +) +from ..materialization import MaterializedFile, gather_in_order +from ..types import ExecResult +from ..util.checksums import sha256_file +from .base import BaseEntry + +if TYPE_CHECKING: + from ..session.base_sandbox_session import BaseSandboxSession + +_COMMIT_REF_RE = re.compile(r"[0-9a-fA-F]{7,40}") + + +class Dir(BaseEntry): + type: Literal["dir"] = "dir" + is_dir: bool = True + children: dict[str | Path, BaseEntry] = Field(default_factory=dict) + + @field_validator("children", mode="before") + @classmethod + def _parse_children(cls, value: object) -> dict[str | Path, BaseEntry]: + if value is None: + return {} + if not isinstance(value, Mapping): + raise TypeError(f"Artifact mapping must be a mapping, got {type(value).__name__}") + return {key: BaseEntry.parse(entry) for key, entry in value.items()} + + @field_serializer("children", when_used="json") + def _serialize_children(self, children: Mapping[str | Path, BaseEntry]) -> dict[str, object]: + out: dict[str, object] = {} + for key, entry in children.items(): + key_str = key.as_posix() if isinstance(key, Path) else str(key) + out[key_str] = entry.model_dump(mode="json") + return out + + def model_post_init(self, context: object, /) -> None: + _ = context + self.permissions.directory = True + + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + await session.mkdir(dest, parents=True) + await self._apply_metadata(session, dest) + return await session._apply_entry_batch( + [(dest / Path(rel_dest), artifact) for rel_dest, artifact in self.children.items()], + base_dir=base_dir, + ) + + +class File(BaseEntry): + type: Literal["file"] = "file" + content: bytes + + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + await session.write(dest, io.BytesIO(self.content)) + await self._apply_metadata(session, dest) + return [] + + +class LocalFile(BaseEntry): + type: Literal["local_file"] = "local_file" + src: Path + + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + src = (base_dir / self.src).resolve() + try: + checksum = sha256_file(src) + except OSError as e: + raise LocalChecksumError(src=src, cause=e) from e + await session.mkdir(Path(dest).parent, parents=True) + try: + with src.open("rb") as f: + await session.write(dest, f) + except OSError as e: + raise LocalFileReadError(src=src, cause=e) from e + await self._apply_metadata(session, dest) + return [MaterializedFile(path=dest, sha256=checksum)] + + +class LocalDir(BaseEntry): + type: Literal["local_dir"] = "local_dir" + is_dir: bool = True + src: Path | None = Field(default=None) + + def model_post_init(self, context: object, /) -> None: + _ = context + self.permissions.directory = True + + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + files: list[MaterializedFile] = [] + if self.src: + src_root = (base_dir / self.src).resolve() + if not src_root.exists(): + raise LocalDirReadError(src=src_root, context={"reason": "path_not_found"}) + # Minimal v1: copy all files recursively. + try: + await session.mkdir(dest, parents=True) + files = [] + local_files = [child for child in src_root.rglob("*") if child.is_file()] + + def _make_copy_task(child: Path) -> Callable[[], Awaitable[MaterializedFile]]: + async def _copy() -> MaterializedFile: + return await self._copy_local_dir_file( + session=session, + src_root=src_root, + src=child, + dest_root=dest, + ) + + return _copy + + copied_files = await gather_in_order( + [_make_copy_task(child) for child in local_files] + ) + files.extend(copied_files) + except OSError as e: + raise LocalDirReadError(src=src_root, cause=e) from e + await self._apply_metadata(session, dest) + else: + await session.mkdir(dest, parents=True) + await self._apply_metadata(session, dest) + return files + + async def _copy_local_dir_file( + self, + *, + session: BaseSandboxSession, + src_root: Path, + src: Path, + dest_root: Path, + ) -> MaterializedFile: + rel_child = src.relative_to(src_root) + child_dest = dest_root / rel_child + try: + checksum = sha256_file(src) + except OSError as e: + raise LocalChecksumError(src=src, cause=e) from e + await session.mkdir(child_dest.parent, parents=True) + try: + with src.open("rb") as f: + await session.write(child_dest, f) + except OSError as e: + raise LocalFileReadError(src=src, cause=e) from e + return MaterializedFile(path=child_dest, sha256=checksum) + + +class GitRepo(BaseEntry): + type: Literal["git_repo"] = "git_repo" + is_dir: bool = True + host: str = "github.com" + repo: str # "owner/name" (or any host-specific path) + ref: str # tag/branch/sha + subpath: str | None = None + + def model_post_init(self, context: object, /) -> None: + _ = context + self.permissions.directory = True + + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + # Ensure git exists in the container. + git_check = await session.exec("command -v git >/dev/null 2>&1") + if not git_check.ok(): + context: dict[str, object] = {"repo": self.repo, "ref": self.ref} + image = getattr(session.state, "image", None) + if image is not None: + context["image"] = image + raise GitMissingInImageError(context=context) + + tmp_dir = f"/tmp/uc-git-{session.state.session_id.hex}-{uuid.uuid4().hex}" + url = f"https://{self.host}/{self.repo}.git" + + _ = await session.exec("rm", "-rf", "--", tmp_dir, shell=False) + clone_error: ExecResult | None = None + if self._looks_like_commit_ref(self.ref): + clone = await self._fetch_commit_ref(session=session, url=url, tmp_dir=tmp_dir) + if not clone.ok(): + clone_error = clone + _ = await session.exec("rm", "-rf", "--", tmp_dir, shell=False) + clone = await self._clone_named_ref(session=session, url=url, tmp_dir=tmp_dir) + else: + clone = await self._clone_named_ref(session=session, url=url, tmp_dir=tmp_dir) + if not clone.ok(): + if clone_error is not None: + clone = clone_error + raise GitCloneError( + url=url, + ref=self.ref, + stderr=clone.stderr.decode("utf-8", errors="replace"), + context={"repo": self.repo, "subpath": self.subpath}, + ) + + git_src_root: str = tmp_dir + if self.subpath is not None: + git_src_root = f"{tmp_dir}/{self.subpath.lstrip('/')}" + + # Copy into destination in the container. + await session.mkdir(dest, parents=True) + copy = await session.exec("cp", "-R", "--", f"{git_src_root}/.", f"{dest}/", shell=False) + if not copy.ok(): + raise GitCopyError( + src_root=git_src_root, + dest=dest, + stderr=copy.stderr.decode("utf-8", errors="replace"), + context={"repo": self.repo, "ref": self.ref, "subpath": self.subpath}, + ) + + _ = await session.exec("rm", "-rf", "--", tmp_dir, shell=False) + await self._apply_metadata(session, dest) + + # Receipt: leave checksums empty for now. (Computing them would + # require reading each file back out of the container.) + return [] + + @staticmethod + def _looks_like_commit_ref(ref: str) -> bool: + return _COMMIT_REF_RE.fullmatch(ref) is not None + + async def _clone_named_ref( + self, + *, + session: BaseSandboxSession, + url: str, + tmp_dir: str, + ) -> ExecResult: + return await session.exec( + "git", + "clone", + "--depth", + "1", + "--no-tags", + "--branch", + self.ref, + url, + tmp_dir, + shell=False, + ) + + async def _fetch_commit_ref( + self, + *, + session: BaseSandboxSession, + url: str, + tmp_dir: str, + ) -> ExecResult: + init = await session.exec("git", "init", tmp_dir, shell=False) + if not init.ok(): + return init + + remote_add = await session.exec( + "git", + "-C", + tmp_dir, + "remote", + "add", + "origin", + url, + shell=False, + ) + if not remote_add.ok(): + return remote_add + + fetch = await session.exec( + "git", + "-C", + tmp_dir, + "fetch", + "--depth", + "1", + "--no-tags", + "origin", + self.ref, + shell=False, + ) + if not fetch.ok(): + return fetch + + return await session.exec( + "git", + "-C", + tmp_dir, + "checkout", + "--detach", + "FETCH_HEAD", + shell=False, + ) diff --git a/src/agents/sandbox/entries/base.py b/src/agents/sandbox/entries/base.py new file mode 100644 index 0000000000..660297a7cc --- /dev/null +++ b/src/agents/sandbox/entries/base.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import abc +import builtins +import inspect +import stat +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +from pydantic import BaseModel, Field + +from ..errors import InvalidManifestPathError +from ..materialization import MaterializedFile +from ..types import FileMode, Group, Permissions, User + +if TYPE_CHECKING: + from ..session.base_sandbox_session import BaseSandboxSession + + +def resolve_workspace_path( + workspace_root: Path, + rel: str | Path, + *, + allow_absolute_within_root: bool = False, +) -> Path: + rel = Path(rel) + workspace_root = Path(workspace_root) + + if rel.is_absolute(): + if not allow_absolute_within_root: + raise InvalidManifestPathError(rel=rel, reason="absolute") + resolved_workspace_root = workspace_root.resolve(strict=False) + resolved_rel = rel.resolve(strict=False) + try: + resolved_rel.relative_to(resolved_workspace_root) + except ValueError as exc: + raise InvalidManifestPathError(rel=rel, reason="absolute", cause=exc) from exc + return resolved_rel + + if ".." in rel.parts: + raise InvalidManifestPathError(rel=rel, reason="escape_root") + + resolved = workspace_root / rel if rel.parts else workspace_root + if allow_absolute_within_root and resolved.is_absolute(): + try: + resolved.relative_to(workspace_root) + except ValueError as exc: + raise InvalidManifestPathError(rel=rel, reason="escape_root", cause=exc) from exc + return resolved + + +class BaseEntry(BaseModel, abc.ABC): + type: str + _subclass_registry: ClassVar[dict[str, builtins.type[BaseEntry]]] = {} + + description: str | None = Field(default=None) + ephemeral: bool = Field(default=False) + group: Group | User | None = Field(default=None) + # Whether this entry should be treated as a directory in the sandbox filesystem. + # Concrete subclasses override this (e.g. Dir/Mount types -> True). + is_dir: bool = Field(default=False) + permissions: Permissions = Field( + default_factory=lambda: Permissions( + owner=FileMode.ALL, + group=FileMode.READ | FileMode.EXEC, + other=FileMode.READ | FileMode.EXEC, + ) + ) + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs: object) -> None: + super().__pydantic_init_subclass__(**kwargs) + + type_field = cls.model_fields.get("type") + type_default = type_field.default if type_field is not None else None + if not isinstance(type_default, str) or type_default == "": + if inspect.isabstract(cls): + return + raise TypeError(f"{cls.__name__} must define a non-empty string default for `type`") + + cls._register_subclass(cls, allow_override=False) + + @classmethod + def _register_subclass( + cls, + entry_cls: builtins.type[BaseEntry], + *, + allow_override: bool = False, + ) -> builtins.type[BaseEntry]: + type_field = entry_cls.model_fields.get("type") + type_default = type_field.default if type_field is not None else None + if not isinstance(type_default, str) or type_default == "": + raise ValueError(f"{entry_cls.__name__} must define a string `type` field default") + + existing = BaseEntry._subclass_registry.get(type_default) + if existing is not None and existing is not entry_cls and not allow_override: + raise ValueError( + f"Artifact type `{type_default}` is already registered to {existing.__name__}; " + f"refusing to register {entry_cls.__name__}" + ) + + BaseEntry._subclass_registry[type_default] = entry_cls + return entry_cls + + @classmethod + def registered_types(cls) -> dict[str, builtins.type[BaseEntry]]: + return dict(BaseEntry._subclass_registry) + + @classmethod + def parse(cls, payload: object) -> BaseEntry: + if isinstance(payload, BaseEntry): + return payload + if not isinstance(payload, Mapping): + raise TypeError( + f"Artifact entry must be a BaseEntry or mapping, got {type(payload).__name__}" + ) + + entry_type = payload.get("type") + if not isinstance(entry_type, str): + raise ValueError("Artifact entry mapping must include a string `type` field") + + entry_cls = BaseEntry._subclass_registry.get(entry_type) + if entry_cls is None: + known = ", ".join(sorted(BaseEntry._subclass_registry)) or "" + raise ValueError(f"Unknown artifact type `{entry_type}`. Registered types: {known}") + return entry_cls.model_validate(dict(payload)) + + async def _apply_metadata( + self, + session: BaseSandboxSession, + dest: Path, + ) -> None: + if self.group is not None: + await session._exec_checked_nonzero("chgrp", self.group.name, str(dest)) + + chmod_perms = f"{stat.S_IMODE(self.permissions.to_mode()):o}".zfill(4) + await session._exec_checked_nonzero("chmod", chmod_perms, str(dest)) + + @abc.abstractmethod + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + raise NotImplementedError diff --git a/src/agents/sandbox/entries/codex.py b/src/agents/sandbox/entries/codex.py new file mode 100644 index 0000000000..1668e974b9 --- /dev/null +++ b/src/agents/sandbox/entries/codex.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import httpx + +from ..errors import UnsupportedCodexTargetError +from ..materialization import MaterializedFile +from ..util.iterator_io import IteratorIO +from .base import BaseEntry + +if TYPE_CHECKING: + from ..session.base_sandbox_session import BaseSandboxSession + +_SUPPORTED_CODEX_OPERATING_SYSTEMS = ("linux", "darwin", "windows") +_SUPPORTED_CODEX_ARCHITECTURES = ("x86_64", "aarch64") +_SUPPORTED_CODEX_LINUX_LIBC_VARIANTS = ("gnu", "musl") +_CODEX_ARCH_ALIASES = { + "x86_64": "x86_64", + "amd64": "x86_64", + "aarch64": "aarch64", + "arm64": "aarch64", +} + + +class Codex(BaseEntry): + type: Literal["codex"] = "codex" + ephemeral: bool = True + version: str = "latest" + + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + _ = base_dir + asset_name = await session.resolve_codex_github_asset_name() + if asset_name.endswith(".exe.tar.gz"): + raise RuntimeError("Windows Codex artifacts are not supported in sandbox manifests.") + archive_url = self._release_asset_url(asset_name) + staging_dir = dest.parent / f".codex-install-{uuid.uuid4().hex}" + archive_path = staging_dir / asset_name + + await session.mkdir(dest.parent, parents=True) + await session.mkdir(staging_dir, parents=True) + try: + with _stream_release_asset(archive_url) as response: + response.raise_for_status() + await session.write( + archive_path, + _IteratorStreamWithLength( + response.iter_bytes(), + content_length=_parse_content_length(response), + ), + ) + + extract_result = await session.exec( + "tar", + "-xzf", + archive_path, + "-C", + staging_dir, + shell=False, + ) + if not extract_result.ok(): + raise RuntimeError(extract_result.stderr.decode("utf-8", errors="replace")) + + extracted_binary = await self._resolve_extracted_binary_path( + session=session, + staging_dir=staging_dir, + ) + await self._copy_extracted_binary_to_destination( + session=session, + extracted_binary=extracted_binary, + dest=dest, + ) + finally: + await session.rm(staging_dir, recursive=True) + + await self._apply_metadata(session, dest) + return [] + + def _release_asset_url(self, asset_name: str) -> str: + if self.version == "latest": + return f"https://github.com/openai/codex/releases/latest/download/{asset_name}" + return ( + f"https://github.com/openai/codex/releases/download/rust-v{self.version}/{asset_name}" + ) + + async def _resolve_extracted_binary_path( + self, + *, + session: BaseSandboxSession, + staging_dir: Path, + ) -> str: + result = await session.exec( + f"find {staging_dir} -type f \\( -name codex -o -name 'codex-*' \\) | head -n 1" + ) + if not result.ok(): + raise RuntimeError("Codex binary not found in extracted archive.") + path = result.stdout.decode("utf-8", errors="replace").strip() + if not path: + raise RuntimeError("Codex binary not found in extracted archive.") + return path + + async def _copy_extracted_binary_to_destination( + self, + *, + session: BaseSandboxSession, + extracted_binary: str, + dest: Path, + ) -> None: + result = await session.exec("cp", extracted_binary, dest, shell=False) + if not result.ok(): + raise RuntimeError(result.stderr.decode("utf-8", errors="replace")) + + +async def resolve_codex_github_asset_name(*, session: BaseSandboxSession) -> str: + """Resolve the Codex GitHub release asset filename for the session target.""" + + target_triple = await resolve_codex_target_triple(session=session) + suffix = ".exe.tar.gz" if target_triple.endswith("windows-msvc") else ".tar.gz" + return f"codex-{target_triple}{suffix}" + + +async def resolve_codex_target_triple(*, session: BaseSandboxSession) -> str: + """Resolve the Codex release target triple for the session target platform.""" + + target_os = await _detect_target_os(session=session) + target_arch = await _detect_target_arch(session=session, target_os=target_os) + + if target_os == "linux": + libc = await _detect_linux_libc_variant(session=session) + return resolve_codex_target_triple_for_target( + target_os=target_os, + target_arch=target_arch, + linux_libc=libc, + ) + + return resolve_codex_target_triple_for_target( + target_os=target_os, + target_arch=target_arch, + ) + + +def resolve_codex_target_triple_for_target( + *, + target_os: str, + target_arch: str, + linux_libc: str | None = None, +) -> str: + normalized_os = target_os.strip().lower() + normalized_arch = target_arch.strip().lower() + canonical_arch = _CODEX_ARCH_ALIASES.get(normalized_arch) + + if normalized_os == "linux": + if canonical_arch is not None: + libc = linux_libc or "gnu" + if libc not in _SUPPORTED_CODEX_LINUX_LIBC_VARIANTS: + raise UnsupportedCodexTargetError( + reason="linux_libc", + target_os=target_os, + target_arch=target_arch, + linux_libc=linux_libc, + supported_operating_systems=_SUPPORTED_CODEX_OPERATING_SYSTEMS, + supported_architectures=_SUPPORTED_CODEX_ARCHITECTURES, + supported_linux_libc_variants=_SUPPORTED_CODEX_LINUX_LIBC_VARIANTS, + ) + return f"{canonical_arch}-unknown-linux-{libc}" + elif normalized_os == "darwin": + if canonical_arch is not None: + return f"{canonical_arch}-apple-darwin" + elif normalized_os == "windows": + if canonical_arch is not None: + return f"{canonical_arch}-pc-windows-msvc" + else: + raise UnsupportedCodexTargetError( + reason="operating_system", + target_os=target_os, + target_arch=target_arch, + supported_operating_systems=_SUPPORTED_CODEX_OPERATING_SYSTEMS, + supported_architectures=_SUPPORTED_CODEX_ARCHITECTURES, + ) + + raise UnsupportedCodexTargetError( + reason="architecture", + target_os=normalized_os, + target_arch=target_arch, + supported_operating_systems=_SUPPORTED_CODEX_OPERATING_SYSTEMS, + supported_architectures=_SUPPORTED_CODEX_ARCHITECTURES, + ) + + +async def _detect_target_os( + *, + session: BaseSandboxSession, +) -> Literal["linux", "darwin", "windows"]: + unix_result = await session.exec("uname", "-s", shell=False) + if unix_result.ok(): + system = unix_result.stdout.decode("utf-8", errors="replace").strip().lower() + if system == "linux": + return "linux" + if system == "darwin": + return "darwin" + + windows_result = await session.exec("cmd", "/c", "echo", "%OS%", shell=False) + if windows_result.ok(): + system = windows_result.stdout.decode("utf-8", errors="replace").strip().lower() + if system == "windows_nt": + return "windows" + + raise RuntimeError("Unable to detect sandbox target operating system.") + + +async def _detect_target_arch(*, session: BaseSandboxSession, target_os: str) -> str: + if target_os == "windows": + result = await session.exec( + "cmd", + "/c", + "echo", + "%PROCESSOR_ARCHITECTURE%", + shell=False, + ) + else: + result = await session.exec("uname", "-m", shell=False) + + if result.ok(): + return result.stdout.decode("utf-8", errors="replace").strip().lower() + + raise RuntimeError(f"Unable to detect sandbox target architecture for {target_os}.") + + +async def _detect_linux_libc_variant(*, session: BaseSandboxSession) -> Literal["gnu", "musl"]: + result = await session.exec("getconf", "GNU_LIBC_VERSION", shell=False) + if result.ok(): + return "gnu" + + result = await session.exec("ldd", "--version", shell=False) + combined = (result.stdout + result.stderr).decode("utf-8", errors="replace").lower() + if "musl" in combined: + return "musl" + if result.ok() and combined: + return "gnu" + + raise RuntimeError("Unable to detect Linux libc variant for Codex release asset.") + + +class _IteratorStreamWithLength(IteratorIO): + def __init__(self, it, *, content_length: int | None) -> None: + super().__init__(it=it) + self.content_length = content_length + + +def _stream_release_asset(url: str): + return httpx.stream("GET", url, follow_redirects=True) + + +def _parse_content_length(response: httpx.Response) -> int | None: + value = response.headers.get("Content-Length") + if value is None: + return None + try: + parsed = int(value) + except ValueError: + return None + return parsed if parsed >= 0 else None diff --git a/src/agents/sandbox/entries/mounts/__init__.py b/src/agents/sandbox/entries/mounts/__init__.py new file mode 100644 index 0000000000..5e68010de6 --- /dev/null +++ b/src/agents/sandbox/entries/mounts/__init__.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from .base import Mount +from .patterns import ( + FuseMountPattern, + MountPattern, + MountPatternBase, + MountpointMountPattern, + RcloneMountPattern, +) +from .providers import AzureBlobMount, GCSMount, S3Mount + +__all__ = [ + "AzureBlobMount", + "FuseMountPattern", + "GCSMount", + "Mount", + "MountPattern", + "MountPatternBase", + "MountpointMountPattern", + "RcloneMountPattern", + "S3Mount", +] diff --git a/src/agents/sandbox/entries/mounts/base.py b/src/agents/sandbox/entries/mounts/base.py new file mode 100644 index 0000000000..ab93c2bc94 --- /dev/null +++ b/src/agents/sandbox/entries/mounts/base.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import abc +import warnings +from pathlib import Path +from typing import TYPE_CHECKING + +from pydantic import Field + +from ...materialization import MaterializedFile +from ...types import FileMode, Permissions +from ..base import BaseEntry + +if TYPE_CHECKING: + from ...session.base_sandbox_session import BaseSandboxSession + + +class Mount(BaseEntry): + is_dir: bool = True + mount_path: Path | None = None + ephemeral: bool = Field(default=True) + + def model_post_init(self, context: object, /) -> None: + _ = context + default_permissions = Permissions( + owner=FileMode.ALL, + group=FileMode.READ | FileMode.EXEC, + other=FileMode.READ | FileMode.EXEC, + ) + if ( + self.permissions.owner != default_permissions.owner + or self.permissions.group != default_permissions.group + or self.permissions.other != default_permissions.other + ): + warnings.warn( + "Mount permissions are not enforced. " + "Please configure access in the cloud provider instead; " + "mount-level permissions can be unreliable.", + stacklevel=2, + ) + self.permissions.owner = default_permissions.owner + self.permissions.group = default_permissions.group + self.permissions.other = default_permissions.other + self.permissions.directory = True + + async def apply( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + _ = base_dir + mount_path = self._resolve_mount_path(session, dest) + await self.mount(session, mount_path) + return [] + + async def unmount( + self, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> None: + _ = base_dir + mount_path = self._resolve_mount_path(session, dest) + await self.unmount_path(session, mount_path) + + async def mount(self, session: BaseSandboxSession, path: Path) -> None: + await self._mount(session, path) + + async def unmount_path( + self, + session: BaseSandboxSession, + path: Path, + ) -> None: + await self._unmount(session, path) + + @abc.abstractmethod + async def _mount(self, session: BaseSandboxSession, path: Path) -> None: + raise NotImplementedError + + @abc.abstractmethod + async def _unmount(self, session: BaseSandboxSession, path: Path) -> None: + raise NotImplementedError + + def _resolve_mount_path( + self, + session: BaseSandboxSession, + dest: Path, + ) -> Path: + manifest_root = Path(getattr(session.state.manifest, "root", "/")) + return self._resolve_mount_path_for_root(manifest_root, dest) + + def _resolve_mount_path_for_root( + self, + manifest_root: Path, + dest: Path, + ) -> Path: + if self.mount_path is not None: + mount_path = Path(self.mount_path) + if mount_path.is_absolute(): + return mount_path + # Relative explicit mount paths are interpreted inside the active workspace root so a + # manifest can stay portable across backends with different concrete root prefixes. + return manifest_root / mount_path + + if dest.is_absolute(): + try: + rel_dest = dest.relative_to(manifest_root) + except ValueError: + return dest + # `dest` may already be normalized to an absolute workspace path; re-anchor it to the + # current manifest root instead of nesting the root twice. + return manifest_root / rel_dest + return manifest_root / dest diff --git a/src/agents/sandbox/entries/mounts/patterns.py b/src/agents/sandbox/entries/mounts/patterns.py new file mode 100644 index 0000000000..eef00dd7a0 --- /dev/null +++ b/src/agents/sandbox/entries/mounts/patterns.py @@ -0,0 +1,748 @@ +from __future__ import annotations + +import abc +import io +import re +import shlex +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Literal, TypeVar + +from pydantic import BaseModel, Field + +from ...errors import ( + MountCommandError, + MountConfigError, + MountToolMissingError, + WorkspaceReadNotFoundError, +) + +if TYPE_CHECKING: + from ...session.base_sandbox_session import BaseSandboxSession + + +@dataclass(frozen=True) +class FuseMountConfig: + account: str + container: str + endpoint: str | None + identity_client_id: str | None + account_key: str | None + mount_type: str + + +@dataclass(frozen=True) +class MountpointMountConfig: + bucket: str + access_key_id: str | None + secret_access_key: str | None + session_token: str | None + prefix: str | None + region: str | None + endpoint_url: str | None + mount_type: str + + +@dataclass(frozen=True) +class RcloneMountConfig: + remote_name: str + remote_path: str + remote_kind: str + mount_type: str + config_text: str | None = None + + +MountPatternConfig = FuseMountConfig | MountpointMountConfig | RcloneMountConfig +MountPatternConfigT = TypeVar("MountPatternConfigT", bound=MountPatternConfig) + + +def _require_mount_config( + config: MountPatternConfig, + expected_type: type[MountPatternConfigT], +) -> MountPatternConfigT: + if not isinstance(config, expected_type): + raise MountConfigError( + message="mount pattern received incompatible runtime config", + context={ + "expected": expected_type.__name__, + "actual": type(config).__name__, + }, + ) + return config + + +class MountPatternBase(BaseModel, abc.ABC): + @abc.abstractmethod + async def apply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + raise NotImplementedError + + @abc.abstractmethod + async def unapply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + raise NotImplementedError + + +class FuseMountPattern(MountPatternBase): + type: Literal["fuse"] = "fuse" + read_only: bool = Field(default=True) + allow_other: bool = Field(default=True) + log_type: str = Field(default="syslog") + log_level: str = Field(default="log_debug") + cache_type: Literal["block_cache", "file_cache"] = Field(default="block_cache") + cache_path: Path | None = None + cache_size_mb: int | None = None + block_cache_block_size_mb: int = Field(default=16) + block_cache_disk_timeout_sec: int = Field(default=3600) + file_cache_timeout_sec: int = Field(default=120) + file_cache_max_size_mb: int | None = None + attr_cache_timeout_sec: int | None = None + entry_cache_timeout_sec: int | None = None + negative_entry_cache_timeout_sec: int | None = None + + @dataclass(frozen=True) + class BlobfuseConfig: + account: str + container: str + endpoint: str + cache_type: str + cache_size_mb: int + block_cache_block_size_mb: int + block_cache_disk_timeout_sec: int + file_cache_timeout_sec: int + file_cache_max_size_mb: int + cache_dir: Path + allow_other: bool + log_type: str + log_level: str + entry_cache_timeout_sec: int | None + negative_entry_cache_timeout_sec: int | None + attr_cache_timeout_sec: int | None + identity_client_id: str | None + account_key: str | None + + def to_text(self) -> str: + lines: list[str] = [] + if self.allow_other: + lines.append("allow-other: true") + lines.append("") + lines.extend( + [ + "logging:", + f" type: {self.log_type}", + f" level: {self.log_level}", + "", + "components:", + " - libfuse", + f" - {self.cache_type}", + " - attr_cache", + " - azstorage", + "", + ] + ) + + libfuse_lines: list[str] = [] + if self.entry_cache_timeout_sec is not None: + libfuse_lines.append(f" entry-expiration-sec: {self.entry_cache_timeout_sec}") + if self.negative_entry_cache_timeout_sec is not None: + libfuse_lines.append( + f" negative-entry-expiration-sec: {self.negative_entry_cache_timeout_sec}" + ) + if libfuse_lines: + lines.append("libfuse:") + lines.extend(libfuse_lines) + lines.append("") + + if self.cache_type == "block_cache": + lines.extend( + [ + "block_cache:", + f" block-size-mb: {self.block_cache_block_size_mb}", + f" mem-size-mb: {self.cache_size_mb}", + f" path: {self.cache_dir}", + f" disk-size-mb: {self.cache_size_mb}", + f" disk-timeout-sec: {self.block_cache_disk_timeout_sec}", + "", + ] + ) + else: + lines.extend( + [ + "file_cache:", + f" path: {self.cache_dir}", + f" timeout-sec: {self.file_cache_timeout_sec}", + f" max-size-mb: {self.file_cache_max_size_mb}", + "", + ] + ) + + attr_cache_timeout = self.attr_cache_timeout_sec or 7200 + lines.extend( + [ + "attr_cache:", + f" timeout-sec: {attr_cache_timeout}", + "", + "azstorage:", + " type: block", + f" account-name: {self.account}", + f" container: {self.container}", + f" endpoint: {self.endpoint}", + ] + ) + if self.account_key: + lines.extend( + [ + " auth-type: key", + f" account-key: {self.account_key}", + ] + ) + else: + lines.append(" mode: msi") + if self.identity_client_id: + lines.append(f" identity-client-id: {self.identity_client_id}") + lines.append("") + return "\n".join(lines) + + async def apply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + fuse_config = _require_mount_config(config, FuseMountConfig) + account = fuse_config.account + container = fuse_config.container + + tool_check = await session.exec("command -v blobfuse2 >/dev/null 2>&1") + if not tool_check.ok(): + raise MountToolMissingError( + tool="blobfuse2", + context={"account": account, "container": container}, + ) + + session_id = getattr(session.state, "session_id", None) + if session_id is None: + raise MountConfigError( + message="mount session is missing session_id", + context={"type": fuse_config.mount_type}, + ) + + mount_path = path + cache_dir = ( + Path(self.cache_path) + if self.cache_path is not None + else Path(f"/tmp/uc-blobfuse-cache/{session_id.hex}") / account / container + ) + config_dir = Path(f"/tmp/uc-blobfuse-config/{session_id.hex}") + config_name = f"{account}_{container}".replace("/", "_") + config_path = config_dir / f"{config_name}.yaml" + + await session.mkdir(mount_path, parents=True) + await session.mkdir(cache_dir, parents=True) + await session.mkdir(config_dir, parents=True) + + endpoint = fuse_config.endpoint or f"https://{account}.blob.core.windows.net" + cache_type = self.cache_type + cache_size_mb = self.cache_size_mb or (50_000 if cache_type == "block_cache" else 4_096) + file_cache_max_size_mb = self.file_cache_max_size_mb or cache_size_mb + blobfuse_config = self.BlobfuseConfig( + account=account, + container=container, + endpoint=endpoint, + cache_type=cache_type, + cache_size_mb=cache_size_mb, + block_cache_block_size_mb=self.block_cache_block_size_mb, + block_cache_disk_timeout_sec=self.block_cache_disk_timeout_sec, + file_cache_timeout_sec=self.file_cache_timeout_sec, + file_cache_max_size_mb=file_cache_max_size_mb, + cache_dir=cache_dir, + allow_other=self.allow_other, + log_type=self.log_type, + log_level=self.log_level, + entry_cache_timeout_sec=self.entry_cache_timeout_sec, + negative_entry_cache_timeout_sec=self.negative_entry_cache_timeout_sec, + attr_cache_timeout_sec=self.attr_cache_timeout_sec, + identity_client_id=fuse_config.identity_client_id, + account_key=fuse_config.account_key, + ) + config_payload = blobfuse_config.to_text().encode("utf-8") + await session.write(config_path, io.BytesIO(config_payload)) + + cmd: list[str] = ["blobfuse2", "mount"] + if self.read_only: + cmd.append("--read-only") + cmd.extend(["--config-file", str(config_path)]) + cmd.append(str(mount_path)) + + result = await session.exec(*cmd, shell=False) + if not result.ok(): + raise MountCommandError( + command=" ".join(cmd), + stderr=result.stderr.decode("utf-8", errors="replace"), + context={"account": account, "container": container}, + ) + + async def unapply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + _ = _require_mount_config(config, FuseMountConfig) + # Best-effort unmount; ignore failures for already-unmounted mounts. + await session.exec( + "sh", + "-lc", + f"fusermount3 -u {shlex.quote(str(path))} || umount {shlex.quote(str(path))}", + shell=False, + ) + + +class MountpointMountPattern(MountPatternBase): + type: Literal["mountpoint"] = "mountpoint" + read_only: bool = Field(default=True) + + @dataclass(frozen=True) + class MountpointOptions: + prefix: str | None = None + region: str | None = None + endpoint_url: str | None = None + + options: MountpointOptions = Field(default_factory=MountpointOptions) + + async def apply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + mountpoint_config = _require_mount_config(config, MountpointMountConfig) + bucket = mountpoint_config.bucket + + tool_check = await session.exec("command -v mount-s3 >/dev/null 2>&1") + if not tool_check.ok(): + raise MountToolMissingError( + tool="mount-s3", + context={"bucket": bucket}, + ) + + await session.mkdir(path, parents=True) + + cmd: list[str] = ["mount-s3"] + if self.read_only: + cmd.append("--read-only") + if mountpoint_config.region: + cmd.extend(["--region", mountpoint_config.region]) + if mountpoint_config.endpoint_url: + cmd.extend(["--endpoint-url", mountpoint_config.endpoint_url]) + if mountpoint_config.prefix: + cmd.extend(["--prefix", mountpoint_config.prefix]) + cmd.extend([bucket, str(path)]) + + env_parts: list[str] = [] + access_key_id = mountpoint_config.access_key_id + secret_access_key = mountpoint_config.secret_access_key + session_token = mountpoint_config.session_token + if access_key_id and secret_access_key: + env_parts.append(f"AWS_ACCESS_KEY_ID={shlex.quote(access_key_id)}") + env_parts.append(f"AWS_SECRET_ACCESS_KEY={shlex.quote(secret_access_key)}") + if session_token: + env_parts.append(f"AWS_SESSION_TOKEN={shlex.quote(session_token)}") + + joined_cmd = " ".join(shlex.quote(part) for part in cmd) + if env_parts: + joined_cmd = f"{' '.join(env_parts)} {joined_cmd}" + + result = await session.exec("sh", "-lc", joined_cmd, shell=False) + if not result.ok(): + raise MountCommandError( + command=joined_cmd, + stderr=result.stderr.decode("utf-8", errors="replace"), + context={"bucket": bucket}, + ) + + async def unapply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + _ = _require_mount_config(config, MountpointMountConfig) + await session.exec( + "sh", + "-lc", + f"fusermount3 -u {shlex.quote(str(path))} || umount {shlex.quote(str(path))}", + shell=False, + ) + + +def _supplement_rclone_config_text( + *, + config_text: str, + remote_name: str, + required_lines: list[str], + mount_type: str | None, +) -> str: + section_pattern = re.compile(rf"^\s*\[{re.escape(remote_name)}\]\s*$", re.MULTILINE) + match = section_pattern.search(config_text) + if not match: + raise MountConfigError( + message="rclone config missing required remote section", + context={"type": mount_type or "mount", "remote_name": remote_name}, + ) + + section_start = match.start() + section_end = match.end() + next_section = re.search(r"^\s*\[.+\]\s*$", config_text[section_end:], re.MULTILINE) + if next_section: + section_body_end = section_end + next_section.start() + else: + section_body_end = len(config_text) + + before = config_text[:section_start] + section_body = config_text[section_start:section_body_end].rstrip("\n") + after = config_text[section_body_end:] + + supplement = "\n".join(required_lines[1:]) # header already present + merged_section = f"{section_body}\n{supplement}\n" + return f"{before}{merged_section}{after}" + + +class RcloneMountPattern(MountPatternBase): + type: Literal["rclone"] = "rclone" + mode: Literal["fuse", "nfs"] = Field(default="fuse") + read_only: bool = Field(default=True) + remote_name: str | None = None + extra_args: list[str] = Field(default_factory=list) + nfs_addr: str | None = None + nfs_mount_options: list[str] | None = None + config_file_path: Path | None = None + + def resolve_remote_name( + self, + *, + session_id: str, + remote_kind: str, + mount_type: str | None = None, + ) -> str: + if self.remote_name: + return self.remote_name + if not remote_kind: + raise MountConfigError( + message="rclone mount requires remote_kind", + context={"type": mount_type or "mount"}, + ) + # Derive a deterministic per-session remote name when the caller did not pin one, so + # multiple mounts can coexist without sharing mutable rclone config sections. + return f"uc_{remote_kind}_{session_id}" + + def _resolve_config_path( + self, + session: BaseSandboxSession, + config_path: Path, + ) -> Path: + manifest_root = Path(getattr(session.state.manifest, "root", "/")) + if config_path.is_absolute(): + return config_path + # Relative config paths are resolved inside the sandbox workspace, not relative to the + # host process that is orchestrating the session. + return manifest_root / config_path + + async def read_config_text( + self, + session: BaseSandboxSession, + remote_name: str, + *, + mount_type: str | None, + ) -> str: + if self.config_file_path is None: + raise MountConfigError( + message="rclone config_file_path is not set", + context={"type": mount_type or "mount"}, + ) + config_path = self._resolve_config_path(session, self.config_file_path) + try: + handle = await session.read(config_path) + except WorkspaceReadNotFoundError: + raise + except FileNotFoundError as e: + raise WorkspaceReadNotFoundError(path=config_path, cause=e) from e + except Exception as e: + raise MountConfigError( + message="failed to read rclone config file", + context={"type": mount_type or "mount", "path": str(config_path)}, + ) from e + + try: + raw_config = handle.read() + finally: + handle.close() + if isinstance(raw_config, bytes): + config_text = raw_config.decode("utf-8", errors="replace") + elif isinstance(raw_config, str): + config_text = raw_config + else: + config_text = str(raw_config) + + if not config_text.strip(): + raise MountConfigError( + message="rclone config file is empty", + context={"type": mount_type or "mount", "path": str(config_path)}, + ) + + section_pattern = rf"^\s*\[{re.escape(remote_name)}\]\s*$" + if not re.search(section_pattern, config_text, re.MULTILINE): + raise MountConfigError( + message="rclone config missing required remote section", + context={ + "type": mount_type or "mount", + "path": str(config_path), + "remote_name": remote_name, + }, + ) + + return config_text + + async def _start_rclone_server( + self, + session: BaseSandboxSession, + *, + config: RcloneMountConfig, + config_path: Path, + nfs_addr: str, + ) -> None: + nfs_check = await session.exec( + "sh", + "-lc", + "/usr/local/bin/rclone serve nfs --help >/dev/null 2>&1" + " || rclone serve nfs --help >/dev/null 2>&1", + shell=False, + ) + if not nfs_check.ok(): + raise MountToolMissingError( + tool="rclone serve nfs", + context={"type": config.mount_type}, + ) + cmd: list[str] = ["rclone", "serve", "nfs", f"{config.remote_name}:{config.remote_path}"] + cmd.extend(["--addr", nfs_addr]) + cmd.extend(["--config", str(config_path)]) + if self.extra_args: + cmd.extend(self.extra_args) + joined_cmd = " ".join(shlex.quote(part) for part in cmd) + # Run in background so we can wait for the server to start. + server_cmd = f"{joined_cmd} &" + result = await session.exec("sh", "-lc", server_cmd, shell=False) + if not result.ok(): + raise MountCommandError( + command=" ".join(cmd), + stderr=result.stderr.decode("utf-8", errors="replace"), + context={"type": config.mount_type}, + ) + + async def _start_rclone_client( + self, + session: BaseSandboxSession, + *, + path: Path, + config: RcloneMountConfig, + config_path: Path, + nfs_addr: str | None = None, + ) -> None: + if self.mode == "fuse": + cmd: list[str] = [ + "rclone", + "mount", + f"{config.remote_name}:{config.remote_path}", + str(path), + ] + if self.read_only: + cmd.append("--read-only") + cmd.extend(["--config", str(config_path), "--daemon"]) + if self.extra_args: + cmd.extend(self.extra_args) + result = await session.exec(*cmd, shell=False) + if not result.ok(): + raise MountCommandError( + command=" ".join(cmd), + stderr=result.stderr.decode("utf-8", errors="replace"), + context={"type": config.mount_type}, + ) + return + + if nfs_addr is None: + raise MountConfigError( + message="nfs_addr required for rclone nfs client", + context={"type": config.mount_type}, + ) + + nfs_supported = await session.exec( + "sh", "-lc", "grep -w nfs /proc/filesystems", shell=False + ) + if not nfs_supported.ok(): + warnings.warn( + "NFS client support not detected; attempting mount anyway. " + "If it fails, use rclone fuse mode or run on a kernel with NFS support.", + stacklevel=2, + ) + + # Default to localhost if no NFS address is provided + host = "127.0.0.1" + port = "2049" + + if ":" in nfs_addr: + host, port = nfs_addr.rsplit(":", 1) + else: + host = nfs_addr + if host in {"0.0.0.0", "::"}: + host = "127.0.0.1" + + mount_options = self.nfs_mount_options or [ + "vers=4.1", + "tcp", + f"port={port}", + "soft", + "timeo=50", + "retrans=1", + ] + option_arg = ",".join(mount_options) + timeout_check = await session.exec( + "sh", "-lc", "command -v timeout >/dev/null 2>&1", shell=False + ) + timeout_prefix = "timeout 10s " if timeout_check.ok() else "" + mount_cmd_string = " ".join( + [ + "for i in 1 2 3; do", + f"{timeout_prefix}mount", + "-v", + "-t", + "nfs", + "-o", + shlex.quote(option_arg), + f"{shlex.quote(host)}:/", + shlex.quote(str(path)), + "&& exit 0; sleep 1; done; exit 1", + ] + ) + mount_cmd = ( + "sh", + "-lc", + mount_cmd_string, + ) + mount_result = await session.exec(*mount_cmd, shell=False) + if not mount_result.ok(): + raise MountCommandError( + command=" ".join(mount_cmd), + stderr=mount_result.stderr.decode("utf-8", errors="replace"), + context={"type": config.mount_type}, + ) + + async def apply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + rclone_config = _require_mount_config(config, RcloneMountConfig) + tool_check = await session.exec( + "sh", + "-lc", + "command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone", + shell=False, + ) + if not tool_check.ok(): + raise MountToolMissingError( + tool="rclone", + context={"type": rclone_config.mount_type}, + ) + + if rclone_config.config_text is None: + raise MountConfigError( + message="rclone mount requires config_text", + context={"type": rclone_config.mount_type}, + ) + + session_id = getattr(session.state, "session_id", None) + if session_id is None: + raise MountConfigError( + message="mount session is missing session_id", + context={"type": rclone_config.mount_type}, + ) + session_id_str = session_id.hex + config_dir = Path(f"/tmp/uc-rclone-config/{session_id_str}") + config_path = config_dir / f"{rclone_config.remote_name}.conf" + await session.mkdir(path, parents=True) + await session.mkdir(config_dir, parents=True) + # Always write an isolated config file for the live mount operation so provider-specific + # augmentation does not mutate a shared source config in the workspace. + await session.write(config_path, io.BytesIO(rclone_config.config_text.encode("utf-8"))) + + if self.mode == "nfs": + nfs_addr = self.nfs_addr or "127.0.0.1:2049" + await self._start_rclone_server( + session, + config=rclone_config, + config_path=config_path, + nfs_addr=nfs_addr, + ) + await self._start_rclone_client( + session, + path=path, + config=rclone_config, + config_path=config_path, + nfs_addr=nfs_addr, + ) + else: + # fuse mode + await self._start_rclone_client( + session, + path=path, + config=rclone_config, + config_path=config_path, + ) + + async def unapply( + self, + session: BaseSandboxSession, + path: Path, + config: MountPatternConfig, + ) -> None: + rclone_config = _require_mount_config(config, RcloneMountConfig) + if self.mode == "fuse": + await session.exec( + "sh", + "-lc", + f"fusermount3 -u {shlex.quote(str(path))} || umount {shlex.quote(str(path))}", + shell=False, + ) + if self.mode == "nfs": + await session.exec( + "sh", + "-lc", + f"umount {shlex.quote(str(path))} >/dev/null 2>&1 || true", + shell=False, + ) + + await session.exec( + "sh", + "-lc", + ( + "pkill -f -- " + f"'rclone (mount|serve nfs) {rclone_config.remote_name}:' >/dev/null 2>&1 || true" + ), + shell=False, + ) + + +MountPattern = Annotated[ + FuseMountPattern | MountpointMountPattern | RcloneMountPattern, + Field(discriminator="type"), +] diff --git a/src/agents/sandbox/entries/mounts/providers.py b/src/agents/sandbox/entries/mounts/providers.py new file mode 100644 index 0000000000..8cba989eb4 --- /dev/null +++ b/src/agents/sandbox/entries/mounts/providers.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import abc +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +from ...errors import MountConfigError +from .base import Mount +from .patterns import ( + FuseMountConfig, + FuseMountPattern, + MountPattern, + MountPatternConfig, + MountpointMountConfig, + MountpointMountPattern, + RcloneMountConfig, + RcloneMountPattern, + _supplement_rclone_config_text, +) + +if TYPE_CHECKING: + from ...session.base_sandbox_session import BaseSandboxSession + + +class _ConfiguredMount(Mount, abc.ABC): + mount_pattern: MountPattern | None = None + + def _require_mount_pattern(self) -> MountPattern: + if self.mount_pattern is None: + raise MountConfigError( + message=f"{self.type} requires mount_pattern", + context={"type": self.type}, + ) + return self.mount_pattern + + @staticmethod + def _require_session_id_hex(session: BaseSandboxSession, mount_type: str) -> str: + session_id = getattr(session.state, "session_id", None) + if not isinstance(session_id, uuid.UUID): + raise MountConfigError( + message="mount session is missing session_id", + context={"type": mount_type}, + ) + return session_id.hex + + async def _build_rclone_config( + self, + *, + session: BaseSandboxSession, + pattern: RcloneMountPattern, + remote_kind: str, + remote_path: str, + required_lines: list[str], + include_config_text: bool, + ) -> RcloneMountConfig: + remote_name = pattern.resolve_remote_name( + session_id=self._require_session_id_hex(session, self.type), + remote_kind=remote_kind, + mount_type=self.type, + ) + config_text: str | None = None + if include_config_text: + if pattern.config_file_path is not None: + config_text = await pattern.read_config_text( + session, + remote_name, + mount_type=self.type, + ) + config_text = _supplement_rclone_config_text( + config_text=config_text, + remote_name=remote_name, + required_lines=required_lines, + mount_type=self.type, + ) + else: + config_text = "\n".join(required_lines) + "\n" + return RcloneMountConfig( + remote_name=remote_name, + remote_path=remote_path, + remote_kind=remote_kind, + mount_type=self.type, + config_text=config_text, + ) + + @abc.abstractmethod + async def _build_mount_config( + self, + session: BaseSandboxSession, + pattern: MountPattern, + *, + include_config_text: bool, + ) -> MountPatternConfig: + raise NotImplementedError + + async def _mount(self, session: BaseSandboxSession, path: Path) -> None: + pattern = self._require_mount_pattern() + config = await self._build_mount_config(session, pattern, include_config_text=True) + await pattern.apply(session, path, config) + + async def _unmount(self, session: BaseSandboxSession, path: Path) -> None: + pattern = self._require_mount_pattern() + config = await self._build_mount_config(session, pattern, include_config_text=False) + await pattern.unapply(session, path, config) + + +class AzureBlobMount(_ConfiguredMount): + type: Literal["azure_blob_mount"] = "azure_blob_mount" + account: str # AZURE_STORAGE_ACCOUNT + container: str # AZURE_STORAGE_CONTAINER + endpoint: str | None = None + identity_client_id: str | None = None # AZURE_CLIENT_ID + account_key: str | None = None # AZURE_STORAGE_ACCOUNT_KEY + + def model_post_init(self, context: object, /) -> None: + super().model_post_init(context) + pattern = self._require_mount_pattern() + if not isinstance(pattern, (RcloneMountPattern, FuseMountPattern)): + raise MountConfigError( + message="invalid mount_pattern type", + context={"type": self.type}, + ) + + async def _build_mount_config( + self, + session: BaseSandboxSession, + pattern: MountPattern, + *, + include_config_text: bool, + ) -> MountPatternConfig: + if isinstance(pattern, RcloneMountPattern): + return await self._build_rclone_config( + session=session, + pattern=pattern, + remote_kind="azureblob", + remote_path=self.container, + required_lines=self._rclone_required_lines( + pattern.resolve_remote_name( + session_id=self._require_session_id_hex(session, self.type), + remote_kind="azureblob", + mount_type=self.type, + ) + ), + include_config_text=include_config_text, + ) + if isinstance(pattern, FuseMountPattern): + return FuseMountConfig( + account=self.account, + container=self.container, + endpoint=self.endpoint, + identity_client_id=self.identity_client_id, + account_key=self.account_key, + mount_type=self.type, + ) + raise MountConfigError( + message="invalid mount_pattern type", + context={"type": self.type}, + ) + + def _rclone_required_lines(self, remote_name: str) -> list[str]: + lines = [ + f"[{remote_name}]", + "type = azureblob", + f"account = {self.account}", + ] + if self.endpoint: + lines.append(f"endpoint = {self.endpoint}") + if self.account_key: + lines.append(f"key = {self.account_key}") + else: + lines.append("use_msi = true") + if self.identity_client_id: + lines.append(f"msi_client_id = {self.identity_client_id}") + return lines + + +class S3Mount(_ConfiguredMount): + type: Literal["s3_mount"] = "s3_mount" + bucket: str + access_key_id: str | None = None + secret_access_key: str | None = None + session_token: str | None = None + + def model_post_init(self, context: object, /) -> None: + super().model_post_init(context) + pattern = self._require_mount_pattern() + if not isinstance(pattern, (RcloneMountPattern, MountpointMountPattern)): + raise MountConfigError( + message="invalid mount_pattern type", + context={"type": self.type}, + ) + + async def _build_mount_config( + self, + session: BaseSandboxSession, + pattern: MountPattern, + *, + include_config_text: bool, + ) -> MountPatternConfig: + if isinstance(pattern, RcloneMountPattern): + return await self._build_rclone_config( + session=session, + pattern=pattern, + remote_kind="s3", + remote_path=self.bucket, + required_lines=self._rclone_required_lines( + pattern.resolve_remote_name( + session_id=self._require_session_id_hex(session, self.type), + remote_kind="s3", + mount_type=self.type, + ) + ), + include_config_text=include_config_text, + ) + if isinstance(pattern, MountpointMountPattern): + options = pattern.options + return MountpointMountConfig( + bucket=self.bucket, + access_key_id=self.access_key_id, + secret_access_key=self.secret_access_key, + session_token=self.session_token, + prefix=options.prefix, + region=options.region, + endpoint_url=options.endpoint_url, + mount_type=self.type, + ) + raise MountConfigError( + message="invalid mount_pattern type", + context={"type": self.type}, + ) + + def _rclone_required_lines(self, remote_name: str) -> list[str]: + lines = [ + f"[{remote_name}]", + "type = s3", + "provider = AWS", + ] + if self.access_key_id and self.secret_access_key: + lines.append("env_auth = false") + lines.append(f"access_key_id = {self.access_key_id}") + lines.append(f"secret_access_key = {self.secret_access_key}") + if self.session_token: + lines.append(f"session_token = {self.session_token}") + else: + lines.append("env_auth = true") + return lines + + +class GCSMount(_ConfiguredMount): + type: Literal["gcs_mount"] = "gcs_mount" + bucket: str + access_id: str | None = None + secret_access_key: str | None = None + + def model_post_init(self, context: object, /) -> None: + super().model_post_init(context) + if self.mount_pattern is None: + # GCS defaults to the S3-compatible mountpoint path so examples can omit the pattern + # unless they specifically need rclone behavior. + self.mount_pattern = MountpointMountPattern() + pattern = self._require_mount_pattern() + if not isinstance(pattern, (RcloneMountPattern, MountpointMountPattern)): + raise MountConfigError( + message="invalid mount_pattern type", + context={"type": self.type}, + ) + + async def _build_mount_config( + self, + session: BaseSandboxSession, + pattern: MountPattern, + *, + include_config_text: bool, + ) -> MountPatternConfig: + if isinstance(pattern, RcloneMountPattern): + return await self._build_rclone_config( + session=session, + pattern=pattern, + remote_kind="gcs", + remote_path=self.bucket, + required_lines=self._rclone_required_lines( + pattern.resolve_remote_name( + session_id=self._require_session_id_hex(session, self.type), + remote_kind="gcs", + mount_type=self.type, + ) + ), + include_config_text=include_config_text, + ) + if isinstance(pattern, MountpointMountPattern): + options = pattern.options + return MountpointMountConfig( + bucket=self.bucket, + access_key_id=self.access_id, + secret_access_key=self.secret_access_key, + session_token=None, + prefix=options.prefix, + region=options.region, + endpoint_url=options.endpoint_url or "https://storage.googleapis.com", + mount_type=self.type, + ) + raise MountConfigError( + message="invalid mount_pattern type", + context={"type": self.type}, + ) + + def _rclone_required_lines(self, remote_name: str) -> list[str]: + lines = [ + f"[{remote_name}]", + "type = s3", + "provider = GCS", + "endpoint = https://storage.googleapis.com", + ] + if self.access_id and self.secret_access_key: + lines.append("env_auth = false") + lines.append(f"access_key_id = {self.access_id}") + lines.append(f"secret_access_key = {self.secret_access_key}") + else: + lines.append("env_auth = true") + return lines diff --git a/src/agents/sandbox/errors.py b/src/agents/sandbox/errors.py new file mode 100644 index 0000000000..7ccc759e76 --- /dev/null +++ b/src/agents/sandbox/errors.py @@ -0,0 +1,739 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Literal + +from .types import ExecResult + + +class ErrorCode(str, Enum): + """Stable, machine-readable error codes for `UniversalComputerError`.""" + + def __str__(self) -> str: + return str(self.value) + + INVALID_MANIFEST_PATH = "invalid_manifest_path" + INVALID_COMPRESSION_SCHEME = "invalid_compression_scheme" + EXEC_NONZERO = "exec_nonzero" + EXEC_TIMEOUT = "exec_timeout" + EXEC_TRANSPORT_ERROR = "exec_transport_error" + + WORKSPACE_READ_NOT_FOUND = "workspace_read_not_found" + WORKSPACE_ARCHIVE_READ_ERROR = "workspace_archive_read_error" + WORKSPACE_ARCHIVE_WRITE_ERROR = "workspace_archive_write_error" + WORKSPACE_WRITE_TYPE_ERROR = "workspace_write_type_error" + WORKSPACE_STOP_ERROR = "workspace_stop_error" + WORKSPACE_START_ERROR = "workspace_start_error" + WORKSPACE_ROOT_NOT_FOUND = "workspace_root_not_found" + + LOCAL_FILE_READ_ERROR = "local_file_read_error" + LOCAL_DIR_READ_ERROR = "local_dir_read_error" + LOCAL_CHECKSUM_ERROR = "local_checksum_error" + + GIT_MISSING_IN_IMAGE = "git_missing_in_image" + GIT_CLONE_ERROR = "git_clone_error" + GIT_COPY_ERROR = "git_copy_error" + UNSUPPORTED_CODEX_TARGET = "unsupported_codex_target" + + MOUNT_MISSING_TOOL = "mount_missing_tool" + MOUNT_FAILED = "mount_failed" + MOUNT_CONFIG_INVALID = "mount_config_invalid" + SKILLS_CONFIG_INVALID = "skills_config_invalid" + + SNAPSHOT_PERSIST_ERROR = "snapshot_persist_error" + SNAPSHOT_RESTORE_ERROR = "snapshot_restore_error" + SNAPSHOT_NOT_RESTORABLE = "snapshot_not_restorable" + + +OpName = Literal[ + "start", + "stop", + "exec", + "read", + "write", + "shutdown", + "running", + "persist_workspace", + "hydrate_workspace", + "materialize", + "snapshot_persist", + "snapshot_restore", +] + + +@dataclass(eq=False) +class UniversalComputerError(Exception): + """Base class for structured, user-facing sandbox errors. + + Attributes: + message: Human-readable error message. + error_code: Stable, machine-readable code for programmatic handling. + op: The operation where the error occurred. + context: Structured metadata to aid debugging. + cause: Optional underlying exception. + """ + + message: str + error_code: ErrorCode + op: OpName + context: dict[str, object] + cause: BaseException | None = None + + def __post_init__(self) -> None: + super().__init__(self.message) + if self.cause is not None: + self.__cause__ = self.cause + + @property + def code(self) -> str: + """Backward-compatible alias for `error_code`.""" + + return str(self.error_code) + + +class ConfigurationError(UniversalComputerError): + """Raised when validating user-provided configuration and inputs.""" + + +class SandboxError(UniversalComputerError): + """Raised for sandbox failures (e.g., Docker/IO/transport).""" + + +class ArtifactError(UniversalComputerError): + """Raised while materializing input artifacts (local files, git repos).""" + + +class SnapshotError(UniversalComputerError): + """Raised for snapshot persist/restore errors.""" + + +def _as_context(context: Mapping[str, object] | None) -> dict[str, object]: + return dict(context or {}) + + +def _format_command(command: Sequence[str | Path]) -> str: + return " ".join(str(p) for p in command) + + +class UnsupportedCodexTargetError(ArtifactError): + """Codex release assets do not support the detected target platform.""" + + reason: Literal["operating_system", "architecture", "linux_libc"] + target_os: str + target_arch: str + linux_libc: str | None + supported_operating_systems: tuple[str, ...] + supported_architectures: tuple[str, ...] + supported_linux_libc_variants: tuple[str, ...] + + def __init__( + self, + *, + reason: Literal["operating_system", "architecture", "linux_libc"], + target_os: str, + target_arch: str, + linux_libc: str | None = None, + supported_operating_systems: Sequence[str], + supported_architectures: Sequence[str], + supported_linux_libc_variants: Sequence[str] = (), + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + supported_os = tuple(supported_operating_systems) + supported_arches = tuple(supported_architectures) + supported_libc_variants = tuple(supported_linux_libc_variants) + + if reason == "operating_system": + message = ( + f"Unsupported Codex target operating system: {target_os}. " + f"Available operating systems: {', '.join(supported_os)}." + ) + elif reason == "architecture": + message = ( + f"Unsupported Codex target architecture for {target_os}: {target_arch}. " + f"Available architectures: {', '.join(supported_arches)}." + ) + else: + message = ( + "Unsupported Linux libc variant for Codex target resolution: " + f"{linux_libc}. Available libc variants: {', '.join(supported_libc_variants)}." + ) + + super().__init__( + message=message, + error_code=ErrorCode.UNSUPPORTED_CODEX_TARGET, + op="materialize", + context={ + "reason": reason, + "target_os": target_os, + "target_arch": target_arch, + "linux_libc": linux_libc, + "supported_operating_systems": supported_os, + "supported_architectures": supported_arches, + "supported_linux_libc_variants": supported_libc_variants, + **_as_context(context), + }, + cause=cause, + ) + self.reason = reason + self.target_os = target_os + self.target_arch = target_arch + self.linux_libc = linux_libc + self.supported_operating_systems = supported_os + self.supported_architectures = supported_arches + self.supported_linux_libc_variants = supported_libc_variants + + +class InvalidManifestPathError(ConfigurationError): + """Manifest path was invalid (absolute or escaped the workspace root).""" + + def __init__( + self, + *, + rel: str | Path, + reason: Literal["absolute", "escape_root"], + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + msg = ( + f"manifest path must be relative: {rel}" + if reason == "absolute" + else f"manifest path must not escape root: {rel}" + ) + super().__init__( + message=msg, + error_code=ErrorCode.INVALID_MANIFEST_PATH, + op="materialize", + context={"rel": str(rel), "reason": reason, **_as_context(context)}, + cause=cause, + ) + + +class InvalidCompressionSchemeError(ConfigurationError): + """Compression scheme was missing or unsupported for a workspace write.""" + + def __init__( + self, + *, + path: Path, + scheme: str | None, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + msg = ( + "could not determine compression scheme" + if not scheme + else "compression scheme must be one of 'zip' 'tar'" + ) + super().__init__( + message=msg, + error_code=ErrorCode.INVALID_COMPRESSION_SCHEME, + op="write", + context={"path": str(path), "scheme": scheme, **_as_context(context)}, + cause=cause, + ) + + +class ExecFailureError(SandboxError): + """Base class for exec()-related failures.""" + + command: tuple[str, ...] + + def __init__( + self, + *, + message: str, + error_code: ErrorCode, + command: Sequence[str | Path], + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + cmd = tuple(str(c) for c in command) + super().__init__( + message=message, + error_code=error_code, + op="exec", + context={"command": cmd, "command_str": _format_command(cmd), **_as_context(context)}, + cause=cause, + ) + self.command = cmd + + +class ExecNonZeroError(ExecFailureError): + """exec() returned a non-zero exit status.""" + + exit_code: int + stdout: bytes + stderr: bytes + + def __init__( + self, + exec_result: ExecResult, + *, + command: Sequence[str | Path], + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=exec_result.stderr.decode("utf-8", errors="replace"), + error_code=ErrorCode.EXEC_NONZERO, + command=command, + context={ + "exit_code": exec_result.exit_code, + **_as_context(context), + }, + cause=cause, + ) + self.exit_code = exec_result.exit_code + self.stdout = exec_result.stdout + self.stderr = exec_result.stderr + + +class ExecTimeoutError(ExecFailureError): + """exec() exceeded its timeout.""" + + timeout_s: float | None + + def __init__( + self, + *, + command: Sequence[str | Path], + timeout_s: float | None, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="command timed out", + error_code=ErrorCode.EXEC_TIMEOUT, + command=command, + context={"timeout_s": timeout_s, **_as_context(context)}, + cause=cause, + ) + self.timeout_s = timeout_s + + +class ExecTransportError(ExecFailureError): + """exec() failed due to a transport-level error (e.g., Docker API).""" + + def __init__( + self, + *, + command: Sequence[str | Path], + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="exec transport error", + error_code=ErrorCode.EXEC_TRANSPORT_ERROR, + command=command, + context=_as_context(context), + cause=cause, + ) + + +class WorkspaceIOError(SandboxError): + """Base class for workspace read/write errors.""" + + +class WorkspaceReadNotFoundError(WorkspaceIOError): + """Workspace read failed because the path does not exist.""" + + def __init__( + self, + *, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"file not found: {path}", + error_code=ErrorCode.WORKSPACE_READ_NOT_FOUND, + op="read", + context={"path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class WorkspaceArchiveReadError(WorkspaceIOError): + """Workspace read failed while reading or decoding the archive stream.""" + + def __init__( + self, + *, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"failed to read archive for path: {path}", + error_code=ErrorCode.WORKSPACE_ARCHIVE_READ_ERROR, + op="read", + context={"path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class WorkspaceArchiveWriteError(WorkspaceIOError): + """Workspace write failed while creating or sending the archive stream.""" + + def __init__( + self, + *, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"failed to write archive for path: {path}", + error_code=ErrorCode.WORKSPACE_ARCHIVE_WRITE_ERROR, + op="write", + context={"path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class WorkspaceWriteTypeError(WorkspaceIOError): + """Workspace write payload was not a binary file-like object.""" + + def __init__( + self, + *, + path: Path, + actual_type: str, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="write() expects a binary file-like object", + error_code=ErrorCode.WORKSPACE_WRITE_TYPE_ERROR, + op="write", + context={"path": str(path), "actual_type": actual_type, **_as_context(context)}, + cause=cause, + ) + + +class WorkspaceStopError(SandboxError): + """SandboxSession stop failed (typically during snapshot persistence).""" + + def __init__( + self, + *, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="failed to stop session", + error_code=ErrorCode.WORKSPACE_STOP_ERROR, + op="stop", + context={"path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class WorkspaceStartError(SandboxError): + """SandboxSession start failed (typically while ensuring the workspace root exists).""" + + def __init__( + self, + *, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="failed to start session", + error_code=ErrorCode.WORKSPACE_START_ERROR, + op="start", + context={"path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class WorkspaceRootNotFoundError(SandboxError): + """Workspace root is missing on disk (e.g. deleted mid-session).""" + + def __init__( + self, + *, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"workspace root not found: {path}", + error_code=ErrorCode.WORKSPACE_ROOT_NOT_FOUND, + op="exec", + context={"path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class LocalArtifactError(ArtifactError): + """Base class for errors while reading local artifacts.""" + + +class LocalFileReadError(LocalArtifactError): + """Failed to read a local file artifact from disk.""" + + def __init__( + self, + *, + src: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"failed to read local file artifact: {src}", + error_code=ErrorCode.LOCAL_FILE_READ_ERROR, + op="materialize", + context={"src": str(src), **_as_context(context)}, + cause=cause, + ) + + +class LocalDirReadError(LocalArtifactError): + """Failed to read a local directory artifact from disk.""" + + def __init__( + self, + *, + src: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"failed to read local dir artifact: {src}", + error_code=ErrorCode.LOCAL_DIR_READ_ERROR, + op="materialize", + context={"src": str(src), **_as_context(context)}, + cause=cause, + ) + + +class LocalChecksumError(LocalArtifactError): + """Failed to compute a checksum for a local artifact.""" + + def __init__( + self, + *, + src: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"failed to checksum local artifact: {src}", + error_code=ErrorCode.LOCAL_CHECKSUM_ERROR, + op="materialize", + context={"src": str(src), **_as_context(context)}, + cause=cause, + ) + + +class GitArtifactError(ArtifactError): + """Base class for errors while materializing git_repo artifacts.""" + + +class GitMissingInImageError(GitArtifactError): + """Container image is missing git, so git_repo artifacts cannot be materialized.""" + + def __init__( + self, + *, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="git is required in the container image to materialize git_repo artifacts", + error_code=ErrorCode.GIT_MISSING_IN_IMAGE, + op="materialize", + context=_as_context(context), + cause=cause, + ) + + +class GitCloneError(GitArtifactError): + """Failed to clone a git repository while materializing an artifact.""" + + def __init__( + self, + *, + url: str, + ref: str, + stderr: str | None = None, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"git clone failed for {url}@{ref}", + error_code=ErrorCode.GIT_CLONE_ERROR, + op="materialize", + context={"url": url, "ref": ref, "stderr": stderr, **_as_context(context)}, + cause=cause, + ) + + +class GitCopyError(GitArtifactError): + """Failed to copy files from a cloned repo into the workspace.""" + + def __init__( + self, + *, + src_root: str, + dest: Path, + stderr: str | None = None, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="copy from git repo failed", + error_code=ErrorCode.GIT_COPY_ERROR, + op="materialize", + context={ + "src_root": src_root, + "dest": str(dest), + "stderr": stderr, + **_as_context(context), + }, + cause=cause, + ) + + +class MountArtifactError(ArtifactError): + """Base class for mount-related errors while materializing artifacts.""" + + +class MountToolMissingError(MountArtifactError): + """Required mount tool is missing in the sandbox.""" + + def __init__( + self, + *, + tool: str, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=f"required mount tool missing: {tool}", + error_code=ErrorCode.MOUNT_MISSING_TOOL, + op="materialize", + context={"tool": tool, **_as_context(context)}, + cause=cause, + ) + + +class MountConfigError(MountArtifactError): + """Mount configuration was invalid or incomplete.""" + + def __init__( + self, + *, + message: str, + context: Mapping[str, object] | None = None, + ) -> None: + super().__init__( + message=message, + error_code=ErrorCode.MOUNT_CONFIG_INVALID, + op="materialize", + context=_as_context(context), + ) + + +class MountCommandError(MountArtifactError): + """Mount command failed to execute successfully.""" + + def __init__( + self, + *, + command: str, + stderr: str | None, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="mount command failed", + error_code=ErrorCode.MOUNT_FAILED, + op="materialize", + context={"command": command, "stderr": stderr, **_as_context(context)}, + cause=cause, + ) + + +class SkillsConfigError(ConfigurationError): + """Skills capability configuration was invalid.""" + + def __init__( + self, + *, + message: str, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message=message, + error_code=ErrorCode.SKILLS_CONFIG_INVALID, + op="materialize", + context=_as_context(context), + cause=cause, + ) + + +class SnapshotPersistError(SnapshotError): + """Failed to persist snapshot bytes to durable storage.""" + + def __init__( + self, + *, + snapshot_id: str, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="failed to persist snapshot", + error_code=ErrorCode.SNAPSHOT_PERSIST_ERROR, + op="snapshot_persist", + context={"snapshot_id": snapshot_id, "path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class SnapshotRestoreError(SnapshotError): + """Failed to restore snapshot bytes from durable storage.""" + + def __init__( + self, + *, + snapshot_id: str, + path: Path, + context: Mapping[str, object] | None = None, + cause: BaseException | None = None, + ) -> None: + super().__init__( + message="failed to restore snapshot", + error_code=ErrorCode.SNAPSHOT_RESTORE_ERROR, + op="snapshot_restore", + context={"snapshot_id": snapshot_id, "path": str(path), **_as_context(context)}, + cause=cause, + ) + + +class SnapshotNotRestorableError(SnapshotError): + """Snapshot cannot be restored because the underlying storage is missing.""" + + def __init__( + self, + *, + snapshot_id: str, + path: Path, + context: Mapping[str, object] | None = None, + ) -> None: + super().__init__( + message="snapshot is not restorable", + error_code=ErrorCode.SNAPSHOT_NOT_RESTORABLE, + op="snapshot_restore", + context={"snapshot_id": snapshot_id, "path": str(path), **_as_context(context)}, + ) diff --git a/src/agents/sandbox/files.py b/src/agents/sandbox/files.py new file mode 100644 index 0000000000..a49b1d346e --- /dev/null +++ b/src/agents/sandbox/files.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from enum import Enum + +from pydantic import BaseModel + +from .types import Permissions + + +class EntryKind(str, Enum): + DIRECTORY = "directory" + FILE = "file" + SYMLINK = "symlink" + OTHER = "other" + + +class FileEntry(BaseModel): + path: str + permissions: Permissions + owner: str + group: str + size: int + kind: EntryKind = EntryKind.FILE + + def is_dir(self) -> bool: + return self.kind == EntryKind.DIRECTORY diff --git a/src/agents/sandbox/manifest.py b/src/agents/sandbox/manifest.py new file mode 100644 index 0000000000..398e2b47a3 --- /dev/null +++ b/src/agents/sandbox/manifest.py @@ -0,0 +1,202 @@ +import abc +import asyncio +from collections.abc import Iterator, Mapping +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel, Field, field_serializer, field_validator +from typing_extensions import assert_never + +from .entries import BaseEntry, Dir, Mount, resolve_workspace_path +from .errors import InvalidManifestPathError +from .manifest_render import render_manifest_description +from .types import Group, User + + +# TODO (sdcoffey) env val from secret store +class EnvValue(BaseModel, abc.ABC): + @abc.abstractmethod + async def resolve(self) -> str: ... + + +class StrEnvValue(EnvValue): + value: str + + async def resolve(self) -> str: + return self.value + + +class EnvEntry(BaseModel): + description: str | None = None + ephemeral: bool = Field(default=False) + value: EnvValue + + +class Environment(BaseModel): + value: dict[str, str | EnvValue | EnvEntry] = Field(default_factory=dict) + + def normalized(self) -> dict[str, EnvEntry]: + result: dict[str, EnvEntry] = {} + for key, value in self.value.items(): + match value: + case str(): + result[key] = EnvEntry(value=StrEnvValue(value=value)) + case EnvValue(): + result[key] = EnvEntry(value=value) + case EnvEntry(): + result[key] = value + case _: + assert_never(value) + + return result + + async def resolve(self) -> dict[str, str]: + normalized = self.normalized() + keys = normalized.keys() + values = await asyncio.gather(*[normalized[key].value.resolve() for key in keys]) + return dict(zip(keys, values)) + + +class Manifest(BaseModel): + version: Literal[1] = 1 + root: str = Field(default="/workspace") + entries: dict[str | Path, BaseEntry] = Field(default_factory=dict) + environment: Environment = Field(default_factory=Environment) + users: list[User] = Field(default_factory=list) + groups: list[Group] = Field(default_factory=list) + + @field_validator("entries", mode="before") + @classmethod + def _parse_entries(cls, value: object) -> dict[str | Path, BaseEntry]: + if value is None: + return {} + if not isinstance(value, Mapping): + raise TypeError(f"Artifact mapping must be a mapping, got {type(value).__name__}") + return {key: BaseEntry.parse(entry) for key, entry in value.items()} + + @field_serializer("entries", when_used="json") + def _serialize_entries(self, entries: Mapping[str | Path, BaseEntry]) -> dict[str, object]: + out: dict[str, object] = {} + for key, entry in entries.items(): + key_str = key.as_posix() if isinstance(key, Path) else str(key) + out[key_str] = entry.model_dump(mode="json") + return out + + def validated_entries(self) -> dict[str | Path, BaseEntry]: + validated: dict[str | Path, BaseEntry] = dict(self.entries) + for _path, _artifact in self.iter_entries(): + pass + return validated + + def ephemeral_entry_paths(self, depth: int | None = 1) -> set[Path]: + _ = depth + return {path for path, artifact in self.iter_entries() if artifact.ephemeral} + + def ephemeral_mount_targets(self) -> list[tuple[Mount, Path]]: + root = Path(self.root) + mounts: list[tuple[Mount, Path]] = [] + for rel_path, artifact in self.iter_entries(): + if not isinstance(artifact, Mount) or not artifact.ephemeral: + continue + dest = resolve_workspace_path(root, rel_path) + mount_path = artifact._resolve_mount_path_for_root(root, dest) + normalized_mount_path = self._normalize_in_workspace_path(root, mount_path) + if normalized_mount_path is not None: + mount_path = normalized_mount_path + mounts.append((artifact, mount_path)) + mounts.sort(key=lambda item: len(item[1].parts), reverse=True) + return mounts + + def ephemeral_persistence_paths(self, depth: int | None = 1) -> set[Path]: + _ = depth + root = Path(self.root) + skip = self.ephemeral_entry_paths(depth=depth) + for _mount, mount_path in self.ephemeral_mount_targets(): + try: + rel_mount_path = mount_path.relative_to(root) + except ValueError: + continue + if rel_mount_path.parts: + skip.add(rel_mount_path) + return skip + + @staticmethod + def _coerce_rel_path(path: str | Path) -> Path: + return path if isinstance(path, Path) else Path(path) + + @staticmethod + def _validate_rel_path(rel: Path) -> None: + if rel.is_absolute(): + raise InvalidManifestPathError(rel=rel, reason="absolute") + if ".." in rel.parts: + raise InvalidManifestPathError(rel=rel, reason="escape_root") + + @staticmethod + def _normalize_rel_path_within_root(rel: Path, *, original: Path) -> Path: + if rel.is_absolute(): + raise InvalidManifestPathError(rel=original, reason="absolute") + + normalized_parts: list[str] = [] + for part in rel.parts: + if part in ("", "."): + continue + if part == "..": + if not normalized_parts: + raise InvalidManifestPathError(rel=original, reason="escape_root") + normalized_parts.pop() + continue + normalized_parts.append(part) + + return Path(*normalized_parts) + + @classmethod + def _normalize_in_workspace_path(cls, root: Path, path: Path) -> Path | None: + if not path.is_absolute(): + normalized_rel = cls._normalize_rel_path_within_root(path, original=path) + return root / normalized_rel if normalized_rel.parts else root + + try: + rel_path = path.relative_to(root) + except ValueError: + return None + + normalized_rel = cls._normalize_rel_path_within_root(rel_path, original=path) + return root / normalized_rel if normalized_rel.parts else root + + def iter_entries(self) -> Iterator[tuple[Path, BaseEntry]]: + stack = [ + (self._coerce_rel_path(path), artifact) + for path, artifact in reversed(list(self.entries.items())) + ] + while stack: + rel_path, artifact = stack.pop() + self._validate_rel_path(rel_path) + yield rel_path, artifact + if not isinstance(artifact, Dir): + continue + + for child_name, child_artifact in reversed(list(artifact.children.items())): + child_rel_path = rel_path / self._coerce_rel_path(child_name) + stack.append((child_rel_path, child_artifact)) + + def describe(self, depth: int | None = 1) -> str: + """ + print a nice fs representation of things inside root with inline descriptions + depth controls how deep the tree is rendered; None renders all levels + eg: + + /workspace (root) + ├── repo/ # /workspace/repo — my repo + │ └── README.md # /workspace/repo/README.md + ├── data/ # /workspace/data + │ └── config.json # /workspace/data/config.json — config + ├── mount-data/ # /workspace/mount-data (mount) + └── notes.txt # /workspace/notes.txt + ... + """ + return render_manifest_description( + root=self.root, + entries=self.validated_entries(), + coerce_rel_path=self._coerce_rel_path, + depth=depth, + ) diff --git a/src/agents/sandbox/manifest_render.py b/src/agents/sandbox/manifest_render.py new file mode 100644 index 0000000000..d33ce320b1 --- /dev/null +++ b/src/agents/sandbox/manifest_render.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Callable + +from .entries import BaseEntry, Dir, Mount + + +def render_manifest_description( + *, + root: str, + entries: dict[str | Path, BaseEntry], + coerce_rel_path: Callable[[str | Path], Path], + depth: int | None = 1, +) -> str: + if depth is not None and depth <= 0: + raise ValueError("depth must be a non-zero positive integer or None") + + root = root.rstrip("/") or "/" + root_path = Path(root) + + def _mount_full_path(entry: str | Path, artifact: Mount) -> Path: + if artifact.mount_path is not None: + mount_path = Path(artifact.mount_path) + return mount_path if mount_path.is_absolute() else root_path / mount_path + return root_path / coerce_rel_path(entry) + + class _Node: + def __init__(self) -> None: + self.children: dict[str, _Node] = {} + self.description: str | None = None + self.is_dir: bool = False + self.full_path: Path | None = None + + def _path_parts(path: Path) -> tuple[str, ...]: + parts = [part for part in path.parts if part not in {"", "."}] + return tuple(parts) + + root_node = _Node() + + def _insert_path( + path: Path, + *, + description: str | None, + is_dir: bool, + full_path: Path | None = None, + max_depth: int | None = None, + ) -> None: + parts = _path_parts(path) + if not parts: + return + node = root_node + limit = len(parts) if max_depth is None else min(len(parts), max_depth) + for index, part in enumerate(parts[:limit]): + node = node.children.setdefault(part, _Node()) + if index < len(parts) - 1: + node.is_dir = True + if node.description is None and description is not None and limit == len(parts): + node.description = description + if full_path is not None and limit == len(parts): + node.full_path = full_path + if is_dir or limit < len(parts): + node.is_dir = True + + def _insert_entry_tree( + path: Path, + artifact: BaseEntry, + *, + full_path: Path | None = None, + ) -> None: + stack: list[tuple[Path, BaseEntry, Path | None]] = [(path, artifact, full_path)] + while stack: + current_path, current_artifact, current_full_path = stack.pop() + _insert_path( + current_path, + description=current_artifact.description, + is_dir=current_artifact.permissions.directory, + full_path=current_full_path, + max_depth=depth, + ) + if not isinstance(current_artifact, Dir): + continue + if depth is not None and len(_path_parts(current_path)) >= depth: + continue + + for child_name, child_artifact in current_artifact.children.items(): + child_rel_path = coerce_rel_path(child_name) + child_path = current_path / child_rel_path + child_full_path = ( + current_full_path / child_rel_path if current_full_path is not None else None + ) + stack.append((child_path, child_artifact, child_full_path)) + + for entry, artifact in entries.items(): + path = coerce_rel_path(entry) + if path.is_absolute(): + path = path.relative_to(path.anchor) + full_path = _mount_full_path(entry, artifact) if isinstance(artifact, Mount) else None + _insert_entry_tree(path, artifact, full_path=full_path) + + def _collect( + node: _Node, + prefix: str, + remaining: int | None, + rel_parts: tuple[str, ...], + ) -> list[tuple[str, str, str, str | None]]: + lines: list[tuple[str, str, str, str | None]] = [] + stack: list[tuple[str, _Node, str, int | None, tuple[str, ...]]] + stack = [("children", node, prefix, remaining, rel_parts)] + while stack: + action, current_node, current_prefix, current_remaining, current_rel_parts = stack.pop() + if action == "line": + child = current_node + name = current_rel_parts[-1] + child_is_dir = child.is_dir or bool(child.children) + display_name = f"{name}/" if child_is_dir else name + if child.full_path is not None: + full_path = str(child.full_path) + else: + full_path = str(root_path / Path(*current_rel_parts)) + lines.append((current_prefix, display_name, full_path, child.description)) + continue + + if current_remaining is not None and current_remaining <= 0: + continue + + names = sorted(current_node.children) + next_remaining = None if current_remaining is None else current_remaining - 1 + for index in range(len(names) - 1, -1, -1): + name = names[index] + child = current_node.children[name] + is_last = index == len(names) - 1 + connector = "└── " if is_last else "├── " + child_parts = current_rel_parts + (name,) + if next_remaining is None or next_remaining > 0: + extension = " " if is_last else "│ " + stack.append( + ( + "children", + child, + current_prefix + extension, + next_remaining, + child_parts, + ) + ) + stack.append( + ("line", child, current_prefix + connector, next_remaining, child_parts) + ) + return lines + + lines: list[str] = [root] + collected = _collect(root_node, "", depth, ()) + if collected: + max_width = max(len(prefix + name) for prefix, name, _, _ in collected) + for prefix, name, full_path_str, description in collected: + spacer = " " * (max_width - len(prefix + name) + 2) + if description: + comment = f"# {full_path_str} — {description}" + else: + comment = f"# {full_path_str}" + lines.append(f"{prefix}{name}{spacer}{comment}") + + return "\n".join(lines) + "\n" diff --git a/src/agents/sandbox/materialization.py b/src/agents/sandbox/materialization.py new file mode 100644 index 0000000000..25528db690 --- /dev/null +++ b/src/agents/sandbox/materialization.py @@ -0,0 +1,68 @@ +import asyncio +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import TypeVar, cast + + +@dataclass(frozen=True) +class MaterializedFile: + path: Path + sha256: str + + +@dataclass(frozen=True) +class MaterializationResult: + files: list[MaterializedFile] + + +_TaskResultT = TypeVar("_TaskResultT") +_MISSING = object() + + +async def gather_in_order( + task_factories: Sequence[Callable[[], Awaitable[_TaskResultT]]], +) -> list[_TaskResultT]: + if not task_factories: + return [] + + results: list[_TaskResultT | object] = [_MISSING] * len(task_factories) + + async def _run(index: int, factory: Callable[[], Awaitable[_TaskResultT]]) -> None: + results[index] = await factory() + + tasks = [ + asyncio.create_task(_run(index, factory)) for index, factory in enumerate(task_factories) + ] + try: + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + + first_error: BaseException | None = None + for task in done: + try: + task.result() + except asyncio.CancelledError: + continue + except BaseException as error: + first_error = error + break + + if first_error is not None: + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + raise first_error + + if pending: + await asyncio.gather(*pending) + except BaseException: + for task in tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + for task in tasks: + task.result() + + return [cast(_TaskResultT, result) for result in results] diff --git a/src/agents/sandbox/py.typed b/src/agents/sandbox/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/agents/sandbox/runtime.py b/src/agents/sandbox/runtime.py new file mode 100644 index 0000000000..4a44bec4e8 --- /dev/null +++ b/src/agents/sandbox/runtime.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Generic + +from ..agent import Agent +from ..exceptions import UserError +from ..items import TResponseInputItem +from ..result import RunResult, RunResultStreaming +from ..run_config import RunConfig +from ..run_context import RunContextWrapper, TContext +from ..run_internal.agent_bindings import ( + AgentBindings, + bind_execution_agent, + bind_public_agent, +) +from ..run_state import RunState +from .runtime_agent_preparation import clone_capabilities, prepare_sandbox_agent +from .runtime_session_manager import SandboxRuntimeSessionManager +from .sandbox_agent import SandboxAgent +from .session.base_sandbox_session import BaseSandboxSession + + +@dataclass +class _SandboxPreparedAgent(Generic[TContext]): + bindings: AgentBindings[TContext] + input: str | list[TResponseInputItem] + + +class SandboxRuntime(Generic[TContext]): + def __init__( + self, + *, + starting_agent: Agent[TContext], + run_config: RunConfig | None, + run_state: RunState[TContext] | None, + ) -> None: + self._sandbox_config = run_config.sandbox if run_config is not None else None + self._session_manager = SandboxRuntimeSessionManager( + starting_agent=starting_agent, + sandbox_config=self._sandbox_config, + run_state=run_state, + ) + self._prepared_agents: dict[int, Agent[TContext]] = {} + self._prepared_sessions: dict[int, BaseSandboxSession] = {} + + @property + def enabled(self) -> bool: + return self._session_manager.enabled + + @property + def current_session(self) -> BaseSandboxSession | None: + return self._session_manager.current_session + + def apply_result_metadata(self, result: RunResult | RunResultStreaming) -> None: + session = self.current_session + result._sandbox_session = session + if isinstance(result, RunResultStreaming): + + async def _cleanup_and_store() -> None: + try: + payload = await self.cleanup() + result._sandbox_resume_state = payload + finally: + result._sandbox_session = None + + result._sandbox_cleanup = _cleanup_and_store + + def assert_agent_supported(self, agent: Agent[TContext]) -> None: + if isinstance(agent, SandboxAgent) and self._sandbox_config is None: + raise UserError("SandboxAgent execution requires `RunConfig(sandbox=...)`") + + async def prepare_agent( + self, + *, + current_agent: Agent[TContext], + current_input: str | list[TResponseInputItem], + context_wrapper: RunContextWrapper[TContext], + is_resumed_state: bool, + ) -> _SandboxPreparedAgent[TContext]: + self.assert_agent_supported(current_agent) + if not isinstance(current_agent, SandboxAgent): + return _SandboxPreparedAgent( + bindings=bind_public_agent(current_agent), + input=current_input, + ) + + self._session_manager.acquire_agent(current_agent) + prepared_agent = self._prepared_agents.get(id(current_agent)) + prepared_capabilities = clone_capabilities(current_agent.capabilities) + session = await self._session_manager.ensure_session( + agent=current_agent, + capabilities=prepared_capabilities, + is_resumed_state=is_resumed_state, + ) + if prepared_agent is not None and self._prepared_sessions.get(id(current_agent)) is session: + return _SandboxPreparedAgent( + bindings=bind_execution_agent( + public_agent=current_agent, + execution_agent=prepared_agent, + ), + input=current_input, + ) + + prepared_agent = prepare_sandbox_agent( + agent=current_agent, + session=session, + capabilities=prepared_capabilities, + ) + self._prepared_agents[id(current_agent)] = prepared_agent + self._prepared_sessions[id(current_agent)] = session + return _SandboxPreparedAgent( + bindings=bind_execution_agent( + public_agent=current_agent, + execution_agent=prepared_agent, + ), + input=current_input, + ) + + async def cleanup(self) -> dict[str, object] | None: + try: + return await self._session_manager.cleanup() + finally: + self._prepared_agents.clear() + self._prepared_sessions.clear() diff --git a/src/agents/sandbox/runtime_agent_preparation.py b/src/agents/sandbox/runtime_agent_preparation.py new file mode 100644 index 0000000000..281f792c78 --- /dev/null +++ b/src/agents/sandbox/runtime_agent_preparation.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Callable +from typing import cast + +from .._public_agent import get_public_agent, set_public_agent +from ..agent import Agent +from ..run_context import RunContextWrapper, TContext +from .capabilities import Capability +from .manifest import Manifest +from .sandbox_agent import SandboxAgent +from .session.base_sandbox_session import BaseSandboxSession + + +def clone_capabilities(capabilities: list[Capability]) -> list[Capability]: + return [capability.clone() for capability in capabilities] + + +def prepare_sandbox_agent( + *, + agent: SandboxAgent[TContext], + session: BaseSandboxSession, + capabilities: list[Capability], +) -> Agent[TContext]: + manifest = session.state.manifest + for capability in capabilities: + capability.bind(session) + + capability_tools = [tool for capability in capabilities for tool in capability.tools()] + prepared_agent = agent.clone( + instructions=build_sandbox_instructions( + agent.instructions, + agent.developer_instructions, + capabilities, + manifest, + ), + tools=[*agent.tools, *capability_tools], + capabilities=capabilities, + ) + set_public_agent(prepared_agent, agent) + return prepared_agent + + +def build_sandbox_instructions( + base_instructions: str + | Callable[[RunContextWrapper[TContext], Agent[TContext]], Awaitable[str | None] | str | None] + | None, + developer_instructions: str | None, + capabilities: list[Capability], + manifest: Manifest | None, +) -> Callable[[RunContextWrapper[TContext], Agent[TContext]], Awaitable[str | None]]: + async def _instructions( + run_context: RunContextWrapper[TContext], + current_agent: Agent[TContext], + ) -> str | None: + parts: list[str] = [] + public_agent = cast(Agent[TContext], get_public_agent(current_agent)) + + base = await resolve_base_instructions( + instructions=base_instructions, + run_context=run_context, + agent=public_agent, + ) + if base: + parts.append(base) + + if developer_instructions: + parts.append(developer_instructions) + + if manifest is not None: + for capability in capabilities: + fragment = await capability.instructions(manifest) + if fragment: + parts.append(fragment) + + return "\n\n".join(parts) if parts else None + + return _instructions + + +async def resolve_base_instructions( + *, + instructions: str + | Callable[[RunContextWrapper[TContext], Agent[TContext]], Awaitable[str | None] | str | None] + | None, + run_context: RunContextWrapper[TContext], + agent: Agent[TContext], +) -> str | None: + if isinstance(instructions, str): + return instructions + if callable(instructions): + result = instructions(run_context, agent) + if inspect.isawaitable(result): + return await result + return result + return None diff --git a/src/agents/sandbox/runtime_session_manager.py b/src/agents/sandbox/runtime_session_manager.py new file mode 100644 index 0000000000..096c5fc54d --- /dev/null +++ b/src/agents/sandbox/runtime_session_manager.py @@ -0,0 +1,851 @@ +from __future__ import annotations + +import asyncio +import copy +import threading +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Generic, cast + +from ..agent import Agent +from ..exceptions import UserError +from ..run_config import SandboxRunConfig +from ..run_context import TContext +from ..run_state import ( + RunState, + _allocate_unique_agent_identity, + _build_agent_identity_keys_by_id, +) +from .capabilities import Capability +from .codex_config import manifest_has_codex_entry +from .entries import BaseEntry, Dir, Mount, resolve_workspace_path +from .manifest import Manifest +from .sandbox_agent import SandboxAgent +from .session.base_sandbox_session import BaseSandboxSession +from .session.sandbox_client import BaseSandboxClient +from .session.sandbox_session import SandboxSession +from .session.sandbox_session_state import SandboxSessionState +from .snapshot import NoopSnapshotSpec, SnapshotSpec +from .snapshot_defaults import resolve_default_local_snapshot_spec + + +class _SandboxSessionResources: + def __init__( + self, + *, + session: BaseSandboxSession, + client: BaseSandboxClient[Any] | None, + owns_session: bool, + ) -> None: + self._session = session + self._client = client + self._owns_session = owns_session + self._cleanup_lock = asyncio.Lock() + self._cleaned = False + self._started = False + + @property + def session(self) -> BaseSandboxSession: + return self._session + + @property + def state(self) -> SandboxSessionState: + return self._session.state + + async def ensure_started(self) -> None: + if self._started and await self._session.running(): + return + if not self._owns_session and await self._session.running(): + self._started = True + return + await self._session.start() + self._started = True + + async def cleanup(self) -> None: + if not self._owns_session: + return + async with self._cleanup_lock: + if self._cleaned: + return + self._cleaned = True + + cleanup_error: BaseException | None = None + try: + await self._session.stop() + except BaseException as exc: # pragma: no cover + cleanup_error = exc + try: + await self._session.shutdown() + except BaseException as exc: # pragma: no cover + if cleanup_error is None: + cleanup_error = exc + finally: + try: + if self._client is not None and isinstance(self._session, SandboxSession): + await self._client.delete(self._session) + except BaseException as exc: # pragma: no cover + if cleanup_error is None: + cleanup_error = exc + finally: + try: + await self._session._aclose_dependencies() + except BaseException as exc: # pragma: no cover + if cleanup_error is None: + cleanup_error = exc + if cleanup_error is not None: + raise cleanup_error + + +@dataclass +class _SandboxConcurrencyGuard: + lock: threading.Lock = field(default_factory=threading.Lock) + active_runs: int = 0 + + +@dataclass(frozen=True) +class _LiveSessionManifestUpdate: + processed_manifest: Manifest | None + entries_to_apply: list[tuple[Path, BaseEntry]] + + +class SandboxRuntimeSessionManager(Generic[TContext]): + def __init__( + self, + *, + starting_agent: Agent[TContext], + sandbox_config: SandboxRunConfig | None, + run_state: RunState[TContext] | None, + ) -> None: + self._sandbox_config = sandbox_config + self._run_state = run_state + resume_identity_root = starting_agent + if ( + run_state is not None + and run_state._starting_agent is not None + and run_state._current_agent is not None + and run_state._starting_agent is not run_state._current_agent + ): + resume_identity_root = run_state._starting_agent + self._stable_resume_keys_by_agent_id = _build_agent_identity_keys_by_id( + resume_identity_root + ) + self._resources_by_agent: dict[int, _SandboxSessionResources] = {} + self._current_agent_id: int | None = None + self._acquired_agents: dict[int, SandboxAgent[TContext]] = {} + self._resume_keys_by_agent_id: dict[int, str] = {} + self._resume_source_key_by_agent_id: dict[int, str] = {} + self._available_resumed_keys_by_name: dict[str, list[str]] | None = None + self._claimed_resumed_keys: set[str] = set() + + @staticmethod + def _resume_agent_base_key(agent: Agent[Any]) -> str: + return agent.name + + @staticmethod + def _serialize_session_entry( + *, + agent: Agent[Any], + session_state: dict[str, object], + ) -> dict[str, object]: + return { + "agent_name": agent.name, + "session_state": session_state, + } + + @property + def enabled(self) -> bool: + return self._sandbox_config is not None + + @property + def current_session(self) -> BaseSandboxSession | None: + if self._current_agent_id is None: + return None + resources = self._resources_by_agent.get(self._current_agent_id) + if resources is None: + return None + return resources.session + + def acquire_agent(self, agent: SandboxAgent[TContext]) -> None: + agent_id = id(agent) + if agent_id in self._acquired_agents: + return + + guard = getattr(agent, "_sandbox_concurrency_guard", None) + if guard is None: + guard = _SandboxConcurrencyGuard() + agent._sandbox_concurrency_guard = guard + with guard.lock: + if guard.active_runs > 0: + raise RuntimeError( + f"SandboxAgent {agent.name!r} cannot be reused concurrently across runs" + ) + guard.active_runs += 1 + self._acquired_agents[agent_id] = agent + self._ensure_resume_key(agent) + + async def ensure_session( + self, + *, + agent: SandboxAgent[TContext], + capabilities: list[Capability], + is_resumed_state: bool, + ) -> BaseSandboxSession: + agent_id = id(agent) + resources = self._resources_by_agent.get(agent_id) + if resources is None: + resources = await self._create_resources( + agent=agent, + capabilities=capabilities, + is_resumed_state=is_resumed_state, + ) + self._resources_by_agent[agent_id] = resources + self._current_agent_id = agent_id + + await resources.ensure_started() + return resources.session + + def serialize_resume_state(self) -> dict[str, object] | None: + existing_payload = ( + copy.deepcopy(self._run_state._sandbox) + if self._run_state is not None and isinstance(self._run_state._sandbox, dict) + else None + ) + if self._sandbox_config is None: + return existing_payload + if self._current_agent_id is None: + return existing_payload + if self._sandbox_config.client is None: + return existing_payload + resources = self._resources_by_agent.get(self._current_agent_id) + if resources is None: + return existing_payload + + client = self._resolve_client() + current_agent = self._acquired_agents.get(self._current_agent_id) + if current_agent is None: + return existing_payload + + sessions_by_agent = self._serialize_sessions_by_agent(client) + return { + "backend_id": client.backend_id, + "current_agent_key": self._ensure_resume_key(current_agent), + "current_agent_name": current_agent.name, + "session_state": client.serialize_session_state(resources.state), + "sessions_by_agent": sessions_by_agent, + } + + async def cleanup(self) -> dict[str, object] | None: + cleanup_error: BaseException | None = None + resume_state: dict[str, object] | None = None + try: + for resources in list(self._resources_by_agent.values()): + try: + await resources.cleanup() + except BaseException as exc: # pragma: no cover + if cleanup_error is None: + cleanup_error = exc + if cleanup_error is None: + resume_state = self.serialize_resume_state() + finally: + self._resources_by_agent.clear() + self._current_agent_id = None + self._release_agents() + if cleanup_error is not None: + raise cleanup_error + return resume_state + + async def _create_resources( + self, + *, + agent: SandboxAgent[TContext], + capabilities: list[Capability], + is_resumed_state: bool, + ) -> _SandboxSessionResources: + sandbox_config = self._require_sandbox_config() + if sandbox_config.session is not None: + self._validate_injected_session(agent=agent, session=sandbox_config.session) + running = await sandbox_config.session.running() + manifest_update = self._process_live_session_manifest( + capabilities=capabilities, + session=sandbox_config.session, + running=running, + ) + if manifest_update.entries_to_apply: + await sandbox_config.session._apply_entry_batch( + manifest_update.entries_to_apply, + base_dir=sandbox_config.session._manifest_base_dir(), + ) + if manifest_update.processed_manifest is not None: + sandbox_config.session.state = sandbox_config.session.state.model_copy( + update={"manifest": manifest_update.processed_manifest} + ) + return _SandboxSessionResources( + session=sandbox_config.session, + client=None, + owns_session=False, + ) + + client = self._resolve_client() + explicit_state = sandbox_config.session_state + resume_from_run_state = False + resumed_payload = self._resume_state_payload_for_agent( + client=client, + agent=agent, + agent_id=id(agent), + ) + if resumed_payload is not None: + explicit_state = client.deserialize_session_state(resumed_payload) + resume_from_run_state = True + + if explicit_state is not None: + explicit_state = self._process_resumed_state_manifest( + capabilities=capabilities, + session_state=explicit_state, + ) + return _SandboxSessionResources( + session=await client.resume(explicit_state, codex=agent.codex), + client=client, + owns_session=True, + ) + + effective_manifest = self._resolve_manifest( + agent=agent, + resume_from_run_state=resume_from_run_state, + ) + if effective_manifest is not None: + effective_manifest = self._process_manifest(capabilities, effective_manifest) + + options = sandbox_config.options + if options is None and not client.supports_default_options: + raise ValueError( + "Sandbox execution requires `run_config.sandbox.options` when creating a session" + ) + + session = await client.create( + snapshot=self._resolve_snapshot_spec(sandbox_config.snapshot), + manifest=effective_manifest, + codex=agent.codex, + options=options, + ) + return _SandboxSessionResources(session=session, client=client, owns_session=True) + + def _resume_state_payload_for_agent( + self, + *, + client: BaseSandboxClient[Any], + agent: SandboxAgent[TContext], + agent_id: int, + ) -> dict[str, object] | None: + if self._run_state is None or self._run_state._sandbox is None: + return None + + resumed = self._run_state._sandbox + backend_id = resumed.get("backend_id") + if backend_id != client.backend_id: + raise ValueError( + "RunState sandbox backend does not match the configured sandbox client" + ) + + sessions_by_agent = resumed.get("sessions_by_agent") + if isinstance(sessions_by_agent, dict): + resume_key = self._assign_resumed_agent_key(agent) + if resume_key is not None: + payload = self._session_payload_from_entry(sessions_by_agent.get(resume_key)) + if payload is not None: + self._remember_resume_source_key(agent_id, resume_key) + return payload + + payload = self._session_payload_from_entry(sessions_by_agent.get(str(agent_id))) + if payload is not None: + self._remember_resume_source_key(agent_id, str(agent_id)) + return payload + + current_agent_key = resumed.get("current_agent_key") + current_agent_name = resumed.get("current_agent_name") + current_agent_id = resumed.get("current_agent_id") + payload = resumed.get("session_state") + if payload is None: + return None + if not isinstance(payload, dict): + raise ValueError("RunState sandbox payload is missing `session_state`") + if isinstance(current_agent_key, str): + resume_key = self._assign_resumed_agent_key(agent) + if resume_key != current_agent_key: + return None + self._remember_resume_source_key(agent_id, current_agent_key) + return payload + if current_agent_name is None and self._run_state._current_agent is not None: + current_agent_name = self._run_state._current_agent.name + if isinstance(current_agent_name, str): + if current_agent_name != self._resume_agent_base_key(agent): + return None + self._remember_resume_source_key(agent_id, current_agent_name) + return payload + if current_agent_id is None or current_agent_id == agent_id: + if current_agent_id is not None: + self._remember_resume_source_key(agent_id, str(current_agent_id)) + return payload + return None + + def _resolve_client(self) -> BaseSandboxClient[Any]: + sandbox_config = self._require_sandbox_config() + if sandbox_config.client is None: + raise ValueError( + "Sandbox execution requires `run_config.sandbox.client` " + "unless a live session is provided" + ) + return sandbox_config.client + + def _require_sandbox_config(self) -> SandboxRunConfig: + if self._sandbox_config is None: + raise ValueError("Sandbox runtime is disabled for this run") + return self._sandbox_config + + @staticmethod + def _resolve_snapshot_spec(snapshot: SnapshotSpec | None) -> SnapshotSpec: + if snapshot is not None: + return snapshot + try: + return resolve_default_local_snapshot_spec() + except OSError: + return NoopSnapshotSpec() + + def _resolve_manifest( + self, + *, + agent: SandboxAgent[TContext], + resume_from_run_state: bool, + ) -> Manifest | None: + sandbox_config = self._require_sandbox_config() + if sandbox_config.session is not None: + return cast(Manifest | None, getattr(sandbox_config.session.state, "manifest", None)) + if sandbox_config.session_state is not None: + return cast(Manifest | None, getattr(sandbox_config.session_state, "manifest", None)) + if resume_from_run_state: + return None + if sandbox_config.manifest is not None: + return sandbox_config.manifest + return agent.default_manifest + + @staticmethod + def _process_manifest( + capabilities: list[Capability], + manifest: Manifest | None, + ) -> Manifest | None: + if manifest is None: + return None + processed_manifest = manifest.model_copy(deep=True) + for capability in capabilities: + processed_manifest = capability.process_manifest(processed_manifest) + return processed_manifest + + @classmethod + def _process_live_session_manifest( + cls, + *, + capabilities: list[Capability], + session: BaseSandboxSession, + running: bool, + ) -> _LiveSessionManifestUpdate: + current_manifest = session.state.manifest + processed_manifest = cls._process_manifest(capabilities, current_manifest) + if processed_manifest is None or processed_manifest == current_manifest: + return _LiveSessionManifestUpdate(processed_manifest=None, entries_to_apply=[]) + + entries_to_apply: list[tuple[Path, BaseEntry]] = [] + if running: + cls._validate_running_live_session_manifest_update( + current_manifest=current_manifest, + processed_manifest=processed_manifest, + ) + entries_to_apply = cls._diff_live_session_entries( + current_entries=current_manifest.entries, + processed_entries=processed_manifest.entries, + ) + entries_to_apply = [ + ( + resolve_workspace_path(Path(processed_manifest.root), rel_path), + artifact, + ) + for rel_path, artifact in entries_to_apply + ] + + return _LiveSessionManifestUpdate( + processed_manifest=processed_manifest, + entries_to_apply=entries_to_apply, + ) + + @classmethod + def _validate_running_live_session_manifest_update( + cls, + *, + current_manifest: Manifest, + processed_manifest: Manifest, + ) -> None: + if processed_manifest.root != current_manifest.root: + raise ValueError( + "Running injected sandbox sessions do not support capability changes to " + "`manifest.root`; use a fresh session or a session_state resume flow." + ) + if processed_manifest.environment != current_manifest.environment: + raise ValueError( + "Running injected sandbox sessions do not support capability changes to " + "`manifest.environment`; use a fresh session or a session_state resume flow." + ) + if ( + processed_manifest.users != current_manifest.users + or processed_manifest.groups != current_manifest.groups + ): + raise ValueError( + "Running injected sandbox sessions do not support capability changes to " + "`manifest.users` or `manifest.groups`; use a fresh session or a " + "session_state resume flow." + ) + + @classmethod + def _diff_live_session_entries( + cls, + *, + current_entries: dict[str | Path, BaseEntry], + processed_entries: dict[str | Path, BaseEntry], + parent_rel: Path = Path(), + ) -> list[tuple[Path, BaseEntry]]: + current_by_name = { + Manifest._coerce_rel_path(name): entry for name, entry in current_entries.items() + } + processed_by_name = { + Manifest._coerce_rel_path(name): entry for name, entry in processed_entries.items() + } + + removed = sorted(current_by_name.keys() - processed_by_name.keys()) + if removed: + removed_paths = ", ".join((parent_rel / rel).as_posix() for rel in removed) + raise ValueError( + "Running injected sandbox sessions do not support removing manifest entries: " + f"{removed_paths}." + ) + + entries_to_apply: list[tuple[Path, BaseEntry]] = [] + for rel_name, processed_entry in processed_by_name.items(): + rel_path = parent_rel / rel_name + current_entry = current_by_name.get(rel_name) + if current_entry is None: + cls._validate_running_live_session_entry_addition( + rel_path=rel_path, + entry=processed_entry, + ) + entries_to_apply.append((rel_path, processed_entry.model_copy(deep=True))) + continue + + delta_entry = cls._diff_live_session_entry( + rel_path=rel_path, + current_entry=current_entry, + processed_entry=processed_entry, + ) + if delta_entry is not None: + entries_to_apply.append((rel_path, delta_entry)) + + return entries_to_apply + + @classmethod + def _diff_live_session_entry( + cls, + *, + rel_path: Path, + current_entry: BaseEntry, + processed_entry: BaseEntry, + ) -> BaseEntry | None: + if current_entry == processed_entry: + return None + + if type(current_entry) is not type(processed_entry) or ( + current_entry.is_dir != processed_entry.is_dir + ): + raise ValueError( + "Running injected sandbox sessions do not support replacing manifest entry " + f"types at {rel_path.as_posix()}; use a fresh session or a session_state " + "resume flow." + ) + + if isinstance(current_entry, Mount): + raise ValueError( + "Running injected sandbox sessions do not support capability changes to mount " + f"entries at {rel_path.as_posix()}; use a fresh session or a session_state " + "resume flow." + ) + + if isinstance(current_entry, Dir) and isinstance(processed_entry, Dir): + changed_children = dict( + cls._diff_live_session_entries( + current_entries=current_entry.children, + processed_entries=processed_entry.children, + parent_rel=Path(), + ) + ) + metadata_changed = current_entry.model_dump( + exclude={"children"} + ) != processed_entry.model_dump(exclude={"children"}) + if not metadata_changed and not changed_children: + return None + return processed_entry.model_copy(update={"children": changed_children}, deep=True) + + return processed_entry.model_copy(deep=True) + + @staticmethod + def _validate_running_live_session_entry_addition( + *, + rel_path: Path, + entry: BaseEntry, + ) -> None: + if SandboxRuntimeSessionManager._entry_contains_mount(entry): + raise ValueError( + "Running injected sandbox sessions do not support capability-added mount " + f"entries at {rel_path.as_posix()}; use a fresh session or a session_state " + "resume flow." + ) + + @staticmethod + def _entry_contains_mount(entry: BaseEntry) -> bool: + if isinstance(entry, Mount): + return True + if isinstance(entry, Dir): + return any( + SandboxRuntimeSessionManager._entry_contains_mount(child) + for child in entry.children.values() + ) + return False + + @classmethod + def _process_resumed_state_manifest( + cls, + *, + capabilities: list[Capability], + session_state: SandboxSessionState, + ) -> SandboxSessionState: + processed_manifest = cls._process_manifest(capabilities, session_state.manifest) + if processed_manifest is None: + return session_state + return session_state.model_copy(update={"manifest": processed_manifest}) + + @staticmethod + def _validate_injected_session( + *, + agent: SandboxAgent[TContext], + session: BaseSandboxSession, + ) -> None: + if manifest_has_codex_entry(session.state.manifest, agent.codex): + return + if not agent.codex: + return + raise UserError( + "Injected sandbox sessions are used as-is and are not auto-provisioned with Codex. " + f"Session for SandboxAgent {agent.name!r} is missing Codex. " + "Create the session with `client.create(..., codex=True)` or set `codex=False` " + "on the SandboxAgent." + ) + + def _release_agents(self) -> None: + if not self._acquired_agents: + return + + released = list(self._acquired_agents.values()) + self._acquired_agents.clear() + self._resume_keys_by_agent_id.clear() + self._resume_source_key_by_agent_id.clear() + self._available_resumed_keys_by_name = None + self._claimed_resumed_keys.clear() + for agent in released: + guard = getattr(agent, "_sandbox_concurrency_guard", None) + if guard is None: + continue + with guard.lock: + guard.active_runs = max(0, guard.active_runs - 1) + + def _ensure_resume_key(self, agent: SandboxAgent[TContext]) -> str: + agent_id = id(agent) + existing = self._resume_keys_by_agent_id.get(agent_id) + if existing is not None: + return existing + + stable_key = self._stable_resume_key_for_agent(agent) + if stable_key is not None and stable_key not in self._used_resume_keys(): + self._resume_keys_by_agent_id[agent_id] = stable_key + return stable_key + + resumed_key = self._assign_resumed_agent_key(agent) + if resumed_key is not None: + return resumed_key + + key = _allocate_unique_agent_identity( + self._resume_agent_base_key(agent), + self._used_resume_keys(), + ) + self._resume_keys_by_agent_id[agent_id] = key + return key + + def _stable_resume_key_for_agent(self, agent: Agent[Any]) -> str | None: + return self._stable_resume_keys_by_agent_id.get(id(agent)) + + def _assign_resumed_agent_key(self, agent: SandboxAgent[TContext]) -> str | None: + agent_id = id(agent) + existing = self._resume_keys_by_agent_id.get(agent_id) + if existing is not None: + return existing + if self._run_state is None or self._run_state._sandbox is None: + return None + + resumed = self._run_state._sandbox + current_key = resumed.get("current_agent_key") + stable_key = self._stable_resume_key_for_agent(agent) + sessions_by_agent = resumed.get("sessions_by_agent") + if ( + isinstance(stable_key, str) + and stable_key not in self._claimed_resumed_keys + and self._entry_matches_agent_name(sessions_by_agent, stable_key, agent.name) + ): + self._claimed_resumed_keys.add(stable_key) + self._resume_keys_by_agent_id[agent_id] = stable_key + return stable_key + + base = self._resume_agent_base_key(agent) + if ( + isinstance(current_key, str) + and current_key not in self._claimed_resumed_keys + and self._run_state._current_agent is agent + and self._entry_matches_agent_name( + sessions_by_agent, + current_key, + base, + ) + ): + self._claimed_resumed_keys.add(current_key) + self._resume_keys_by_agent_id[agent_id] = current_key + return current_key + + available = self._resumed_keys_by_name().get(base, []) + for key in available: + if key in self._claimed_resumed_keys: + continue + if ( + isinstance(current_key, str) + and key == current_key + and self._run_state._current_agent is not agent + ): + continue + self._claimed_resumed_keys.add(key) + self._resume_keys_by_agent_id[agent_id] = key + return key + return None + + def _resumed_keys_by_name(self) -> dict[str, list[str]]: + cached = self._available_resumed_keys_by_name + if cached is not None: + return cached + + grouped: dict[str, list[str]] = {} + if self._run_state is not None and self._run_state._sandbox is not None: + sessions_by_agent = self._run_state._sandbox.get("sessions_by_agent") + if isinstance(sessions_by_agent, dict): + for key, entry in sessions_by_agent.items(): + if not isinstance(key, str): + continue + agent_name = self._agent_name_from_entry(key=key, entry=entry) + if agent_name is None: + continue + grouped.setdefault(agent_name, []).append(key) + + self._available_resumed_keys_by_name = grouped + return grouped + + def _legacy_session_entries(self) -> dict[str, object]: + if self._run_state is None or self._run_state._sandbox is None: + return {} + + resumed = self._run_state._sandbox + sessions_by_agent = resumed.get("sessions_by_agent") + if isinstance(sessions_by_agent, dict): + return { + key: copy.deepcopy(entry) + for key, entry in sessions_by_agent.items() + if isinstance(key, str) + } + + payload = resumed.get("session_state") + if not isinstance(payload, dict): + return {} + + current_key = resumed.get("current_agent_key") + if isinstance(current_key, str): + return {current_key: copy.deepcopy(payload)} + + current_agent_name = resumed.get("current_agent_name") + if current_agent_name is None and self._run_state._current_agent is not None: + current_agent_name = self._run_state._current_agent.name + if isinstance(current_agent_name, str): + return {current_agent_name: copy.deepcopy(payload)} + + current_agent_id = resumed.get("current_agent_id") + if current_agent_id is not None: + return {str(current_agent_id): copy.deepcopy(payload)} + return {} + + def _serialize_sessions_by_agent( + self, + client: BaseSandboxClient[Any], + ) -> dict[str, object]: + sessions_by_agent = self._legacy_session_entries() + for agent_id, agent_resources in self._resources_by_agent.items(): + agent = self._acquired_agents.get(agent_id) + if agent is None: + continue + resume_key = self._ensure_resume_key(agent) + source_key = self._resume_source_key_by_agent_id.get(agent_id) + if source_key is not None and source_key != resume_key: + sessions_by_agent.pop(source_key, None) + sessions_by_agent[resume_key] = self._serialize_session_entry( + agent=agent, + session_state=client.serialize_session_state(agent_resources.state), + ) + return sessions_by_agent + + def _used_resume_keys(self) -> set[str]: + used = set(self._legacy_session_entries()) + used.update(self._resume_keys_by_agent_id.values()) + return used + + def _remember_resume_source_key(self, agent_id: int, key: str) -> None: + self._resume_source_key_by_agent_id[agent_id] = key + + @staticmethod + def _entry_matches_agent_name( + sessions_by_agent: object, + key: str, + agent_name: str, + ) -> bool: + if not isinstance(sessions_by_agent, dict): + return False + entry = sessions_by_agent.get(key) + return ( + SandboxRuntimeSessionManager._agent_name_from_entry(key=key, entry=entry) == agent_name + ) + + @staticmethod + def _agent_name_from_entry(*, key: str, entry: object) -> str | None: + if isinstance(entry, dict): + entry_name = entry.get("agent_name") + session_state = entry.get("session_state") + if isinstance(entry_name, str) and isinstance(session_state, dict): + return entry_name + return key + return None + + @staticmethod + def _session_payload_from_entry(entry: object) -> dict[str, object] | None: + if entry is None: + return None + if not isinstance(entry, dict): + raise ValueError("RunState sandbox payload has an invalid `sessions_by_agent` item") + session_state = entry.get("session_state") + if isinstance(session_state, dict): + return session_state + return entry diff --git a/src/agents/sandbox/sandbox_agent.py b/src/agents/sandbox/sandbox_agent.py new file mode 100644 index 0000000000..d2d7ab21f8 --- /dev/null +++ b/src/agents/sandbox/sandbox_agent.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +from ..agent import Agent +from ..run_context import TContext +from .capabilities import Capability +from .codex_config import CodexConfig +from .manifest import Manifest + + +@dataclass +class SandboxAgent(Agent[TContext]): + """An `Agent` with sandbox-specific configuration. + + Runtime transport details such as the sandbox client, client options, and live session are + provided at run time through `RunConfig(sandbox=...)`, not stored on the agent itself. + """ + + default_manifest: Manifest | None = None + """Default sandbox manifest for new sessions created by `Runner` sandbox execution.""" + + developer_instructions: str | None = None + """Additional deterministic instructions appended after the base agent instructions.""" + + capabilities: list[Capability] = field(default_factory=list) + """Sandbox capabilities that can mutate the manifest, add instructions, and expose tools.""" + + codex: bool | CodexConfig = True + """Whether to provision Codex for runtime-created or resumed sandbox sessions.""" + + _sandbox_concurrency_guard: object | None = field(default=None, init=False, repr=False) diff --git a/src/agents/sandbox/sandboxes/__init__.py b/src/agents/sandbox/sandboxes/__init__.py new file mode 100644 index 0000000000..fa4d02611b --- /dev/null +++ b/src/agents/sandbox/sandboxes/__init__.py @@ -0,0 +1,41 @@ +""" +Sandbox implementations for the sandbox package. + +This subpackage contains concrete session/client implementations for different +execution environments (e.g. Docker, local Unix). +""" + +from .unix_local import ( + UnixLocalSandboxClient, + UnixLocalSandboxSession, + UnixLocalSandboxSessionState, +) + +try: + from .docker import ( # noqa: F401 + DockerSandboxClient, + DockerSandboxClientOptions, + DockerSandboxSession, + DockerSandboxSessionState, + ) + + _HAS_DOCKER = True +except Exception: # pragma: no cover + # Docker is an optional extra; keep base imports working without it. + _HAS_DOCKER = False + +__all__ = [ + "UnixLocalSandboxClient", + "UnixLocalSandboxSession", + "UnixLocalSandboxSessionState", +] + +if _HAS_DOCKER: + __all__.extend( + [ + "DockerSandboxClient", + "DockerSandboxClientOptions", + "DockerSandboxSession", + "DockerSandboxSessionState", + ] + ) diff --git a/src/agents/sandbox/sandboxes/docker.py b/src/agents/sandbox/sandboxes/docker.py new file mode 100644 index 0000000000..3f47ff855c --- /dev/null +++ b/src/agents/sandbox/sandboxes/docker.py @@ -0,0 +1,790 @@ +import asyncio +import io +import queue +import tarfile +import tempfile +import threading +import uuid +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import Final, cast + +import docker.errors # type: ignore[import-untyped] +from docker import DockerClient as DockerSDKClient +from docker.models.containers import Container # type: ignore[import-untyped] +from docker.utils import parse_repository_tag # type: ignore[import-untyped] +from typing_extensions import Buffer + +from ..codex_config import CodexConfig, apply_codex_to_manifest, apply_codex_to_session_state +from ..entries import ( + FuseMountPattern, + Mount, + MountpointMountPattern, + RcloneMountPattern, + resolve_workspace_path, +) +from ..errors import ( + ExecTimeoutError, + ExecTransportError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, +) +from ..manifest import Manifest +from ..session import SandboxSession, SandboxSessionState +from ..session.base_sandbox_session import BaseSandboxSession +from ..session.dependencies import Dependencies +from ..session.manager import Instrumentation +from ..session.sandbox_client import BaseSandboxClient +from ..session.workspace_payloads import coerce_write_payload +from ..snapshot import SnapshotSpec, resolve_snapshot +from ..types import ExecResult +from ..util.iterator_io import IteratorIO +from ..util.retry import ( + TRANSIENT_HTTP_STATUS_CODES, + exception_chain_has_status_code, + retry_async, +) +from ..util.tar_utils import should_skip_tar_member + +_DOCKER_EXECUTOR: Final = ThreadPoolExecutor( + max_workers=8, + thread_name_prefix="agents-docker-sandbox", +) + + +class _QueueWriter(io.RawIOBase): + def __init__(self, chunks: queue.Queue[bytes | BaseException | None]) -> None: + self._chunks = chunks + self._closed = False + + def writable(self) -> bool: + return True + + def write(self, b: Buffer, /) -> int: + if self._closed: + raise ValueError("I/O operation on closed file.") + payload = bytes(b) + if payload: + self._chunks.put(payload) + return len(payload) + + def close(self) -> None: + if self._closed: + return + self._closed = True + super().close() + + +class _StreamingTarPipe: + def __init__(self) -> None: + self._chunks: queue.Queue[bytes | BaseException | None] = queue.Queue() + self.writer = _QueueWriter(self._chunks) + self.reader = IteratorIO(self._iter_chunks()) + + def _iter_chunks(self): + while True: + item = self._chunks.get() + if item is None: + return + if isinstance(item, BaseException): + raise item + yield item + + def set_error(self, error: BaseException) -> None: + self._chunks.put(error) + + def close(self) -> None: + try: + self.writer.close() + finally: + self._chunks.put(None) + + +class DockerSandboxSessionState(SandboxSessionState): + image: str + container_id: str + workspace_root_ready: bool = False + + +@dataclass(frozen=True) +class DockerSandboxClientOptions: + image: str + + +class DockerSandboxSession(BaseSandboxSession): + _docker_client: DockerSDKClient + _container: Container + _workspace_root_ready: bool + _resume_workspace_probe_pending: bool + _resume_preserves_system_state: bool + + state: DockerSandboxSessionState + _ARCHIVE_STAGING_DIR: Path = Path("/tmp/uc-docker-archive") + + def __init__( + self, + *, + docker_client: DockerSDKClient, + container: Container, + state: DockerSandboxSessionState, + ) -> None: + self._docker_client = docker_client + self._container = container + self.state = state + self._workspace_root_ready = state.workspace_root_ready + self._resume_workspace_probe_pending = False + self._resume_preserves_system_state = False + + @classmethod + def from_state( + cls, + state: DockerSandboxSessionState, + *, + container: Container, + docker_client: DockerSDKClient, + ) -> "DockerSandboxSession": + return cls(docker_client=docker_client, container=container, state=state) + + @property + def container_id(self) -> str: + return self.state.container_id + + def _archive_stage_path(self, *, name_hint: str) -> Path: + # Unique name avoids clashes across concurrent reads/writes. + return self._ARCHIVE_STAGING_DIR / f"{uuid.uuid4().hex}_{name_hint}" + + async def _stage_workspace_copy(self) -> tuple[Path, Path]: + root = Path(self.state.manifest.root) + root_name = root.name or "workspace" + staging_parent = self._archive_stage_path(name_hint="workspace") + staging_workspace = staging_parent / root_name + + await self._exec_checked( + "mkdir", + "-p", + str(staging_parent), + error_cls=WorkspaceArchiveReadError, + error_path=root, + ) + await self._exec_checked( + "cp", + "-R", + "--", + str(root), + str(staging_workspace), + error_cls=WorkspaceArchiveReadError, + error_path=root, + ) + return staging_parent, staging_workspace + + async def _rm_best_effort(self, path: Path) -> None: + try: + await self.exec("rm", "-rf", "--", str(path), shell=False) + except Exception: + pass + + async def _exec_checked( + self, + *cmd: str | Path, + error_cls: type[WorkspaceArchiveReadError] | type[WorkspaceArchiveWriteError], + error_path: Path, + ) -> ExecResult: + res = await self.exec(*cmd, shell=False) + if not res.ok(): + raise error_cls( + path=error_path, + context={ + "command": [str(c) for c in cmd], + "stdout": res.stdout.decode("utf-8", errors="replace"), + "stderr": res.stderr.decode("utf-8", errors="replace"), + }, + ) + return res + + async def start(self) -> None: + self._container.reload() + if not await self.running(): + self._container.start() + await super().start() + self._workspace_root_ready = True + self.state.workspace_root_ready = True + self._resume_workspace_probe_pending = False + + async def _exec_run( + self, + *, + cmd: list[str], + workdir: str | None, + timeout: float | None, + command_for_errors: tuple[str | Path, ...], + kill_on_timeout: bool, + ) -> ExecResult: + loop = asyncio.get_running_loop() + future = loop.run_in_executor( + _DOCKER_EXECUTOR, + lambda: self._container.exec_run(cmd=cmd, demux=True, workdir=workdir), + ) + try: + exec_result = await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError as e: + if kill_on_timeout: + # Best-effort: kill processes matching the command line. + # If this fails, the caller still gets a timeout error. + try: + pattern = " ".join(str(c) for c in command_for_errors).replace("'", "'\\''") + self._container.exec_run( + cmd=[ + "sh", + "-lc", + f"pkill -f -- '{pattern}' >/dev/null 2>&1 || true", + ], + demux=True, + ) + except Exception: + pass + raise ExecTimeoutError(command=command_for_errors, timeout_s=timeout, cause=e) from e + except Exception as e: + raise ExecTransportError(command=command_for_errors, cause=e) from e + + stdout, stderr = exec_result.output + return ExecResult( + stdout=stdout or b"", + stderr=stderr or b"", + exit_code=exec_result.exit_code or 0, + ) + + async def _recover_workspace_root_ready(self, *, timeout: float | None) -> None: + if self._workspace_root_ready or not self._resume_workspace_probe_pending: + return + + root = self.state.manifest.root + probe_command = ("test", "-d", "--", root) + try: + result = await self._exec_run( + cmd=[str(c) for c in probe_command], + workdir=None, + timeout=timeout, + command_for_errors=probe_command, + kill_on_timeout=False, + ) + except (ExecTimeoutError, ExecTransportError): + return + finally: + self._resume_workspace_probe_pending = False + + if result.ok(): + self._workspace_root_ready = True + self.state.workspace_root_ready = True + + async def _exec_internal( + self, *command: str | Path, timeout: float | None = None + ) -> ExecResult: + # `docker-py` is synchronous and can block indefinitely (e.g. hung + # process, daemon issues). Run in a worker thread so we can enforce a + # timeout without requiring `timeout(1)` in the container image. + # Use a shared bounded executor so repeated timeouts do not leak one + # new thread per command. + cmd: list[str] = [str(c) for c in command] + await self._recover_workspace_root_ready(timeout=timeout) + # The workspace root is created during `apply_manifest()`, so the first + # bootstrap commands must not force Docker to chdir there yet. + workdir = self.state.manifest.root if self._workspace_root_ready else None + return await self._exec_run( + cmd=cmd, + workdir=workdir, + timeout=timeout, + command_for_errors=command, + kill_on_timeout=True, + ) + + async def read(self, path: Path) -> io.IOBase: + workspace_path = resolve_workspace_path( + Path(self.state.manifest.root), + path, + allow_absolute_within_root=True, + ) + + # Docker's archive APIs (put/get) can be flaky for paths that exist *inside* the container + # but are not visible to the Docker daemon (notably FUSE mounts like mount-s3/rclone). + # Mirror `write()`: always stage into a daemon-visible directory, then `get_archive` there. + staging_path = self._archive_stage_path(name_hint=workspace_path.name) + + await self._exec_checked( + "mkdir", + "-p", + str(self._ARCHIVE_STAGING_DIR), + error_cls=WorkspaceArchiveReadError, + error_path=path, + ) + + cp_res = await self.exec("cp", "--", str(workspace_path), str(staging_path), shell=False) + if not cp_res.ok(): + # Best-effort: treat stage failure as not-found. (It can also be permissions, but we + # don't have a dedicated error type for that yet.) + raise WorkspaceReadNotFoundError( + path=path, + context={ + "command": ["cp", "--", str(workspace_path), str(staging_path)], + "stdout": cp_res.stdout.decode("utf-8", errors="replace"), + "stderr": cp_res.stderr.decode("utf-8", errors="replace"), + }, + ) + + try: + stream, _ = self._container.get_archive(str(staging_path)) + except docker.errors.NotFound as e: + raise WorkspaceReadNotFoundError(path=path, cause=e) from e + except docker.errors.APIError as e: + raise WorkspaceArchiveReadError(path=path, cause=e) from e + finally: + # Best-effort cleanup. + await self._rm_best_effort(staging_path) + + # `get_archive` returns a tar stream. For a single-file read we buffer + # the tar bytes so tarfile can operate in non-streaming mode (seeking + # is required by some reads). + try: + raw = b"".join(stream) + with tarfile.open(fileobj=io.BytesIO(raw), mode="r:*") as tar: + members = tar.getmembers() + if not members: + raise WorkspaceReadNotFoundError(path=path) + extracted = tar.extractfile(members[0]) + if extracted is None: + raise WorkspaceReadNotFoundError(path=path) + return io.BytesIO(extracted.read()) + except WorkspaceReadNotFoundError: + raise + except (tarfile.TarError, OSError) as e: + raise WorkspaceArchiveReadError(path=path, cause=e) from e + + async def write(self, path: Path, data: io.IOBase) -> None: + payload = coerce_write_payload(path=path, data=data) + + path = resolve_workspace_path( + Path(self.state.manifest.root), + path, + allow_absolute_within_root=True, + ) + + parent = path.parent + await self.mkdir(parent, parents=True) + + # Docker's archive APIs (put/get) can be flaky for paths that exist *inside* the container + # but are not visible to the Docker daemon (notably FUSE mounts like mount-s3/rclone). + # To make writes robust across normal dirs and mountpoints, always stage the payload in + # a daemon-visible directory and then copy into place from inside the container. + staging_path = self._archive_stage_path(name_hint=path.name) + staging_name = staging_path.name + + await self._exec_checked( + "mkdir", + "-p", + str(self._ARCHIVE_STAGING_DIR), + error_cls=WorkspaceArchiveWriteError, + error_path=self._ARCHIVE_STAGING_DIR, + ) + + try: + tar_stream: io.IOBase + if payload.content_length is None: + tar_stream = self._buffered_single_file_tar( + staging_name=staging_name, + stream=payload.stream, + ) + else: + tar_stream = self._streaming_single_file_tar( + staging_name=staging_name, + stream=payload.stream, + content_length=payload.content_length, + ) + ok = self._container.put_archive(str(self._ARCHIVE_STAGING_DIR), tar_stream) + except docker.errors.APIError as e: + raise WorkspaceArchiveWriteError(path=self._ARCHIVE_STAGING_DIR, cause=e) from e + if not ok: + raise WorkspaceArchiveWriteError( + path=self._ARCHIVE_STAGING_DIR, + context={"reason": "put_archive_returned_false"}, + ) + + # Copy into place using a process inside the container, which can see mounts. + cp_res = await self.exec("cp", "--", str(staging_path), str(path), shell=False) + if not cp_res.ok(): + raise WorkspaceArchiveWriteError( + path=parent, + context={ + "command": ["cp", "--", str(staging_path), str(path)], + "stdout": cp_res.stdout.decode("utf-8", errors="replace"), + "stderr": cp_res.stderr.decode("utf-8", errors="replace"), + }, + ) + + # Best-effort cleanup. Ignore failures (e.g. concurrent cleanup). + await self._rm_best_effort(staging_path) + + @staticmethod + def _buffered_single_file_tar(*, staging_name: str, stream: io.IOBase) -> io.BytesIO: + payload = stream.read() + if isinstance(payload, bytearray): + payload = bytes(payload) + if not isinstance(payload, bytes): + raise TypeError(f"expected bytes payload, got {type(payload).__name__}") + + info = tarfile.TarInfo(name=staging_name) + info.size = len(payload) + + tar_buf = io.BytesIO() + with tarfile.open(fileobj=tar_buf, mode="w") as tar: + tar.addfile(info, io.BytesIO(payload)) + tar_buf.seek(0) + return tar_buf + + @staticmethod + def _streaming_single_file_tar( + *, + staging_name: str, + stream: io.IOBase, + content_length: int, + ) -> io.IOBase: + pipe = _StreamingTarPipe() + + def _produce() -> None: + info = tarfile.TarInfo(name=staging_name) + info.size = content_length + try: + with tarfile.open(fileobj=pipe.writer, mode="w|") as tar: + tar.addfile(info, stream) + except BaseException as exc: + pipe.set_error(exc) + finally: + pipe.close() + + threading.Thread( + target=_produce, + name=f"docker-put-archive-{staging_name}", + daemon=True, + ).start() + return pipe.reader + + async def running(self) -> bool: + # docker-py caches container attributes; refresh to avoid stale status, + # especially right after start/stop. + try: + self._container.reload() + except docker.errors.APIError: + # Best-effort: if we can't reload, fall back to last known status. + pass + return cast(str, self._container.status) == "running" + + async def stop(self) -> None: + # Persistence-only. Container teardown is handled in `shutdown()`. + await super().stop() + + async def shutdown(self) -> None: + # Best-effort: stop the container if it exists. + try: + self._container.reload() + except Exception: + pass + try: + if await self.running(): + self._container.stop() + except Exception: + # If the container is already gone/stopped, ignore. + pass + + def should_provision_manifest_accounts_on_resume(self) -> bool: + return not self._resume_preserves_system_state + + async def exists(self) -> bool: + try: + self._docker_client.containers.get(self.state.container_id) + return True + except docker.errors.NotFound: + return False + + @retry_async( + retry_if=lambda exc, self: exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) + ) + async def persist_workspace(self) -> io.IOBase: + def _error_context_summary(error: WorkspaceArchiveReadError) -> dict[str, str]: + summary = {"message": error.message} + if error.cause is not None: + summary["cause_type"] = type(error.cause).__name__ + summary["cause"] = str(error.cause) + return summary + + skip = self._persist_workspace_skip_relpaths() + root = Path(self.state.manifest.root) + unmounted_mounts: list[tuple[Mount, Path]] = [] + unmount_error: WorkspaceArchiveReadError | None = None + for mount_entry, mount_path in self.state.manifest.ephemeral_mount_targets(): + try: + await mount_entry.unmount_path(self, mount_path) + except Exception as e: + unmount_error = WorkspaceArchiveReadError(path=root, cause=e) + break + unmounted_mounts.append((mount_entry, mount_path)) + + snapshot_error: WorkspaceArchiveReadError | None = None + archive: io.IOBase | None = None + staging_parent: Path | None = None + if unmount_error is None: + try: + try: + staging_parent, staging_workspace = await self._stage_workspace_copy() + for rel_path in skip: + await self._rm_best_effort(staging_workspace / rel_path) + + bits, _ = self._container.get_archive(str(staging_workspace)) + root_name = root.name or "workspace" + if not skip: + archive = IteratorIO(it=bits) + else: + in_stream = IteratorIO(it=bits) + out_stream = tempfile.SpooledTemporaryFile( + max_size=16 * 1024 * 1024, mode="w+b" + ) + try: + with ( + tarfile.open(fileobj=in_stream, mode="r|*") as in_tar, + tarfile.open(fileobj=out_stream, mode="w") as out_tar, + ): + for member in in_tar: + if should_skip_tar_member( + member.name, skip_rel_paths=skip, root_name=root_name + ): + continue + fileobj = in_tar.extractfile(member) if member.isreg() else None + out_tar.addfile(member, fileobj) + if fileobj is not None: + fileobj.close() + except (tarfile.TarError, OSError) as e: + out_stream.close() + raise WorkspaceArchiveReadError(path=root, cause=e) from e + + out_stream.seek(0) + archive = cast(io.IOBase, out_stream) + except docker.errors.NotFound as e: + snapshot_error = WorkspaceArchiveReadError(path=root, cause=e) + except docker.errors.APIError as e: + snapshot_error = WorkspaceArchiveReadError(path=root, cause=e) + except WorkspaceArchiveReadError as e: + snapshot_error = e + finally: + if staging_parent is not None: + await self._rm_best_effort(staging_parent) + + remount_error: WorkspaceArchiveReadError | None = None + for mount_entry, mount_path in reversed(unmounted_mounts): + try: + await mount_entry.mount(self, mount_path) + except Exception as e: + current_error = WorkspaceArchiveReadError(path=root, cause=e) + if remount_error is None: + remount_error = current_error + if unmount_error is not None: + remount_error.context["earlier_unmount_error"] = _error_context_summary( + unmount_error + ) + else: + additional_remount_errors = remount_error.context.setdefault( + "additional_remount_errors", [] + ) + assert isinstance(additional_remount_errors, list) + additional_remount_errors.append(_error_context_summary(current_error)) + + if remount_error is not None: + if snapshot_error is not None: + remount_error.context["snapshot_error_before_remount_corruption"] = ( + _error_context_summary(snapshot_error) + ) + raise remount_error + if unmount_error is not None: + raise unmount_error + if snapshot_error is not None: + raise snapshot_error + + assert archive is not None + return archive + + async def hydrate_workspace(self, data: io.IOBase) -> None: + root = self.state.manifest.root + hydration_target = Path(root).parent + try: + ok = self._container.put_archive(str(hydration_target), data) + except docker.errors.APIError as e: + raise WorkspaceArchiveWriteError(path=Path(root), cause=e) from e + if not ok: + raise WorkspaceArchiveWriteError( + path=Path(root), context={"reason": "put_archive_returned_false"} + ) + + +class DockerSandboxClient(BaseSandboxClient[DockerSandboxClientOptions]): + backend_id = "docker" + docker_client: DockerSDKClient + _instrumentation: Instrumentation + + def __init__( + self, + docker_client: DockerSDKClient, + *, + instrumentation: Instrumentation | None = None, + dependencies: Dependencies | None = None, + ) -> None: + super().__init__() + self.docker_client = docker_client + self._instrumentation = instrumentation or Instrumentation() + self._dependencies = dependencies + + async def create( + self, + *, + snapshot: SnapshotSpec | None = None, + manifest: Manifest | None = None, + codex: bool | CodexConfig = False, + options: DockerSandboxClientOptions, + ) -> SandboxSession: + manifest = apply_codex_to_manifest(manifest, codex) + image = options.image + + container = await self._create_container(image, manifest=manifest) + container.start() + + session_id = uuid.uuid4() + container_id = container.id + assert container_id is not None + snapshot_id = str(session_id) + snapshot_instance = resolve_snapshot(snapshot, snapshot_id) + state = DockerSandboxSessionState( + session_id=session_id, + manifest=manifest or Manifest(), + image=image, + snapshot=snapshot_instance, + container_id=container_id, + ) + + inner = DockerSandboxSession( + docker_client=self.docker_client, + container=container, + state=state, + ) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + async def delete(self, session: SandboxSession) -> SandboxSession: + inner = session._inner + if not isinstance(inner, DockerSandboxSession): + raise TypeError("DockerSandboxClient.delete expects a DockerSandboxSession") + try: + container = self.docker_client.containers.get(inner.state.container_id) + except docker.errors.NotFound: + return session + # Ensure teardown happens before removal. + try: + await inner.shutdown() + except Exception: + pass + try: + container.remove() + except docker.errors.NotFound: + return session + return session + + async def resume( + self, + state: SandboxSessionState, + *, + codex: bool | CodexConfig = False, + ) -> SandboxSession: + if not isinstance(state, DockerSandboxSessionState): + raise TypeError("DockerSandboxClient.resume expects a DockerSandboxSessionState") + state = apply_codex_to_session_state(state, codex) + container = self.get_container(state.container_id) + reused_existing_container = container is not None + if container is None: + container = await self._create_container(state.image, manifest=state.manifest) + container_id = container.id + assert container_id is not None + state.container_id = container_id + state.workspace_root_ready = False + + # Use the existing container (or the one we just created). + inner = DockerSandboxSession( + container=container, docker_client=self.docker_client, state=state + ) + inner._resume_workspace_probe_pending = True + inner._resume_preserves_system_state = reused_existing_container + return self._wrap_session(inner, instrumentation=self._instrumentation) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return DockerSandboxSessionState.model_validate(payload) + + async def _create_container(self, image: str, *, manifest: Manifest | None = None) -> Container: + # create image if it does not exist + if not self.image_exists(image): + repo, tag = parse_repository_tag(image) + self.docker_client.images.pull(repo, tag=tag or None, all_tags=False) + + assert self.image_exists(image) + environment: dict[str, str] | None = None + if manifest: + environment = await manifest.environment.resolve() + create_kwargs: dict[str, object] = { + "entrypoint": ["tail"], + "image": image, + "detach": True, + "command": ["-f", "/dev/null"], + "environment": environment, + } + if _manifest_requires_fuse(manifest): + create_kwargs.update( + devices=["/dev/fuse"], + cap_add=["SYS_ADMIN"], + security_opt=["apparmor:unconfined"], + ) + elif _manifest_requires_sys_admin(manifest): + create_kwargs.update( + cap_add=["SYS_ADMIN"], + security_opt=["apparmor:unconfined"], + ) + return self.docker_client.containers.create(**create_kwargs) + + def image_exists(self, image: str) -> bool: + try: + self.docker_client.images.get(image) + return True + except docker.errors.ImageNotFound: + return False + + def get_container(self, container_id: str) -> Container | None: + try: + return self.docker_client.containers.get(container_id) + except docker.errors.NotFound: + return None + + +def _manifest_requires_fuse(manifest: Manifest | None) -> bool: + if manifest is None: + return False + for _path, artifact in manifest.iter_entries(): + if isinstance(artifact, Mount): + mount_pattern = getattr(artifact, "mount_pattern", None) + if isinstance(mount_pattern, (FuseMountPattern, MountpointMountPattern)): + return True + if isinstance(mount_pattern, RcloneMountPattern) and mount_pattern.mode == "fuse": + return True + return False + + +def _manifest_requires_sys_admin(manifest: Manifest | None) -> bool: + if manifest is None: + return False + for _path, artifact in manifest.iter_entries(): + if isinstance(artifact, Mount): + mount_pattern = getattr(artifact, "mount_pattern", None) + if isinstance(mount_pattern, RcloneMountPattern) and mount_pattern.mode == "nfs": + return True + return False diff --git a/src/agents/sandbox/sandboxes/unix_local.py b/src/agents/sandbox/sandboxes/unix_local.py new file mode 100644 index 0000000000..fd4475532f --- /dev/null +++ b/src/agents/sandbox/sandboxes/unix_local.py @@ -0,0 +1,597 @@ +import asyncio +import io +import logging +import os +import shlex +import shutil +import signal +import sys +import tarfile +import tempfile +import uuid +from collections.abc import Mapping, Sequence +from pathlib import Path +from typing import Literal, cast + +from ..codex_config import CodexConfig, apply_codex_to_manifest, apply_codex_to_session_state +from ..entries import resolve_workspace_path +from ..errors import ( + ExecNonZeroError, + ExecTimeoutError, + ExecTransportError, + InvalidManifestPathError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceRootNotFoundError, + WorkspaceStartError, + WorkspaceStopError, +) +from ..files import EntryKind, FileEntry +from ..manifest import Manifest +from ..materialization import MaterializationResult +from ..session import SandboxSession, SandboxSessionState +from ..session.base_sandbox_session import BaseSandboxSession +from ..session.dependencies import Dependencies +from ..session.manager import Instrumentation +from ..session.sandbox_client import BaseSandboxClient +from ..session.workspace_payloads import coerce_write_payload +from ..snapshot import SnapshotSpec, resolve_snapshot +from ..types import ExecResult, Permissions, User +from ..util.tar_utils import ( + UnsafeTarMemberError, + safe_extract_tarfile, + should_skip_tar_member, +) + +_DEFAULT_WORKSPACE_PREFIX = "uc-local-" +_DEFAULT_MANIFEST_ROOT = cast(str, Manifest.model_fields["root"].default) + +logger = logging.getLogger(__name__) + + +class UnixLocalSandboxSessionState(SandboxSessionState): + workspace_root_owned: bool = False + + +class UnixLocalSandboxSession(BaseSandboxSession): + """ + Unix-only session implementation that runs commands on the host and uses the host filesystem + as the workspace (rooted at `self.state.manifest.root`). + """ + + state: UnixLocalSandboxSessionState + _running: bool + + def __init__(self, *, state: UnixLocalSandboxSessionState) -> None: + self.state = state + self._running = False + + @classmethod + def from_state(cls, state: UnixLocalSandboxSessionState) -> "UnixLocalSandboxSession": + return cls(state=state) + + async def start(self) -> None: + workspace = Path(self.state.manifest.root) + try: + workspace.mkdir(parents=True, exist_ok=True) + except OSError as e: + raise WorkspaceStartError(path=workspace, cause=e) from e + + self._running = True + await super().start() + + async def stop(self) -> None: + try: + await super().stop() + except Exception as e: + raise WorkspaceStopError(path=Path(self.state.manifest.root), cause=e) from e + + async def apply_manifest(self, *, only_ephemeral: bool = False) -> MaterializationResult: + if self.state.manifest.users or self.state.manifest.groups: + raise ValueError( + "UnixLocalSandboxSession does not support manifest users or groups because " + "provisioning would run on the host machine" + ) + return await super().apply_manifest(only_ephemeral=only_ephemeral) + + async def provision_manifest_accounts(self) -> None: + if self.state.manifest.users or self.state.manifest.groups: + raise ValueError( + "UnixLocalSandboxSession does not support manifest users or groups because " + "provisioning would run on the host machine" + ) + + async def shutdown(self) -> None: + # Best-effort: mark session not running. We intentionally do not delete the workspace + # directory here; cleanup is handled by the Client.delete(). + self._running = False + + def _prepare_exec_command( + self, + *command: str | Path, + shell: bool | list[str], + user: str | User | None, + ) -> list[str]: + if shell is True: + shell = ["sh", "-c"] + return super()._prepare_exec_command(*command, shell=shell, user=user) + + async def _exec_internal( + self, *command: str | Path, timeout: float | None = None + ) -> ExecResult: + env, cwd = await self._resolved_exec_context() + workspace_root = Path(cwd).resolve() + command_parts = self._workspace_relative_command_parts(command, workspace_root) + process_cwd, command_parts = self._shell_workspace_process_context( + command_parts=command_parts, + workspace_root=workspace_root, + cwd=cwd, + ) + exec_command = self._confined_exec_command( + command_parts=command_parts, + workspace_root=workspace_root, + env=env, + ) + + try: + proc = await asyncio.create_subprocess_exec( + *exec_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=process_cwd, + env=env, + start_new_session=True, + ) + + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) + except asyncio.TimeoutError as e: + try: + # process tree cleanup + os.killpg(proc.pid, signal.SIGKILL) + except Exception: + pass + raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e + except ExecTimeoutError: + raise + except Exception as e: + raise ExecTransportError(command=command, cause=e) from e + + return ExecResult( + stdout=stdout or b"", stderr=stderr or b"", exit_code=proc.returncode or 0 + ) + + async def _resolved_exec_context(self) -> tuple[dict[str, str], str]: + env = os.environ.copy() + env.update(await self.state.manifest.environment.resolve()) + + workspace = Path(self.state.manifest.root) + if not workspace.exists(): + raise WorkspaceRootNotFoundError(path=workspace) + + env["HOME"] = str(workspace) + return env, str(workspace) + + def _confined_exec_command( + self, + *, + command_parts: list[str], + workspace_root: Path, + env: Mapping[str, str], + ) -> list[str]: + if sys.platform != "darwin": + return command_parts + + sandbox_exec = shutil.which("sandbox-exec") + if not sandbox_exec: + raise ExecTransportError( + command=command_parts, + context={ + "reason": "unix_local_confinement_unavailable", + "platform": sys.platform, + "workspace_root": str(workspace_root), + }, + ) + + profile = self._darwin_exec_profile( + workspace_root, + extra_read_paths=self._darwin_additional_read_paths( + command_parts=command_parts, + env=env, + ), + ) + return [sandbox_exec, "-p", profile, *command_parts] + + @staticmethod + def _workspace_relative_command_parts( + command: tuple[str | Path, ...], + workspace_root: Path, + ) -> list[str]: + command_parts = [str(part) for part in command] + rewritten = [command_parts[0]] + for part in command_parts[1:]: + path_part = Path(part) + if not path_part.is_absolute(): + rewritten.append(part) + continue + try: + relative = path_part.relative_to(workspace_root) + except ValueError: + rewritten.append(part) + continue + rewritten.append("." if not relative.parts else relative.as_posix()) + return rewritten + + @staticmethod + def _darwin_allowable_read_roots(path: Path, *, host_home: Path) -> list[Path]: + candidates: set[Path] = set() + normalized = path.expanduser() + try: + resolved = normalized.resolve(strict=False) + except OSError: + resolved = normalized + + if normalized.is_dir(): + candidates.add(normalized) + else: + candidates.add(normalized.parent) + + if resolved.is_dir(): + candidates.add(resolved) + else: + candidates.add(resolved.parent) + + resolved_text = resolved.as_posix() + if resolved_text == "/opt/homebrew" or resolved_text.startswith("/opt/homebrew/"): + candidates.add(Path("/opt/homebrew")) + if resolved_text == "/usr/local" or resolved_text.startswith("/usr/local/"): + candidates.add(Path("/usr/local")) + if resolved_text == "/Library/Frameworks" or resolved_text.startswith( + "/Library/Frameworks/" + ): + candidates.add(Path("/Library/Frameworks")) + + try: + relative_to_home = resolved.relative_to(host_home) + except ValueError: + relative_to_home = None + if relative_to_home is not None and relative_to_home.parts: + first_segment = relative_to_home.parts[0] + if first_segment.startswith("."): + candidates.add(host_home / first_segment) + elif len(relative_to_home.parts) >= 2 and relative_to_home.parts[:2] == ( + "Library", + "Python", + ): + candidates.add(host_home / "Library" / "Python") + + return sorted( + candidates, key=lambda candidate: (len(candidate.parts), candidate.as_posix()) + ) + + def _darwin_additional_read_paths( + self, + *, + command_parts: list[str], + env: Mapping[str, str], + ) -> list[Path]: + host_home = Path.home().resolve() + allowed: list[Path] = [] + seen: set[str] = set() + + def _append(path: str | Path | None) -> None: + if path is None: + return + candidate = Path(path).expanduser() + if not candidate.is_absolute(): + return + for root in self._darwin_allowable_read_roots(candidate, host_home=host_home): + key = root.as_posix() + if key in seen: + continue + seen.add(key) + allowed.append(root) + + for path_entry in env.get("PATH", "").split(os.pathsep): + if path_entry: + _append(path_entry) + + executable = shutil.which(command_parts[0], path=env.get("PATH")) + _append(executable) + return allowed + + def _darwin_exec_profile( + self, + workspace_root: Path, + *, + extra_read_paths: Sequence[Path] = (), + ) -> str: + def _literal(path: Path | str) -> str: + escaped = str(path).replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + + denied_paths = [ + Path("/Users"), + Path("/Volumes"), + Path("/Applications"), + Path("/Library"), + Path("/opt"), + Path("/etc"), + Path("/private/etc"), + Path("/tmp"), + Path("/private/tmp"), + Path("/private"), + Path("/var"), + Path("/usr"), + ] + allow_rules = [ + f"(allow file-read-data file-read-metadata (subpath {_literal(workspace_root)}))", + f"(allow file-write* (subpath {_literal(workspace_root)}))", + *[ + f"(allow file-read-data file-read-metadata (subpath {_literal(path)}))" + for path in extra_read_paths + ], + '(allow file-read-data file-read-metadata (subpath "/usr/bin"))', + '(allow file-read-data file-read-metadata (subpath "/usr/lib"))', + '(allow file-read-data file-read-metadata (subpath "/bin"))', + '(allow file-read-data file-read-metadata (subpath "/System"))', + '(allow file-read-data file-read-metadata (literal "/private/var/select/sh"))', + '(allow file-write* (literal "/dev/null"))', + ] + deny_rules = "\n".join( + f"(deny file-read-data (subpath {_literal(path)}))\n" + f"(deny file-write* (subpath {_literal(path)}))" + for path in denied_paths + ) + return "\n".join( + [ + "(version 1)", + "(allow default)", + deny_rules, + *allow_rules, + ] + ) + + @staticmethod + def _shell_workspace_process_context( + *, + command_parts: list[str], + workspace_root: Path, + cwd: str, + ) -> tuple[str, list[str]]: + if len(command_parts) < 3 or command_parts[0] != "sh" or command_parts[1] != "-c": + return cwd, command_parts + + workspace_cd = f"cd {shlex.quote(str(workspace_root))} && {command_parts[2]}" + rewritten = [*command_parts] + rewritten[2] = workspace_cd + return "/", rewritten + + def _resolve_workspace_path(self, path: Path) -> Path: + workspace_root = Path(self.state.manifest.root).resolve() + confined = resolve_workspace_path( + workspace_root, + path, + allow_absolute_within_root=True, + ) + resolved = confined.resolve(strict=False) + try: + resolved.relative_to(workspace_root) + except ValueError as exc: + reason: Literal["absolute", "escape_root"] = ( + "absolute" if path.is_absolute() else "escape_root" + ) + raise InvalidManifestPathError(rel=path, reason=reason, cause=exc) from exc + return resolved + + def normalize_path(self, path: Path | str) -> Path: + if isinstance(path, str): + path = Path(path) + return self._resolve_workspace_path(path) + + async def ls(self, path: Path | str) -> list[FileEntry]: + normalized = self.normalize_path(path) + command = ("ls", "-la", "--", str(normalized)) + try: + with os.scandir(normalized) as entries: + listed: list[FileEntry] = [] + for entry in entries: + stat_result = entry.stat(follow_symlinks=False) + if entry.is_symlink(): + kind = EntryKind.SYMLINK + elif entry.is_dir(follow_symlinks=False): + kind = EntryKind.DIRECTORY + elif entry.is_file(follow_symlinks=False): + kind = EntryKind.FILE + else: + kind = EntryKind.OTHER + listed.append( + FileEntry( + path=entry.path, + permissions=Permissions.from_mode(stat_result.st_mode), + owner=str(stat_result.st_uid), + group=str(stat_result.st_gid), + size=stat_result.st_size, + kind=kind, + ) + ) + return listed + except OSError as e: + raise ExecNonZeroError( + ExecResult(stdout=b"", stderr=str(e).encode("utf-8"), exit_code=1), + command=command, + cause=e, + ) from e + + async def mkdir(self, path: Path | str, *, parents: bool = False) -> None: + normalized = self.normalize_path(path) + try: + normalized.mkdir(parents=parents, exist_ok=True) + except OSError as e: + raise WorkspaceArchiveWriteError(path=normalized, cause=e) from e + + async def rm(self, path: Path | str, *, recursive: bool = False) -> None: + normalized = self.normalize_path(path) + try: + if normalized.is_dir() and not normalized.is_symlink(): + if recursive: + shutil.rmtree(normalized) + else: + normalized.rmdir() + else: + normalized.unlink() + except FileNotFoundError as e: + if recursive: + return + raise ExecNonZeroError( + ExecResult(stdout=b"", stderr=str(e).encode("utf-8"), exit_code=1), + command=("rm", "-rf" if recursive else "--", str(normalized)), + cause=e, + ) from e + except OSError as e: + raise WorkspaceArchiveWriteError(path=normalized, cause=e) from e + + async def read(self, path: Path) -> io.IOBase: + workspace_path = self._resolve_workspace_path(path) + try: + return workspace_path.open("rb") + except FileNotFoundError as e: + raise WorkspaceReadNotFoundError(path=path, cause=e) from e + except OSError as e: + raise WorkspaceArchiveReadError(path=path, cause=e) from e + + async def write(self, path: Path, data: io.IOBase) -> None: + payload = coerce_write_payload(path=path, data=data) + + workspace_path = self._resolve_workspace_path(path) + try: + workspace_path.parent.mkdir(parents=True, exist_ok=True) + with workspace_path.open("wb") as f: + shutil.copyfileobj(payload.stream, f) + except OSError as e: + raise WorkspaceArchiveWriteError(path=workspace_path, cause=e) from e + + async def running(self) -> bool: + return self._running + + async def persist_workspace(self) -> io.IOBase: + root = Path(self.state.manifest.root) + if not root.exists(): + raise WorkspaceArchiveReadError( + path=root, context={"reason": "workspace_root_not_found"} + ) + + skip = self._persist_workspace_skip_relpaths() + buf = io.BytesIO() + try: + with tarfile.open(fileobj=buf, mode="w") as tar: + tar.add( + root, + arcname=".", + filter=lambda ti: ( + None + if should_skip_tar_member( + ti.name, + skip_rel_paths=skip, + root_name=None, + ) + else ti + ), + ) + except (tarfile.TarError, OSError) as e: + raise WorkspaceArchiveReadError(path=root, cause=e) from e + + buf.seek(0) + return buf + + async def hydrate_workspace(self, data: io.IOBase) -> None: + root = Path(self.state.manifest.root) + try: + root.mkdir(parents=True, exist_ok=True) + with tarfile.open(fileobj=data, mode="r:*") as tar: + safe_extract_tarfile(tar, root=root) + except UnsafeTarMemberError as e: + raise WorkspaceArchiveWriteError( + path=root, context={"reason": e.reason, "member": e.member}, cause=e + ) from e + except (tarfile.TarError, OSError) as e: + raise WorkspaceArchiveWriteError(path=root, cause=e) from e + + +class UnixLocalSandboxClient(BaseSandboxClient[None]): + backend_id = "unix_local" + supports_default_options = True + _instrumentation: Instrumentation + + def __init__( + self, + *, + instrumentation: Instrumentation | None = None, + dependencies: Dependencies | None = None, + ) -> None: + self._instrumentation = instrumentation or Instrumentation() + self._dependencies = dependencies + + async def create( + self, + *, + snapshot: SnapshotSpec | None = None, + manifest: Manifest | None = None, + codex: bool | CodexConfig = False, + options: None = None, + ) -> SandboxSession: + if options is not None: + raise ValueError("UnixLocalSandboxClient.create does not accept options") + manifest = apply_codex_to_manifest(manifest, codex) + # For local execution, runner-created sessions should always get an isolated temp root + # unless the caller explicitly chose a custom host path. + workspace_root_owned = False + if manifest is None or manifest.root == _DEFAULT_MANIFEST_ROOT: + workspace_dir = tempfile.mkdtemp(prefix=_DEFAULT_WORKSPACE_PREFIX) + workspace_root_owned = True + if manifest is None: + manifest = Manifest(root=workspace_dir) + else: + manifest = manifest.model_copy(update={"root": workspace_dir}, deep=True) + + session_id = uuid.uuid4() + snapshot_id = str(session_id) + snapshot_instance = resolve_snapshot(snapshot, snapshot_id) + state = UnixLocalSandboxSessionState( + session_id=session_id, + manifest=manifest, + snapshot=snapshot_instance, + workspace_root_owned=workspace_root_owned, + ) + inner = UnixLocalSandboxSession.from_state(state) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + async def delete(self, session: SandboxSession) -> SandboxSession: + """Best-effort cleanup of the on-disk workspace directory.""" + inner = session._inner + if not isinstance(inner, UnixLocalSandboxSession): + raise TypeError("UnixLocalSandboxClient.delete expects a UnixLocalSandboxSession") + if not inner.state.workspace_root_owned: + return session + try: + shutil.rmtree(Path(inner.state.manifest.root), ignore_errors=False) + except FileNotFoundError: + pass + except Exception: + pass + return session + + async def resume( + self, + state: SandboxSessionState, + *, + codex: bool | CodexConfig = False, + ) -> SandboxSession: + if not isinstance(state, UnixLocalSandboxSessionState): + raise TypeError("UnixLocalSandboxClient.resume expects a UnixLocalSandboxSessionState") + inner = UnixLocalSandboxSession.from_state(apply_codex_to_session_state(state, codex)) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return UnixLocalSandboxSessionState.model_validate(payload) diff --git a/src/agents/sandbox/session/__init__.py b/src/agents/sandbox/session/__init__.py new file mode 100644 index 0000000000..738e21b799 --- /dev/null +++ b/src/agents/sandbox/session/__init__.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +__all__ = [ + "BaseSandboxClient", + "BaseSandboxSession", + "CallbackSink", + "ChainedSink", + "ClientOptionsT", + "Dependencies", + "DependenciesBindingError", + "DependenciesError", + "DependenciesMissingDependencyError", + "DependencyKey", + "EventPayloadPolicy", + "EventSink", + "HttpProxySink", + "Instrumentation", + "JsonlOutboxSink", + "SandboxSession", + "SandboxSessionState", + "UCEvent", + "UCFinishEvent", + "UCStartEvent", + "WorkspaceJsonlSink", + "event_to_json_line", + "validate_uc_event", +] + +if TYPE_CHECKING: + from .base_sandbox_session import BaseSandboxSession + from .dependencies import ( + Dependencies, + DependenciesBindingError, + DependenciesError, + DependenciesMissingDependencyError, + DependencyKey, + ) + from .events import ( + EventPayloadPolicy, + UCEvent, + UCFinishEvent, + UCStartEvent, + validate_uc_event, + ) + from .manager import Instrumentation + from .sandbox_client import BaseSandboxClient, ClientOptionsT + from .sandbox_session import SandboxSession + from .sandbox_session_state import SandboxSessionState + from .sinks import ( + CallbackSink, + ChainedSink, + EventSink, + HttpProxySink, + JsonlOutboxSink, + WorkspaceJsonlSink, + ) + from .utils import event_to_json_line + + +def __getattr__(name: str) -> object: + if name == "BaseSandboxSession": + from .base_sandbox_session import BaseSandboxSession + + return BaseSandboxSession + if name in { + "Dependencies", + "DependenciesBindingError", + "DependenciesError", + "DependenciesMissingDependencyError", + "DependencyKey", + }: + from . import dependencies as dependencies_module + + return getattr(dependencies_module, name) + if name in { + "EventPayloadPolicy", + "UCEvent", + "UCFinishEvent", + "UCStartEvent", + "validate_uc_event", + }: + from . import events as events_module + + return getattr(events_module, name) + if name == "Instrumentation": + from .manager import Instrumentation + + return Instrumentation + if name in {"BaseSandboxClient", "ClientOptionsT"}: + from . import sandbox_client as sandbox_client_module + + return getattr(sandbox_client_module, name) + if name == "SandboxSession": + from .sandbox_session import SandboxSession + + return SandboxSession + if name == "SandboxSessionState": + from .sandbox_session_state import SandboxSessionState + + return SandboxSessionState + if name in { + "CallbackSink", + "ChainedSink", + "EventSink", + "HttpProxySink", + "JsonlOutboxSink", + "WorkspaceJsonlSink", + }: + from . import sinks as sinks_module + + return getattr(sinks_module, name) + if name == "event_to_json_line": + from .utils import event_to_json_line + + return event_to_json_line + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/agents/sandbox/session/archive_extraction.py b/src/agents/sandbox/session/archive_extraction.py new file mode 100644 index 0000000000..6bf5dc09ac --- /dev/null +++ b/src/agents/sandbox/session/archive_extraction.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import io +import shutil +import tarfile +import tempfile +import zipfile +from collections.abc import Awaitable, Callable, Iterator +from contextlib import contextmanager +from pathlib import Path, PurePosixPath +from typing import Literal, cast + +from ..errors import ExecNonZeroError, WorkspaceArchiveWriteError +from ..files import EntryKind, FileEntry +from ..util.tar_utils import UnsafeTarMemberError, safe_tar_member_rel_path + + +class UnsafeZipMemberError(ValueError): + """Raised when a zip member would escape or violate archive extraction rules.""" + + def __init__(self, *, member: str, reason: str) -> None: + super().__init__(f"unsafe zip member {member!r}: {reason}") + self.member = member + self.reason = reason + + +class WorkspaceArchiveExtractor: + def __init__( + self, + *, + mkdir: Callable[[Path], Awaitable[None]], + write: Callable[[Path, io.IOBase], Awaitable[None]], + ls: Callable[[Path], Awaitable[list[FileEntry]]], + ) -> None: + self._mkdir = mkdir + self._write = write + self._ls = ls + + async def extract_tar_archive( + self, + *, + archive_path: Path, + destination_root: Path, + data: io.IOBase, + ) -> None: + child_entry_cache: dict[Path, dict[str, EntryKind]] = {} + try: + with tarfile.open(fileobj=data, mode="r:*") as archive: + for member in archive.getmembers(): + rel_path = safe_tar_member_rel_path(member) + if rel_path is None: + continue + + await self._ensure_no_symlink_extract_parents( + destination_root=destination_root, + rel_path=rel_path, + member_name=member.name, + error_type="tar", + child_entry_cache=child_entry_cache, + ) + dest = destination_root / rel_path + if member.isdir(): + await self._mkdir(dest) + self._record_extract_entry( + child_entry_cache=child_entry_cache, + destination_root=destination_root, + path=dest, + kind=EntryKind.DIRECTORY, + ) + continue + + fileobj = archive.extractfile(member) + if fileobj is None: + raise UnsafeTarMemberError( + member=member.name, + reason="missing file payload", + ) + try: + await self._mkdir(dest.parent) + self._record_extract_entry( + child_entry_cache=child_entry_cache, + destination_root=destination_root, + path=dest.parent, + kind=EntryKind.DIRECTORY, + ) + await self._write(dest, cast(io.IOBase, fileobj)) + self._record_extract_entry( + child_entry_cache=child_entry_cache, + destination_root=destination_root, + path=dest, + kind=EntryKind.FILE, + ) + finally: + fileobj.close() + except UnsafeTarMemberError as e: + raise WorkspaceArchiveWriteError( + path=archive_path, + context={"member": e.member, "reason": e.reason}, + cause=e, + ) from e + except (tarfile.TarError, OSError) as e: + raise WorkspaceArchiveWriteError(path=archive_path, cause=e) from e + + async def extract_zip_archive( + self, + *, + archive_path: Path, + destination_root: Path, + data: io.IOBase, + ) -> None: + child_entry_cache: dict[Path, dict[str, EntryKind]] = {} + try: + with zipfile_compatible_stream(data) as zip_data: + with zipfile.ZipFile(zip_data) as archive: + for member in archive.infolist(): + rel_path = safe_zip_member_rel_path(member) + if rel_path is None: + continue + + await self._ensure_no_symlink_extract_parents( + destination_root=destination_root, + rel_path=rel_path, + member_name=member.filename, + error_type="zip", + child_entry_cache=child_entry_cache, + ) + dest = destination_root / rel_path + if member.is_dir(): + await self._mkdir(dest) + self._record_extract_entry( + child_entry_cache=child_entry_cache, + destination_root=destination_root, + path=dest, + kind=EntryKind.DIRECTORY, + ) + continue + + await self._mkdir(dest.parent) + self._record_extract_entry( + child_entry_cache=child_entry_cache, + destination_root=destination_root, + path=dest.parent, + kind=EntryKind.DIRECTORY, + ) + with archive.open(member, mode="r") as member_data: + await self._write(dest, cast(io.IOBase, member_data)) + self._record_extract_entry( + child_entry_cache=child_entry_cache, + destination_root=destination_root, + path=dest, + kind=EntryKind.FILE, + ) + except UnsafeZipMemberError as e: + raise WorkspaceArchiveWriteError( + path=archive_path, + context={"member": e.member, "reason": e.reason}, + cause=e, + ) from e + except ValueError as e: + raise WorkspaceArchiveWriteError(path=archive_path, cause=e) from e + except (zipfile.BadZipFile, OSError) as e: + raise WorkspaceArchiveWriteError(path=archive_path, cause=e) from e + + async def _ensure_no_symlink_extract_parents( + self, + *, + destination_root: Path, + rel_path: Path, + member_name: str, + error_type: Literal["tar", "zip"], + child_entry_cache: dict[Path, dict[str, EntryKind]], + ) -> None: + symlink_component = await self._find_symlink_component( + base_dir=destination_root, + rel_path=rel_path, + child_entry_cache=child_entry_cache, + ) + if symlink_component is None: + return + + reason = f"symlink in parent path: {symlink_component.as_posix()}" + if error_type == "tar": + raise UnsafeTarMemberError(member=member_name, reason=reason) + raise UnsafeZipMemberError(member=member_name, reason=reason) + + async def _find_symlink_component( + self, + *, + base_dir: Path, + rel_path: Path, + child_entry_cache: dict[Path, dict[str, EntryKind]], + ) -> Path | None: + current_dir = base_dir + traversed = Path() + + for part in rel_path.parts: + entry_kind = await self._lookup_child_entry_kind( + current_dir, + part, + child_entry_cache=child_entry_cache, + ) + if entry_kind is None: + return None + + traversed /= part + if entry_kind == EntryKind.SYMLINK: + return traversed + + current_dir = current_dir / part + + return None + + async def _lookup_child_entry_kind( + self, + parent_dir: Path, + child_name: str, + *, + child_entry_cache: dict[Path, dict[str, EntryKind]], + ) -> EntryKind | None: + cached_entries = child_entry_cache.get(parent_dir) + if cached_entries is None: + try: + entries = await self._ls(parent_dir) + except ExecNonZeroError: + return None + cached_entries = {Path(entry.path).name: entry.kind for entry in entries} + child_entry_cache[parent_dir] = cached_entries + + return cached_entries.get(child_name) + + @staticmethod + def _record_extract_entry( + *, + child_entry_cache: dict[Path, dict[str, EntryKind]], + destination_root: Path, + path: Path, + kind: EntryKind, + ) -> None: + try: + rel_path = path.relative_to(destination_root) + except ValueError: + return + + if not rel_path.parts: + return + + current_dir = destination_root + for index, part in enumerate(rel_path.parts): + child_kind = kind if index == len(rel_path.parts) - 1 else EntryKind.DIRECTORY + cached_entries = child_entry_cache.get(current_dir) + if cached_entries is not None: + cached_entries[part] = child_kind + current_dir = current_dir / part + + +def _supports_zip_random_access(stream: io.IOBase) -> bool: + try: + position = stream.tell() + stream.seek(position, io.SEEK_SET) + except (AttributeError, OSError, TypeError, ValueError): + return False + return True + + +@contextmanager +def zipfile_compatible_stream(stream: io.IOBase) -> Iterator[io.IOBase]: + if _supports_zip_random_access(stream): + yield _ZipFileStreamAdapter(stream) + return + + spool = tempfile.SpooledTemporaryFile(max_size=16 * 1024 * 1024, mode="w+b") + try: + shutil.copyfileobj(stream, spool) + spool.seek(0) + yield _ZipFileStreamAdapter(cast(io.IOBase, spool)) + finally: + spool.close() + + +def safe_zip_member_rel_path(member: zipfile.ZipInfo) -> Path | None: + if member.filename in ("", ".", "./"): + return None + + rel = PurePosixPath(member.filename) + if rel.is_absolute(): + raise UnsafeZipMemberError(member=member.filename, reason="absolute path") + if ".." in rel.parts: + raise UnsafeZipMemberError(member=member.filename, reason="parent traversal") + + mode = (member.external_attr >> 16) & 0o170000 + if mode == 0o120000: + raise UnsafeZipMemberError(member=member.filename, reason="link member not allowed") + + return Path(*rel.parts) + + +class _ZipFileStreamAdapter(io.IOBase): + # Python 3.10's zipfile._SharedFile reads `file.seekable` directly, so this + # adapter keeps ZIP-compatible random-access streams working across versions. + def __init__(self, stream: io.IOBase) -> None: + self._stream = stream + + def seekable(self) -> bool: + return True + + def readable(self) -> bool: + return True + + def tell(self) -> int: + return int(self._stream.tell()) + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + return int(self._stream.seek(offset, whence)) + + def read(self, size: int = -1) -> bytes: + data = self._stream.read(size) + if isinstance(data, bytes): + return data + raise TypeError(f"expected bytes from wrapped stream, got {type(data).__name__}") + + def close(self) -> None: + return diff --git a/src/agents/sandbox/session/base_sandbox_session.py b/src/agents/sandbox/session/base_sandbox_session.py new file mode 100644 index 0000000000..d13ba228f5 --- /dev/null +++ b/src/agents/sandbox/session/base_sandbox_session.py @@ -0,0 +1,495 @@ +import abc +import io +import shlex +import shutil +import tempfile +from collections.abc import Sequence +from pathlib import Path +from typing import Literal, cast + +from typing_extensions import Self + +from ..entries import BaseEntry, resolve_workspace_path +from ..entries.codex import ( + Codex, + resolve_codex_github_asset_name as resolve_codex_github_asset_name_for_session, + resolve_codex_target_triple as resolve_codex_target_triple_for_session, +) +from ..errors import ( + ExecNonZeroError, + InvalidCompressionSchemeError, +) +from ..files import FileEntry +from ..manifest import Manifest +from ..materialization import MaterializationResult, MaterializedFile +from ..snapshot import NoopSnapshot +from ..types import ExecResult, User +from ..util.parse_utils import parse_ls_la +from .archive_extraction import ( + WorkspaceArchiveExtractor, + safe_zip_member_rel_path, +) +from .dependencies import Dependencies +from .manifest_application import ManifestApplier +from .sandbox_session_state import SandboxSessionState + + +class BaseSandboxSession(abc.ABC): + state: SandboxSessionState + _dependencies: Dependencies | None = None + _dependencies_closed: bool = False + _runtime_persist_workspace_skip_relpaths: set[Path] | None = None + + async def start(self) -> None: + if await self.state.snapshot.restorable(): + # Ensure the snapshot is the single source of truth on resume. + await self._clear_workspace_root_on_resume() + await self.hydrate_workspace(await self.state.snapshot.restore()) + if self.should_provision_manifest_accounts_on_resume(): + await self.provision_manifest_accounts() + # Reapply only ephemeral manifest entries on resume so persisted workspace state wins + # for durable files while temporary scaffolding is rebuilt for the new process. + await self.apply_manifest(only_ephemeral=True) + await self._materialize_missing_codex_entries_on_resume() + else: + await self.apply_manifest() + + async def stop(self) -> None: + """ + Persist/snapshot the workspace. + + Note: `stop()` is intentionally persistence-only. Sandboxes that need to tear down + sandbox resources (Docker containers, remote sessions, etc.) should implement + `shutdown()` instead. + """ + if isinstance(self.state.snapshot, NoopSnapshot): + return + await self.state.snapshot.persist(await self.persist_workspace()) + + @abc.abstractmethod + async def shutdown(self) -> None: + """ + Tear down sandbox resources (best-effort). + + Default is a no-op. Sandbox-specific sessions (e.g. Docker) should override. + """ + + async def __aenter__(self) -> Self: + await self.start() + return self + + async def aclose(self) -> None: + """Run the session cleanup lifecycle outside of ``async with``. + + This performs the same session-owned cleanup as ``__aexit__()``: persist/snapshot the + workspace via ``stop()``, tear down session resources via ``shutdown()``, and close + session-scoped dependencies. If the session came from a sandbox client, call the client's + ``delete()`` separately for backend-specific deletion such as removing a Docker container + or deleting a temporary host workspace. + """ + try: + await self.stop() + await self.shutdown() + finally: + await self._aclose_dependencies() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object | None, + ) -> None: + await self.aclose() + + @property + def dependencies(self) -> Dependencies: + dependencies = self._dependencies + if dependencies is None: + dependencies = Dependencies() + self._dependencies = dependencies + self._dependencies_closed = False + return dependencies + + def set_dependencies(self, dependencies: Dependencies | None) -> None: + if dependencies is None: + return + self._dependencies = dependencies + self._dependencies_closed = False + + async def _aclose_dependencies(self) -> None: + dependencies = self._dependencies + if dependencies is None or self._dependencies_closed: + return + self._dependencies_closed = True + await dependencies.aclose() + + def _register_persist_workspace_skip_relpath(self, path: Path | str) -> Path: + rel_path = Manifest._coerce_rel_path(path) + Manifest._validate_rel_path(rel_path) + if rel_path in (Path(""), Path(".")): + raise ValueError("Persist workspace skip paths must target a concrete relative path.") + + if self._runtime_persist_workspace_skip_relpaths is None: + self._runtime_persist_workspace_skip_relpaths = set() + self._runtime_persist_workspace_skip_relpaths.add(rel_path) + return rel_path + + def _persist_workspace_skip_relpaths(self) -> set[Path]: + skip_paths = set(self.state.manifest.ephemeral_persistence_paths()) + if self._runtime_persist_workspace_skip_relpaths: + skip_paths.update(self._runtime_persist_workspace_skip_relpaths) + return skip_paths + + async def exec( + self, + *command: str | Path, + timeout: float | None = None, + shell: bool | list[str] = True, + user: str | User | None = None, + ) -> ExecResult: + """Execute a command inside the session. + + :param command: Command and args (will be stringified). + :param timeout: Optional wall-clock timeout in seconds. + :param shell: Whether to run this command in a shell. If ``True`` is provided, + the command will be run prefixed by ``sh -lc``. A custom shell prefix may be used + by providing a list. + + :returns: An ``ExecResult`` containing stdout/stderr and exit code. + + :raises TimeoutError: If the sandbox cannot complete within `timeout`. + """ + + sanitized_command = self._prepare_exec_command(*command, shell=shell, user=user) + return await self._exec_internal(*sanitized_command, timeout=timeout) + + async def resolve_codex_github_asset_name(self) -> str: + """Resolve the Codex GitHub release asset filename for the session target.""" + + return await resolve_codex_github_asset_name_for_session(session=self) + + async def resolve_codex_target_triple(self) -> str: + """Resolve the Codex release target triple for the session target platform.""" + + return await resolve_codex_target_triple_for_session(session=self) + + def _prepare_exec_command( + self, + *command: str | Path, + shell: bool | list[str], + user: str | User | None, + ) -> list[str]: + sanitized_command = [str(c) for c in command] + + if shell: + joined = ( + sanitized_command[0] + if len(sanitized_command) == 1 + else shlex.join(sanitized_command) + ) + if isinstance(shell, list): + sanitized_command = shell + [joined] + else: + sanitized_command = ["sh", "-lc", joined] + + if user: + if isinstance(user, User): + user = user.name + + assert isinstance(user, str) + + sanitized_command = ["sudo", "-u", user, "--"] + sanitized_command + + return sanitized_command + + @abc.abstractmethod + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: ... + + @abc.abstractmethod + async def read(self, path: Path) -> io.IOBase: + """Read a file from the session's workspace. + + :param path: Absolute path in the container or path relative to the + workspace root. + :returns: A readable file-like object. + :raises: FileNotFoundError: If the path does not exist. + """ + + @abc.abstractmethod + async def write(self, path: Path, data: io.IOBase) -> None: + """Write a file into the session's workspace. + + :param path: Absolute path in the container or path relative to the + workspace root. + :param data: A file-like object positioned at the start of the payload. + """ + + @abc.abstractmethod + async def running(self) -> bool: + """ + :returns: whether the underlying sandbox is currently running. + """ + + @abc.abstractmethod + async def persist_workspace(self) -> io.IOBase: + """Serialize the session's workspace into a byte stream. + + :returns: A readable tar binary stream representing the full workspace. + """ + + @abc.abstractmethod + async def hydrate_workspace(self, data: io.IOBase) -> None: + """Populate the session's workspace from a serialized byte stream. + + :param data: A readable tar binary stream as produced by `persist_workspace`. + """ + + async def ls(self, path: Path | str) -> list[FileEntry]: + """List directory contents. + + :param path: Path to list. + :returns: A list of `FileEntry` objects. + """ + path = self.normalize_path(path) + + cmd = ("ls", "-la", "--", str(path)) + result = await self.exec(*cmd, shell=False) + if not result.ok(): + raise ExecNonZeroError(result, command=cmd) + + return parse_ls_la(result.stdout.decode("utf-8", errors="replace"), base=str(path)) + + async def rm(self, path: Path | str, *, recursive: bool = False) -> None: + """Remove a file or directory. + + :param path: Path to remove. + :param recursive: If true, remove directories recursively. + """ + path = self.normalize_path(path) + + cmd: list[str] = ["rm"] + if recursive: + cmd.append("-rf") + cmd.extend(["--", str(path)]) + + result = await self.exec(*cmd, shell=False) + if not result.ok(): + raise ExecNonZeroError(result, command=cmd) + + async def mkdir(self, path: Path | str, *, parents: bool = False) -> None: + """Create a directory. + + :param path: Directory to create on the remote. + :param parents: If true, create missing parents. + """ + path = self.normalize_path(path) + + cmd: list[str] = ["mkdir"] + if parents: + cmd.append("-p") + cmd.append(str(path)) + + result = await self.exec(*cmd, shell=False) + if not result.ok(): + raise ExecNonZeroError(result, command=cmd) + + async def extract( + self, + path: Path | str, + data: io.IOBase, + *, + compression_scheme: Literal["tar", "zip"] | None = None, + ) -> None: + """ + Write a compressed archive to a destination on the remote. + Optionally extract the archive once written. + + :param path: Path on the host machine to extract to + :param data: a file-like io stream. + :param compression_scheme: either "tar" or "zip". If not provided, + it will try to infer from the path. + """ + if isinstance(path, str): + path = Path(path) + + if compression_scheme is None: + suffix = path.suffix.removeprefix(".") + compression_scheme = cast(Literal["tar", "zip"], suffix) if suffix else None + + if compression_scheme is None or compression_scheme not in ["zip", "tar"]: + raise InvalidCompressionSchemeError(path=path, scheme=compression_scheme) + + normalized_path = self.normalize_path(path) + destination_root = normalized_path.parent + + # Materialize the archive into a local spool once because both `write()` and the + # extraction step consume the stream, and zip extraction may require seeking. + spool = tempfile.SpooledTemporaryFile(max_size=16 * 1024 * 1024, mode="w+b") + try: + shutil.copyfileobj(data, spool) + spool.seek(0) + await self.write(normalized_path, spool) + spool.seek(0) + + if compression_scheme == "tar": + await self._extract_tar_archive( + archive_path=normalized_path, + destination_root=destination_root, + data=spool, + ) + else: + await self._extract_zip_archive( + archive_path=normalized_path, + destination_root=destination_root, + data=spool, + ) + finally: + spool.close() + + def normalize_path(self, path: Path | str) -> Path: + if isinstance(path, str): + path = Path(path) + + root = Path(self.state.manifest.root) + return resolve_workspace_path(root, path, allow_absolute_within_root=True) + + def describe(self) -> str: + return self.state.manifest.describe() + + async def _materialize_missing_codex_entries_on_resume(self) -> None: + missing_codex_entries: dict[str | Path, BaseEntry] = {} + for rel_path, artifact in self.state.manifest.iter_entries(): + if not isinstance(artifact, Codex): + continue + exists = await self.exec("test", "-e", str(self.normalize_path(rel_path)), shell=False) + if exists.ok(): + continue + missing_codex_entries[rel_path] = artifact.model_copy(deep=True) + + if not missing_codex_entries: + return + + codex_manifest = Manifest( + root=self.state.manifest.root, + entries=missing_codex_entries, + ) + await ManifestApplier( + mkdir=lambda path: self.mkdir(path, parents=True), + exec_checked_nonzero=self._exec_checked_nonzero, + apply_entry=lambda artifact, dest, base_dir: artifact.apply(self, dest, base_dir), + ).apply_manifest(codex_manifest, base_dir=self._manifest_base_dir()) + + async def _extract_tar_archive( + self, + *, + archive_path: Path, + destination_root: Path, + data: io.IOBase, + ) -> None: + extractor = WorkspaceArchiveExtractor( + mkdir=lambda path: self.mkdir(path, parents=True), + write=self.write, + ls=lambda path: self.ls(path), + ) + await extractor.extract_tar_archive( + archive_path=archive_path, + destination_root=destination_root, + data=data, + ) + + async def _extract_zip_archive( + self, + *, + archive_path: Path, + destination_root: Path, + data: io.IOBase, + ) -> None: + extractor = WorkspaceArchiveExtractor( + mkdir=lambda path: self.mkdir(path, parents=True), + write=self.write, + ls=lambda path: self.ls(path), + ) + await extractor.extract_zip_archive( + archive_path=archive_path, + destination_root=destination_root, + data=data, + ) + + @staticmethod + def _safe_zip_member_rel_path(member) -> Path | None: + return safe_zip_member_rel_path(member) + + async def apply_manifest(self, *, only_ephemeral: bool = False) -> MaterializationResult: + applier = ManifestApplier( + mkdir=lambda path: self.mkdir(path, parents=True), + exec_checked_nonzero=self._exec_checked_nonzero, + apply_entry=lambda artifact, dest, base_dir: artifact.apply(self, dest, base_dir), + ) + return await applier.apply_manifest( + self.state.manifest, + only_ephemeral=only_ephemeral, + base_dir=self._manifest_base_dir(), + ) + + async def provision_manifest_accounts(self) -> None: + applier = ManifestApplier( + mkdir=lambda path: self.mkdir(path, parents=True), + exec_checked_nonzero=self._exec_checked_nonzero, + apply_entry=lambda artifact, dest, base_dir: artifact.apply(self, dest, base_dir), + ) + await applier.provision_accounts(self.state.manifest) + + def should_provision_manifest_accounts_on_resume(self) -> bool: + return True + + async def _apply_entry_batch( + self, + entries: Sequence[tuple[Path, BaseEntry]], + *, + base_dir: Path, + ) -> list[MaterializedFile]: + applier = ManifestApplier( + mkdir=lambda path: self.mkdir(path, parents=True), + exec_checked_nonzero=self._exec_checked_nonzero, + apply_entry=lambda artifact, dest, current_base_dir: artifact.apply( + self, + dest, + current_base_dir, + ), + ) + return await applier._apply_entry_batch(entries, base_dir=base_dir) + + def _manifest_base_dir(self) -> Path: + return Path.cwd() + + async def _exec_checked_nonzero(self, *command: str | Path) -> ExecResult: + result = await self.exec(*command, shell=False) + if not result.ok(): + raise ExecNonZeroError(result, command=command) + return result + + async def _clear_workspace_root_on_resume(self) -> None: + """ + Best-effort cleanup step for snapshot resume. + + We intentionally clear *contents* of the workspace root rather than deleting the root + directory itself. Some sandboxes configure their process working directory to the workspace + root (e.g. Modal sandboxes), and deleting the directory can make subsequent exec() calls + fail with "failed to find initial working directory". + """ + + root = Path(self.state.manifest.root) + try: + entries = await self.ls(root) + except ExecNonZeroError: + # If the root doesn't exist (or isn't listable), treat it as empty and let hydrate/apply + # create it as needed. + return + + for entry in entries: + # `parse_ls_la` filters "." and ".." already; remove everything else recursively. + await self.rm(Path(entry.path), recursive=True) diff --git a/src/agents/sandbox/session/dependencies.py b/src/agents/sandbox/session/dependencies.py new file mode 100644 index 0000000000..cb1cec7552 --- /dev/null +++ b/src/agents/sandbox/session/dependencies.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass +from typing import cast + +from typing_extensions import Self + +DependencyKey = str + + +class DependenciesError(RuntimeError): + pass + + +class DependenciesBindingError(DependenciesError, ValueError): + pass + + +class DependenciesMissingDependencyError(DependenciesError, LookupError): + pass + + +FactoryFn = Callable[["Dependencies"], object | Awaitable[object]] + + +@dataclass(slots=True) +class _ValueBinding: + value: object + + +@dataclass(slots=True) +class _FactoryBinding: + factory: FactoryFn + cache: bool + owns_result: bool + + +_Binding = _ValueBinding | _FactoryBinding + + +async def _close_best_effort(value: object) -> None: + close = getattr(value, "aclose", None) + if close is not None: + try: + result = close() + if inspect.isawaitable(result): + await cast(Awaitable[object], result) + return + except Exception: + return + + close = getattr(value, "close", None) + if close is None: + return + try: + result = close() + if inspect.isawaitable(result): + await cast(Awaitable[object], result) + except Exception: + return + + +class Dependencies: + """Session-scoped dependency container for manifest entry materialization. + + Sandbox clients hold a configured template of bindings and clone it for each created or resumed + session. That gives each session its own cache and owned-resource lifecycle while still letting + callers register shared runtime-only objects such as service clients or lazy factories. + """ + + def __init__(self) -> None: + self._bindings: dict[DependencyKey, _Binding] = {} + self._cache: dict[DependencyKey, object] = {} + self._owned_results: list[object] = [] + self._closed = False + + @classmethod + def with_values( + cls, + values: Mapping[DependencyKey, object], + ) -> Dependencies: + dependencies = cls() + for key, value in values.items(): + dependencies.bind_value(key, value) + return dependencies + + def bind_value( + self, + key: DependencyKey, + value: object, + *, + overwrite: bool = False, + ) -> Self: + if not key: + raise ValueError("Dependency key must be non-empty") + self._bind(key, _ValueBinding(value=value), overwrite=overwrite) + return self + + def clone(self) -> Dependencies: + cloned = Dependencies() + for key, binding in self._bindings.items(): + if isinstance(binding, _ValueBinding): + cloned._bindings[key] = _ValueBinding(value=binding.value) + else: + cloned._bindings[key] = _FactoryBinding( + factory=binding.factory, + cache=binding.cache, + owns_result=binding.owns_result, + ) + return cloned + + def bind_factory( + self, + key: DependencyKey, + factory: FactoryFn, + *, + cache: bool = True, + overwrite: bool = False, + owns_result: bool = False, + ) -> Self: + if not key: + raise ValueError("Dependency key must be non-empty") + self._bind( + key, + _FactoryBinding( + factory=factory, + cache=cache, + owns_result=owns_result, + ), + overwrite=overwrite, + ) + return self + + def _bind( + self, + key: DependencyKey, + binding: _Binding, + *, + overwrite: bool, + ) -> None: + if not overwrite and key in self._bindings: + raise DependenciesBindingError(f"Dependency `{key}` is already bound") + self._bindings[key] = binding + self._cache.pop(key, None) + + async def get(self, key: DependencyKey) -> object | None: + binding = self._bindings.get(key) + if binding is None: + return None + return await self._resolve(key, binding) + + async def require( + self, + key: DependencyKey, + *, + consumer: str | None = None, + ) -> object: + value = await self.get(key) + if value is not None: + return value + + consumer_part = f" for {consumer}" if consumer else "" + raise DependenciesMissingDependencyError( + f"Missing dependency `{key}`{consumer_part}. " + "Bind it on a Dependencies instance and pass it as " + "`dependencies=` when constructing the sandbox client." + ) + + async def _resolve(self, key: DependencyKey, binding: _Binding) -> object: + if isinstance(binding, _ValueBinding): + return binding.value + + assert isinstance(binding, _FactoryBinding) + if binding.cache and key in self._cache: + return self._cache[key] + + produced = binding.factory(self) + value = ( + await cast(Awaitable[object], produced) if inspect.isawaitable(produced) else produced + ) + + if binding.cache: + self._cache[key] = value + if binding.owns_result: + self._owned_results.append(value) + return value + + async def aclose(self) -> None: + if self._closed: + return + self._closed = True + + seen_ids: set[int] = set() + for value in reversed(self._owned_results): + value_id = id(value) + if value_id in seen_ids: + continue + seen_ids.add(value_id) + await _close_best_effort(value) diff --git a/src/agents/sandbox/session/events.py b/src/agents/sandbox/session/events.py new file mode 100644 index 0000000000..3d9e3145d7 --- /dev/null +++ b/src/agents/sandbox/session/events.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Annotated, Literal + +from pydantic import BaseModel, Field, TypeAdapter + +from ..errors import ErrorCode, OpName + +EventPhase = Literal["start", "finish"] + + +def _utcnow() -> datetime: + return datetime.now(tz=timezone.utc) + + +class EventPayloadPolicy(BaseModel): + """Controls how much potentially sensitive/large data is included in events.""" + + # Exec output can be noisy and sensitive; default off. + include_exec_output: bool = Field(default=False) + + # When enabled, bound output sizes. + max_stdout_chars: int = Field(default=8_000, ge=0) + max_stderr_chars: int = Field(default=8_000, ge=0) + + # For write events, we only include a best-effort byte count (never file bytes). + include_write_len: bool = Field(default=True) + + +class UCEventBase(BaseModel): + """Shared fields for all instrumentation events.""" + + version: int = Field(default=1) + + event_id: uuid.UUID = Field(default_factory=uuid.uuid4) + ts: datetime = Field(default_factory=_utcnow) + + session_id: uuid.UUID + seq: int + + op: OpName + phase: EventPhase + + span_id: uuid.UUID + parent_span_id: uuid.UUID | None = None + + # Operation-specific metadata (paths, argv, timings, etc.) + data: dict[str, object] = Field(default_factory=dict) + + +class UCStartEvent(UCEventBase): + """The start event for an operation.""" + + phase: Literal["start"] = Field(default="start") + + +class UCFinishEvent(UCEventBase): + """The finish event for an operation.""" + + phase: Literal["finish"] = Field(default="finish") + + ok: bool + duration_ms: float + + error_code: ErrorCode | None = None + error_type: str | None = None + error_message: str | None = None + + # Optional exec outputs (truncated / opt-in via policy). + stdout: str | None = None + stderr: str | None = None + + # Raw exec outputs (bytes) for per-sink/per-op policy application. + # These are excluded from serialization (JSONL / HTTP) by default. + stdout_bytes: bytes | None = Field(default=None, exclude=True) + stderr_bytes: bytes | None = Field(default=None, exclude=True) + + +# Discriminated union keyed by `phase`. +UCEvent = Annotated[UCStartEvent | UCFinishEvent, Field(discriminator="phase")] +_UC_EVENT_ADAPTER: TypeAdapter[UCEvent] = TypeAdapter(UCEvent) + + +def validate_uc_event(obj: object) -> UCEvent: + """Parse an event payload (e.g. from JSON) into the correct phase-specific model.""" + + return _UC_EVENT_ADAPTER.validate_python(obj) diff --git a/src/agents/sandbox/session/manager.py b/src/agents/sandbox/session/manager.py new file mode 100644 index 0000000000..91a6dbf025 --- /dev/null +++ b/src/agents/sandbox/session/manager.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Sequence + +from ..errors import OpName +from .events import EventPayloadPolicy, UCEvent, UCFinishEvent +from .sinks import ChainedSink, EventSink +from .utils import _safe_decode + +logger = logging.getLogger(__name__) + + +class Instrumentation: + def __init__( + self, + *, + sinks: Sequence[EventSink] | None = None, + payload_policy: EventPayloadPolicy | None = None, + payload_policy_by_op: dict[OpName, EventPayloadPolicy] | None = None, + ) -> None: + self._sinks: list[EventSink] = list(sinks or []) + self.payload_policy = payload_policy or EventPayloadPolicy() + self.payload_policy_by_op = payload_policy_by_op or {} + self._tasks: set[asyncio.Task[None]] = set() + + @property + def sinks(self) -> list[EventSink]: + return list(self._sinks) + + def add_sink(self, sink: EventSink) -> None: + self._sinks.append(sink) + + async def emit(self, event: UCEvent) -> None: + for sink in self._sinks: + if isinstance(sink, ChainedSink): + for inner in sink.sinks: + policy = self._policy_for(event.op, inner) + per_sink_event = self._apply_policy(event, policy) + # ChainedSink promises in-order delivery; ensure each sink completes + # before moving on, regardless of inner sink.mode. + await self._deliver_chained(inner, per_sink_event) + else: + policy = self._policy_for(event.op, sink) + per_sink_event = self._apply_policy(event, policy) + await self._deliver(sink, per_sink_event) + + async def flush(self) -> None: + pending = tuple(self._tasks) + if not pending: + return + await asyncio.gather(*pending, return_exceptions=True) + + def _policy_for(self, op: OpName, sink: EventSink) -> EventPayloadPolicy: + # Merge semantics: default -> per-op overrides -> per-sink overrides. + effective = self.payload_policy.model_copy(deep=True) + + op_policy = self.payload_policy_by_op.get(op) + if op_policy is not None: + effective = effective.model_copy(update=self._overrides(op_policy)) + + sink_policy = getattr(sink, "payload_policy", None) + if sink_policy is not None: + effective = effective.model_copy(update=self._overrides(sink_policy)) + + return effective + + def _overrides(self, policy: EventPayloadPolicy) -> dict[str, object]: + # Only override fields explicitly set by the user. + return {name: getattr(policy, name) for name in policy.model_fields_set} + + def _apply_policy(self, event: UCEvent, policy: EventPayloadPolicy) -> UCEvent: + # Clone per sink so we can redact/augment fields without affecting other sinks. + out = event.model_copy(deep=True) + + # Generic stream-length metadata redaction. + if not policy.include_write_len and "bytes" in out.data: + out.data.pop("bytes", None) + + # Exec output redaction/formatting. + if isinstance(out, UCFinishEvent): + if not policy.include_exec_output: + out.stdout = None + out.stderr = None + out.stdout_bytes = None + out.stderr_bytes = None + else: + if out.stdout_bytes is not None: + out.stdout = _safe_decode(out.stdout_bytes, max_chars=policy.max_stdout_chars) + if out.stderr_bytes is not None: + out.stderr = _safe_decode(out.stderr_bytes, max_chars=policy.max_stderr_chars) + + return out + + async def _deliver(self, sink: EventSink, event: UCEvent) -> None: + async def _run() -> None: + await sink.handle(event) + + if sink.mode == "sync": + try: + await _run() + except Exception: + self._handle_sink_error(sink, event) + elif sink.mode == "async": + if sink.on_error == "raise": + await _run() + return + + async def _task() -> None: + try: + await _run() + except Exception: + self._handle_sink_error(sink, event) + + task = asyncio.create_task(_task()) + # Track background deliveries so the task is kept alive and can be discarded once done. + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + elif sink.mode == "best_effort": + + async def _task() -> None: + try: + await _run() + except Exception: + self._handle_sink_error(sink, event, force_no_raise=True) + + task = asyncio.create_task(_task()) + # Same bookkeeping as async mode, but failures are always swallowed after logging. + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + else: + raise AssertionError(f"unknown sink.mode: {sink.mode!r}") + + async def _deliver_chained(self, sink: EventSink, event: UCEvent) -> None: + """ + Deliver an event to a sink as part of a ChainedSink group. + + The ChainedSink contract is "run in order", which implies later sinks should not + observe side effects before earlier sinks complete. To uphold that, we always + await completion here (ignoring sink.mode scheduling). + """ + try: + await sink.handle(event) + except Exception: + force_no_raise = sink.mode == "best_effort" + self._handle_sink_error(sink, event, force_no_raise=force_no_raise) + + def _handle_sink_error( + self, sink: EventSink, event: UCEvent, *, force_no_raise: bool = False + ) -> None: + if force_no_raise or sink.on_error in ("log", "ignore"): + if sink.on_error == "log": + logger.exception("instrumentation sink failed (ignored): %s", type(sink).__name__) + return + raise RuntimeError( + "instrumentation sink failed: " + f"{type(sink).__name__} while handling event {event.event_id}" + ) diff --git a/src/agents/sandbox/session/manifest_application.py b/src/agents/sandbox/session/manifest_application.py new file mode 100644 index 0000000000..039713f15c --- /dev/null +++ b/src/agents/sandbox/session/manifest_application.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from pathlib import Path + +from ..entries import BaseEntry, Dir, Mount, resolve_workspace_path +from ..manifest import Manifest +from ..materialization import MaterializationResult, MaterializedFile, gather_in_order +from ..types import ExecResult, User + + +class ManifestApplier: + def __init__( + self, + *, + mkdir: Callable[[Path], Awaitable[None]], + exec_checked_nonzero: Callable[..., Awaitable[ExecResult]], + apply_entry: Callable[[BaseEntry, Path, Path], Awaitable[list[MaterializedFile]]], + ) -> None: + self._mkdir = mkdir + self._exec_checked_nonzero = exec_checked_nonzero + self._apply_entry = apply_entry + + async def apply_manifest( + self, + manifest: Manifest, + *, + only_ephemeral: bool = False, + base_dir: Path | None = None, + ) -> MaterializationResult: + base_dir = Path("/") if base_dir is None else base_dir + + await self._mkdir(Path(manifest.root)) + + if not only_ephemeral: + await self.provision_accounts(manifest) + + entries_to_apply: list[tuple[Path, BaseEntry]] = [] + if only_ephemeral: + for rel_dest, artifact in self._ephemeral_entries(manifest): + dest = resolve_workspace_path(Path(manifest.root), rel_dest) + entries_to_apply.append((dest, artifact)) + else: + for raw_rel_dest, artifact in manifest.validated_entries().items(): + dest = resolve_workspace_path( + Path(manifest.root), + Manifest._coerce_rel_path(raw_rel_dest), + ) + entries_to_apply.append((dest, artifact)) + + return MaterializationResult( + files=await self._apply_entry_batch(entries_to_apply, base_dir=base_dir), + ) + + async def provision_accounts(self, manifest: Manifest) -> None: + all_users: set[User] = set(manifest.users) + for group in manifest.groups: + all_users |= set(group.users) + await self._exec_checked_nonzero("groupadd", group.name) + + for user in all_users: + await self._exec_checked_nonzero( + "useradd", + "-U", + "-M", + "-s", + "/usr/sbin/nologin", + user.name, + ) + + for group in manifest.groups: + for user in group.users: + await self._exec_checked_nonzero("usermod", "-aG", group.name, user.name) + + def _ephemeral_entries(self, manifest: Manifest) -> list[tuple[Path, BaseEntry]]: + entries: list[tuple[Path, BaseEntry]] = [] + for rel_dest, artifact in manifest.entries.items(): + self._collect_ephemeral_entries( + rel_dest=Manifest._coerce_rel_path(rel_dest), + artifact=artifact, + out=entries, + ) + return entries + + def _collect_ephemeral_entries( + self, + *, + rel_dest: Path, + artifact: BaseEntry, + out: list[tuple[Path, BaseEntry]], + ) -> None: + manifest_rel = Manifest._coerce_rel_path(rel_dest) + Manifest._validate_rel_path(manifest_rel) + if artifact.ephemeral: + out.append((manifest_rel, self._prune_to_ephemeral(artifact))) + return + if isinstance(artifact, Dir): + for child_name, child_artifact in artifact.children.items(): + self._collect_ephemeral_entries( + rel_dest=manifest_rel / Manifest._coerce_rel_path(child_name), + artifact=child_artifact, + out=out, + ) + + def _prune_to_ephemeral(self, artifact: BaseEntry) -> BaseEntry: + if not isinstance(artifact, Dir): + return artifact + if artifact.ephemeral: + return artifact.model_copy(deep=True) + + pruned_children: dict[str | Path, BaseEntry] = {} + for child_name, child_artifact in artifact.children.items(): + if child_artifact.ephemeral: + pruned_children[child_name] = self._prune_to_ephemeral(child_artifact) + continue + if isinstance(child_artifact, Dir): + nested = self._prune_to_ephemeral(child_artifact) + if isinstance(nested, Dir) and nested.children: + pruned_children[child_name] = nested + + return artifact.model_copy(update={"children": pruned_children}, deep=True) + + @staticmethod + def _paths_overlap(left: Path, right: Path) -> bool: + return left == right or left in right.parents or right in left.parents + + async def _apply_entry_batch( + self, + entries: Sequence[tuple[Path, BaseEntry]], + *, + base_dir: Path, + ) -> list[MaterializedFile]: + files: list[MaterializedFile] = [] + parallel_batch: list[tuple[Path, BaseEntry]] = [] + + async def _flush_parallel_batch() -> None: + nonlocal files + if not parallel_batch: + return + + def _make_apply_task( + dest: Path, + artifact: BaseEntry, + ) -> Callable[[], Awaitable[list[MaterializedFile]]]: + async def _apply() -> list[MaterializedFile]: + return await self._apply_entry(artifact, dest, base_dir) + + return _apply + + batch = list(parallel_batch) + parallel_batch.clear() + batch_files = await gather_in_order( + [_make_apply_task(dest, artifact) for dest, artifact in batch] + ) + for entry_files in batch_files: + files.extend(entry_files) + + for dest, artifact in entries: + if isinstance(artifact, Mount) or any( + self._paths_overlap(dest, queued_dest) for queued_dest, _ in parallel_batch + ): + await _flush_parallel_batch() + files.extend(await self._apply_entry(artifact, dest, base_dir)) + continue + + parallel_batch.append((dest, artifact)) + + await _flush_parallel_batch() + return files diff --git a/src/agents/sandbox/session/sandbox_client.py b/src/agents/sandbox/session/sandbox_client.py new file mode 100644 index 0000000000..25ee8af777 --- /dev/null +++ b/src/agents/sandbox/session/sandbox_client.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import abc +from typing import Generic, TypeVar + +from ..codex_config import CodexConfig +from ..manifest import Manifest +from ..snapshot import SnapshotSpec +from .base_sandbox_session import BaseSandboxSession +from .dependencies import Dependencies +from .manager import Instrumentation +from .sandbox_session import SandboxSession +from .sandbox_session_state import SandboxSessionState + +ClientOptionsT = TypeVar("ClientOptionsT") + + +class BaseSandboxClient(abc.ABC, Generic[ClientOptionsT]): + backend_id: str + supports_default_options: bool = False + _dependencies: Dependencies | None = None + + def _resolve_dependencies(self) -> Dependencies | None: + if self._dependencies is None: + return None + # Sessions get clones instead of the shared template so per-session factory caches and + # owned resources do not leak across unrelated sandboxes. + return self._dependencies.clone() + + def _wrap_session( + self, + inner: BaseSandboxSession, + *, + instrumentation: Instrumentation | None = None, + ) -> SandboxSession: + # Always return the instrumented wrapper so callers get consistent events and dependency + # lifecycle handling regardless of which backend created the inner session. + return SandboxSession( + inner, + instrumentation=instrumentation, + dependencies=self._resolve_dependencies(), + ) + + @abc.abstractmethod + async def create( + self, + *, + snapshot: SnapshotSpec | None = None, + manifest: Manifest | None = None, + codex: bool | CodexConfig = False, + options: ClientOptionsT, + ) -> SandboxSession: + """Create a new session. + + Args: + snapshot: Snapshot spec used to create a snapshot instance for + the session. If omitted, the session uses a no-op snapshot. + manifest: Optional manifest to materialize into the workspace when + the session starts. + codex: Whether to provision Codex into the workspace, or a custom + Codex provisioning config. + options: Sandbox-specific settings. For example, Docker expects + ``DockerSandboxClientOptions(image="...")``. + Returns: + A `SandboxSession` that can be entered with `async with` or closed explicitly with + `await session.aclose()`. + """ + + @abc.abstractmethod + async def delete(self, session: SandboxSession) -> SandboxSession: + """Delete a session and release sandbox resources.""" + + @abc.abstractmethod + async def resume( + self, + state: SandboxSessionState, + *, + codex: bool | CodexConfig = False, + ) -> SandboxSession: + """Resume a session from a previously persisted `SandboxSessionState`. + + The returned session should hydrate its workspace from `state.snapshot` + during `SandboxSession.start()`. + """ + + def serialize_session_state(self, state: SandboxSessionState) -> dict[str, object]: + """Serialize backend-specific sandbox state into a JSON-compatible payload.""" + return state.model_dump(mode="json") + + @abc.abstractmethod + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + """Deserialize backend-specific sandbox state from a JSON-compatible payload.""" diff --git a/src/agents/sandbox/session/sandbox_session.py b/src/agents/sandbox/session/sandbox_session.py new file mode 100644 index 0000000000..4e30967242 --- /dev/null +++ b/src/agents/sandbox/session/sandbox_session.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import io +import time +import uuid +from collections.abc import Coroutine +from contextvars import Token +from functools import wraps +from pathlib import Path +from typing import Callable, TypeVar, cast + +from ..errors import OpName, UniversalComputerError +from ..types import ExecResult, User +from .base_sandbox_session import BaseSandboxSession +from .dependencies import Dependencies +from .events import UCFinishEvent, UCStartEvent +from .manager import Instrumentation +from .sandbox_session_state import SandboxSessionState +from .sinks import ChainedSink, SandboxSessionBoundSink +from .utils import ( + _best_effort_stream_len, + current_span_id, +) + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Coroutine[object, object, object]]) + + +def instrumented_op( + op: OpName, + *, + data: Callable[..., dict[str, object] | None] | None = None, + finish_data: ( + Callable[[dict[str, object] | None, object], dict[str, object] | None] | None + ) = None, + ok: Callable[[object], bool] | None = None, + outputs: Callable[[object], tuple[bytes | None, bytes | None]] | None = None, +) -> Callable[[F], F]: + """Decorator to emit UCEvents around a SandboxSession operation.""" + + def _decorator(fn: F) -> F: + @wraps(fn) + async def _wrapped(self: SandboxSession, *args: object, **kwargs: object) -> object: + start_data = data(self, *args, **kwargs) if data is not None else None + finish_cb: Callable[[object], dict[str, object]] | None + if finish_data is None: + finish_cb = None + else: + fd = finish_data + + def _finish_cb(res: object) -> dict[str, object]: + return dict(fd(start_data, res) or {}) + + finish_cb = _finish_cb + + return await self._annotate( + op=op, + start_data=start_data, + run=lambda: fn(self, *args, **kwargs), + finish_data=finish_cb, + ok=ok, + outputs=outputs, + ) + + return cast(F, _wrapped) + + return _decorator + + +def _exec_start_data( + _self: SandboxSession, + *command: str | Path, + timeout: float | None = None, + shell: bool | list[str] = True, + user: str | User | None = None, +) -> dict[str, object]: + user_value: str | None + if isinstance(user, User): + user_value = user.name + else: + user_value = user + return { + "command": [str(c) for c in command], + "timeout_s": timeout, + "shell": shell, + "user": user_value, + } + + +def _exec_finish_data(start_data: dict[str, object] | None, result: object) -> dict[str, object]: + out = dict(start_data or {}) + out["exit_code"] = cast(ExecResult, result).exit_code + return out + + +def _write_start_data(self: SandboxSession, path: Path, data: io.IOBase) -> dict[str, object]: + out: dict[str, object] = {"path": str(path)} + n = _best_effort_stream_len(data) + if n is not None: + out["bytes"] = n + return out + + +def _running_finish_data( + _start_data: dict[str, object] | None, + result: object, +) -> dict[str, object]: + return {"alive": bool(result)} + + +def _snapshot_tar_path(self: SandboxSession) -> str | None: + """ + Best-effort path to the persisted workspace tar on the *host*. + + Today Snapshot is a LocalSnapshot whose persist() writes `/.tar`. + We keep this best-effort (instead of importing LocalSnapshot) to avoid coupling. + """ + + snap = getattr(self.state, "snapshot", None) + base_path = getattr(snap, "base_path", None) + snap_id = getattr(snap, "id", None) + if isinstance(base_path, Path) and isinstance(snap_id, str) and snap_id: + return str(Path(str(base_path / snap_id) + ".tar")) + return None + + +def _persist_start_data(self: SandboxSession) -> dict[str, object]: + out: dict[str, object] = {"workspace_root": str(self.state.manifest.root)} + tar_path = _snapshot_tar_path(self) + if tar_path is not None: + out["tar_path"] = tar_path + return out + + +def _persist_finish_data( + start_data: dict[str, object] | None, + result: object, +) -> dict[str, object]: + out = dict(start_data or {}) + n = _best_effort_stream_len(cast(io.IOBase, result)) + if n is not None: + out["bytes"] = n + return out + + +def _hydrate_start_data(self: SandboxSession, data: io.IOBase) -> dict[str, object]: + out: dict[str, object] = {"untar_dir": str(self.state.manifest.root)} + n = _best_effort_stream_len(data) + if n is not None: + out["bytes"] = n + return out + + +class SandboxSession(BaseSandboxSession): + """A SandboxSession wrapper that emits UCEvent objects around core operations.""" + + _inner: BaseSandboxSession + _instrumentation: Instrumentation + _seq: int + + def __init__( + self, + inner: BaseSandboxSession, + *, + instrumentation: Instrumentation | None = None, + dependencies: Dependencies | None = None, + ) -> None: + self._inner = inner + self._inner.set_dependencies(dependencies) + self._instrumentation = instrumentation or Instrumentation() + self._seq = 0 + + self._bind_session_to_sinks() + + def _bind_session_to_sinks(self) -> None: + # Bind sinks to the *inner* session to avoid recursive instrumentation loops. + for sink in self._instrumentation.sinks: + sinks: list[object] + if isinstance(sink, ChainedSink): + sinks = list(sink.sinks) + else: + sinks = [sink] + for s in sinks: + if isinstance(s, SandboxSessionBoundSink): + s.bind(self._inner) + + @property + def state(self) -> SandboxSessionState: + return self._inner.state + + @state.setter + def state(self, value: SandboxSessionState) -> None: # pragma: no cover + self._inner.state = value + + @property + def dependencies(self) -> Dependencies: + return self._inner.dependencies + + async def _aclose_dependencies(self) -> None: + await self._inner._aclose_dependencies() + + async def aclose(self) -> None: + try: + await super().aclose() + finally: + await self._instrumentation.flush() + + def _next_seq(self) -> int: + self._seq += 1 + return self._seq + + async def _emit_start_event( + self, + *, + op: OpName, + span_id: uuid.UUID, + parent_span_id: uuid.UUID | None, + data: dict[str, object] | None = None, + ) -> None: + await self._instrumentation.emit( + UCStartEvent( + session_id=self.state.session_id, + seq=self._next_seq(), + op=op, + span_id=span_id, + parent_span_id=parent_span_id, + data=data or {}, + ) + ) + + async def _annotate( + self, + *, + op: OpName, + start_data: dict[str, object] | None, + run: Callable[[], Coroutine[object, object, T]], + finish_data: Callable[[T], dict[str, object]] | None = None, + ok: Callable[[T], bool] | None = None, + outputs: Callable[[T], tuple[bytes | None, bytes | None]] | None = None, + ) -> T: + span_id = uuid.uuid4() + parent = current_span_id.get() + token = current_span_id.set(span_id) + + try: + await self._emit_start_event( + op=op, span_id=span_id, parent_span_id=parent, data=start_data + ) + except Exception: + current_span_id.reset(token) + raise + + t0 = time.monotonic() + try: + value = await run() + except Exception as e: + await self._emit_finish_event( + op=op, + span_id=span_id, + parent_span_id=parent, + start_t=t0, + token=token, + ok=False, + exc=e, + data=start_data, + stdout=None, + stderr=None, + ) + raise + + data_finish = finish_data(value) if finish_data is not None else start_data + ok_value = ok(value) if ok is not None else True + stdout, stderr = outputs(value) if outputs is not None else (None, None) + await self._emit_finish_event( + op=op, + span_id=span_id, + parent_span_id=parent, + start_t=t0, + token=token, + ok=ok_value, + exc=None, + data=data_finish, + stdout=stdout, + stderr=stderr, + ) + return value + + async def _emit_finish_event( + self, + *, + op: OpName, + span_id: uuid.UUID, + parent_span_id: uuid.UUID | None, + start_t: float, + token: Token[uuid.UUID | None], + ok: bool, + exc: BaseException | None, + data: dict[str, object] | None, + stdout: bytes | None, + stderr: bytes | None, + ) -> None: + duration_ms = (time.monotonic() - start_t) * 1000.0 + event = UCFinishEvent( + session_id=self.state.session_id, + seq=self._next_seq(), + op=op, + span_id=span_id, + parent_span_id=parent_span_id, + data=data or {}, + ok=ok, + duration_ms=duration_ms, + ) + + if exc is not None: + event.error_type = type(exc).__name__ + event.error_message = str(exc) + if isinstance(exc, UniversalComputerError): + event.error_code = exc.error_code + + # Preserve raw bytes so Instrumentation can apply per-op/per-sink policies later. + # Decoding here would force one global formatting decision before sink-specific redaction + # and truncation rules have a chance to run. + event.stdout_bytes = stdout + event.stderr_bytes = stderr + + try: + await self._instrumentation.emit(event) + finally: + current_span_id.reset(token) + + @instrumented_op("start") + async def start(self) -> None: + await self._inner.start() + + @instrumented_op("stop") + async def stop(self) -> None: + await self._inner.stop() + + @instrumented_op("shutdown") + async def shutdown(self) -> None: + await self._inner.shutdown() + + @instrumented_op( + "exec", + data=_exec_start_data, + finish_data=_exec_finish_data, + ok=lambda result: cast(ExecResult, result).ok(), + outputs=lambda result: ( + cast(ExecResult, result).stdout, + cast(ExecResult, result).stderr, + ), + ) + async def exec( + self, + *command: str | Path, + timeout: float | None = None, + shell: bool | list[str] = True, + user: str | User | None = None, + ) -> ExecResult: + return await self._inner.exec(*command, timeout=timeout, shell=shell, user=user) + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + raise NotImplementedError("this should never be invoked") + + @instrumented_op("read", data=lambda _self, path: {"path": str(path)}) + async def read(self, path: Path) -> io.IOBase: + return await self._inner.read(path) + + @instrumented_op("write", data=_write_start_data) + async def write(self, path: Path, data: io.IOBase) -> None: + await self._inner.write(path, data) + + @instrumented_op( + "running", + finish_data=_running_finish_data, + ok=lambda _alive: True, + ) + async def running(self) -> bool: + return await self._inner.running() + + @instrumented_op( + "persist_workspace", + data=_persist_start_data, + finish_data=_persist_finish_data, + ) + async def persist_workspace(self) -> io.IOBase: + return await self._inner.persist_workspace() + + @instrumented_op( + "hydrate_workspace", + data=_hydrate_start_data, + ) + async def hydrate_workspace(self, data: io.IOBase) -> None: + await self._inner.hydrate_workspace(data) diff --git a/src/agents/sandbox/session/sandbox_session_state.py b/src/agents/sandbox/session/sandbox_session_state.py new file mode 100644 index 0000000000..97bd9929a6 --- /dev/null +++ b/src/agents/sandbox/session/sandbox_session_state.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import uuid + +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator + +from ..manifest import Manifest +from ..snapshot import SnapshotBase + + +class SandboxSessionState(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + session_id: uuid.UUID = Field(default_factory=uuid.uuid4) + snapshot: SnapshotBase + manifest: Manifest + + @field_validator("snapshot", mode="before") + @classmethod + def _coerce_snapshot(cls, value: object) -> SnapshotBase: + return SnapshotBase.parse(value) + + @field_serializer("snapshot", when_used="json") + def _serialize_snapshot(self, snapshot: SnapshotBase) -> object: + # Ensure subclass fields (e.g. LocalSnapshot.base_path) are preserved in JSON. + return snapshot.model_dump(mode="json") diff --git a/src/agents/sandbox/session/sinks.py b/src/agents/sandbox/session/sinks.py new file mode 100644 index 0000000000..f586e7068f --- /dev/null +++ b/src/agents/sandbox/session/sinks.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import abc +import asyncio +import io +import logging +from pathlib import Path +from types import ModuleType +from typing import Callable, Literal, Protocol, runtime_checkable +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from ..errors import WorkspaceReadNotFoundError +from .base_sandbox_session import BaseSandboxSession +from .events import EventPayloadPolicy, UCEvent +from .utils import event_to_json_line + +logger = logging.getLogger(__name__) + +DeliveryMode = Literal["sync", "async", "best_effort"] +OnErrorPolicy = Literal["raise", "log", "ignore"] + + +def _unwrap_session_wrapper(session: BaseSandboxSession) -> BaseSandboxSession: + """ + Defensive unwrapping: if a sink is accidentally bound to a SandboxSession wrapper, + unwrap to the underlying session to avoid recursive event loops. + """ + + # Avoid importing session.sandbox_session.SandboxSession here + # (would create a dependency cycle). + cls = type(session) + if not ( + cls.__name__ == "SandboxSession" + and cls.__module__ == "agents.sandbox.session.sandbox_session" + ): + return session + inner = getattr(session, "_inner", None) + return inner if isinstance(inner, BaseSandboxSession) else session + + +class EventSink(abc.ABC): + """Consumes UCEvent objects (e.g., callback, file outbox, proxy HTTP).""" + + name: str | None = None + mode: DeliveryMode + on_error: OnErrorPolicy + payload_policy: EventPayloadPolicy | None + + @abc.abstractmethod + async def handle(self, event: UCEvent) -> None: ... + + +@runtime_checkable +class SandboxSessionBoundSink(Protocol): + """Optional interface for sinks that need access to the underlying SandboxSession.""" + + def bind(self, session: BaseSandboxSession) -> None: ... + + +class CallbackSink(EventSink): + """Deliver events to a user-provided callable. + + Supports sync or async callables. + """ + + def __init__( + self, + callback: Callable[[UCEvent, BaseSandboxSession], object], + *, + mode: DeliveryMode = "sync", + on_error: OnErrorPolicy = "raise", + payload_policy: EventPayloadPolicy | None = None, + name: str | None = None, + ) -> None: + self._callback = callback + self.mode = mode + self.on_error = on_error + self.payload_policy = payload_policy + self._session: BaseSandboxSession | None = None + self.name = name + + def bind(self, session: BaseSandboxSession) -> None: + self._session = _unwrap_session_wrapper(session) + + async def handle(self, event: UCEvent) -> None: + if self._session is None: + raise RuntimeError( + "CallbackSink requires a bound session; use SandboxSession / " + "a sandbox client with instrumentation (or call bind(session))." + ) + out = self._callback(event, self._session) + if asyncio.iscoroutine(out): + await out + + +class JsonlOutboxSink(EventSink): + """Append events to a JSONL file on the host filesystem.""" + + def __init__( + self, + path: Path, + *, + mode: DeliveryMode = "best_effort", + on_error: OnErrorPolicy = "log", + payload_policy: EventPayloadPolicy | None = None, + ) -> None: + self.path = path + self.mode = mode + self.on_error = on_error + self.payload_policy = payload_policy + + async def handle(self, event: UCEvent) -> None: + line = event_to_json_line(event) + await asyncio.to_thread(self._append_line, line) + + def _append_line(self, line: str) -> None: + self.path.parent.mkdir(parents=True, exist_ok=True) + fcntl_mod: ModuleType | None + try: + import fcntl as fcntl_mod + except Exception: + # Not available on all platforms (e.g. Windows) + fcntl_mod = None + + with self.path.open("a", encoding="utf-8") as f: + if fcntl_mod is not None: + try: + fcntl_mod.flock(f.fileno(), fcntl_mod.LOCK_EX) + except Exception: + pass + f.write(line) + f.flush() + if fcntl_mod is not None: + try: + # Nice to have release here; the OS releases the lock + # automatically when the file is closed. + fcntl_mod.flock(f.fileno(), fcntl_mod.LOCK_UN) + except Exception: + pass + + +class WorkspaceJsonlSink(EventSink): + """ + Append events to a JSONL file inside the session workspace (under manifest.root). + + This sink still runs in the client process, but writes into the session via + `SandboxSession.write()`, so it works across sandboxes (Docker/Modal) + without requiring host-mounted volumes. + """ + + def __init__( + self, + *, + workspace_relpath: Path = Path("logs/events-{session_id}.jsonl"), + ephemeral: bool = False, + mode: DeliveryMode = "best_effort", + on_error: OnErrorPolicy = "log", + payload_policy: EventPayloadPolicy | None = None, + flush_every: int = 1, + ) -> None: + """ + Args: + workspace_relpath: Relative path under the session workspace root. + This also supports lightweight templating which is expanded on `bind()`: + - `"{session_id}"` (UUID string, e.g. "550e8400-e29b-41d4-a716-446655440000") + - `"{session_id_hex}"` (UUID hex, e.g. "550e8400e29b41d4a716446655440000") + + Example: + Path("logs/events-{session_id}.jsonl") + """ + self.workspace_relpath = workspace_relpath + self.ephemeral = ephemeral + self.mode = mode + self.on_error = on_error + self.payload_policy = payload_policy + self._session: BaseSandboxSession | None = None + self._resolved_workspace_relpath: Path | None = None + self._buf = bytearray() + self._seen = 0 + self._lock = asyncio.Lock() + self._flush_every = max(1, int(flush_every)) + self._existing_outbox_loaded = False + + def _resolve_relpath(self) -> Path: + rel = self.workspace_relpath + if self._session is None: + return rel + template = str(rel) + try: + rendered = template.format( + session_id=self._session.state.session_id, + session_id_hex=self._session.state.session_id.hex, + ) + except Exception: + # If formatting fails for any reason, fall back to the literal path. + rendered = template + return Path(rendered) + + def bind(self, session: BaseSandboxSession) -> None: + self._session = _unwrap_session_wrapper(session) + self._resolved_workspace_relpath = self._resolve_relpath() + if self.ephemeral: + relpath = self._resolved_workspace_relpath or self.workspace_relpath + self._session._register_persist_workspace_skip_relpath(relpath) + + def _buffer_event(self, event: UCEvent) -> bool: + self._buf.extend(event_to_json_line(event).encode("utf-8")) + self._seen += 1 + + if self._seen % self._flush_every == 0: + return True + if event.op == "persist_workspace" and event.phase == "start": + return True + if event.op == "stop": + return True + if event.op == "shutdown" and event.phase == "start": + return True + if event.op == "shutdown" and event.phase == "finish": + return False + + return False + + async def _can_flush_to_workspace(self) -> bool: + if self._session is None: + return False + + # `SandboxSession.start()` emits the `start` event before the underlying sandbox + # is fully running, so writes may still fail during early startup or late teardown. + try: + return await self._session.running() + except Exception: + return False + + async def _flush_buffer(self) -> None: + if self._session is None: + return + + await self._ensure_existing_outbox_loaded() + relpath = self._resolved_workspace_relpath or self.workspace_relpath + await self._session.write(relpath, io.BytesIO(bytes(self._buf))) + + async def _ensure_existing_outbox_loaded(self) -> None: + if self._session is None or self._existing_outbox_loaded: + return + + relpath = self._resolved_workspace_relpath or self.workspace_relpath + try: + existing = await self._session.read(relpath) + except (FileNotFoundError, WorkspaceReadNotFoundError): + self._existing_outbox_loaded = True + return + + try: + payload = existing.read() + finally: + existing.close() + + if isinstance(payload, str): + payload = payload.encode("utf-8") + if payload: + self._buf = bytearray(payload) + self._buf + self._existing_outbox_loaded = True + + async def handle(self, event: UCEvent) -> None: + # If unbound (e.g., Instrumentation.emit used without a SandboxSession wrapper), + # no-op. + if self._session is None: + return + + async with self._lock: + if not self._buffer_event(event): + return + + if not await self._can_flush_to_workspace(): + return + + await self._flush_buffer() + + +class HttpProxySink(EventSink): + """POST events as JSON to a proxy endpoint (local daemon or remote service).""" + + def __init__( + self, + endpoint: str, + *, + headers: dict[str, str] | None = None, + timeout_s: float = 5.0, + spool_path: Path | None = None, + mode: DeliveryMode = "best_effort", + on_error: OnErrorPolicy = "log", + payload_policy: EventPayloadPolicy | None = None, + ) -> None: + self.endpoint = endpoint + self.headers = headers or {} + self.timeout_s = timeout_s + self.spool_path = spool_path + self.mode = mode + self.on_error = on_error + self.payload_policy = payload_policy + + async def handle(self, event: UCEvent) -> None: + payload = event.model_dump_json().encode("utf-8") + spool_line = event_to_json_line(event) if self.spool_path is not None else None + await asyncio.to_thread(self._post, payload, spool_line) + + def _post(self, body: bytes, spool_line: str | None) -> None: + # TODO: thinking about using proxy instead of direct http call + req = Request( + self.endpoint, + data=body, + headers={"content-type": "application/json", **self.headers}, + method="POST", + ) + try: + with urlopen(req, timeout=self.timeout_s) as resp: + _ = resp.read(1) # ensure request completes + except (HTTPError, URLError) as e: + if spool_line is not None and self.spool_path is not None: + try: + self.spool_path.parent.mkdir(parents=True, exist_ok=True) + with self.spool_path.open("a", encoding="utf-8") as f: + f.write(spool_line) + f.flush() + except Exception: + pass + raise RuntimeError(f"http proxy sink POST failed: {e}") from e + + +class ChainedSink(EventSink): + """ + Groups multiple sinks that should run in order. + + Note: Instrumentation unwraps this group and applies per-op/per-sink payload policies to each + inner sink individually (so grouping does not disable per-sink policy behavior). + """ + + def __init__(self, *sinks: EventSink) -> None: + self.sinks = list(sinks) + # These are not used directly when Instrumentation unwraps the group, but keep the object + # conforming to EventSink. + self.mode = "sync" + self.on_error = "raise" + self.payload_policy = None + + async def handle(self, event: UCEvent) -> None: + # Fallback behavior if used directly (without Instrumentation unwrapping). + for sink in self.sinks: + await sink.handle(event) diff --git a/src/agents/sandbox/session/utils.py b/src/agents/sandbox/session/utils.py new file mode 100644 index 0000000000..c9f91e488d --- /dev/null +++ b/src/agents/sandbox/session/utils.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import io +import json +import uuid +from contextvars import ContextVar + +from .events import UCEvent + + +def _safe_decode(b: bytes, *, max_chars: int) -> str: + # Decode bytes as UTF-8 with replacement to keep event JSON valid. + # Truncation is on decoded string length, not raw bytes. + s = b.decode("utf-8", errors="replace") + if len(s) > max_chars: + return s[:max_chars] + "…" + return s + + +def _best_effort_stream_len(stream: io.IOBase) -> int | None: + # Avoid consuming the stream. This only works for seekable streams. + try: + pos = stream.tell() + stream.seek(0, io.SEEK_END) + end = stream.tell() + stream.seek(pos, io.SEEK_SET) + return int(end - pos) + except Exception: + return None + + +def event_to_json_line(event: UCEvent) -> str: + payload = event.model_dump(mode="json") + return json.dumps(payload, separators=(",", ":"), sort_keys=True) + "\n" + + +current_span_id: ContextVar[uuid.UUID | None] = ContextVar("uc_current_span_id", default=None) diff --git a/src/agents/sandbox/session/workspace_payloads.py b/src/agents/sandbox/session/workspace_payloads.py new file mode 100644 index 0000000000..5141707861 --- /dev/null +++ b/src/agents/sandbox/session/workspace_payloads.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import io +from dataclasses import dataclass +from pathlib import Path + +from ..errors import WorkspaceWriteTypeError + + +@dataclass(frozen=True) +class WritePayload: + stream: io.IOBase + content_length: int | None = None + + +class _BinaryReadAdapter(io.IOBase): + def __init__(self, *, path: Path, stream: io.IOBase) -> None: + self._path = path + self._stream = stream + + def readable(self) -> bool: + return True + + def read(self, size: int = -1) -> bytes: + chunk = self._stream.read(size) + if chunk is None: + return b"" + if isinstance(chunk, bytes): + return chunk + if isinstance(chunk, bytearray): + return bytes(chunk) + raise WorkspaceWriteTypeError(path=self._path, actual_type=type(chunk).__name__) + + def readinto(self, b: bytearray) -> int: + data = self.read(len(b)) + n = len(data) + b[:n] = data + return n + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + return int(self._stream.seek(offset, whence)) + + def tell(self) -> int: + return int(self._stream.tell()) + + +def coerce_write_payload(*, path: Path, data: io.IOBase) -> WritePayload: + stream = _BinaryReadAdapter(path=path, stream=data) + return WritePayload(stream=stream, content_length=_best_effort_content_length(data)) + + +def _best_effort_content_length(stream: io.IOBase) -> int | None: + for attr in ("content_length", "length"): + value = getattr(stream, attr, None) + if isinstance(value, int) and value >= 0: + return value + + headers = getattr(stream, "headers", None) + if headers is not None: + content_length = None + get = getattr(headers, "get", None) + if callable(get): + content_length = get("Content-Length") + if isinstance(content_length, str): + try: + parsed = int(content_length) + except ValueError: + parsed = None + if parsed is not None and parsed >= 0: + return parsed + + try: + pos = stream.tell() + stream.seek(0, io.SEEK_END) + end = stream.tell() + stream.seek(pos, io.SEEK_SET) + return int(end - pos) + except Exception: + return None diff --git a/src/agents/sandbox/snapshot.py b/src/agents/sandbox/snapshot.py new file mode 100644 index 0000000000..e7f89c30f6 --- /dev/null +++ b/src/agents/sandbox/snapshot.py @@ -0,0 +1,138 @@ +import abc +import io +import shutil +from pathlib import Path +from typing import Annotated, ClassVar, Literal + +from pydantic import BaseModel, Field, PrivateAttr + +from .errors import ( + SnapshotNotRestorableError, + SnapshotPersistError, + SnapshotRestoreError, +) + +SnapshotClass = type["SnapshotBase"] + + +class SnapshotBase(BaseModel, abc.ABC): + type: str + id: str + _subclass_registry: ClassVar[dict[str, SnapshotClass]] = {} + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs: object) -> None: + super().__pydantic_init_subclass__(**kwargs) + + type_field = cls.model_fields.get("type") + type_default = type_field.default if type_field is not None else None + if not isinstance(type_default, str) or type_default == "": + raise TypeError(f"{cls.__name__} must define a non-empty string default for `type`") + + existing = SnapshotBase._subclass_registry.get(type_default) + if existing is not None and existing is not cls: + raise TypeError( + f"snapshot type `{type_default}` is already registered by {existing.__name__}" + ) + SnapshotBase._subclass_registry[type_default] = cls + + @classmethod + def parse(cls, payload: object) -> "SnapshotBase": + if isinstance(payload, SnapshotBase): + return payload + + if isinstance(payload, dict): + snapshot_type = payload.get("type") + if isinstance(snapshot_type, str): + snapshot_class = cls._snapshot_class_for_type(snapshot_type) + if snapshot_class is not None: + return snapshot_class.model_validate(payload) + + raise ValueError(f"unknown snapshot type `{snapshot_type}`") + + raise TypeError("snapshot payload must be a SnapshotBase or object payload") + + @classmethod + def _snapshot_class_for_type(cls, snapshot_type: str) -> SnapshotClass | None: + return SnapshotBase._subclass_registry.get(snapshot_type) + + @abc.abstractmethod + async def persist(self, data: io.IOBase) -> None: ... + + @abc.abstractmethod + async def restore(self) -> io.IOBase: ... + + @abc.abstractmethod + async def restorable(self) -> bool: ... + + +class LocalSnapshot(SnapshotBase): + type: Literal["local"] = "local" + + base_path: Path + _checksum: str | None = PrivateAttr(default=None) + + async def persist(self, data: io.IOBase) -> None: + path = self._path() + try: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("wb") as f: + shutil.copyfileobj(data, f) + except OSError as e: + raise SnapshotPersistError(snapshot_id=self.id, path=path, cause=e) from e + + async def restore(self) -> io.IOBase: + path = self._path() + try: + return path.open("rb") + except OSError as e: + raise SnapshotRestoreError(snapshot_id=self.id, path=path, cause=e) from e + + async def restorable(self) -> bool: + return self._path().exists() + + def _path(self) -> Path: + return Path(str(self.base_path / self.id) + ".tar") + + +class NoopSnapshot(SnapshotBase): + type: Literal["noop"] = "noop" + + async def persist(self, data: io.IOBase) -> None: + _ = data + return + + async def restore(self) -> io.IOBase: + raise SnapshotNotRestorableError(snapshot_id=self.id, path=Path("")) + + async def restorable(self) -> bool: + return False + + +class SnapshotSpec(BaseModel, abc.ABC): + type: str + + @abc.abstractmethod + def build(self, snapshot_id: str) -> SnapshotBase: ... + + +class LocalSnapshotSpec(SnapshotSpec): + type: Literal["local"] = "local" + base_path: Path + + def build(self, snapshot_id: str) -> SnapshotBase: + return LocalSnapshot(id=snapshot_id, base_path=self.base_path) + + +class NoopSnapshotSpec(SnapshotSpec): + type: Literal["noop"] = "noop" + + def build(self, snapshot_id: str) -> SnapshotBase: + return NoopSnapshot(id=snapshot_id) + + +SnapshotSpecUnion = Annotated[LocalSnapshotSpec | NoopSnapshotSpec, Field(discriminator="type")] + + +def resolve_snapshot(spec: SnapshotSpec | None, snapshot_id: str) -> SnapshotBase: + return (spec or NoopSnapshotSpec()).build(snapshot_id) diff --git a/src/agents/sandbox/snapshot_defaults.py b/src/agents/sandbox/snapshot_defaults.py new file mode 100644 index 0000000000..afe7b9c8a8 --- /dev/null +++ b/src/agents/sandbox/snapshot_defaults.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import os +import sys +import time +from collections.abc import Mapping +from pathlib import Path + +from .snapshot import LocalSnapshotSpec + +_DEFAULT_LOCAL_SNAPSHOT_TTL_SECONDS = 60 * 60 * 24 * 30 +_DEFAULT_LOCAL_SNAPSHOT_SUBDIR = Path("openai-agents-python") / "sandbox" / "snapshots" + + +def default_local_snapshot_base_dir( + *, + home: Path | None = None, + env: Mapping[str, str] | None = None, + platform: str | None = None, + os_name: str | None = None, +) -> Path: + resolved_home = home or Path.home() + resolved_env = env or os.environ + resolved_platform = platform or sys.platform + resolved_os_name = os_name or os.name + + if resolved_platform == "darwin": + base = resolved_home / "Library" / "Application Support" + elif resolved_os_name == "nt": + local_app_data = resolved_env.get("LOCALAPPDATA") or resolved_env.get("APPDATA") + base = Path(local_app_data) if local_app_data else resolved_home / "AppData" / "Local" + else: + xdg_state_home = resolved_env.get("XDG_STATE_HOME") + base = Path(xdg_state_home) if xdg_state_home else resolved_home / ".local" / "state" + + return base / _DEFAULT_LOCAL_SNAPSHOT_SUBDIR + + +def cleanup_stale_default_local_snapshots( + base_path: Path, + *, + now: float | None = None, + max_age_seconds: int = _DEFAULT_LOCAL_SNAPSHOT_TTL_SECONDS, +) -> None: + # This is intentionally limited to stale files in the SDK-managed default directory. + # We do not delete snapshots during normal session teardown because pause/resume may still + # need them. If we add explicit artifact cleanup later, it should be a separate opt-in path + # that can also account for backend-specific remote artifacts. + if max_age_seconds < 0 or not base_path.exists(): + return + + cutoff = (time.time() if now is None else now) - max_age_seconds + try: + candidates = list(base_path.glob("*.tar")) + except OSError: + return + + for candidate in candidates: + try: + if not candidate.is_file(): + continue + if candidate.stat().st_mtime >= cutoff: + continue + candidate.unlink(missing_ok=True) + except OSError: + continue + + +def resolve_default_local_snapshot_spec( + *, + home: Path | None = None, + env: Mapping[str, str] | None = None, + platform: str | None = None, + os_name: str | None = None, + now: float | None = None, +) -> LocalSnapshotSpec: + base_path = default_local_snapshot_base_dir( + home=home, + env=env, + platform=platform, + os_name=os_name, + ) + base_path.mkdir(parents=True, exist_ok=True, mode=0o700) + if (os_name or os.name) != "nt": + try: + base_path.chmod(0o700) + except OSError: + pass + return LocalSnapshotSpec(base_path=base_path) diff --git a/src/agents/sandbox/types.py b/src/agents/sandbox/types.py new file mode 100644 index 0000000000..2c3ad15d04 --- /dev/null +++ b/src/agents/sandbox/types.py @@ -0,0 +1,147 @@ +import stat +from enum import IntEnum + +from pydantic import BaseModel, Field +from typing_extensions import Self + + +class User(BaseModel): + name: str + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, User): + return NotImplemented + return self.name == other.name + + +class Group(BaseModel): + name: str + users: list[User] + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Group): + return NotImplemented + return self.name == other.name + + +class Permissions(BaseModel): + owner: int = Field(default=0o7) + group: int = Field(default=0) + other: int = Field(default=0) + directory: bool = Field(default=False) + + def to_mode(self) -> int: + mode = 0 + for perms, shift in [(self.owner, 6), (self.group, 3), (self.other, 0)]: + mode |= int(perms) << shift + if self.directory: + mode |= stat.S_IFDIR + return mode + + @classmethod + def from_mode(cls, mode: int) -> "Permissions": + return cls( + owner=(mode >> 6) & 0b111, + group=(mode >> 3) & 0b111, + other=(mode >> 0) & 0b111, + directory=bool(mode & stat.S_IFDIR), + ) + + @classmethod + def from_str(cls, perms: str) -> "Permissions": + if len(perms) == 11 and perms[-1] in {"@", "+"}: + perms = perms[:-1] + if len(perms) != 10: + raise ValueError(f"invalid permissions string length: {perms!r}") + + directory = perms[0] == "d" + if perms[0] not in {"d", "-"}: + raise ValueError(f"invalid permissions type: {perms!r}") + + def parse_triplet(triplet: str) -> int: + if len(triplet) != 3: + raise ValueError(f"invalid permissions triplet: {triplet!r}") + mask = 0 + if triplet[0] == "r": + mask |= FileMode.READ + elif triplet[0] != "-": + raise ValueError(f"invalid read flag: {triplet!r}") + if triplet[1] == "w": + mask |= FileMode.WRITE + elif triplet[1] != "-": + raise ValueError(f"invalid write flag: {triplet!r}") + if triplet[2] == "x": + mask |= FileMode.EXEC + elif triplet[2] != "-": + raise ValueError(f"invalid exec flag: {triplet!r}") + return int(mask) + + owner = parse_triplet(perms[1:4]) + group = parse_triplet(perms[4:7]) + other = parse_triplet(perms[7:10]) + return cls( + owner=owner, + group=group, + other=other, + directory=directory, + ) + + def owner_can(self, mode: int) -> Self: + self.owner = mode + return self + + def group_can(self, mode: int) -> Self: + self.group = mode + return self + + def others_can(self, mode: int) -> Self: + self.other = mode + return self + + def __repr__(self) -> str: + def fmt(perms: int) -> str: + return "".join( + c if perms & p else "-" + for p, c in [(FileMode.READ, "r"), (FileMode.WRITE, "w"), (FileMode.EXEC, "x")] + ) + + return ("d" if self.directory else "-") + "".join( + fmt(perms) for perms in (self.owner, self.group, self.other) + ) + + def __str__(self) -> str: + return repr(self) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Permissions): + return NotImplemented + return self.to_mode() == other.to_mode() + + +class FileMode(IntEnum): + ALL = 0o7 + NONE = 0 + + READ = 1 << 2 + WRITE = 1 << 1 + EXEC = 1 + + +class ExecResult: + stdout: bytes + stderr: bytes + exit_code: int + + def __init__(self, *, stdout: bytes, stderr: bytes, exit_code: int) -> None: + self.stdout = stdout + self.stderr = stderr + self.exit_code = exit_code + + def ok(self) -> bool: + return self.exit_code == 0 diff --git a/src/agents/sandbox/util/__init__.py b/src/agents/sandbox/util/__init__.py new file mode 100644 index 0000000000..13c9850a70 --- /dev/null +++ b/src/agents/sandbox/util/__init__.py @@ -0,0 +1,42 @@ +from .deep_merge import deep_merge +from .github import clone_repo, ensure_git_available +from .parse_utils import parse_ls_la +from .retry import ( + DEFAULT_TRANSIENT_RETRY_BACKOFF, + DEFAULT_TRANSIENT_RETRY_INTERVAL_S, + DEFAULT_TRANSIENT_RETRY_MAX_ATTEMPT, + TRANSIENT_HTTP_STATUS_CODES, + BackoffStrategy, + exception_chain_contains_type, + exception_chain_has_status_code, + iter_exception_chain, + retry_async, +) +from .tar_utils import ( + UnsafeTarMemberError, + safe_extract_tarfile, + safe_tar_member_rel_path, + should_skip_tar_member, + validate_tarfile, +) + +__all__ = [ + "DEFAULT_TRANSIENT_RETRY_BACKOFF", + "DEFAULT_TRANSIENT_RETRY_INTERVAL_S", + "DEFAULT_TRANSIENT_RETRY_MAX_ATTEMPT", + "BackoffStrategy", + "TRANSIENT_HTTP_STATUS_CODES", + "exception_chain_contains_type", + "exception_chain_has_status_code", + "iter_exception_chain", + "retry_async", + "deep_merge", + "clone_repo", + "ensure_git_available", + "parse_ls_la", + "UnsafeTarMemberError", + "safe_extract_tarfile", + "safe_tar_member_rel_path", + "should_skip_tar_member", + "validate_tarfile", +] diff --git a/src/agents/sandbox/util/checksums.py b/src/agents/sandbox/util/checksums.py new file mode 100644 index 0000000000..0b47f3d173 --- /dev/null +++ b/src/agents/sandbox/util/checksums.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import hashlib +from pathlib import Path + + +def sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + while True: + chunk = handle.read(1024 * 1024) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() diff --git a/src/agents/sandbox/util/deep_merge.py b/src/agents/sandbox/util/deep_merge.py new file mode 100644 index 0000000000..d8aa96b160 --- /dev/null +++ b/src/agents/sandbox/util/deep_merge.py @@ -0,0 +1,21 @@ +from typing import TypeGuard + + +def _is_string_object_dict(value: object) -> TypeGuard[dict[str, object]]: + return isinstance(value, dict) and all(isinstance(key, str) for key in value) + + +def deep_merge(dict1: dict[str, object], dict2: dict[str, object]) -> dict[str, object]: + """ + Recursively merge dict2 into dict1 and return a new dict. + If both values for a key are dicts, merge them. + Otherwise, dict2's value overwrites dict1's. + """ + result = dict1.copy() + for key, value in dict2.items(): + existing = result.get(key) + if _is_string_object_dict(existing) and _is_string_object_dict(value): + result[key] = deep_merge(existing, value) + else: + result[key] = value + return result diff --git a/src/agents/sandbox/util/github.py b/src/agents/sandbox/util/github.py new file mode 100644 index 0000000000..4a35462158 --- /dev/null +++ b/src/agents/sandbox/util/github.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import shutil +import subprocess +from pathlib import Path + + +def ensure_git_available() -> None: + if shutil.which("git") is None: + raise RuntimeError("git is required to use github_repo artifacts") + + +def clone_repo(*, repo: str, ref: str, dest: Path) -> None: + """Shallow clone a GitHub repo at a ref (tag/branch/sha).""" + + ensure_git_available() + url = f"https://github.com/{repo}.git" + dest.parent.mkdir(parents=True, exist_ok=True) + + # Use a shallow clone for tags/branches; fall back to a pinned checkout for SHAs. + try: + subprocess.run( + [ + "git", + "clone", + "--depth", + "1", + "--no-tags", + "--branch", + ref, + url, + str(dest), + ], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return + except subprocess.CalledProcessError: + pass + + subprocess.run( + ["git", "clone", "--no-checkout", url, str(dest)], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + subprocess.run( + ["git", "-C", str(dest), "checkout", ref], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) diff --git a/src/agents/sandbox/util/iterator_io.py b/src/agents/sandbox/util/iterator_io.py new file mode 100644 index 0000000000..cf4119543e --- /dev/null +++ b/src/agents/sandbox/util/iterator_io.py @@ -0,0 +1,69 @@ +import io +from collections.abc import Iterator + + +class IteratorIO(io.IOBase): + def __init__(self, it: Iterator[bytes]): + self._it = it + self._buffer = bytearray() + self._closed = False + + def readable(self) -> bool: + return True + + def read(self, size: int = -1) -> bytes: + if self._closed: + return b"" + + if size < 0: + # Read all remaining data. + chunks: list[bytes] = [] + if self._buffer: + chunks.append(bytes(self._buffer)) + self._buffer.clear() + for chunk in self._it: + if chunk: + chunks.append(chunk) + self._closed = True + return b"".join(chunks) + + if size == 0: + return b"" + + # Fill buffer until we can satisfy the request or iterator is exhausted. + while len(self._buffer) < size and not self._closed: + try: + chunk = next(self._it) + if not chunk: + continue + self._buffer.extend(chunk) + except StopIteration: + self._closed = True + + out = bytes(self._buffer[:size]) + del self._buffer[:size] + return out + + def readinto(self, b: bytearray) -> int: + if self._closed: + return 0 + + # Fill buffer until we have something or iterator is exhausted + while not self._buffer: + try: + chunk = next(self._it) + if not chunk: + continue + self._buffer.extend(chunk) + except StopIteration: + self._closed = True + return 0 + + n = min(len(b), len(self._buffer)) + b[:n] = self._buffer[:n] + del self._buffer[:n] + return n + + def close(self) -> None: + self._closed = True + super().close() diff --git a/src/agents/sandbox/util/parse_utils.py b/src/agents/sandbox/util/parse_utils.py new file mode 100644 index 0000000000..4b1ced206a --- /dev/null +++ b/src/agents/sandbox/util/parse_utils.py @@ -0,0 +1,64 @@ +from ..files import EntryKind, FileEntry +from ..types import Permissions + + +def parse_ls_la(output: str, *, base: str) -> list[FileEntry]: + entries: list[FileEntry] = [] + for raw_line in output.splitlines(): + line = raw_line.strip("\n") + if not line or line.startswith("total"): + continue + + # Typical coreutils format: + # drwxr-xr-x 2 root root 4096 Jan 1 00:00 dirname + # -rw-r--r-- 1 root root 123 Jan 1 00:00 file.txt + # lrwxrwxrwx 1 root root 12 Jan 1 00:00 link -> target + parts = line.split(maxsplit=8) + if len(parts) < 9: + continue + + permissions_str = parts[0] + owner = parts[2] + group = parts[3] + try: + size = int(parts[4]) + except ValueError: + continue + + kind_map: dict[str, EntryKind] = { + "d": EntryKind.DIRECTORY, + "-": EntryKind.FILE, + "l": EntryKind.SYMLINK, + } + kind: EntryKind = kind_map.get(permissions_str[:1], EntryKind.OTHER) + + # Permissions only track rwx bits and directory-ness; for symlink/other entries we + # preserve rwx bits by normalizing the leading type marker to "-". + if permissions_str[:1] not in {"d", "-"} and len(permissions_str) >= 2: + permissions_str = "-" + permissions_str[1:] + + name = parts[8] + if " -> " in name: + name = name.split(" -> ", 1)[0] + + if name in {".", ".."}: + continue + + permissions = Permissions.from_str(permissions_str) + entry_path = ( + name + if name.startswith("/") + else (f"{base.rstrip('/')}/{name}" if base != "/" else f"/{name}") + ) + entries.append( + FileEntry( + path=entry_path, + permissions=permissions, + owner=owner, + group=group, + size=size, + kind=kind, + ) + ) + + return entries diff --git a/src/agents/sandbox/util/retry.py b/src/agents/sandbox/util/retry.py new file mode 100644 index 0000000000..889058bd6d --- /dev/null +++ b/src/agents/sandbox/util/retry.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import asyncio +import functools +import inspect +from collections.abc import Callable, Coroutine, Iterable +from enum import Enum +from typing import ParamSpec, TypeVar, cast + +P = ParamSpec("P") +T = TypeVar("T") + + +class BackoffStrategy(str, Enum): + def __str__(self) -> str: + return str(self.value) + + FIXED = "fixed" + LINEAR = "linear" + EXPONENTIAL = "exponential" + + +DEFAULT_TRANSIENT_RETRY_INTERVAL_S = 0.25 +DEFAULT_TRANSIENT_RETRY_MAX_ATTEMPT = 3 +DEFAULT_TRANSIENT_RETRY_BACKOFF = BackoffStrategy.EXPONENTIAL +TRANSIENT_HTTP_STATUS_CODES: frozenset[int] = frozenset({500, 502, 503, 504}) + + +def iter_exception_chain(exc: BaseException) -> Iterable[BaseException]: + seen: set[int] = set() + current: BaseException | None = exc + while current is not None and id(current) not in seen: + yield current + seen.add(id(current)) + current = cast( + BaseException | None, + getattr(current, "__cause__", None) or getattr(current, "__context__", None), + ) + + +def exception_chain_contains_type( + exc: BaseException, + error_types: tuple[type[BaseException], ...], +) -> bool: + if not error_types: + return False + return any(isinstance(candidate, error_types) for candidate in iter_exception_chain(exc)) + + +def exception_chain_has_status_code( + exc: BaseException, + status_codes: set[int] | frozenset[int], +) -> bool: + for candidate in iter_exception_chain(exc): + for value in ( + getattr(candidate, "status_code", None), + getattr(candidate, "http_code", None), + getattr(getattr(candidate, "response", None), "status_code", None), + ): + if isinstance(value, int) and value in status_codes: + return True + return False + + +def retry_async( + *, + interval: float = DEFAULT_TRANSIENT_RETRY_INTERVAL_S, + max_attempt: int = DEFAULT_TRANSIENT_RETRY_MAX_ATTEMPT, + backoff: BackoffStrategy = DEFAULT_TRANSIENT_RETRY_BACKOFF, + retry_if: Callable[..., bool], + on_retry: Callable[..., object] | None = None, +) -> Callable[ + [Callable[P, Coroutine[object, object, T]]], + Callable[P, Coroutine[object, object, T]], +]: + """Retry an async function when `retry_if` marks the exception as transient. + + `backoff=BackoffStrategy.FIXED` keeps a constant delay equal to `interval`. + `backoff=BackoffStrategy.LINEAR` scales delay as `interval * attempt`. + `backoff=BackoffStrategy.EXPONENTIAL` doubles the delay on each retry attempt. + """ + + if max_attempt < 1: + raise ValueError("max_attempt must be >= 1") + if interval < 0: + raise ValueError("interval must be >= 0") + if backoff not in { + BackoffStrategy.FIXED, + BackoffStrategy.LINEAR, + BackoffStrategy.EXPONENTIAL, + }: + raise ValueError( + "backoff must be BackoffStrategy.FIXED, " + "BackoffStrategy.LINEAR, or BackoffStrategy.EXPONENTIAL" + ) + + def decorator( + fn: Callable[P, Coroutine[object, object, T]], + ) -> Callable[P, Coroutine[object, object, T]]: + @functools.wraps(fn) + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + for attempt in range(1, max_attempt + 1): + try: + return await fn(*args, **kwargs) + except Exception as exc: + if attempt >= max_attempt or not retry_if(exc, *args, **kwargs): + raise + + if backoff is BackoffStrategy.EXPONENTIAL: + delay_s = interval * (2 ** (attempt - 1)) + elif backoff is BackoffStrategy.LINEAR: + delay_s = interval * attempt + else: + delay_s = interval + + if on_retry is not None: + hook_result = on_retry(exc, attempt, max_attempt, delay_s, *args, **kwargs) + if inspect.isawaitable(hook_result): + await hook_result + + await asyncio.sleep(delay_s) + + raise AssertionError("unreachable") + + return cast(Callable[P, Coroutine[object, object, T]], wrapped) + + return decorator diff --git a/src/agents/sandbox/util/tar_utils.py b/src/agents/sandbox/util/tar_utils.py new file mode 100644 index 0000000000..6240ad9186 --- /dev/null +++ b/src/agents/sandbox/util/tar_utils.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import os +import shutil +import tarfile +from collections.abc import Iterable +from pathlib import Path, PurePosixPath + + +class UnsafeTarMemberError(ValueError): + def __init__(self, *, member: str, reason: str) -> None: + super().__init__(f"unsafe tar member {member!r}: {reason}") + self.member = member + self.reason = reason + + +def safe_tar_member_rel_path(member: tarfile.TarInfo) -> Path | None: + if member.name in ("", ".", "./"): + return None + rel = PurePosixPath(member.name) + if rel.is_absolute(): + raise UnsafeTarMemberError(member=member.name, reason="absolute path") + if ".." in rel.parts: + raise UnsafeTarMemberError(member=member.name, reason="parent traversal") + if member.issym() or member.islnk(): + raise UnsafeTarMemberError(member=member.name, reason="link member not allowed") + if not (member.isdir() or member.isreg()): + raise UnsafeTarMemberError(member=member.name, reason="unsupported member type") + return Path(*rel.parts) + + +def _normalize_rel(prefix: str | Path) -> Path: + rel = prefix if isinstance(prefix, Path) else Path(prefix) + posix = rel.as_posix() + parts = [p for p in Path(posix).parts if p not in ("", ".")] + if parts[:1] == ["/"]: + parts = parts[1:] + return Path(*parts) + + +def _is_within(path: Path, prefix: Path) -> bool: + if prefix == Path(): + return True + if path == prefix: + return True + return path.parts[: len(prefix.parts)] == prefix.parts + + +def should_skip_tar_member( + member_name: str, + *, + skip_rel_paths: Iterable[str | Path], + root_name: str | None, +) -> bool: + """ + Decide whether a tar member should be excluded based on workspace-relative prefixes. + + `member_name` is the raw name from the tar, which may include `.` or the workspace root + directory name depending on how the tar was produced. + """ + + raw_parts = [p for p in Path(member_name).parts if p not in ("", ".")] + if raw_parts[:1] == ["/"]: + raw_parts = raw_parts[1:] + if not raw_parts: + rel_variants = [Path()] + else: + rel_variants = [Path(*raw_parts)] + if root_name and raw_parts and raw_parts[0] == root_name: + rel_variants.append(Path(*raw_parts[1:])) + + prefixes = [_normalize_rel(p) for p in skip_rel_paths] + return any(_is_within(rel, prefix) for rel in rel_variants for prefix in prefixes) + + +def _ensure_no_symlink_parents(*, root: Path, dest: Path) -> None: + """ + Ensure that no existing parent directory in `dest` is a symlink. + + This helps prevent writing outside `root` via pre-existing symlink components. + """ + + root_resolved = root.resolve() + dest_resolved = dest.resolve() + if not (dest_resolved == root_resolved or dest_resolved.is_relative_to(root_resolved)): + raise UnsafeTarMemberError(member=str(dest), reason="path escapes root after resolution") + + rel = dest.relative_to(root) + cur = root + for part in rel.parts[:-1]: + cur = cur / part + if cur.exists() and cur.is_symlink(): + raise UnsafeTarMemberError(member=str(rel.as_posix()), reason="symlink in parent path") + + +def validate_tarfile(tar: tarfile.TarFile) -> None: + for member in tar.getmembers(): + safe_tar_member_rel_path(member) + + +def safe_extract_tarfile(tar: tarfile.TarFile, *, root: Path) -> None: + """ + Safely extract a tar archive into `root`. + + This rejects: + - absolute member paths + - paths containing `..` + - symlinks / hardlinks + - non-regular-file and non-directory members (devices, fifos, etc.) + + It also ensures extraction doesn't traverse through existing symlink parents. + """ + + root.mkdir(parents=True, exist_ok=True) + root_resolved = root.resolve() + + validate_tarfile(tar) + for member in tar.getmembers(): + name = member.name + rel_path = safe_tar_member_rel_path(member) + if rel_path is None: + continue + + dest = root_resolved / rel_path + _ensure_no_symlink_parents(root=root_resolved, dest=dest) + + if member.isdir(): + dest.mkdir(parents=True, exist_ok=True) + continue + + # Regular file + fileobj = tar.extractfile(member) + if fileobj is None: + raise UnsafeTarMemberError(member=name, reason="missing file payload") + + dest.parent.mkdir(parents=True, exist_ok=True) + _ensure_no_symlink_parents(root=root_resolved, dest=dest) + + flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + if hasattr(os, "O_NOFOLLOW"): + flags |= os.O_NOFOLLOW + fd = os.open(dest, flags, 0o600) + try: + with os.fdopen(fd, "wb") as out: + shutil.copyfileobj(fileobj, out) + finally: + try: + fileobj.close() + except Exception: + pass diff --git a/tests/extensions/experiemental/codex/test_codex_tool.py b/tests/extensions/experiemental/codex/test_codex_tool.py index b9a78c7d0f..a133c3c775 100644 --- a/tests/extensions/experiemental/codex/test_codex_tool.py +++ b/tests/extensions/experiemental/codex/test_codex_tool.py @@ -27,6 +27,7 @@ from agents.lifecycle import RunHooks from agents.run_config import RunConfig from agents.run_context import RunContextWrapper +from agents.run_internal.agent_bindings import bind_public_agent from agents.run_internal.run_steps import ToolRunFunction from agents.run_internal.tool_execution import execute_function_tool_calls from agents.tool_context import ToolContext @@ -920,7 +921,7 @@ async def _error_tool() -> str: with pytest.raises(UserError, match="Error running tool error_tool: boom"): await execute_function_tool_calls( - agent=agent, + bindings=bind_public_agent(agent), tool_runs=tool_runs, hooks=RunHooks(), context_wrapper=context_wrapper, diff --git a/tests/extensions/test_sandbox_e2b.py b/tests/extensions/test_sandbox_e2b.py new file mode 100644 index 0000000000..0ab0a14161 --- /dev/null +++ b/tests/extensions/test_sandbox_e2b.py @@ -0,0 +1,487 @@ +from __future__ import annotations + +import base64 +import io +import tarfile +import uuid +from pathlib import Path + +import pytest +from pydantic import PrivateAttr + +from agents.extensions.sandbox.sandboxes.e2b import E2BSandboxSession, E2BSandboxSessionState +from agents.sandbox import Manifest +from agents.sandbox.entries import Dir, Mount +from agents.sandbox.errors import ( + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceStartError, +) +from agents.sandbox.snapshot import NoopSnapshot + + +class _FakeE2BResult: + def __init__(self, *, stdout: str = "", stderr: str = "", exit_code: int = 0) -> None: + self.stdout = stdout + self.stderr = stderr + self.exit_code = exit_code + + +class _FakeE2BFiles: + def __init__(self) -> None: + self.make_dir_calls: list[tuple[str, float | None]] = [] + + def write( + self, + path: str, + data: bytes, + request_timeout: float | None = None, + ) -> None: + _ = (path, data, request_timeout) + + def remove(self, path: str, request_timeout: float | None = None) -> None: + _ = (path, request_timeout) + + def make_dir(self, path: str, request_timeout: float | None = None) -> bool: + self.make_dir_calls.append((path, request_timeout)) + return True + + def read(self, path: str, format: str = "bytes") -> bytes: + _ = (path, format) + return b"" + + +class _FakeE2BCommands: + def __init__(self) -> None: + self.exec_root_ready = False + self.calls: list[dict[str, object]] = [] + self.mkdir_result: _FakeE2BResult | None = None + self.next_result = _FakeE2BResult() + + def run( + self, + command: str, + timeout: float | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + user: str | None = None, + ) -> _FakeE2BResult: + self.calls.append( + { + "command": command, + "timeout": timeout, + "cwd": cwd, + "envs": envs, + "user": user, + } + ) + if command == "mkdir -p -- /workspace" and cwd == "/": + result = self.mkdir_result or _FakeE2BResult() + if result.exit_code == 0: + self.exec_root_ready = True + self.mkdir_result = None + return result + if cwd == "/workspace" and not self.exec_root_ready: + raise ValueError("cwd '/workspace' does not exist") + result = self.next_result + self.next_result = _FakeE2BResult() + return result + + +class _FakeE2BSandbox: + def __init__(self) -> None: + self.sandbox_id = "sb-123" + self.files = _FakeE2BFiles() + self.commands = _FakeE2BCommands() + + def beta_pause(self) -> None: + return + + def kill(self) -> None: + return + + def is_running(self, request_timeout: float | None = None) -> bool: + _ = request_timeout + return True + + +class _RecordingMount(Mount): + type: str = "recording_mount" + _mounted_paths: list[Path] = PrivateAttr(default_factory=list) + _unmounted_paths: list[Path] = PrivateAttr(default_factory=list) + _events: list[tuple[str, str]] = PrivateAttr(default_factory=list) + + def bind_events(self, events: list[tuple[str, str]]) -> _RecordingMount: + self._events = events + return self + + async def _mount(self, session: object, path: Path) -> None: + _ = session + self._events.append(("mount", str(path))) + self._mounted_paths.append(path) + + async def _unmount(self, session: object, path: Path) -> None: + _ = session + self._events.append(("unmount", str(path))) + self._unmounted_paths.append(path) + + +class _FailingUnmountMount(_RecordingMount): + type: str = "failing_unmount_mount" + + async def _unmount(self, session: object, path: Path) -> None: + _ = session + self._events.append(("unmount_fail", str(path))) + raise RuntimeError("boom while unmounting second mount") + + +class _FailingRemountMount(_RecordingMount): + type: str = "failing_remount_mount" + + async def _mount(self, session: object, path: Path) -> None: + _ = session + self._events.append(("mount_fail", str(path))) + raise RuntimeError("boom while remounting second mount") + + +def _session(*, workspace_root_ready: bool = False) -> tuple[E2BSandboxSession, _FakeE2BSandbox]: + sandbox = _FakeE2BSandbox() + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=workspace_root_ready, + ) + return E2BSandboxSession.from_state(state, sandbox=sandbox), sandbox + + +def _tar_bytes() -> bytes: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo("note.txt") + payload = b"hello" + info.size = len(payload) + tar.addfile(info, io.BytesIO(payload)) + return buf.getvalue() + + +@pytest.mark.asyncio +async def test_e2b_exec_omits_cwd_until_workspace_ready() -> None: + session, sandbox = _session(workspace_root_ready=False) + + result = await session._exec_internal("find", ".", timeout=0.01) # noqa: SLF001 + + assert result.ok() + assert sandbox.commands.calls == [ + { + "command": "find .", + "timeout": 0.01, + "cwd": None, + "envs": {}, + "user": None, + } + ] + + +@pytest.mark.asyncio +async def test_e2b_exec_uses_manifest_root_after_workspace_ready() -> None: + session, sandbox = _session(workspace_root_ready=True) + sandbox.commands.exec_root_ready = True + + result = await session._exec_internal("find", ".", timeout=0.01) # noqa: SLF001 + + assert result.ok() + assert sandbox.commands.calls == [ + { + "command": "find .", + "timeout": 0.01, + "cwd": "/workspace", + "envs": {}, + "user": None, + } + ] + + +@pytest.mark.asyncio +async def test_e2b_start_prepares_workspace_root_for_command_cwd() -> None: + session, sandbox = _session(workspace_root_ready=False) + + await session.start() + result = await session._exec_internal("pwd", timeout=0.01) # noqa: SLF001 + + assert result.ok() + assert session.state.workspace_root_ready is True + assert session._workspace_root_ready is True # noqa: SLF001 + assert sandbox.files.make_dir_calls == [("/workspace", 10), ("/workspace", 10)] + assert sandbox.commands.calls == [ + { + "command": "mkdir -p -- /workspace", + "timeout": 10, + "cwd": "/", + "envs": {}, + "user": None, + }, + { + "command": "pwd", + "timeout": 0.01, + "cwd": "/workspace", + "envs": {}, + "user": None, + }, + ] + + +@pytest.mark.asyncio +async def test_e2b_start_raises_on_nonzero_workspace_root_setup_exit() -> None: + session, sandbox = _session(workspace_root_ready=False) + sandbox.commands.mkdir_result = _FakeE2BResult(stderr="mkdir failed", exit_code=2) + + with pytest.raises(WorkspaceStartError) as exc_info: + await session.start() + + assert exc_info.value.context["reason"] == "workspace_root_nonzero_exit" + assert exc_info.value.context["exit_code"] == 2 + assert session.state.workspace_root_ready is False + assert session._workspace_root_ready is False # noqa: SLF001 + + +@pytest.mark.asyncio +async def test_e2b_skip_start_still_prepares_workspace_root_for_resumed_exec_cwd() -> None: + session, sandbox = _session(workspace_root_ready=False) + session._skip_start = True # noqa: SLF001 + + await session.start() + result = await session._exec_internal("pwd", timeout=0.01) # noqa: SLF001 + + assert result.ok() + assert session.state.workspace_root_ready is True + assert session._workspace_root_ready is True # noqa: SLF001 + assert sandbox.commands.calls == [ + { + "command": "mkdir -p -- /workspace", + "timeout": 10, + "cwd": "/", + "envs": {}, + "user": None, + }, + { + "command": "pwd", + "timeout": 0.01, + "cwd": "/workspace", + "envs": {}, + "user": None, + }, + ] + + +@pytest.mark.asyncio +async def test_e2b_running_requires_workspace_root_ready() -> None: + session, _sandbox = _session(workspace_root_ready=False) + + assert await session.running() is False + + +@pytest.mark.asyncio +async def test_e2b_running_checks_remote_after_workspace_ready() -> None: + session, sandbox = _session(workspace_root_ready=True) + sandbox.commands.exec_root_ready = True + + assert await session.running() is True + + +@pytest.mark.asyncio +async def test_e2b_persist_workspace_raises_on_nonzero_snapshot_exit() -> None: + session, sandbox = _session(workspace_root_ready=True) + sandbox.commands.exec_root_ready = True + sandbox.commands.next_result = _FakeE2BResult(stderr="tar failed", exit_code=2) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert exc_info.value.context["reason"] == "snapshot_nonzero_exit" + assert exc_info.value.context["exit_code"] == 2 + + +@pytest.mark.asyncio +async def test_e2b_persist_workspace_excludes_runtime_skip_paths() -> None: + session, sandbox = _session(workspace_root_ready=True) + sandbox.commands.exec_root_ready = True + session._register_persist_workspace_skip_relpath(Path("logs/events.jsonl")) # noqa: SLF001 + sandbox.commands.next_result = _FakeE2BResult( + stdout=base64.b64encode(b"fake-tar-bytes").decode("ascii") + ) + + archive = await session.persist_workspace() + + assert archive.read() == b"fake-tar-bytes" + expected_command = ( + "tar --exclude=logs/events.jsonl --exclude=./logs/events.jsonl " + "-C /workspace -cf - . | base64 -w0" + ) + assert sandbox.commands.calls == [ + { + "command": expected_command, + "timeout": session.state.timeouts.snapshot_tar_s, + "cwd": "/", + "envs": {}, + "user": None, + } + ] + + +@pytest.mark.asyncio +async def test_e2b_hydrate_workspace_raises_on_nonzero_extract_exit() -> None: + session, sandbox = _session(workspace_root_ready=False) + sandbox.commands.next_result = _FakeE2BResult(stderr="tar failed", exit_code=2) + + with pytest.raises(WorkspaceArchiveWriteError) as exc_info: + await session.hydrate_workspace(io.BytesIO(_tar_bytes())) + + assert exc_info.value.context["reason"] == "hydrate_nonzero_exit" + assert exc_info.value.context["exit_code"] == 2 + assert session.state.workspace_root_ready is False + assert session._workspace_root_ready is False # noqa: SLF001 + + +@pytest.mark.asyncio +async def test_e2b_persist_workspace_remounts_mounts_after_snapshot() -> None: + mount = _RecordingMount() + sandbox = _FakeE2BSandbox() + sandbox.commands.exec_root_ready = True + sandbox.commands.next_result = _FakeE2BResult( + stdout=base64.b64encode(b"fake-tar-bytes").decode("ascii") + ) + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace", entries={"mount": mount}), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + archive = await session.persist_workspace() + + assert archive.read() == b"fake-tar-bytes" + assert mount._unmounted_paths == [Path("/workspace/mount")] + assert mount._mounted_paths == [Path("/workspace/mount")] + + +@pytest.mark.asyncio +async def test_e2b_persist_workspace_uses_nested_mount_targets_and_resolved_excludes() -> None: + parent_mount = _RecordingMount(mount_path=Path("repo")) + child_mount = _RecordingMount(mount_path=Path("repo/sub")) + events: list[tuple[str, str]] = [] + sandbox = _FakeE2BSandbox() + sandbox.commands.exec_root_ready = True + sandbox.commands.next_result = _FakeE2BResult( + stdout=base64.b64encode(b"fake-tar-bytes").decode("ascii") + ) + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest( + root="/workspace", + entries={ + "parent": parent_mount.bind_events(events), + "nested": Dir(children={"child": child_mount.bind_events(events)}), + }, + ), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + archive = await session.persist_workspace() + + assert archive.read() == b"fake-tar-bytes" + assert [path for kind, path in events if kind == "unmount"] == [ + "/workspace/repo/sub", + "/workspace/repo", + ] + assert [path for kind, path in events if kind == "mount"] == [ + "/workspace/repo", + "/workspace/repo/sub", + ] + tar_command = str(sandbox.commands.calls[-1]["command"]) + assert "--exclude=repo" in tar_command + assert "--exclude=./repo" in tar_command + assert "--exclude=repo/sub" in tar_command + assert "--exclude=./repo/sub" in tar_command + + +@pytest.mark.asyncio +async def test_e2b_persist_workspace_remounts_prior_mounts_after_unmount_failure() -> None: + events: list[tuple[str, str]] = [] + sandbox = _FakeE2BSandbox() + sandbox.commands.exec_root_ready = True + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest( + root="/workspace", + entries={ + "repo": Dir( + children={ + "mount1": _RecordingMount().bind_events(events), + "mount2": _FailingUnmountMount().bind_events(events), + } + ) + }, + ), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + with pytest.raises(WorkspaceArchiveReadError): + await session.persist_workspace() + + assert [kind for kind, _path in events] == [ + "unmount", + "unmount_fail", + "mount", + ] + assert sandbox.commands.calls == [] + + +@pytest.mark.asyncio +async def test_e2b_persist_workspace_keeps_remounting_and_raises_remount_error_first() -> None: + events: list[tuple[str, str]] = [] + sandbox = _FakeE2BSandbox() + sandbox.commands.exec_root_ready = True + sandbox.commands.next_result = _FakeE2BResult(stderr="tar failed", exit_code=2) + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest( + root="/workspace", + entries={ + "repo": Dir( + children={ + "a": _RecordingMount().bind_events(events), + "b": _FailingRemountMount().bind_events(events), + } + ) + }, + ), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert isinstance(exc_info.value.cause, RuntimeError) + assert str(exc_info.value.cause) == "boom while remounting second mount" + assert exc_info.value.context["snapshot_error_before_remount_corruption"] == { + "message": "failed to read archive for path: /workspace", + } + assert [kind for kind, _path in events] == [ + "unmount", + "unmount", + "mount_fail", + "mount", + ] diff --git a/tests/extensions/test_sandbox_modal.py b/tests/extensions/test_sandbox_modal.py new file mode 100644 index 0000000000..ad63961384 --- /dev/null +++ b/tests/extensions/test_sandbox_modal.py @@ -0,0 +1,737 @@ +from __future__ import annotations + +import importlib +import io +import sys +import types +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import pytest + +from agents.sandbox import Manifest +from agents.sandbox.entries import File, GCSMount +from agents.sandbox.errors import InvalidManifestPathError, WorkspaceArchiveReadError +from agents.sandbox.manifest import Environment +from agents.sandbox.types import ExecResult + + +def _load_modal_module( + monkeypatch: pytest.MonkeyPatch, +) -> tuple[Any, list[dict[str, object]], list[str]]: + create_calls: list[dict[str, object]] = [] + registry_tags: list[str] = [] + + class _FakeImage: + object_id = "im-123" + + @staticmethod + def from_registry(_tag: str) -> _FakeImage: + registry_tags.append(_tag) + return _FakeImage() + + @staticmethod + def from_id(_image_id: str) -> _FakeImage: + return _FakeImage() + + class _FakeSandboxInstance: + object_id = "sb-123" + + def __init__(self) -> None: + self.terminate_calls = 0 + + def terminate(self) -> None: + self.terminate_calls += 1 + + def poll(self) -> None: + return None + + class _FakeSandbox: + @staticmethod + def create(**kwargs: object) -> _FakeSandboxInstance: + create_calls.append(dict(kwargs)) + return _FakeSandboxInstance() + + @staticmethod + def from_id(_sandbox_id: str) -> _FakeSandboxInstance: + return _FakeSandboxInstance() + + class _FakeApp: + @staticmethod + def lookup(_name: str, *, create_if_missing: bool = False) -> object: + _ = create_if_missing + return object() + + fake_modal: Any = types.ModuleType("modal") + fake_modal.Image = _FakeImage + fake_modal.App = _FakeApp + fake_modal.Sandbox = _FakeSandbox + + fake_container_process: Any = types.ModuleType("modal.container_process") + fake_container_process.ContainerProcess = object + + monkeypatch.setitem(sys.modules, "modal", fake_modal) + monkeypatch.setitem(sys.modules, "modal.container_process", fake_container_process) + sys.modules.pop("agents.extensions.sandbox.sandboxes.modal", None) + + module: Any = importlib.import_module("agents.extensions.sandbox.sandboxes.modal") + return module, create_calls, registry_tags + + +@pytest.mark.asyncio +async def test_modal_sandbox_create_passes_manifest_environment( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, create_calls, registry_tags = _load_modal_module(monkeypatch) + + client = modal_module.ModalSandboxClient() + session = await client.create( + manifest=Manifest(environment=Environment(value={"SANDBOX_FLAG": "enabled"})), + options=modal_module.ModalSandboxClientOptions(app_name="sandbox-tests"), + ) + + await session._inner._ensure_sandbox() # noqa: SLF001 + + assert create_calls + assert create_calls[0]["env"] == {"SANDBOX_FLAG": "enabled"} + assert registry_tags == ["python:3.11-slim"] + + +@pytest.mark.asyncio +async def test_modal_stop_is_persistence_only_and_shutdown_terminates( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + sandbox = sys.modules["modal"].Sandbox.create() + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id=sandbox.object_id, + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=sandbox) + session._running = True + + await session.stop() + + assert sandbox.terminate_calls == 0 + assert session.state.sandbox_id == "sb-123" + assert await session.running() is True + + await session.shutdown() + + assert sandbox.terminate_calls == 1 + assert session.state.sandbox_id is None + assert await session.running() is False + + +@pytest.mark.asyncio +async def test_modal_tar_persist_respects_runtime_skip_paths( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id="sb-123", + ) + session = modal_module.ModalSandboxSession.from_state(state) + session._register_persist_workspace_skip_relpath(Path("logs/events.jsonl")) # noqa: SLF001 + + commands: list[list[str]] = [] + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + rendered = [str(part) for part in command] + commands.append(rendered) + return ExecResult(stdout=b"fake-tar-bytes", stderr=b"", exit_code=0) + + monkeypatch.setattr(session, "exec", _fake_exec) + + archive = await session.persist_workspace() + + assert archive.read() == b"fake-tar-bytes" + assert commands == [ + [ + "tar", + "cf", + "-", + "--exclude", + "./logs/events.jsonl", + "-C", + "/workspace", + ".", + ] + ] + + +@pytest.mark.asyncio +async def test_modal_snapshot_failure_restores_ephemeral_paths( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FakeRestoreProcess: + def __init__(self, owner: Any) -> None: + self._owner = owner + self.stderr = io.BytesIO(b"") + self.stdin = self._FakeStdin(owner) + + class _FakeStdin: + def __init__(self, owner: Any) -> None: + self._owner = owner + self._buffer = bytearray() + + def write(self, data: bytes) -> None: + self._buffer.extend(data) + + def write_eof(self) -> None: + return + + def drain(self) -> None: + return + + def wait(self) -> int: + self._owner.restore_payloads.append(bytes(self.stdin._buffer)) + return 0 + + class _FakeSnapshotSandbox: + object_id = "sb-123" + + def __init__(self) -> None: + self.restore_payloads: list[bytes] = [] + + def snapshot_filesystem(self) -> str: + raise RuntimeError("snapshot failed") + + def exec(self, *command: object, **kwargs: object) -> _FakeRestoreProcess: + _ = kwargs + assert command[:3] == ("tar", "xf", "-") + return _FakeRestoreProcess(self) + + sandbox = _FakeSnapshotSandbox() + state = modal_module.ModalSandboxSessionState( + manifest=Manifest( + root="/workspace", + entries={"tmp.txt": File(content=b"ephemeral", ephemeral=True)}, + ), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id=sandbox.object_id, + workspace_persistence="snapshot_filesystem", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=sandbox) + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + rendered = [str(part) for part in command] + if rendered[:2] == ["sh", "-lc"]: + return ExecResult(stdout=b"ephemeral-backup", stderr=b"", exit_code=0) + if rendered[:3] == ["rm", "-rf", "--"]: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + raise AssertionError(f"unexpected command: {rendered!r}") + + async def _fake_call_modal( + fn: Callable[..., object], + *args: object, + call_timeout: float | None = None, + **kwargs: object, + ) -> object: + _ = call_timeout + return fn(*args, **kwargs) + + monkeypatch.setattr(session, "exec", _fake_exec) + monkeypatch.setattr(session, "_call_modal", _fake_call_modal) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert exc_info.value.context["reason"] == "snapshot_filesystem_failed" + assert sandbox.restore_payloads == [b"ephemeral-backup"] + + +@pytest.mark.asyncio +async def test_modal_snapshot_cleanup_failure_raises_before_snapshot( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FakeRestoreProcess: + def __init__(self, owner: Any) -> None: + self._owner = owner + self.stderr = io.BytesIO(b"") + self.stdin = self._FakeStdin(owner) + + class _FakeStdin: + def __init__(self, owner: Any) -> None: + self._owner = owner + self._buffer = bytearray() + + def write(self, data: bytes) -> None: + self._buffer.extend(data) + + def write_eof(self) -> None: + return + + def drain(self) -> None: + return + + def wait(self) -> int: + self._owner.restore_payloads.append(bytes(self.stdin._buffer)) + return 0 + + class _FakeSnapshotSandbox: + object_id = "sb-123" + + def __init__(self) -> None: + self.restore_payloads: list[bytes] = [] + self.snapshot_calls = 0 + + def snapshot_filesystem(self) -> str: + self.snapshot_calls += 1 + return "snap-123" + + def exec(self, *command: object, **kwargs: object) -> _FakeRestoreProcess: + _ = kwargs + assert command[:3] == ("tar", "xf", "-") + return _FakeRestoreProcess(self) + + sandbox = _FakeSnapshotSandbox() + state = modal_module.ModalSandboxSessionState( + manifest=Manifest( + root="/workspace", + entries={"tmp.txt": File(content=b"ephemeral", ephemeral=True)}, + ), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id=sandbox.object_id, + workspace_persistence="snapshot_filesystem", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=sandbox) + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + rendered = [str(part) for part in command] + if rendered[:2] == ["sh", "-lc"]: + return ExecResult(stdout=b"ephemeral-backup", stderr=b"", exit_code=0) + if rendered[:3] == ["rm", "-rf", "--"]: + return ExecResult(stdout=b"", stderr=b"rm failed", exit_code=1) + raise AssertionError(f"unexpected command: {rendered!r}") + + async def _fake_call_modal( + fn: Callable[..., object], + *args: object, + call_timeout: float | None = None, + **kwargs: object, + ) -> object: + _ = call_timeout + return fn(*args, **kwargs) + + monkeypatch.setattr(session, "exec", _fake_exec) + monkeypatch.setattr(session, "_call_modal", _fake_call_modal) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert exc_info.value.context["reason"] == "snapshot_filesystem_ephemeral_remove_failed" + assert exc_info.value.context["exit_code"] == 1 + assert exc_info.value.context["stderr"] == "rm failed" + assert sandbox.snapshot_calls == 0 + assert sandbox.restore_payloads == [b"ephemeral-backup"] + + +@pytest.mark.asyncio +async def test_modal_snapshot_filesystem_uses_resolved_mount_paths_for_backup_and_removal( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FakeRestoreProcess: + def __init__(self) -> None: + self.stderr = io.BytesIO(b"") + self.stdin = self._FakeStdin() + + class _FakeStdin: + def write(self, data: bytes) -> None: + _ = data + + def write_eof(self) -> None: + return + + def drain(self) -> None: + return + + def wait(self) -> int: + return 0 + + class _FakeSnapshotSandbox: + object_id = "sb-123" + + def snapshot_filesystem(self) -> str: + return "snap-123" + + def exec(self, *command: object, **kwargs: object) -> _FakeRestoreProcess: + _ = kwargs + assert command[:3] == ("tar", "xf", "-") + return _FakeRestoreProcess() + + sandbox = _FakeSnapshotSandbox() + state = modal_module.ModalSandboxSessionState( + manifest=Manifest( + root="/workspace", + entries={"logical": GCSMount(bucket="bucket", mount_path=Path("actual"))}, + ), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id=sandbox.object_id, + workspace_persistence="snapshot_filesystem", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=sandbox) + commands: list[list[str]] = [] + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + rendered = [str(part) for part in command] + commands.append(rendered) + if rendered[:2] == ["sh", "-lc"]: + return ExecResult(stdout=b"ephemeral-backup", stderr=b"", exit_code=0) + if rendered[:3] == ["rm", "-rf", "--"]: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + raise AssertionError(f"unexpected command: {rendered!r}") + + async def _fake_call_modal( + fn: Callable[..., object], + *args: object, + call_timeout: float | None = None, + **kwargs: object, + ) -> object: + _ = call_timeout + return fn(*args, **kwargs) + + monkeypatch.setattr(session, "exec", _fake_exec) + monkeypatch.setattr(session, "_call_modal", _fake_call_modal) + + archive = await session.persist_workspace() + + assert archive.read() == modal_module._encode_snapshot_filesystem_ref(snapshot_id="snap-123") + assert commands[0][0:2] == ["sh", "-lc"] + assert "actual" in commands[0][2] + assert "logical" in commands[0][2] + assert commands[1] == ["rm", "-rf", "--", "/workspace/actual", "/workspace/logical"] + + +@pytest.mark.asyncio +async def test_modal_tar_persist_uses_resolved_mount_paths_for_excludes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + state = modal_module.ModalSandboxSessionState( + manifest=Manifest( + root="/workspace", + entries={"logical": GCSMount(bucket="bucket", mount_path=Path("actual"))}, + ), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=None) + commands: list[list[str]] = [] + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + rendered = [str(part) for part in command] + commands.append(rendered) + return ExecResult(stdout=b"tar-bytes", stderr=b"", exit_code=0) + + monkeypatch.setattr(session, "exec", _fake_exec) + + archive = await session.persist_workspace() + + assert archive.read() == b"tar-bytes" + assert commands == [ + [ + "tar", + "cf", + "-", + "--exclude", + "./actual", + "--exclude", + "./logical", + "-C", + "/workspace", + ".", + ] + ] + + +@pytest.mark.asyncio +async def test_modal_snapshot_filesystem_rejects_escaping_mount_paths_before_exec( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FakeSnapshotSandbox: + object_id = "sb-123" + + def __init__(self) -> None: + self.snapshot_calls = 0 + + def snapshot_filesystem(self) -> str: + self.snapshot_calls += 1 + return "snap-123" + + sandbox = _FakeSnapshotSandbox() + state = modal_module.ModalSandboxSessionState( + manifest=Manifest( + root="/workspace", + entries={"logical": GCSMount(bucket="bucket", mount_path=Path("/workspace/../../tmp"))}, + ), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id=sandbox.object_id, + workspace_persistence="snapshot_filesystem", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=sandbox) + commands: list[list[str]] = [] + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + commands.append([str(part) for part in command]) + raise AssertionError("exec() should not run for escaping mount paths") + + async def _fake_call_modal( + fn: Callable[..., object], + *args: object, + call_timeout: float | None = None, + **kwargs: object, + ) -> object: + _ = (fn, args, call_timeout, kwargs) + raise AssertionError("snapshot_filesystem() should not run for escaping mount paths") + + monkeypatch.setattr(session, "exec", _fake_exec) + monkeypatch.setattr(session, "_call_modal", _fake_call_modal) + + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.persist_workspace() + + assert commands == [] + assert sandbox.snapshot_calls == 0 + + +@pytest.mark.asyncio +async def test_modal_write_chunks_large_payload_before_draining( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FakeWaitResult: + def wait(self) -> int: + return 0 + + class _FakeStdin: + def __init__(self, *, limit: int) -> None: + self._limit = limit + self._buffer = bytearray() + self.chunks: list[bytes] = [] + self.write_eof_calls = 0 + self.drain_calls = 0 + + def write(self, data: bytes | bytearray | memoryview) -> None: + rendered = bytes(data) + if len(self._buffer) + len(rendered) > self._limit: + raise BufferError("Buffer size exceed limit. Call drain to flush the buffer.") + self._buffer.extend(rendered) + + def write_eof(self) -> None: + self.write_eof_calls += 1 + + def drain(self) -> None: + self.chunks.append(bytes(self._buffer)) + self._buffer.clear() + self.drain_calls += 1 + + class _FakeProcess: + def __init__(self, *, limit: int) -> None: + self.stdin = _FakeStdin(limit=limit) + self.stderr = io.BytesIO(b"") + + def wait(self) -> int: + return 0 + + class _FakeSandbox: + object_id = "sb-123" + + def __init__(self) -> None: + self.processes: list[_FakeProcess] = [] + self.commands: list[tuple[object, ...]] = [] + + def exec(self, *command: object, **kwargs: object) -> object: + _ = kwargs + self.commands.append(command) + if command[:3] == ("mkdir", "-p", "--"): + return _FakeWaitResult() + process = _FakeProcess(limit=5) + self.processes.append(process) + return process + + sandbox = _FakeSandbox() + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id=sandbox.object_id, + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=sandbox) + monkeypatch.setattr(modal_module, "_MODAL_STDIN_CHUNK_SIZE", 5) + + async def _fake_call_modal( + fn: Callable[..., object], + *args: object, + call_timeout: float | None = None, + **kwargs: object, + ) -> object: + _ = call_timeout + return fn(*args, **kwargs) + + monkeypatch.setattr(session, "_call_modal", _fake_call_modal) + + payload = b"abcdefghijklm" + await session.write(Path("nested/file.bin"), io.BytesIO(payload)) + + assert sandbox.commands == [ + ("mkdir", "-p", "--", "/workspace/nested"), + ("sh", "-lc", "cat > /workspace/nested/file.bin"), + ] + assert len(sandbox.processes) == 1 + assert sandbox.processes[0].stdin.chunks == [b"abcde", b"fghij", b"klm", b""] + assert sandbox.processes[0].stdin.write_eof_calls == 1 + assert sandbox.processes[0].stdin.drain_calls == 4 + + +@pytest.mark.asyncio +async def test_modal_hydrate_tar_chunks_large_payload_before_draining( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FakeWaitResult: + def wait(self) -> int: + return 0 + + class _FakeStdin: + def __init__(self, *, limit: int) -> None: + self._limit = limit + self._buffer = bytearray() + self.chunks: list[bytes] = [] + self.write_eof_calls = 0 + self.drain_calls = 0 + + def write(self, data: bytes | bytearray | memoryview) -> None: + rendered = bytes(data) + if len(self._buffer) + len(rendered) > self._limit: + raise BufferError("Buffer size exceed limit. Call drain to flush the buffer.") + self._buffer.extend(rendered) + + def write_eof(self) -> None: + self.write_eof_calls += 1 + + def drain(self) -> None: + self.chunks.append(bytes(self._buffer)) + self._buffer.clear() + self.drain_calls += 1 + + class _FakeProcess: + def __init__(self, *, limit: int) -> None: + self.stdin = _FakeStdin(limit=limit) + self.stderr = io.BytesIO(b"") + + def wait(self) -> int: + return 0 + + class _FakeSandbox: + object_id = "sb-123" + + def __init__(self) -> None: + self.processes: list[_FakeProcess] = [] + self.commands: list[tuple[object, ...]] = [] + + def exec(self, *command: object, **kwargs: object) -> object: + _ = kwargs + self.commands.append(command) + if command[:3] == ("mkdir", "-p", "--"): + return _FakeWaitResult() + process = _FakeProcess(limit=7) + self.processes.append(process) + return process + + sandbox = _FakeSandbox() + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id=sandbox.object_id, + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=sandbox) + monkeypatch.setattr(modal_module, "_MODAL_STDIN_CHUNK_SIZE", 7) + + async def _fake_call_modal( + fn: Callable[..., object], + *args: object, + call_timeout: float | None = None, + **kwargs: object, + ) -> object: + _ = call_timeout + return fn(*args, **kwargs) + + monkeypatch.setattr(session, "_call_modal", _fake_call_modal) + + tar_payload = io.BytesIO() + with modal_module.tarfile.open(fileobj=tar_payload, mode="w") as tar: + info = modal_module.tarfile.TarInfo(name="large.txt") + contents = b"abcdefghijklmno" + info.size = len(contents) + tar.addfile(info, io.BytesIO(contents)) + tar_payload.seek(0) + + await session.hydrate_workspace(tar_payload) + + assert sandbox.commands == [ + ("mkdir", "-p", "--", "/workspace"), + ("tar", "xf", "-", "-C", "/workspace"), + ] + assert len(sandbox.processes) == 1 + assert b"".join(sandbox.processes[0].stdin.chunks[:-1]) == tar_payload.getvalue() + assert sandbox.processes[0].stdin.write_eof_calls == 1 + assert sandbox.processes[0].stdin.drain_calls >= 2 diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 8b07297167..2984dd34bd 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -54,6 +54,7 @@ from agents.lifecycle import RunHooks from agents.run import AgentRunner, get_default_agent_runner, set_default_agent_runner from agents.run_config import _default_trace_include_sensitive_data +from agents.run_internal.agent_bindings import bind_public_agent from agents.run_internal.items import ( drop_orphan_function_calls, ensure_input_item_format, @@ -2065,7 +2066,7 @@ async def test_conversation_lock_rewind_skips_when_no_snapshot() -> None: agent = Agent(name="test", model=model) result = await get_new_response( - agent=agent, + bindings=bind_public_agent(agent), system_prompt=None, input=[history_item, new_item], output_schema=None, @@ -2110,7 +2111,7 @@ async def test_get_new_response_uses_agent_retry_settings() -> None: ) result = await get_new_response( - agent=agent, + bindings=bind_public_agent(agent), system_prompt=None, input=[get_text_input_item("hello")], output_schema=None, diff --git a/tests/test_agent_runner_sync.py b/tests/test_agent_runner_sync.py index a570eea284..73906e7e93 100644 --- a/tests/test_agent_runner_sync.py +++ b/tests/test_agent_runner_sync.py @@ -1,6 +1,6 @@ import asyncio from collections.abc import Generator -from typing import Any +from typing import Any, Protocol import pytest @@ -8,10 +8,16 @@ from agents.run import AgentRunner +class _EventLoopPolicy(Protocol): + def get_event_loop(self) -> asyncio.AbstractEventLoop: ... + + def set_event_loop(self, loop: asyncio.AbstractEventLoop | None) -> None: ... + + @pytest.fixture -def fresh_event_loop_policy() -> Generator[asyncio.AbstractEventLoopPolicy, None, None]: +def fresh_event_loop_policy() -> Generator[_EventLoopPolicy, None, None]: policy_before = asyncio.get_event_loop_policy() - new_policy = asyncio.DefaultEventLoopPolicy() + new_policy = type(policy_before)() asyncio.set_event_loop_policy(new_policy) try: yield new_policy diff --git a/tests/test_computer_action.py b/tests/test_computer_action.py index bb6823942d..dd69e87537 100644 --- a/tests/test_computer_action.py +++ b/tests/test_computer_action.py @@ -571,7 +571,7 @@ def on_sc(data: ComputerToolSafetyCheckData) -> bool: ctx = RunContextWrapper(context=None) results = await run_loop.execute_computer_actions( - agent=agent, + public_agent=agent, actions=[run_action], hooks=RunHooks[Any](), context_wrapper=ctx, diff --git a/tests/test_example_workflows.py b/tests/test_example_workflows.py index dff1ef7910..bf39f78c61 100644 --- a/tests/test_example_workflows.py +++ b/tests/test_example_workflows.py @@ -28,6 +28,12 @@ from agents.agent import ToolsToFinalOutputResult from agents.items import TResponseInputItem from agents.tool import FunctionToolResult, function_tool +from examples.sandbox.basic import _stream_event_banner +from examples.sandbox.sandbox_agents_as_tools import ( + PricingPacketReview, + RolloutRiskReview, + _structured_tool_output_extractor, +) from .fake_model import FakeModel from .test_responses import ( @@ -487,6 +493,162 @@ async def fake_invoke(ctx, input: str) -> str: ) +@pytest.mark.asyncio +async def test_sandbox_agents_as_tools_example_serializes_structured_reviews() -> None: + pricing_model = FakeModel() + pricing_model.set_next_output( + [ + get_final_output_message( + json.dumps( + { + "requested_discount_percent": 15, + "requested_term_months": 24, + "pricing_risk": "medium", + "summary": "Discount ask is above target band.", + "recommended_next_step": "Trade discount for a stronger give-get.", + "evidence_files": ["pricing_summary.md", "commercial_notes.md"], + } + ) + ) + ] + ) + rollout_model = FakeModel() + rollout_model.set_next_output( + [ + get_final_output_message( + json.dumps( + { + "rollout_risk": "medium", + "summary": "Launch timing is compressed.", + "blockers": [ + "Regional admin training is incomplete.", + "SSO migration lands in week 2.", + ], + "recommended_next_step": "Require a phased rollout plan.", + "evidence_files": ["rollout_plan.md", "support_history.md"], + } + ) + ) + ] + ) + orchestrator_model = FakeModel() + orchestrator_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "review_pricing_packet", + json.dumps({"input": "Review pricing"}), + call_id="outer_pricing", + ), + get_function_tool_call( + "review_rollout_risk", + json.dumps({"input": "Review rollout"}), + call_id="outer_rollout", + ), + get_function_tool_call( + "get_discount_approval_rule", + json.dumps({"discount_percent": 15}), + call_id="outer_approval", + ), + ], + [get_text_message("Recommendation complete")], + ] + ) + + @function_tool + def get_discount_approval_rule(discount_percent: int) -> str: + if discount_percent <= 10: + return "AE" + if discount_percent <= 15: + return "RSD" + return "Finance + RSD" + + pricing_agent = Agent( + name="pricing", + model=pricing_model, + output_type=PricingPacketReview, + ) + rollout_agent = Agent( + name="rollout", + model=rollout_model, + output_type=RolloutRiskReview, + ) + orchestrator = Agent( + name="orchestrator", + model=orchestrator_model, + tools=[ + pricing_agent.as_tool( + "review_pricing_packet", + "Pricing review", + custom_output_extractor=_structured_tool_output_extractor, + ), + rollout_agent.as_tool( + "review_rollout_risk", + "Rollout review", + custom_output_extractor=_structured_tool_output_extractor, + ), + get_discount_approval_rule, + ], + model_settings=ModelSettings(tool_choice="required"), + ) + + result = await Runner.run(orchestrator, "Review the renewal") + + assert result.final_output == "Recommendation complete" + outer_second_turn_input = cast( + list[dict[str, Any]], + orchestrator_model.last_turn_args["input"], + ) + outer_tool_outputs = [ + item for item in outer_second_turn_input if item.get("type") == "function_call_output" + ] + assert outer_tool_outputs == [ + { + "call_id": "outer_pricing", + "output": json.dumps( + { + "evidence_files": ["pricing_summary.md", "commercial_notes.md"], + "pricing_risk": "medium", + "recommended_next_step": "Trade discount for a stronger give-get.", + "requested_discount_percent": 15, + "requested_term_months": 24, + "summary": "Discount ask is above target band.", + }, + sort_keys=True, + ), + "type": "function_call_output", + }, + { + "call_id": "outer_rollout", + "output": json.dumps( + { + "blockers": [ + "Regional admin training is incomplete.", + "SSO migration lands in week 2.", + ], + "evidence_files": ["rollout_plan.md", "support_history.md"], + "recommended_next_step": "Require a phased rollout plan.", + "rollout_risk": "medium", + "summary": "Launch timing is compressed.", + }, + sort_keys=True, + ), + "type": "function_call_output", + }, + { + "call_id": "outer_approval", + "output": "RSD", + "type": "function_call_output", + }, + ] + + +def test_docker_runner_stream_event_banner_uses_stable_event_names() -> None: + assert _stream_event_banner("tool_called") == "[tool call] shell" + assert _stream_event_banner("tool_output") == "[tool output] shell" + assert _stream_event_banner("message_output_created") is None + + @pytest.mark.asyncio async def test_forcing_tool_use_behaviors_align_with_example() -> None: """Mimics forcing_tool_use example: default vs first_tool vs custom behaviors.""" diff --git a/tests/test_hitl_error_scenarios.py b/tests/test_hitl_error_scenarios.py index d0de312d69..845147aebe 100644 --- a/tests/test_hitl_error_scenarios.py +++ b/tests/test_hitl_error_scenarios.py @@ -26,6 +26,7 @@ function_tool, tool_namespace, ) +from agents._public_agent import set_public_agent from agents.computer import Computer, Environment from agents.exceptions import ModelBehaviorError, UserError from agents.items import ( @@ -39,10 +40,12 @@ from agents.lifecycle import RunHooks from agents.run import RunConfig from agents.run_internal import run_loop +from agents.run_internal.agent_bindings import bind_execution_agent, bind_public_agent from agents.run_internal.run_loop import ( NextStepInterruption, NextStepRunAgain, ProcessedResponse, + ToolRunApplyPatchCall, ToolRunComputerAction, ToolRunFunction, ToolRunMCPApprovalRequest, @@ -84,6 +87,20 @@ ) +def _bind_agent(agent: Agent[Any]): + public_agent = getattr(agent, "_agents_public_agent", None) + if isinstance(public_agent, Agent): + return bind_execution_agent(public_agent=public_agent, execution_agent=agent) + return bind_public_agent(agent) + + +async def _resolve_interrupted_turn(*, agent: Agent[Any], **kwargs: Any): + return await run_loop.resolve_interrupted_turn( + bindings=_bind_agent(agent), + **kwargs, + ) + + class TrackingComputer(Computer): """Minimal computer implementation that records method calls.""" @@ -705,7 +722,7 @@ class DummyMcpTool: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="test", original_pre_step_items=[approval_item], @@ -745,7 +762,7 @@ async def test_shell_call_without_call_id_raises() -> None: ) with pytest.raises(ModelBehaviorError): - await run_loop.resolve_interrupted_turn( + await _resolve_interrupted_turn( agent=agent, original_input="test", original_pre_step_items=[], @@ -891,7 +908,7 @@ def bad_tool() -> str: ) with pytest.raises(UserError, match="needs_approval"): - await run_loop.resolve_interrupted_turn( + await _resolve_interrupted_turn( agent=agent, original_input="resume invalid", original_pre_step_items=[], @@ -1006,7 +1023,7 @@ def approve_me(reason: Optional[str] = None) -> str: # noqa: UP007 interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1078,7 +1095,7 @@ async def deferred_lookup_account(customer_id: str) -> str: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1099,6 +1116,71 @@ async def deferred_lookup_account(customer_id: str) -> str: assert deferred_outputs == ["deferred:customer_1"] +@pytest.mark.asyncio +async def test_resume_does_not_rebuild_approved_calls_for_same_named_sibling_agent() -> None: + """Approved interruptions should match the current public agent, not any same-named sibling.""" + + first_calls: list[str] = [] + second_calls: list[str] = [] + + @function_tool(needs_approval=True, name_override="approval_tool") + async def first_approval_tool() -> str: + first_calls.append("first") + return "first" + + @function_tool(needs_approval=True, name_override="approval_tool") + async def second_approval_tool() -> str: + second_calls.append("second") + return "second" + + first = Agent(name="sandbox", tools=[first_approval_tool]) + second = Agent(name="sandbox", tools=[second_approval_tool]) + first.handoffs = [second] + second.handoffs = [first] + + approval_item = ToolApprovalItem( + agent=second, + raw_item=make_function_tool_call( + name="approval_tool", + call_id="call-sibling-approval", + arguments="{}", + ), + tool_name="approval_tool", + ) + context_wrapper = make_context_wrapper() + context_wrapper.approve_tool(approval_item) + run_state = make_state_with_interruptions(first, [approval_item]) + processed_response = ProcessedResponse( + new_items=[], + handoffs=[], + functions=[], + computer_actions=[], + local_shell_calls=[], + shell_calls=[], + apply_patch_calls=[], + tools_used=[], + mcp_approval_requests=[], + interruptions=[], + ) + + execution_agent = set_public_agent(first.clone(), first) + result = await _resolve_interrupted_turn( + agent=execution_agent, + original_input="resume approvals", + original_pre_step_items=[], + new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), + processed_response=processed_response, + hooks=RunHooks(), + context_wrapper=context_wrapper, + run_config=RunConfig(), + run_state=run_state, + ) + + assert first_calls == [] + assert second_calls == [] + assert not any(isinstance(item, ToolCallOutputItem) for item in result.new_step_items) + + @pytest.mark.asyncio async def test_resume_honors_permanent_namespaced_function_approval_with_new_call_id() -> None: @function_tool(needs_approval=True, name_override="lookup_account") @@ -1198,7 +1280,7 @@ def approve_me(reason: Optional[str] = None) -> str: # noqa: UP007 interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1252,7 +1334,7 @@ async def test_resume_rebuilds_local_mcp_function_runs_from_approvals() -> None: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1317,7 +1399,7 @@ async def get_weather() -> str: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1377,7 +1459,7 @@ def pending_me(text: str = "wait") -> str: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1399,6 +1481,127 @@ def pending_me(text: str = "wait") -> str: assert rejection_outputs, "Rejected function call should emit rejection output" +@pytest.mark.asyncio +async def test_resume_function_rejection_outputs_use_public_agent() -> None: + @function_tool(needs_approval=True) + def reject_me(text: str = "nope") -> str: + return text + + _model, public_agent = make_model_and_agent(tools=[reject_me]) + execution_agent = public_agent.clone() + set_public_agent(execution_agent, public_agent) + context_wrapper = make_context_wrapper() + + rejected_call = make_function_tool_call(reject_me.name, call_id="obj-reject-public") + assert isinstance(rejected_call, ResponseFunctionToolCall) + rejected_item = ToolApprovalItem(agent=public_agent, raw_item=rejected_call) + context_wrapper.reject_tool(rejected_item) + + run_state = make_state_with_interruptions(public_agent, [rejected_item]) + processed_response = ProcessedResponse( + new_items=[], + handoffs=[], + functions=[], + computer_actions=[], + local_shell_calls=[], + shell_calls=[], + apply_patch_calls=[], + tools_used=[], + mcp_approval_requests=[], + interruptions=[], + ) + + result = await _resolve_interrupted_turn( + agent=execution_agent, + original_input="resume approvals", + original_pre_step_items=[], + new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), + processed_response=processed_response, + hooks=RunHooks(), + context_wrapper=context_wrapper, + run_config=RunConfig(), + run_state=run_state, + ) + + rejection_outputs = [ + item + for item in result.new_step_items + if isinstance(item, ToolCallOutputItem) and item.output == HITL_REJECTION_MSG + ] + assert rejection_outputs + assert all(item.agent is public_agent for item in rejection_outputs) + + +@pytest.mark.parametrize("tool_kind", ["shell", "apply_patch"]) +@pytest.mark.asyncio +async def test_resume_non_function_rejection_outputs_use_public_agent( + tool_kind: str, +) -> None: + context_wrapper = make_context_wrapper() + processed_response = ProcessedResponse( + new_items=[], + handoffs=[], + functions=[], + computer_actions=[], + local_shell_calls=[], + shell_calls=[], + apply_patch_calls=[], + tools_used=[], + mcp_approval_requests=[], + interruptions=[], + ) + + if tool_kind == "shell": + shell_tool = ShellTool(executor=lambda _req: "should_not_run", needs_approval=True) + _model, public_agent = make_model_and_agent(tools=[shell_tool]) + raw_item = cast( + dict[str, Any], + make_shell_call( + "call_reject_shell_public", + id_value="shell_reject_public", + commands=["echo test"], + status="in_progress", + ), + ) + processed_response.shell_calls = [ + ToolRunShellCall(tool_call=raw_item, shell_tool=shell_tool) + ] + tool_name = shell_tool.name + else: + apply_patch_tool = ApplyPatchTool(editor=RecordingEditor(), needs_approval=True) + _model, public_agent = make_model_and_agent(tools=[apply_patch_tool]) + raw_item = cast(Any, make_apply_patch_dict("call_apply_reject_public")) + processed_response.apply_patch_calls = [ + ToolRunApplyPatchCall(tool_call=raw_item, apply_patch_tool=apply_patch_tool) + ] + tool_name = apply_patch_tool.name + + execution_agent = public_agent.clone() + set_public_agent(execution_agent, public_agent) + approval_item = ToolApprovalItem(agent=public_agent, raw_item=raw_item, tool_name=tool_name) + context_wrapper.reject_tool(approval_item) + + result = await _resolve_interrupted_turn( + agent=execution_agent, + original_input="resume rejection", + original_pre_step_items=[], + new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), + processed_response=processed_response, + hooks=RunHooks(), + context_wrapper=context_wrapper, + run_config=RunConfig(), + run_state=make_state_with_interruptions(public_agent, [approval_item]), + ) + + rejection_outputs = [ + item + for item in result.new_step_items + if isinstance(item, ToolCallOutputItem) and item.output == HITL_REJECTION_MSG + ] + assert rejection_outputs + assert all(item.agent is public_agent for item in rejection_outputs) + + @pytest.mark.asyncio async def test_resume_keeps_unmatched_pending_approvals_with_function_runs() -> None: """Pending approvals should persist even when resume has other function runs.""" @@ -1437,7 +1640,7 @@ def inner_tool() -> str: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1477,7 +1680,7 @@ def already_ran() -> str: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume run", original_pre_step_items=[], @@ -1538,7 +1741,7 @@ def already_ran() -> str: ) ] - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume run", original_pre_step_items=original_pre_step_items, @@ -1593,7 +1796,7 @@ async def test_resume_skips_shell_calls_with_existing_output() -> None: ) ] - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume shell", original_pre_step_items=cast(list[RunItem], original_pre_step_items), @@ -1653,7 +1856,7 @@ def pending_tool() -> str: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume shell with pending approval", original_pre_step_items=[], @@ -1709,7 +1912,7 @@ async def test_resume_executes_pending_computer_actions() -> None: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume computer", original_pre_step_items=[], @@ -1777,7 +1980,7 @@ async def test_resume_skips_computer_actions_with_existing_output() -> None: ) ] - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume computer existing", original_pre_step_items=cast(list[RunItem], original_pre_step_items), @@ -1840,7 +2043,7 @@ def pending_me(text: str = "wait") -> str: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1910,7 +2113,7 @@ async def test_rebuild_preserves_unmatched_pending_approvals( interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume approvals", original_pre_step_items=[], @@ -1957,7 +2160,7 @@ async def test_rejected_shell_calls_emit_rejection_output() -> None: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume shell rejection", original_pre_step_items=[], @@ -2041,7 +2244,7 @@ async def test_rejected_shell_calls_with_existing_output_are_not_duplicated() -> ) ] - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="resume shell rejection existing", original_pre_step_items=cast(list[RunItem], original_pre_step_items), @@ -2101,7 +2304,7 @@ def __init__(self) -> None: interruptions=[], ) - result = await run_loop.resolve_interrupted_turn( + result = await _resolve_interrupted_turn( agent=agent, original_input="handle mcp", original_pre_step_items=[], diff --git a/tests/test_run_impl_resume_paths.py b/tests/test_run_impl_resume_paths.py index 542d1f3749..07ffbf97c1 100644 --- a/tests/test_run_impl_resume_paths.py +++ b/tests/test_run_impl_resume_paths.py @@ -1,5 +1,5 @@ import json -from typing import cast +from typing import Any, cast import pytest from openai.types.responses import ResponseFunctionToolCall, ResponseOutputMessage @@ -7,11 +7,18 @@ import agents.run as run_module from agents import Agent, Runner, function_tool from agents.agent import ToolsToFinalOutputResult -from agents.items import MessageOutputItem, ModelResponse, ToolCallItem, ToolCallOutputItem +from agents.items import ( + MessageOutputItem, + ModelResponse, + ToolApprovalItem, + ToolCallItem, + ToolCallOutputItem, +) from agents.lifecycle import RunHooks from agents.run import RunConfig from agents.run_context import RunContextWrapper from agents.run_internal import run_loop, turn_resolution +from agents.run_internal.agent_bindings import bind_public_agent from agents.run_internal.run_loop import ( NextStepFinalOutput, NextStepInterruption, @@ -84,7 +91,7 @@ async def fake_execute_final_output( ) result = await run_loop.resolve_interrupted_turn( - agent=agent, + bindings=bind_public_agent(agent), original_input="input", original_pre_step_items=[], new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), @@ -266,3 +273,110 @@ async def test_tool() -> str: assert call_count == 1 assert output_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("schema_version", "expect_execution"), + [("1.6", True), ("1.7", False)], +) +async def test_resolve_interrupted_turn_only_uses_name_fallback_for_legacy_approval_agents( + schema_version: str, + expect_execution: bool, +) -> None: + calls: list[str] = [] + + @function_tool(name_override="needs_ok", needs_approval=True) + async def needs_ok(text: str) -> str: + calls.append(text) + return text + + base_duplicate = Agent(name="duplicate", instructions="alpha", tools=[needs_ok]) + resumed_duplicate = Agent(name="duplicate", instructions="zeta", tools=[needs_ok]) + root = Agent(name="triage", handoffs=[base_duplicate, resumed_duplicate]) + base_duplicate.handoffs = [root] + resumed_duplicate.handoffs = [root] + + state: RunState[dict[str, str], Agent[Any]] = RunState( + context=RunContextWrapper(context={}), + original_input="input", + starting_agent=root, + max_turns=2, + ) + state._current_agent = resumed_duplicate + state._current_step = NextStepInterruption( + interruptions=[ + ToolApprovalItem( + agent=resumed_duplicate, + raw_item=cast( + ResponseFunctionToolCall, + get_function_tool_call( + "needs_ok", + json.dumps({"text": "one"}), + call_id="legacy-call", + ), + ), + ) + ] + ) + state._last_processed_response = ProcessedResponse( + new_items=[], + handoffs=[], + functions=[], + computer_actions=[], + local_shell_calls=[], + shell_calls=[], + apply_patch_calls=[], + tools_used=[], + mcp_approval_requests=[], + interruptions=[], + ) + state._model_responses = [ModelResponse(output=[], usage=Usage(), response_id="resp")] + + json_data = state.to_json() + current_agent_data = cast(dict[str, str], json_data["current_agent"]) + assert current_agent_data["name"] == "duplicate" + assert "identity" in current_agent_data + + interruption_data = cast( + dict[str, object], + json_data["current_step"]["data"]["interruptions"][0], + ) + interruption_agent_data = cast(dict[str, str], interruption_data["agent"]) + assert interruption_agent_data["identity"] == current_agent_data["identity"] + interruption_agent_data.pop("identity") + json_data["$schemaVersion"] = schema_version + + restored = await RunState.from_json(root, json_data) + assert restored._schema_version == schema_version + assert restored._current_agent is resumed_duplicate + restored_approval = restored.get_interruptions()[0] + restored.approve(restored_approval) + assert restored._context is not None + assert restored._last_processed_response is not None + + result = await turn_resolution.resolve_interrupted_turn( + bindings=bind_public_agent(cast(Agent[dict[str, str]], restored._current_agent)), + original_input=restored._original_input, + original_pre_step_items=restored._generated_items, + new_response=restored._model_responses[-1], + processed_response=restored._last_processed_response, + hooks=RunHooks(), + context_wrapper=restored._context, + run_config=RunConfig(), + run_state=restored, + ) + + if expect_execution: + assert isinstance(result.next_step, NextStepRunAgain) + assert calls == ["one"] + assert any( + isinstance(item, ToolCallOutputItem) and item.output == "one" + for item in result.new_step_items + ) + else: + assert calls == [] + assert not any( + isinstance(item, ToolCallOutputItem) and item.output == "one" + for item in result.new_step_items + ) diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 56cd61fab2..3302feb2d5 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -68,8 +68,10 @@ ) from agents.run_state import ( CURRENT_SCHEMA_VERSION, + SCHEMA_VERSION_SUMMARIES, SUPPORTED_SCHEMA_VERSIONS, RunState, + _build_agent_identity_map, _build_agent_map, _deserialize_items, _deserialize_processed_response, @@ -101,6 +103,7 @@ from .test_responses import ( get_final_output_message, get_function_tool_call, + get_handoff_tool_call, get_text_message, ) from .utils.factories import ( @@ -118,6 +121,9 @@ run_and_resume_with_mutation, ) +_CURRENT_SCHEMA_MAJOR, _CURRENT_SCHEMA_MINOR = CURRENT_SCHEMA_VERSION.split(".") +_NEXT_UNSUPPORTED_SCHEMA_VERSION = f"{_CURRENT_SCHEMA_MAJOR}.{int(_CURRENT_SCHEMA_MINOR) + 1}" + TContext = TypeVar("TContext") @@ -242,6 +248,304 @@ def test_to_json_and_to_string_produce_valid_json(self): assert isinstance(str_data, str) assert json.loads(str_data) == json_data + @pytest.mark.asyncio + async def test_from_json_restores_duplicate_name_current_agent_by_identity(self): + """Duplicate agent names should round-trip through the serialized identity key.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + second = Agent(name="duplicate") + first = Agent(name="duplicate", handoffs=[second]) + second.handoffs = [first] + state = make_state(first, context=context, original_input="input1", max_turns=2) + state._current_agent = second + + json_data = state.to_json() + assert json_data["current_agent"] == {"name": "duplicate", "identity": "duplicate#2"} + + restored = await RunState.from_json(first, json_data) + assert restored._current_agent is second + + def test_build_agent_identity_map_avoids_literal_suffix_collisions(self) -> None: + """Literal `#` names should not collide with generated duplicate identities.""" + first = Agent(name="sandbox") + literal_suffix = Agent(name="sandbox#2") + second = Agent(name="sandbox") + first.handoffs = [literal_suffix, second] + literal_suffix.handoffs = [first, second] + second.handoffs = [first, literal_suffix] + + identity_map = _build_agent_identity_map(first) + + assert identity_map == { + "sandbox": first, + "sandbox#2": literal_suffix, + "sandbox#3": second, + } + + def test_build_agent_identity_map_is_stable_across_reordered_duplicate_agents(self) -> None: + """Duplicate-name identities should not change when reachable order changes.""" + + @function_tool(name_override="alpha_tool") + def alpha_tool() -> str: + return "alpha" + + @function_tool(name_override="beta_tool") + def beta_tool() -> str: + return "beta" + + def _identity_for( + identity_map: Mapping[str, Agent[Any]], + target: Agent[Any], + ) -> str: + return next(identity for identity, agent in identity_map.items() if agent is target) + + first_alpha = Agent(name="sandbox", instructions="Alpha", tools=[alpha_tool]) + first_beta = Agent(name="sandbox", instructions="Beta", tools=[beta_tool]) + first_root = Agent(name="triage", handoffs=[first_beta, first_alpha]) + first_alpha.handoffs = [first_root] + first_beta.handoffs = [first_root] + + second_alpha = Agent(name="sandbox", instructions="Alpha", tools=[alpha_tool]) + second_beta = Agent(name="sandbox", instructions="Beta", tools=[beta_tool]) + second_root = Agent(name="triage", handoffs=[second_alpha, second_beta]) + second_alpha.handoffs = [second_root] + second_beta.handoffs = [second_root] + + first_identity_map = _build_agent_identity_map(first_root) + second_identity_map = _build_agent_identity_map(second_root) + + assert _identity_for(first_identity_map, first_alpha) == _identity_for( + second_identity_map, second_alpha + ) + assert _identity_for(first_identity_map, first_beta) == _identity_for( + second_identity_map, second_beta + ) + + @pytest.mark.asyncio + async def test_from_json_restores_duplicate_name_current_agent_with_reordered_graph(self): + """Restore should keep the same logical duplicate agent after graph reordering.""" + + @function_tool(name_override="alpha_tool") + def alpha_tool() -> str: + return "alpha" + + @function_tool(name_override="beta_tool") + def beta_tool() -> str: + return "beta" + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + first_alpha = Agent(name="sandbox", instructions="Alpha", tools=[alpha_tool]) + first_beta = Agent(name="sandbox", instructions="Beta", tools=[beta_tool]) + first_root = Agent(name="triage", handoffs=[first_beta, first_alpha]) + first_alpha.handoffs = [first_root] + first_beta.handoffs = [first_root] + + state = make_state(first_root, context=context, original_input="input1", max_turns=2) + state._current_agent = first_beta + json_data = state.to_json() + + restored_alpha = Agent(name="sandbox", instructions="Alpha", tools=[alpha_tool]) + restored_beta = Agent(name="sandbox", instructions="Beta", tools=[beta_tool]) + restored_root = Agent(name="triage", handoffs=[restored_alpha, restored_beta]) + restored_alpha.handoffs = [restored_root] + restored_beta.handoffs = [restored_root] + + restored = await RunState.from_json(restored_root, json_data) + assert restored._current_agent is restored_beta + + @pytest.mark.asyncio + async def test_from_json_restores_bare_duplicate_name_current_agent_via_identity_map(self): + """Bare duplicate names should resolve through the identity map, not traversal order.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + first = Agent(name="duplicate", instructions="zeta") + second = Agent(name="duplicate", instructions="alpha") + root = Agent(name="triage", handoffs=[first, second]) + first.handoffs = [root] + second.handoffs = [root] + + state = make_state(root, context=context, original_input="input1", max_turns=2) + state._current_agent = second + + json_data = state.to_json() + assert json_data["current_agent"] == {"name": "duplicate"} + + restored = await RunState.from_json(root, json_data) + assert restored._current_agent is second + + def test_build_agent_identity_map_uses_tool_use_behavior_for_duplicate_names(self) -> None: + """Duplicate-name identities should stay stable when only tool_use_behavior differs.""" + + def _identity_for( + identity_map: Mapping[str, Agent[Any]], + target: Agent[Any], + ) -> str: + return next(identity for identity, agent in identity_map.items() if agent is target) + + first_default = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="run_llm_again", + ) + first_stop = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="stop_on_first_tool", + ) + first_root = Agent(name="triage", handoffs=[first_default, first_stop]) + first_default.handoffs = [first_root] + first_stop.handoffs = [first_root] + + second_default = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="run_llm_again", + ) + second_stop = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="stop_on_first_tool", + ) + second_root = Agent(name="triage", handoffs=[second_stop, second_default]) + second_default.handoffs = [second_root] + second_stop.handoffs = [second_root] + + first_identity_map = _build_agent_identity_map(first_root) + second_identity_map = _build_agent_identity_map(second_root) + + assert _identity_for(first_identity_map, first_default) == _identity_for( + second_identity_map, second_default + ) + assert _identity_for(first_identity_map, first_stop) == _identity_for( + second_identity_map, second_stop + ) + + @pytest.mark.asyncio + async def test_from_json_restores_duplicate_name_current_agent_when_tool_use_behavior_differs( + self, + ) -> None: + """Duplicate-name restore should stay stable when tool_use_behavior is the only delta.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + first_default = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="run_llm_again", + ) + first_stop = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="stop_on_first_tool", + ) + first_root = Agent(name="triage", handoffs=[first_default, first_stop]) + first_default.handoffs = [first_root] + first_stop.handoffs = [first_root] + + state = make_state(first_root, context=context, original_input="input1", max_turns=2) + state._current_agent = first_stop + json_data = state.to_json() + + restored_default = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="run_llm_again", + ) + restored_stop = Agent( + name="sandbox", + instructions="Shared instructions.", + tool_use_behavior="stop_on_first_tool", + ) + restored_root = Agent(name="triage", handoffs=[restored_stop, restored_default]) + restored_default.handoffs = [restored_root] + restored_stop.handoffs = [restored_root] + + restored = await RunState.from_json(restored_root, json_data) + assert restored._current_agent is restored_stop + + @pytest.mark.asyncio + async def test_from_json_rejects_missing_saved_duplicate_identity(self): + """Identity-aware snapshots should fail when the saved duplicate no longer exists.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + second = Agent(name="duplicate", instructions="Second") + first = Agent(name="duplicate", instructions="First", handoffs=[second]) + second.handoffs = [first] + state = make_state(first, context=context, original_input="input1", max_turns=2) + state._current_agent = second + + json_data = state.to_json() + restored_root = Agent(name="duplicate", instructions="First") + + with pytest.raises(UserError, match="agent identity"): + await RunState.from_json(restored_root, json_data) + + @pytest.mark.asyncio + async def test_result_to_state_preserves_duplicate_name_root_and_owned_state(self): + """RunResult.to_state should keep the root graph while preserving the active duplicate.""" + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + first_model = FakeModel() + second_model = FakeModel() + first = Agent(name="duplicate", model=first_model) + second = Agent( + name="duplicate", + model=second_model, + tools=[approval_tool], + model_settings=ModelSettings(tool_choice="required"), + ) + first.handoffs = [second] + second.handoffs = [first] + + first_model.add_multiple_turn_outputs([[get_handoff_tool_call(second)]]) + second_model.add_multiple_turn_outputs( + [[get_function_tool_call("approval_tool", json.dumps({}), call_id="call_approval")]] + ) + + result = await Runner.run(first, "start") + assert result.interruptions + + state = result.to_state() + assert state._starting_agent is first + assert state._current_agent is second + + json_data = state.to_json() + assert json_data["current_agent"] == {"name": "duplicate", "identity": "duplicate#2"} + assert json_data["tool_use_tracker"]["duplicate#2"] == ["approval_tool"] + assert json_data["current_step"] is not None + assert json_data["current_step"]["data"]["interruptions"][0]["agent"] == { + "name": "duplicate", + "identity": "duplicate#2", + } + + approval_tool_items = [ + item + for item in json_data["generated_items"] + if item["type"] == "tool_call_item" + and item["raw_item"].get("call_id") == "call_approval" + ] + assert len(approval_tool_items) == 1 + assert approval_tool_items[0]["agent"] == { + "name": "duplicate", + "identity": "duplicate#2", + } + assert approval_tool_items[0]["raw_item"] == { + "arguments": "{}", + "call_id": "call_approval", + "id": "1", + "name": "approval_tool", + "type": "function_call", + } + + restored = await RunState.from_json(first, json_data) + assert restored._starting_agent is first + assert restored._current_agent is second + assert restored.get_interruptions()[0].agent is second + assert any( + isinstance(item, ToolCallItem) + and item.agent is second + and getattr(item.raw_item, "call_id", None) == "call_approval" + for item in restored._generated_items + ) + async def test_reasoning_item_id_policy_survives_serialization(self): """RunState should preserve reasoning item input policy across serialization.""" context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) @@ -1917,6 +2221,64 @@ async def test_serialization_includes_handoff_fields(self): assert len(restored._generated_items) == 1 assert restored._generated_items[0].type == "handoff_output_item" + @pytest.mark.asyncio + async def test_serialization_uses_duplicate_identities_for_handoff_and_output_guardrails(self): + """Duplicate-name item ownership should round-trip with identity keys.""" + first = Agent(name="duplicate") + second = Agent(name="duplicate") + third = Agent(name="duplicate") + first.handoffs = [second, third] + second.handoffs = [third] + third.handoffs = [first] + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + state = make_state(first, context=context, original_input="test handoff", max_turns=2) + state._current_agent = second + state._generated_items = [ + HandoffOutputItem( + agent=second, + raw_item={"type": "handoff_output", "status": "completed"}, # type: ignore[arg-type] + source_agent=second, + target_agent=third, + ) + ] + + output_guardrail = OutputGuardrail( + guardrail_function=lambda _ctx, _agent, _output: GuardrailFunctionOutput( + output_info={"guardrail": "ok"}, + tripwire_triggered=False, + ), + name="duplicate_output_guardrail", + ) + state._output_guardrail_results = [ + OutputGuardrailResult( + guardrail=output_guardrail, + agent_output="done", + agent=third, + output=GuardrailFunctionOutput( + output_info={"guardrail": "ok"}, + tripwire_triggered=False, + ), + ) + ] + + json_data = state.to_json() + item_data = json_data["generated_items"][0] + assert item_data["agent"] == {"name": "duplicate", "identity": "duplicate#2"} + assert item_data["source_agent"] == {"name": "duplicate", "identity": "duplicate#2"} + assert item_data["target_agent"] == {"name": "duplicate", "identity": "duplicate#3"} + assert json_data["output_guardrail_results"][0]["agent"] == { + "name": "duplicate", + "identity": "duplicate#3", + } + + restored = await RunState.from_json(first, json_data) + restored_item = cast(HandoffOutputItem, restored._generated_items[0]) + assert restored_item.agent is second + assert restored_item.source_agent is second + assert restored_item.target_agent is third + assert restored._output_guardrail_results[0].agent is third + async def test_model_response_serialization_roundtrip(self): """Test that model responses serialize and deserialize correctly.""" @@ -3969,7 +4331,7 @@ async def test_from_json_missing_schema_version(self): await RunState.from_json(agent, state_json) @pytest.mark.asyncio - @pytest.mark.parametrize("schema_version", ["1.7", "2.0"]) + @pytest.mark.parametrize("schema_version", [_NEXT_UNSUPPORTED_SCHEMA_VERSION, "2.0", "9.9"]) async def test_from_json_unsupported_schema_version(self, schema_version: str): """Test that from_json raises error when schema version is unsupported.""" agent = Agent(name="TestAgent") @@ -4021,9 +4383,42 @@ async def test_from_json_accepts_previous_schema_version(self): def test_supported_schema_versions_match_released_boundary(self): """The support set should include released versions plus the current unreleased writer.""" assert SUPPORTED_SCHEMA_VERSIONS == frozenset( - {"1.0", "1.1", "1.2", "1.3", "1.4", "1.5", CURRENT_SCHEMA_VERSION} + {"1.0", "1.1", "1.2", "1.3", "1.4", "1.5", "1.6", CURRENT_SCHEMA_VERSION} ) + def test_supported_schema_versions_have_non_empty_summaries(self): + """Every supported schema version should have a one-line historical summary.""" + assert frozenset(SCHEMA_VERSION_SUMMARIES) == SUPPORTED_SCHEMA_VERSIONS + assert CURRENT_SCHEMA_VERSION in SCHEMA_VERSION_SUMMARIES + assert all(summary.strip() for summary in SCHEMA_VERSION_SUMMARIES.values()) + + @pytest.mark.asyncio + async def test_from_json_accepts_schema_version_1_5_without_sandbox_payload(self): + """RunState snapshots written before sandbox resume support should still restore.""" + agent = Agent(name="TestAgent") + state_json = { + "$schemaVersion": "1.5", + "original_input": "test", + "current_agent": {"name": "TestAgent"}, + "context": { + "context": {"foo": "bar"}, + "usage": {"requests": 0, "input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + "approvals": {}, + }, + "max_turns": 3, + "current_turn": 0, + "model_responses": [], + "generated_items": [], + } + + restored = await RunState.from_json(agent, state_json) + + assert restored._current_agent is not None + assert restored._current_agent.name == "TestAgent" + assert restored._context is not None + assert restored._context.context == {"foo": "bar"} + assert restored._sandbox is None + @pytest.mark.asyncio async def test_from_json_agent_not_found(self): """Test that from_json raises error when agent is not found in agent map.""" diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index c8226903a8..4cb8fa7718 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -46,7 +46,9 @@ tool_output_guardrail, trace, ) -from agents.run_internal import run_loop +from agents._public_agent import set_public_agent +from agents.run_internal import run_loop, turn_resolution +from agents.run_internal.agent_bindings import bind_execution_agent, bind_public_agent from agents.run_internal.run_loop import ( NextStepFinalOutput, NextStepHandoff, @@ -106,6 +108,13 @@ def _function_span_names() -> list[str]: return names +def _bind_agent(agent: Agent[Any]): + public_agent = getattr(agent, "_agents_public_agent", None) + if isinstance(public_agent, Agent): + return bind_execution_agent(public_agent=public_agent, execution_agent=agent) + return bind_public_agent(agent) + + @pytest.mark.asyncio async def test_empty_response_is_final_output(): agent = Agent[None](name="test") @@ -1142,7 +1151,7 @@ def _failure_handler(_ctx: RunContextWrapper[Any], error: Exception) -> str: execution_task = asyncio.create_task( execute_function_tool_calls( - agent=agent, + bindings=bind_public_agent(agent), tool_runs=tool_runs, hooks=RecordingHooks(), context_wrapper=RunContextWrapper(None), @@ -1188,7 +1197,7 @@ async def _shipping_eta(tracking_number: str) -> str: with trace("test_execute_function_tool_calls_collapse_trace_name_for_top_level_deferred_tools"): await execute_function_tool_calls( - agent=Agent(name="test", tools=[tool]), + bindings=bind_public_agent(Agent(name="test", tools=[tool])), tool_runs=[tool_run], hooks=RunHooks(), context_wrapper=RunContextWrapper(None), @@ -1230,7 +1239,7 @@ async def _shipping_eta(tracking_number: str) -> str: with trace("test_execute_function_tool_calls_preserve_trace_name_for_explicit_namespace"): await execute_function_tool_calls( - agent=Agent(name="test", tools=[tool]), + bindings=bind_public_agent(Agent(name="test", tools=[tool])), tool_runs=[tool_run], hooks=RunHooks(), context_wrapper=RunContextWrapper(None), @@ -2556,7 +2565,7 @@ async def get_execute_result( handoffs=handoffs, ) return await run_loop.execute_tools_and_side_effects( - agent=agent, + bindings=_bind_agent(agent), original_input=original_input or "hello", new_response=response, pre_step_items=generated_items or [], @@ -2574,7 +2583,7 @@ async def run_execute_with_processed_response( """Execute tools for a pre-constructed ProcessedResponse.""" return await run_loop.execute_tools_and_side_effects( - agent=agent, + bindings=_bind_agent(agent), original_input="test", pre_step_items=[], new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), @@ -2759,6 +2768,58 @@ async def test_execute_tools_runs_hosted_mcp_callback_when_present(): assert not result.processed_response or not result.processed_response.interruptions +@pytest.mark.asyncio +async def test_execute_tools_uses_public_agent_for_hosted_mcp_callback_results(): + """Hosted MCP callback responses should expose the public agent when execution uses a clone.""" + + mcp_tool = HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "test_mcp_server", + "server_url": "https://example.com", + "require_approval": "always", + }, + on_approval_request=lambda request: {"approve": True}, + ) + public_agent = make_agent(tools=[mcp_tool]) + execution_agent = public_agent.clone() + set_public_agent(execution_agent, public_agent) + request_item = McpApprovalRequest( + id="mcp-approval-callback-public-agent", + type="mcp_approval_request", + server_label="test_mcp_server", + arguments="{}", + name="list_repo_languages", + ) + processed_response = make_processed_response( + new_items=[MCPApprovalRequestItem(raw_item=request_item, agent=execution_agent)], + mcp_approval_requests=[ + ToolRunMCPApprovalRequest( + request_item=request_item, + mcp_tool=mcp_tool, + ) + ], + ) + + result = await run_loop.execute_tools_and_side_effects( + bindings=_bind_agent(execution_agent), + original_input="test", + pre_step_items=[], + new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), + processed_response=processed_response, + output_schema=None, + hooks=RunHooks(), + context_wrapper=make_context_wrapper(), + run_config=RunConfig(), + ) + + assert not isinstance(result.next_step, NextStepInterruption) + assert any( + isinstance(item, MCPApprovalResponseItem) and item.agent is public_agent + for item in result.new_step_items + ) + + @pytest.mark.asyncio async def test_execute_tools_surfaces_hosted_mcp_interruptions_without_callback(): """Hosted MCP approvals should surface as interruptions when no callback is provided.""" @@ -2802,6 +2863,150 @@ async def test_execute_tools_surfaces_hosted_mcp_interruptions_without_callback( ) +@pytest.mark.asyncio +async def test_execute_tools_uses_public_agent_for_hosted_mcp_interruptions(): + """Hosted MCP approval items should expose the public agent when execution uses a clone.""" + + mcp_tool = HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "test_mcp_server", + "server_url": "https://example.com", + "require_approval": "always", + }, + on_approval_request=None, + ) + public_agent = make_agent(tools=[mcp_tool]) + execution_agent = public_agent.clone() + set_public_agent(execution_agent, public_agent) + request_item = McpApprovalRequest( + id="mcp-approval-public-agent", + type="mcp_approval_request", + server_label="test_mcp_server", + arguments="{}", + name="list_repo_languages", + ) + processed_response = make_processed_response( + new_items=[MCPApprovalRequestItem(raw_item=request_item, agent=execution_agent)], + mcp_approval_requests=[ + ToolRunMCPApprovalRequest( + request_item=request_item, + mcp_tool=mcp_tool, + ) + ], + ) + + result = await run_loop.execute_tools_and_side_effects( + bindings=_bind_agent(execution_agent), + original_input="test", + pre_step_items=[], + new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), + processed_response=processed_response, + output_schema=None, + hooks=RunHooks(), + context_wrapper=make_context_wrapper(), + run_config=RunConfig(), + ) + + assert isinstance(result.next_step, NextStepInterruption) + assert result.next_step.interruptions + assert all(item.agent is public_agent for item in result.next_step.interruptions) + assert any( + isinstance(item, ToolApprovalItem) + and getattr(item.raw_item, "id", None) == "mcp-approval-public-agent" + and item.agent is public_agent + for item in result.new_step_items + ) + + +@pytest.mark.asyncio +async def test_resolve_interrupted_turn_uses_public_agent_for_resumed_hosted_mcp_approvals(): + """Resumed hosted MCP approvals should keep the public agent on approval responses.""" + + mcp_tool = HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "test_mcp_server", + "server_url": "https://example.com", + "require_approval": "always", + }, + on_approval_request=None, + ) + public_agent = make_agent(tools=[mcp_tool]) + execution_agent = public_agent.clone() + set_public_agent(execution_agent, public_agent) + request_item = McpApprovalRequest( + id="mcp-approval-resume-public-agent", + type="mcp_approval_request", + server_label="test_mcp_server", + arguments="{}", + name="list_repo_languages", + ) + approval_item = ToolApprovalItem( + agent=public_agent, + raw_item=request_item, + tool_name="list_repo_languages", + ) + context_wrapper = make_context_wrapper() + context_wrapper.approve_tool(approval_item) + processed_response = make_processed_response( + new_items=[MCPApprovalRequestItem(raw_item=request_item, agent=execution_agent)], + mcp_approval_requests=[ + ToolRunMCPApprovalRequest( + request_item=request_item, + mcp_tool=mcp_tool, + ) + ], + ) + + result = await turn_resolution.resolve_interrupted_turn( + bindings=_bind_agent(execution_agent), + original_input="test", + original_pre_step_items=[approval_item], + new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), + processed_response=processed_response, + hooks=RunHooks(), + context_wrapper=context_wrapper, + run_config=RunConfig(), + ) + + responses = [ + item + for item in result.new_step_items + if isinstance(item, MCPApprovalResponseItem) + and item.raw_item.get("approval_request_id") == "mcp-approval-resume-public-agent" + ] + assert responses + assert all(item.agent is public_agent for item in responses) + + +@pytest.mark.asyncio +async def test_execute_handoffs_uses_public_agent_for_ignored_extra_handoffs(): + """Ignored extra handoff outputs should stay owned by the public agent.""" + + first_target = Agent(name="alpha") + second_target = Agent(name="beta") + public_agent = Agent(name="triage", handoffs=[first_target, second_target]) + execution_agent = public_agent.clone() + set_public_agent(execution_agent, public_agent) + response = ModelResponse( + output=[get_handoff_tool_call(first_target), get_handoff_tool_call(second_target)], + usage=Usage(), + response_id="resp", + ) + + result = await get_execute_result(execution_agent, response) + + ignored_outputs = [ + item + for item in result.new_step_items + if isinstance(item, ToolCallOutputItem) + and item.output == "Multiple handoffs detected, ignoring this one." + ] + assert len(ignored_outputs) == 1 + assert ignored_outputs[0].agent is public_agent + + @pytest.mark.asyncio async def test_execute_tools_emits_hosted_mcp_rejection_response(): """Hosted MCP rejections without callbacks should emit approval responses.""" @@ -2836,7 +3041,7 @@ async def test_execute_tools_emits_hosted_mcp_rejection_response(): reject_tool_call(context_wrapper, agent, request_item, tool_name="list_repo_languages") result = await run_loop.execute_tools_and_side_effects( - agent=agent, + bindings=_bind_agent(agent), original_input="test", pre_step_items=[], new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), @@ -2897,7 +3102,7 @@ async def test_execute_tools_emits_hosted_mcp_rejection_reason_from_explicit_mes ) result = await run_loop.execute_tools_and_side_effects( - agent=agent, + bindings=_bind_agent(agent), original_input="test", pre_step_items=[], new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"), diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 2682ba647d..8d83193185 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -232,7 +232,7 @@ def fake_nest( monkeypatch.setattr("agents.run_internal.turn_resolution.nest_handoff_history", fake_nest) result = await run_loop.execute_handoffs( - agent=source_agent, + public_agent=source_agent, original_input=list(original_input), pre_step_items=pre_step_items, new_step_items=new_step_items, @@ -280,7 +280,7 @@ def fake_nest( monkeypatch.setattr("agents.run_internal.turn_resolution.nest_handoff_history", fake_nest) result = await run_loop.execute_handoffs( - agent=source_agent, + public_agent=source_agent, original_input=list(original_input), pre_step_items=pre_step_items, new_step_items=new_step_items, diff --git a/tests/test_sandbox_app_server_client.py b/tests/test_sandbox_app_server_client.py new file mode 100644 index 0000000000..5bbf59ef9e --- /dev/null +++ b/tests/test_sandbox_app_server_client.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import json +from collections import deque +from typing import Any, Callable + +import pytest + +from agents.sandbox.app_server.client import ( + AppServerClient, + AppServerConfig, + AppServerTransportOps, +) +from agents.sandbox.app_server.models import InitializeResponse, UnknownNotification + + +class _FakeConnection: + def __init__(self, responses: list[str | Callable[[list[dict[str, object]]], str]]) -> None: + self._responses = deque(responses) + self.sent_payloads: list[dict[str, object]] = [] + self.closed = False + + def send(self, message: str, text: bool | None = None) -> None: + assert text is True + self.sent_payloads.append(json.loads(message)) + + def recv(self, timeout: float | None = None) -> str: + assert timeout is None or timeout >= 0 + response = self._responses.popleft() + if callable(response): + return response(self.sent_payloads) + return response + + def close(self) -> None: + self.closed = True + + +def _initialize_result(sent_payloads: list[dict[str, object]]) -> str: + request = next(payload for payload in sent_payloads if payload.get("method") == "initialize") + return json.dumps( + { + "id": request["id"], + "result": { + "serverInfo": {"name": "codex-app-server", "version": "2"}, + "platformOs": "linux", + }, + } + ) + + +def test_app_server_client_initializes_over_explicit_websocket_url() -> None: + fake_connection = _FakeConnection([_initialize_result]) + captured_connect: dict[str, object] = {} + + def _connect(url: str, **kwargs: object) -> Any: + captured_connect["url"] = url + captured_connect["kwargs"] = kwargs + return fake_connection + + client = AppServerClient( + AppServerConfig( + websocket_url="ws://sandbox.example.test:4500/", + websocket_headers={"Authorization": "Bearer test"}, + ), + transport_ops=AppServerTransportOps(ws_connect=_connect), + ) + + try: + client.start() + response = client.initialize() + finally: + client.close() + + assert captured_connect["url"] == "ws://sandbox.example.test:4500/" + assert captured_connect["kwargs"] == { + "additional_headers": {"Authorization": "Bearer test"}, + "open_timeout": 10.0, + "max_size": None, + } + assert isinstance(response, InitializeResponse) + assert response.serverInfo is not None + assert response.serverInfo.name == "codex-app-server" + assert response.platformOs == "linux" + assert fake_connection.sent_payloads[0]["method"] == "initialize" + assert fake_connection.sent_payloads[1] == {"method": "initialized", "params": {}} + assert fake_connection.closed is True + + +def test_app_server_client_handles_server_requests_and_queues_notifications() -> None: + def _initialize_with_notification(sent_payloads: list[dict[str, object]]) -> str: + request = next( + payload for payload in sent_payloads if payload.get("method") == "initialize" + ) + return json.dumps({"id": request["id"], "result": {"platformOs": "linux"}}) + + fake_connection = _FakeConnection( + [ + json.dumps( + { + "id": "server-approval-1", + "method": "item/commandExecution/requestApproval", + "params": {"command": "ls"}, + } + ), + json.dumps({"method": "custom/notice", "params": {"seen": True}}), + _initialize_with_notification, + ] + ) + + def _connect(url: str, **kwargs: object) -> Any: + return fake_connection + + client = AppServerClient( + AppServerConfig(websocket_url="ws://sandbox.example.test:4500/"), + transport_ops=AppServerTransportOps(ws_connect=_connect), + ) + + try: + client.start() + response = client.initialize() + notification = client.next_notification() + finally: + client.close() + + assert response.platformOs == "linux" + assert fake_connection.sent_payloads[1] == { + "id": "server-approval-1", + "result": {"decision": "accept"}, + } + assert notification.method == "custom/notice" + assert isinstance(notification.payload, UnknownNotification) + assert notification.payload.params == {"seen": True} + + +def test_app_server_client_requires_websocket_url() -> None: + def _unused_connect(url: str, **kwargs: object) -> Any: + raise AssertionError("ws_connect should not be called without a websocket_url") + + client = AppServerClient( + AppServerConfig(), + transport_ops=AppServerTransportOps(ws_connect=_unused_connect), + ) + + with pytest.raises(ValueError, match="websocket_url is required"): + client.start() diff --git a/tests/test_sandbox_dependencies.py b/tests/test_sandbox_dependencies.py new file mode 100644 index 0000000000..ed282cf3e1 --- /dev/null +++ b/tests/test_sandbox_dependencies.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import pytest + +from agents.sandbox.session import ( + Dependencies, + DependenciesBindingError, + DependenciesMissingDependencyError, +) + + +class _AsyncClosable: + def __init__(self) -> None: + self.calls = 0 + + async def aclose(self) -> None: + self.calls += 1 + + +class _AsyncCloseMethod: + def __init__(self) -> None: + self.calls = 0 + + async def close(self) -> None: + self.calls += 1 + + +class _SyncClosable: + def __init__(self) -> None: + self.calls = 0 + + def close(self) -> None: + self.calls += 1 + + +@pytest.mark.asyncio +async def test_dependencies_with_values_binds_multiple_values() -> None: + key1 = "tests.with_values.str" + key2 = "tests.with_values.int" + dependencies = Dependencies.with_values({key1: "hello", key2: 123}) + + assert await dependencies.require(key1) == "hello" + assert await dependencies.require(key2) == 123 + + +@pytest.mark.asyncio +async def test_dependencies_bind_value_and_require() -> None: + dependencies = Dependencies() + key = "tests.value" + dependencies.bind_value(key, "hello") + + assert await dependencies.get(key) == "hello" + assert await dependencies.require(key, consumer="test") == "hello" + + +@pytest.mark.asyncio +async def test_dependencies_missing_dependency_includes_key_and_consumer() -> None: + dependencies = Dependencies() + key = "tests.missing" + + with pytest.raises(DependenciesMissingDependencyError, match="tests.missing"): + await dependencies.require(key, consumer="SedimentFile") + + +def test_dependencies_duplicate_binding_raises() -> None: + dependencies = Dependencies() + key = "tests.dup" + dependencies.bind_value(key, "a") + + with pytest.raises(DependenciesBindingError, match="already bound"): + dependencies.bind_value(key, "b") + + +def test_dependencies_empty_key_raises() -> None: + dependencies = Dependencies() + + with pytest.raises(ValueError, match="non-empty"): + dependencies.bind_value("", "x") + + with pytest.raises(ValueError, match="non-empty"): + dependencies.bind_factory("", lambda _dependencies: "x") + + +@pytest.mark.asyncio +async def test_dependencies_cached_factory_resolves_once() -> None: + dependencies = Dependencies() + key = "tests.cached_factory" + calls = 0 + + def _factory(_dependencies: Dependencies) -> str: + nonlocal calls + calls += 1 + return f"value-{calls}" + + dependencies.bind_factory(key, _factory, cache=True) + + assert await dependencies.require(key) == "value-1" + assert await dependencies.require(key) == "value-1" + assert calls == 1 + + +@pytest.mark.asyncio +async def test_dependencies_uncached_factory_resolves_every_time() -> None: + dependencies = Dependencies() + key = "tests.uncached_factory" + calls = 0 + + def _factory(_dependencies: Dependencies) -> str: + nonlocal calls + calls += 1 + return f"value-{calls}" + + dependencies.bind_factory(key, _factory, cache=False) + + assert await dependencies.require(key) == "value-1" + assert await dependencies.require(key) == "value-2" + assert calls == 2 + + +@pytest.mark.asyncio +async def test_dependencies_async_factory_supported() -> None: + dependencies = Dependencies() + key = "tests.async_factory" + + async def _factory(_dependencies: Dependencies) -> str: + return "async-value" + + dependencies.bind_factory(key, _factory) + assert await dependencies.require(key) == "async-value" + + +@pytest.mark.asyncio +async def test_dependencies_aclose_closes_owned_results_and_is_idempotent() -> None: + dependencies = Dependencies() + k1 = "tests.async_aclose" + k2 = "tests.async_close" + k3 = "tests.sync_close" + + dependencies.bind_factory(k1, lambda _deps: _AsyncClosable(), owns_result=True) + dependencies.bind_factory(k2, lambda _deps: _AsyncCloseMethod(), owns_result=True) + dependencies.bind_factory(k3, lambda _deps: _SyncClosable(), owns_result=True, cache=False) + + v1 = await dependencies.require(k1) + v2 = await dependencies.require(k2) + v3a = await dependencies.require(k3) + v3b = await dependencies.require(k3) + + assert v3a is not v3b + + await dependencies.aclose() + await dependencies.aclose() + + assert isinstance(v1, _AsyncClosable) and v1.calls == 1 + assert isinstance(v2, _AsyncCloseMethod) and v2.calls == 1 + assert isinstance(v3a, _SyncClosable) and v3a.calls == 1 + assert isinstance(v3b, _SyncClosable) and v3b.calls == 1 + + +@pytest.mark.asyncio +async def test_dependencies_bound_values_are_not_closed() -> None: + dependencies = Dependencies() + key = "tests.bound_value" + value = _SyncClosable() + dependencies.bind_value(key, value) + + _ = await dependencies.require(key) + await dependencies.aclose() + + assert value.calls == 0 diff --git a/tests/test_sandbox_docker.py b/tests/test_sandbox_docker.py new file mode 100644 index 0000000000..5708df879d --- /dev/null +++ b/tests/test_sandbox_docker.py @@ -0,0 +1,937 @@ +from __future__ import annotations + +import asyncio +import io +import shutil +import tarfile +import uuid +from collections.abc import Callable +from pathlib import Path +from typing import cast + +import docker.errors # type: ignore[import-untyped] +import pytest +from pydantic import PrivateAttr + +import agents.sandbox.sandboxes.docker as docker_sandbox +from agents.sandbox.entries import ( + AzureBlobMount, + Dir, + File, + FuseMountPattern, + Mount, + RcloneMountPattern, +) +from agents.sandbox.errors import ( + ExecTimeoutError, + InvalidManifestPathError, + WorkspaceArchiveReadError, +) +from agents.sandbox.manifest import Manifest +from agents.sandbox.sandboxes.docker import ( + DockerSandboxClient, + DockerSandboxSession, + DockerSandboxSessionState, + _manifest_requires_fuse, + _manifest_requires_sys_admin, +) +from agents.sandbox.snapshot import NoopSnapshot +from agents.sandbox.types import ExecResult + + +class _FakeDockerContainer: + def __init__(self, host_root: Path, *, archive_error: Exception | None = None) -> None: + self._host_root = host_root + self.status = "running" + self.archive_calls: list[str] = [] + self.archive_error = archive_error + + def reload(self) -> None: + return + + def get_archive(self, path: str) -> tuple[object, dict[str, object]]: + self.archive_calls.append(path) + if self.archive_error is not None: + raise self.archive_error + if path == "/workspace": + raise docker.errors.APIError("root archive unsupported") + + host_path = self._host_path(path) + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + tar.add(host_path, arcname=Path(path).name) + buf.seek(0) + return iter([buf.getvalue()]), {} + + def _host_path(self, path: str | Path) -> Path: + container_path = Path(path) + return self._host_root / container_path.relative_to("/") + + +class _PullRecorder: + def __init__(self) -> None: + self.calls: list[tuple[str, str | None, bool]] = [] + + def pull(self, repo: str, *, tag: str | None = None, all_tags: bool = False) -> None: + self.calls.append((repo, tag, all_tags)) + + +class _FakeDockerClient: + def __init__(self) -> None: + self.images = _PullRecorder() + + +class _HostBackedDockerSession(DockerSandboxSession): + def __init__( + self, + *, + host_root: Path, + manifest: Manifest, + event_log: list[tuple[str, str]] | None = None, + archive_error: Exception | None = None, + ) -> None: + container = _FakeDockerContainer(host_root, archive_error=archive_error) + state = DockerSandboxSessionState( + manifest=manifest, + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="container", + ) + super().__init__( + docker_client=object(), + container=container, + state=state, + ) + self._host_root = host_root + self._fake_container = container + self._event_log = event_log if event_log is not None else [] + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd = [str(part) for part in command] + if cmd[:2] == ["mkdir", "-p"]: + self._host_path(cmd[2]).mkdir(parents=True, exist_ok=True) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + if cmd[:3] == ["cp", "-R", "--"]: + self._event_log.append(("cp", cmd[3])) + src = self._host_path(cmd[3]) + dst = self._host_path(cmd[4]) + shutil.copytree(src, dst) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + if cmd[:2] == ["rm", "-rf"]: + target = self._host_path(cmd[3]) + if target.is_dir(): + shutil.rmtree(target, ignore_errors=True) + else: + try: + target.unlink() + except FileNotFoundError: + pass + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + raise AssertionError(f"Unexpected command: {cmd!r}") + + def _host_path(self, path: str | Path) -> Path: + container_path = Path(path) + return self._host_root / container_path.relative_to("/") + + +class _RecordingMount(Mount): + type: str = f"recording_mount_{uuid.uuid4().hex}" + remove_on_unmount: bool = True + remount_marker: str | None = None + _events: list[tuple[str, str]] = PrivateAttr(default_factory=list) + + def bind_events(self, events: list[tuple[str, str]]) -> _RecordingMount: + self._events = events + return self + + async def _mount(self, session: object, path: Path) -> None: + host_path = cast(_HostBackedDockerSession, session)._host_path(path) + host_path.mkdir(parents=True, exist_ok=True) + self._events.append(("mount", str(path))) + if self.remount_marker is not None: + (host_path / self.remount_marker).write_text("remounted", encoding="utf-8") + + async def _unmount(self, session: object, path: Path) -> None: + host_path = cast(_HostBackedDockerSession, session)._host_path(path) + self._events.append(("unmount", str(path))) + if not self.remove_on_unmount: + return + shutil.rmtree(host_path, ignore_errors=True) + + +class _FailingUnmountMount(_RecordingMount): + type: str = f"failing_unmount_mount_{uuid.uuid4().hex}" + + async def _unmount(self, session: object, path: Path) -> None: + self._events.append(("unmount_fail", str(path))) + raise RuntimeError("boom while unmounting second mount") + + +class _FailingRemountMount(_RecordingMount): + type: str = f"failing_remount_mount_{uuid.uuid4().hex}" + + async def _mount(self, session: object, path: Path) -> None: + self._events.append(("mount_fail", str(path))) + raise RuntimeError("boom while remounting second mount") + + +class _OrderSensitiveMount(_RecordingMount): + type: str = f"order_sensitive_mount_{uuid.uuid4().hex}" + require_unmounted_before: str | None = None + require_mounted_before: str | None = None + + async def _mount(self, session: object, path: Path) -> None: + if ( + self.require_mounted_before is not None + and ( + "mount", + self.require_mounted_before, + ) + not in self._events + ): + self._events.append(("mount_fail", str(path))) + raise RuntimeError("parent mount missing") + await super()._mount(session, path) + + async def _unmount(self, session: object, path: Path) -> None: + if ( + self.require_unmounted_before is not None + and ( + "unmount", + self.require_unmounted_before, + ) + not in self._events + ): + self._events.append(("unmount_fail", str(path))) + raise RuntimeError("target is busy") + await super()._unmount(session, path) + + +def _archive_member_names(archive: io.IOBase) -> list[str]: + payload = archive.read() + if not isinstance(payload, bytes): + raise AssertionError(f"Expected bytes archive payload, got {type(payload)!r}") + with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as tar: + return tar.getnames() + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_stages_copy_before_get_archive( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + workspace.mkdir(parents=True) + (workspace / "README.md").write_text("hello from workspace", encoding="utf-8") + + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest(root="/workspace"), + ) + + archive = await session.persist_workspace() + + names = _archive_member_names(archive) + + assert "/workspace" not in session._fake_container.archive_calls + assert any(name.endswith("workspace") for name in names) + assert any(name.endswith("workspace/README.md") for name in names) + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_prunes_ephemeral_entries_from_staged_copy( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + workspace.mkdir(parents=True) + (workspace / "keep.txt").write_text("keep", encoding="utf-8") + (workspace / "skip.txt").write_text("skip", encoding="utf-8") + + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest( + root="/workspace", + entries={ + "skip.txt": File(content=b"skip", ephemeral=True), + }, + ), + ) + + archive = await session.persist_workspace() + + names = _archive_member_names(archive) + + assert any(name.endswith("workspace/keep.txt") for name in names) + assert not any(name.endswith("workspace/skip.txt") for name in names) + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_unmounts_nested_ephemeral_mounts_before_copy( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + mount_dir = workspace / "repo" / "mount" + mount_dir.mkdir(parents=True) + (mount_dir / "remote.txt").write_text("remote", encoding="utf-8") + + events: list[tuple[str, str]] = [] + mount = _RecordingMount(remount_marker="remounted.txt").bind_events(events) + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest( + root="/workspace", + entries={ + "repo": Dir( + children={ + "mount": mount, + } + ) + }, + ), + event_log=events, + ) + + archive = await session.persist_workspace() + + names = _archive_member_names(archive) + + assert [kind for kind, _path in events if kind in {"unmount", "cp", "mount"}] == [ + "unmount", + "cp", + "mount", + ] + assert not any(name.endswith("workspace/repo/mount/remote.txt") for name in names) + assert (mount_dir / "remounted.txt").read_text(encoding="utf-8") == "remounted" + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_prunes_runtime_only_skip_paths_from_staged_copy( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + logs = workspace / "logs" + logs.mkdir(parents=True) + (logs / "keep.txt").write_text("keep", encoding="utf-8") + (logs / "events.jsonl").write_text("skip", encoding="utf-8") + + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest(root="/workspace"), + ) + session._register_persist_workspace_skip_relpath(Path("logs/events.jsonl")) # noqa: SLF001 + + archive = await session.persist_workspace() + + names = _archive_member_names(archive) + + assert any(name.endswith("workspace/logs/keep.txt") for name in names) + assert not any(name.endswith("workspace/logs/events.jsonl") for name in names) + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_prunes_explicit_mount_path_from_staged_copy( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + actual_mount_path = workspace / "actual" + actual_mount_path.mkdir(parents=True) + (actual_mount_path / "remote.txt").write_text("remote", encoding="utf-8") + + mount = _RecordingMount(mount_path=Path("actual"), remove_on_unmount=False) + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest( + root="/workspace", + entries={ + "logical": mount, + }, + ), + ) + + archive = await session.persist_workspace() + + names = _archive_member_names(archive) + + assert not any(name.endswith("workspace/actual/remote.txt") for name in names) + assert (actual_mount_path / "remote.txt").read_text(encoding="utf-8") == "remote" + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_remounts_prior_mounts_after_unmount_failure( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + first_mount_dir = workspace / "repo" / "mount1" + second_mount_dir = workspace / "repo" / "mount2" + first_mount_dir.mkdir(parents=True) + second_mount_dir.mkdir(parents=True) + (first_mount_dir / "remote1.txt").write_text("remote-1", encoding="utf-8") + (second_mount_dir / "remote2.txt").write_text("remote-2", encoding="utf-8") + + events: list[tuple[str, str]] = [] + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest( + root="/workspace", + entries={ + "repo": Dir( + children={ + "mount1": _RecordingMount(remount_marker="remounted.txt").bind_events( + events + ), + "mount2": _FailingUnmountMount().bind_events(events), + } + ) + }, + ), + event_log=events, + ) + + with pytest.raises(WorkspaceArchiveReadError): + await session.persist_workspace() + + assert [kind for kind, _path in events] == [ + "unmount", + "unmount_fail", + "mount", + ] + assert "cp" not in [kind for kind, _path in events] + assert (first_mount_dir / "remounted.txt").read_text(encoding="utf-8") == "remounted" + assert (second_mount_dir / "remote2.txt").read_text(encoding="utf-8") == "remote-2" + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_unmounts_nested_mount_paths_deepest_first( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + parent_mount_dir = workspace / "repo" + child_mount_dir = parent_mount_dir / "sub" + child_mount_dir.mkdir(parents=True) + (child_mount_dir / "remote.txt").write_text("remote", encoding="utf-8") + + events: list[tuple[str, str]] = [] + child_path = "/workspace/repo/sub" + parent_path = "/workspace/repo" + parent_mount = _OrderSensitiveMount( + remount_marker="parent-remounted.txt", + require_unmounted_before=child_path, + ).bind_events(events) + child_mount = _OrderSensitiveMount( + mount_path=Path("repo/sub"), + remount_marker="child-remounted.txt", + require_mounted_before=parent_path, + ).bind_events(events) + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest( + root="/workspace", + entries={ + "repo": parent_mount, + "child": child_mount, + }, + ), + event_log=events, + ) + + archive = await session.persist_workspace() + + names = _archive_member_names(archive) + + assert [kind for kind, _path in events] == [ + "unmount", + "unmount", + "cp", + "mount", + "mount", + ] + assert [path for kind, path in events if kind == "unmount"] == [ + child_path, + parent_path, + ] + assert [path for kind, path in events if kind == "mount"] == [ + parent_path, + child_path, + ] + assert not any(name.endswith("workspace/repo/remote.txt") for name in names) + assert not any(name.endswith("workspace/repo/sub/remote.txt") for name in names) + assert (parent_mount_dir / "parent-remounted.txt").read_text(encoding="utf-8") == "remounted" + assert (child_mount_dir / "child-remounted.txt").read_text(encoding="utf-8") == "remounted" + + +@pytest.mark.asyncio +async def test_docker_persist_workspace_keeps_remounting_and_raises_remount_error_first( + tmp_path: Path, +) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + first_mount_dir = workspace / "repo" / "a" + second_mount_dir = workspace / "repo" / "b" + first_mount_dir.mkdir(parents=True) + second_mount_dir.mkdir(parents=True) + (first_mount_dir / "remote1.txt").write_text("remote-1", encoding="utf-8") + (second_mount_dir / "remote2.txt").write_text("remote-2", encoding="utf-8") + + events: list[tuple[str, str]] = [] + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest( + root="/workspace", + entries={ + "repo": Dir( + children={ + "a": _RecordingMount(remount_marker="a-remounted.txt").bind_events(events), + "b": _FailingRemountMount().bind_events(events), + } + ) + }, + ), + event_log=events, + archive_error=docker.errors.APIError("snapshot failed"), + ) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert isinstance(exc_info.value.cause, RuntimeError) + assert str(exc_info.value.cause) == "boom while remounting second mount" + snapshot_error = cast( + dict[str, str], + exc_info.value.context["snapshot_error_before_remount_corruption"], + ) + assert snapshot_error == { + "message": "failed to read archive for path: /workspace", + "cause_type": "APIError", + "cause": "snapshot failed", + } + assert exc_info.value.context["snapshot_error_before_remount_corruption"] == snapshot_error + assert "earlier_unmount_error" not in exc_info.value.context + assert "additional_remount_errors" not in exc_info.value.context + assert snapshot_error["cause"] == "snapshot failed" + assert snapshot_error["cause_type"] == "APIError" + assert exc_info.value.context["path"] == "/workspace" + assert [kind for kind, _path in events] == [ + "unmount", + "unmount", + "cp", + "mount_fail", + "mount", + ] + assert (first_mount_dir / "a-remounted.txt").read_text(encoding="utf-8") == "remounted" + + +@pytest.mark.asyncio +async def test_docker_read_and_write_reject_paths_outside_workspace_root(tmp_path: Path) -> None: + host_root = tmp_path / "container" + workspace = host_root / "workspace" + workspace.mkdir(parents=True) + + session = _HostBackedDockerSession( + host_root=host_root, + manifest=Manifest(root="/workspace"), + ) + + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.read(Path("../secret.txt")) + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.write(Path("../secret.txt"), io.BytesIO(b"nope")) + + +def test_manifest_requires_fuse_detects_nested_mounts() -> None: + manifest = Manifest( + entries={ + "workspace": Dir( + children={ + "mount": AzureBlobMount( + account="account", + container="container", + mount_pattern=FuseMountPattern(), + ) + } + ) + } + ) + + assert _manifest_requires_fuse(manifest) is True + + +def test_manifest_requires_sys_admin_detects_nested_mounts() -> None: + manifest = Manifest( + entries={ + "workspace": Dir( + children={ + "mount": AzureBlobMount( + account="account", + container="container", + mount_pattern=RcloneMountPattern(mode="nfs"), + ) + } + ) + } + ) + + assert _manifest_requires_sys_admin(manifest) is True + + +@pytest.mark.asyncio +async def test_docker_create_container_parses_registry_port_image_refs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + docker_client = _FakeDockerClient() + client = DockerSandboxClient(docker_client=cast(object, docker_client)) + + def _missing_image(_image: str) -> bool: + return False + + monkeypatch.setattr(client, "image_exists", _missing_image) + with pytest.raises(AssertionError): + await client._create_container("localhost:5000/myimg:latest") + + assert docker_client.images.calls == [("localhost:5000/myimg", "latest", False)] + + +class _ExecRunContainer: + def __init__(self, *, workspace_exists: bool = False) -> None: + self.exec_calls: list[dict[str, object]] = [] + self._workspace_exists = workspace_exists + + def exec_run( + self, + cmd: list[str], + demux: bool = True, + workdir: str | None = None, + ) -> object: + self.exec_calls.append({"cmd": cmd, "demux": demux, "workdir": workdir}) + exit_code = 0 + if cmd == ["test", "-d", "--", "/workspace"]: + exit_code = 0 if self._workspace_exists else 1 + return type( + "_ExecResult", + (), + {"output": (b"", b""), "exit_code": exit_code}, + )() + + +class _ResumeDockerClient: + def __init__(self, container: object) -> None: + self._container = container + self.containers = self + + def get(self, container_id: str) -> object: + _ = container_id + if isinstance(self._container, BaseException): + raise self._container + return self._container + + +class _PositionalOnlyMissingDockerClient: + def __init__(self) -> None: + self.containers = self + + def get(self, container_id: str, /) -> object: + _ = container_id + raise docker.errors.NotFound("missing") + + +class _ResumeContainer: + def __init__( + self, + *, + status: str, + container_id: str = "container", + workspace_exists: bool = False, + ) -> None: + self.status = status + self.id = container_id + self.exec_calls: list[dict[str, object]] = [] + self._workspace_exists = workspace_exists + + def reload(self) -> None: + return + + def exec_run( + self, + cmd: list[str], + demux: bool = True, + workdir: str | None = None, + ) -> object: + self.exec_calls.append({"cmd": cmd, "demux": demux, "workdir": workdir}) + exit_code = 0 + if cmd == ["test", "-d", "--", "/workspace"]: + exit_code = 0 if self._workspace_exists else 1 + return type( + "_ExecResult", + (), + {"output": (b"", b""), "exit_code": exit_code}, + )() + + +@pytest.mark.asyncio +async def test_docker_exec_timeout_uses_shared_executor(monkeypatch: pytest.MonkeyPatch) -> None: + container = _ExecRunContainer() + session = DockerSandboxSession( + docker_client=object(), + container=container, + state=DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="container", + ), + ) + + submitted_executors: list[object] = [] + loop = asyncio.get_running_loop() + + def fake_run_in_executor(executor: object, func: object) -> asyncio.Future[object]: + _ = func + submitted_executors.append(executor) + return asyncio.Future() + + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + with pytest.raises(ExecTimeoutError): + await session._exec_internal("sleep", "10", timeout=0.01) + with pytest.raises(ExecTimeoutError): + await session._exec_internal("sleep", "20", timeout=0.01) + + assert submitted_executors == [ + docker_sandbox._DOCKER_EXECUTOR, + docker_sandbox._DOCKER_EXECUTOR, + ] + assert container.exec_calls == [ + { + "cmd": ["sh", "-lc", "pkill -f -- 'sleep 10' >/dev/null 2>&1 || true"], + "demux": True, + "workdir": None, + }, + { + "cmd": ["sh", "-lc", "pkill -f -- 'sleep 20' >/dev/null 2>&1 || true"], + "demux": True, + "workdir": None, + }, + ] + + +@pytest.mark.asyncio +async def test_docker_exec_omits_workdir_until_workspace_ready( + monkeypatch: pytest.MonkeyPatch, +) -> None: + container = _ExecRunContainer() + session = DockerSandboxSession( + docker_client=object(), + container=container, + state=DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="container", + ), + ) + + loop = asyncio.get_running_loop() + + def fake_run_in_executor( + executor: object, func: Callable[[], object] + ) -> asyncio.Future[object]: + _ = executor + future: asyncio.Future[object] = asyncio.Future() + future.set_result(func()) + return future + + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + result = await session._exec_internal("find", ".", timeout=0.01) + + assert result.ok() + assert container.exec_calls == [ + { + "cmd": ["find", "."], + "demux": True, + "workdir": None, + } + ] + + +@pytest.mark.asyncio +async def test_docker_exec_uses_manifest_root_as_workdir_after_workspace_ready( + monkeypatch: pytest.MonkeyPatch, +) -> None: + container = _ExecRunContainer() + session = DockerSandboxSession( + docker_client=object(), + container=container, + state=DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="container", + ), + ) + session._workspace_root_ready = True + + loop = asyncio.get_running_loop() + + def fake_run_in_executor( + executor: object, func: Callable[[], object] + ) -> asyncio.Future[object]: + _ = executor + future: asyncio.Future[object] = asyncio.Future() + future.set_result(func()) + return future + + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + result = await session._exec_internal("find", ".", timeout=0.01) + + assert result.ok() + assert container.exec_calls == [ + { + "cmd": ["find", "."], + "demux": True, + "workdir": "/workspace", + } + ] + + +@pytest.mark.asyncio +async def test_docker_resume_preserves_workspace_readiness_from_state() -> None: + client = DockerSandboxClient( + docker_client=_ResumeDockerClient(_ResumeContainer(status="running")) + ) + + ready_session = await client.resume( + DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="container", + workspace_root_ready=True, + ) + ) + not_ready_session = await client.resume( + DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="container", + workspace_root_ready=False, + ) + ) + + assert isinstance(ready_session._inner, DockerSandboxSession) + assert ready_session._inner._workspace_root_ready is True + assert ready_session._inner.should_provision_manifest_accounts_on_resume() is False + assert isinstance(not_ready_session._inner, DockerSandboxSession) + assert not_ready_session._inner._workspace_root_ready is False + assert not_ready_session._inner.should_provision_manifest_accounts_on_resume() is False + + +@pytest.mark.asyncio +async def test_docker_resume_resets_workspace_readiness_when_container_is_recreated( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = DockerSandboxClient( + docker_client=cast(object, _ResumeDockerClient(docker.errors.NotFound("missing"))) + ) + replacement = _ResumeContainer(status="created", container_id="replacement") + + async def _fake_create_container(image: str, *, manifest: Manifest | None = None) -> object: + _ = (image, manifest) + return replacement + + monkeypatch.setattr(client, "_create_container", _fake_create_container) + + resumed = await client.resume( + DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="missing", + workspace_root_ready=True, + ) + ) + + assert isinstance(resumed._inner, DockerSandboxSession) + inner = resumed._inner + assert inner.state.container_id == "replacement" + assert inner.state.workspace_root_ready is False + assert inner._workspace_root_ready is False + assert inner.should_provision_manifest_accounts_on_resume() is True + + +@pytest.mark.asyncio +async def test_docker_resume_recovers_workspace_workdir_when_root_already_exists( + monkeypatch: pytest.MonkeyPatch, +) -> None: + container = _ResumeContainer(status="running", workspace_exists=True) + client = DockerSandboxClient(docker_client=_ResumeDockerClient(container)) + + payload = DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="container", + workspace_root_ready=True, + ).model_dump(mode="json") + payload.pop("workspace_root_ready") + + resumed = await client.resume(client.deserialize_session_state(payload)) + assert isinstance(resumed._inner, DockerSandboxSession) + + loop = asyncio.get_running_loop() + + def fake_run_in_executor( + executor: object, func: Callable[[], object] + ) -> asyncio.Future[object]: + _ = executor + future: asyncio.Future[object] = asyncio.Future() + future.set_result(func()) + return future + + monkeypatch.setattr(loop, "run_in_executor", fake_run_in_executor) + + result = await resumed._inner._exec_internal("find", ".", timeout=0.01) + + assert result.ok() + assert resumed._inner.state.workspace_root_ready is True + assert resumed._inner._workspace_root_ready is True + assert container.exec_calls == [ + { + "cmd": ["test", "-d", "--", "/workspace"], + "demux": True, + "workdir": None, + }, + { + "cmd": ["find", "."], + "demux": True, + "workdir": "/workspace", + }, + ] + + +@pytest.mark.asyncio +async def test_docker_exists_returns_false_for_missing_container() -> None: + session = DockerSandboxSession( + docker_client=cast(object, _PositionalOnlyMissingDockerClient()), + container=_ResumeContainer(status="running"), + state=DockerSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + image="python:3.11-slim", + container_id="missing", + ), + ) + + assert await session.exists() is False diff --git a/tests/test_sandbox_entries.py b/tests/test_sandbox_entries.py new file mode 100644 index 0000000000..2defb52f67 --- /dev/null +++ b/tests/test_sandbox_entries.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import io +import tarfile +from pathlib import Path +from typing import Literal + +import pytest + +import agents.sandbox.entries.codex as codex_module +from agents.sandbox.entries import Codex, Dir, File, GitRepo, LocalFile +from agents.sandbox.errors import ExecNonZeroError +from agents.sandbox.manifest import Manifest +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.session.sandbox_session_state import SandboxSessionState +from agents.sandbox.snapshot import NoopSnapshot, SnapshotBase +from agents.sandbox.types import ExecResult, User + + +class _RecordingSession(BaseSandboxSession): + def __init__(self, manifest: Manifest | None = None) -> None: + self.state = SandboxSessionState( + manifest=manifest or Manifest(), + snapshot=NoopSnapshot(id="noop"), + ) + self.exec_calls: list[tuple[str, ...]] = [] + self.writes: dict[Path, bytes] = {} + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd = tuple(str(part) for part in command) + self.exec_calls.append(cmd) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def read(self, path: Path) -> io.IOBase: + return io.BytesIO(self.writes[path]) + + async def write(self, path: Path, data: io.IOBase) -> None: + self.writes[path] = data.read() + + async def running(self) -> bool: + return True + + async def persist_workspace(self) -> io.IOBase: + return io.BytesIO() + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + + async def shutdown(self) -> None: + return + + +class _GitRefSession(_RecordingSession): + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd = tuple(str(part) for part in command) + self.exec_calls.append(cmd) + if cmd == ("command -v git >/dev/null 2>&1",): + return ExecResult(stdout=b"/usr/bin/git\n", stderr=b"", exit_code=0) + if cmd[:2] == ("git", "clone"): + return ExecResult(stdout=b"", stderr=b"unexpected clone path", exit_code=1) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + +class _MetadataFailureSession(_RecordingSession): + def __init__( + self, + manifest: Manifest | None = None, + *, + fail_commands: set[str], + ) -> None: + super().__init__(manifest) + self.fail_commands = fail_commands + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd = tuple(str(part) for part in command) + self.exec_calls.append(cmd) + if cmd and cmd[0] in self.fail_commands: + return ExecResult(stdout=b"", stderr=b"metadata failed", exit_code=1) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + +class _CodexSession(_RecordingSession): + def __init__(self, asset_name: str, *, resolved_binary_path: str) -> None: + super().__init__() + self.asset_name = asset_name + self.resolved_binary_path = resolved_binary_path + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd = tuple(str(part) for part in command) + self.exec_calls.append(cmd) + if cmd[:2] == ("sh", "-lc") and "find " in cmd[2] and "head -n 1" in cmd[2]: + return ExecResult( + stdout=f"{self.resolved_binary_path}\n".encode(), + stderr=b"", + exit_code=0, + ) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def resolve_codex_github_asset_name(self) -> str: + return self.asset_name + + +class _RestorableSnapshot(SnapshotBase): + __test__ = False + type: Literal["entry-restorable"] = "entry-restorable" + + async def persist(self, data: io.IOBase) -> None: + _ = data + + async def restore(self) -> io.IOBase: + return io.BytesIO(b"snapshot") + + async def restorable(self) -> bool: + return True + + +class _ResumeCodexSession(_CodexSession): + def __init__(self, asset_name: str, *, resolved_binary_path: str, codex_path: Path) -> None: + super().__init__(asset_name, resolved_binary_path=resolved_binary_path) + self.state.snapshot = _RestorableSnapshot(id="resume") + self.codex_path = codex_path + self.hydrated = False + self.existing_paths: set[str] = set() + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd = tuple(str(part) for part in command) + self.exec_calls.append(cmd) + if cmd[:2] == ("test", "-e"): + return ExecResult( + stdout=b"", + stderr=b"", + exit_code=0 if cmd[2] in self.existing_paths else 1, + ) + if cmd[:2] == ("sh", "-lc") and "find " in cmd[2] and "head -n 1" in cmd[2]: + return ExecResult( + stdout=f"{self.resolved_binary_path}\n".encode(), + stderr=b"", + exit_code=0, + ) + if cmd[:1] == ("cp",): + self.existing_paths.add(cmd[2]) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + self.hydrated = True + + +def _tar_gz_bytes(*, members: dict[str, bytes]) -> bytes: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as archive: + for name, payload in members.items(): + info = tarfile.TarInfo(name=name) + info.size = len(payload) + archive.addfile(info, io.BytesIO(payload)) + return buf.getvalue() + + +@pytest.mark.asyncio +async def test_base_sandbox_session_uses_current_working_directory_for_local_file_sources( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + source = tmp_path / "source.txt" + source.write_text("hello", encoding="utf-8") + monkeypatch.chdir(tmp_path) + session = _RecordingSession( + Manifest( + entries={"copied.txt": LocalFile(src=Path("source.txt"))}, + ), + ) + + result = await session.apply_manifest() + + assert result.files[0].path == Path("/workspace/copied.txt") + assert session.writes[Path("/workspace/copied.txt")] == b"hello" + + +@pytest.mark.asyncio +async def test_git_repo_uses_fetch_checkout_path_for_commit_refs() -> None: + session = _GitRefSession() + repo = GitRepo(repo="openai/example", ref="deadbeef") + + await repo.apply(session, Path("/workspace/repo"), Path("/ignored")) + + assert not any(call[:2] == ("git", "clone") for call in session.exec_calls) + assert any(call[:2] == ("git", "init") for call in session.exec_calls) + assert any( + len(call) >= 7 + and call[:2] == ("git", "-C") + and call[3:6] == ("remote", "add", "origin") + and call[6] == "https://github.com/openai/example.git" + for call in session.exec_calls + ) + assert any( + len(call) >= 9 + and call[:2] == ("git", "-C") + and call[3:7] == ("fetch", "--depth", "1", "--no-tags") + and call[-2:] == ("origin", "deadbeef") + for call in session.exec_calls + ) + assert any( + len(call) >= 6 + and call[:2] == ("git", "-C") + and call[3:5] == ("checkout", "--detach") + and call[-1] == "FETCH_HEAD" + for call in session.exec_calls + ) + + +@pytest.mark.asyncio +async def test_dir_metadata_strips_file_type_bits_before_chmod() -> None: + session = _RecordingSession() + + await Dir()._apply_metadata(session, Path("/workspace/dir")) + + assert ("chmod", "0755", "/workspace/dir") in session.exec_calls + + +@pytest.mark.asyncio +async def test_apply_manifest_raises_on_chmod_failure() -> None: + session = _MetadataFailureSession( + Manifest(entries={"copied.txt": File(content=b"hello")}), + fail_commands={"chmod"}, + ) + + with pytest.raises(ExecNonZeroError): + await session.apply_manifest() + + +@pytest.mark.asyncio +async def test_apply_manifest_raises_on_chgrp_failure() -> None: + session = _MetadataFailureSession( + Manifest( + entries={ + "copied.txt": File( + content=b"hello", + group=User(name="sandbox-user"), + ) + } + ), + fail_commands={"chgrp"}, + ) + + with pytest.raises(ExecNonZeroError): + await session.apply_manifest() + + assert ("chgrp", "sandbox-user", "/workspace/copied.txt") in session.exec_calls + assert not any(call[0] == "chmod" for call in session.exec_calls) + + +@pytest.mark.asyncio +async def test_codex_artifact_downloads_resolved_release_asset_inside_unix_box() -> None: + session = _CodexSession( + "codex-x86_64-unknown-linux-gnu.tar.gz", + resolved_binary_path="/workspace/.codex_bin/.codex-install-123/codex", + ) + entry = Codex(version="v1.2.3") + archive_bytes = _tar_gz_bytes(members={"codex": b"#!/bin/sh\n"}) + + class _FakeResponse: + headers = {"Content-Length": str(len(archive_bytes))} + + def raise_for_status(self) -> None: + return None + + def iter_bytes(self): + yield archive_bytes[:5] + yield archive_bytes[5:] + + class _FakeStreamContext: + def __enter__(self) -> _FakeResponse: + return _FakeResponse() + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def _fake_stream(url: str) -> _FakeStreamContext: + _ = url + return _FakeStreamContext() + + original_stream = codex_module._stream_release_asset + codex_module._stream_release_asset = _fake_stream + try: + result = await entry.apply(session, Path("/workspace/.codex_bin/codex"), Path("/ignored")) + finally: + codex_module._stream_release_asset = original_stream + + assert result == [] + assert any(path.name == "codex-x86_64-unknown-linux-gnu.tar.gz" for path in session.writes) + assert Path("/workspace/.codex_bin/codex") not in session.writes + assert any( + call[:2] == ("tar", "-xzf") + and call[2].endswith("/codex-x86_64-unknown-linux-gnu.tar.gz") + and call[3:5] == ("-C", call[4]) + and "/.codex-install-" in call[2] + for call in session.exec_calls + ) + assert ( + "cp", + "/workspace/.codex_bin/.codex-install-123/codex", + "/workspace/.codex_bin/codex", + ) in session.exec_calls + assert ("chmod", "0755", "/workspace/.codex_bin/codex") in session.exec_calls + + +@pytest.mark.asyncio +async def test_codex_artifact_rejects_windows_release_assets() -> None: + session = _CodexSession( + "codex-x86_64-pc-windows-msvc.exe.tar.gz", + resolved_binary_path="/workspace/.codex_bin/.codex-install-456/codex.exe", + ) + entry = Codex() + + with pytest.raises(RuntimeError, match="Windows Codex artifacts are not supported"): + await entry.apply(session, Path("/workspace/.codex_bin/codex.exe"), Path("/ignored")) + + +@pytest.mark.asyncio +async def test_base_session_reapplies_missing_codex_entry_after_snapshot_restore() -> None: + codex_path = Path("/workspace/.codex_bin/codex") + session = _ResumeCodexSession( + "codex-x86_64-unknown-linux-gnu.tar.gz", + resolved_binary_path="/workspace/.codex_bin/.codex-install-789/codex", + codex_path=codex_path, + ) + session.state.manifest = Manifest(entries={".codex_bin/codex": Codex(version="v1.2.3")}) + archive_bytes = _tar_gz_bytes(members={"codex": b"#!/bin/sh\n"}) + + class _FakeResponse: + headers = {"Content-Length": str(len(archive_bytes))} + + def raise_for_status(self) -> None: + return None + + def iter_bytes(self): + yield archive_bytes + + class _FakeStreamContext: + def __enter__(self) -> _FakeResponse: + return _FakeResponse() + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def _fake_stream(url: str) -> _FakeStreamContext: + _ = url + return _FakeStreamContext() + + original_stream = codex_module._stream_release_asset + codex_module._stream_release_asset = _fake_stream + try: + await session.start() + finally: + codex_module._stream_release_asset = original_stream + + assert session.hydrated is True + assert ("test", "-e", str(codex_path)) in session.exec_calls + assert ( + "cp", + "/workspace/.codex_bin/.codex-install-789/codex", + str(codex_path), + ) in session.exec_calls + assert str(codex_path) in session.existing_paths diff --git a/tests/test_sandbox_extract.py b/tests/test_sandbox_extract.py new file mode 100644 index 0000000000..93c6fca1b5 --- /dev/null +++ b/tests/test_sandbox_extract.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import io +import os +import tarfile +import zipfile +from pathlib import Path + +import pytest + +from agents.sandbox.entries import GCSMount +from agents.sandbox.errors import InvalidManifestPathError, WorkspaceArchiveWriteError +from agents.sandbox.files import EntryKind, FileEntry +from agents.sandbox.manifest import Manifest +from agents.sandbox.sandboxes.unix_local import ( + UnixLocalSandboxSession, + UnixLocalSandboxSessionState, +) +from agents.sandbox.session.archive_extraction import zipfile_compatible_stream +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.snapshot import NoopSnapshot +from agents.sandbox.types import ExecResult, Permissions + + +def _build_session(tmp_path: Path) -> UnixLocalSandboxSession: + state = UnixLocalSandboxSessionState( + manifest=Manifest(root=str(tmp_path / "workspace")), + snapshot=NoopSnapshot(id="noop"), + ) + return UnixLocalSandboxSession.from_state(state) + + +class _CountingExtractSession(BaseSandboxSession): + def __init__(self, workspace_root: Path) -> None: + self.state = UnixLocalSandboxSessionState( + manifest=Manifest(root=str(workspace_root)), + snapshot=NoopSnapshot(id="noop"), + ) + self.ls_calls: list[Path] = [] + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = (command, timeout) + raise AssertionError("exec() should not be called in this test") + + async def read(self, path: Path) -> io.IOBase: + return self.normalize_path(path).open("rb") + + async def write(self, path: Path, data: io.IOBase) -> None: + workspace_path = self.normalize_path(path) + workspace_path.parent.mkdir(parents=True, exist_ok=True) + payload = data.read() + if isinstance(payload, str): + payload = payload.encode("utf-8") + workspace_path.write_bytes(payload) + + async def running(self) -> bool: + return True + + async def persist_workspace(self) -> io.IOBase: + return io.BytesIO() + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + + async def shutdown(self) -> None: + return + + async def mkdir(self, path: Path | str, *, parents: bool = False) -> None: + self.normalize_path(path).mkdir(parents=parents, exist_ok=True) + + async def ls(self, path: Path | str) -> list[FileEntry]: + directory = self.normalize_path(path) + self.ls_calls.append(directory) + if not directory.exists(): + raise AssertionError(f"ls() called for missing directory: {directory}") + + entries: list[FileEntry] = [] + for child in directory.iterdir(): + if child.is_symlink(): + kind = EntryKind.SYMLINK + elif child.is_dir(): + kind = EntryKind.DIRECTORY + else: + kind = EntryKind.FILE + entries.append( + FileEntry( + path=str(child), + permissions=Permissions(), + owner="root", + group="root", + size=0, + kind=kind, + ) + ) + return entries + + +def _tar_bytes(*, members: dict[str, bytes]) -> io.BytesIO: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as archive: + for name, payload in members.items(): + info = tarfile.TarInfo(name=name) + info.size = len(payload) + archive.addfile(info, io.BytesIO(payload)) + buf.seek(0) + return buf + + +def _zip_bytes(*, members: dict[str, bytes]) -> io.BytesIO: + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode="w") as archive: + for name, payload in members.items(): + archive.writestr(name, payload) + buf.seek(0) + return buf + + +@pytest.mark.asyncio +async def test_extract_tar_writes_archive_and_unpacks_contents(tmp_path: Path) -> None: + session = _build_session(tmp_path) + await session.start() + try: + await session.extract( + "bundle.tar", + _tar_bytes(members={"nested/hello.txt": b"hello from tar"}), + ) + finally: + await session.shutdown() + + workspace = Path(session.state.manifest.root) + assert (workspace / "bundle.tar").is_file() + assert (workspace / "nested" / "hello.txt").read_text(encoding="utf-8") == "hello from tar" + + +@pytest.mark.asyncio +async def test_extract_zip_writes_archive_and_unpacks_contents(tmp_path: Path) -> None: + session = _build_session(tmp_path) + await session.start() + try: + await session.extract( + "bundle.zip", + _zip_bytes(members={"nested/hello.txt": b"hello from zip"}), + ) + finally: + await session.shutdown() + + workspace = Path(session.state.manifest.root) + assert (workspace / "bundle.zip").is_file() + assert (workspace / "nested" / "hello.txt").read_text(encoding="utf-8") == "hello from zip" + + +class _NoSeekableZipStream(io.IOBase): + def __init__(self, payload: bytes) -> None: + self._buffer = io.BytesIO(payload) + + def tell(self) -> int: + return self._buffer.tell() + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + return self._buffer.seek(offset, whence) + + def read(self, size: int = -1) -> bytes: + return self._buffer.read(size) + + +class _ChunkedBinaryStream(io.IOBase): + def __init__(self, chunks: list[bytes]) -> None: + self._chunks = list(chunks) + self.headers = {"Content-Length": str(sum(len(chunk) for chunk in chunks))} + + def read(self, size: int = -1) -> bytes: + if not self._chunks: + return b"" + if size < 0: + data = b"".join(self._chunks) + self._chunks.clear() + return data + + remaining = size + out = bytearray() + while remaining > 0 and self._chunks: + chunk = self._chunks[0] + if len(chunk) <= remaining: + out.extend(self._chunks.pop(0)) + remaining -= len(chunk) + continue + out.extend(chunk[:remaining]) + self._chunks[0] = chunk[remaining:] + remaining = 0 + return bytes(out) + + +class _SeekableFalseZipStream(io.IOBase): + def __init__(self, payload: bytes) -> None: + self._buffer = io.BytesIO(payload) + + def seekable(self) -> bool: + return False + + def read(self, size: int = -1) -> bytes: + return self._buffer.read(size) + + +def test_zipfile_compatible_stream_supports_streams_without_seekable() -> None: + raw_stream = _NoSeekableZipStream(_zip_bytes(members={"file.txt": b"hello"}).getvalue()) + + with zipfile_compatible_stream(raw_stream) as compatible: + assert compatible.seekable() is True + with zipfile.ZipFile(compatible) as archive: + assert archive.read("file.txt") == b"hello" + + +def test_zipfile_compatible_stream_buffers_streams_with_seekable_false() -> None: + raw_stream = _SeekableFalseZipStream(_zip_bytes(members={"file.txt": b"hello"}).getvalue()) + + with zipfile_compatible_stream(raw_stream) as compatible: + assert compatible.seekable() is True + with zipfile.ZipFile(compatible) as archive: + assert archive.read("file.txt") == b"hello" + + +@pytest.mark.asyncio +async def test_unix_local_write_accepts_chunked_non_seekable_binary_stream(tmp_path: Path) -> None: + session = _build_session(tmp_path) + await session.start() + try: + await session.write( + Path("streamed.bin"), + _ChunkedBinaryStream([b"hello ", b"from ", b"stream"]), + ) + finally: + await session.shutdown() + + workspace = Path(session.state.manifest.root) + assert (workspace / "streamed.bin").read_bytes() == b"hello from stream" + + +@pytest.mark.asyncio +async def test_extract_tar_rejects_symlinked_parent_paths(tmp_path: Path) -> None: + session = _build_session(tmp_path) + await session.start() + try: + workspace = Path(session.state.manifest.root) + outside = tmp_path / "outside" + outside.mkdir() + os.symlink(outside, workspace / "link", target_is_directory=True) + + with pytest.raises(WorkspaceArchiveWriteError) as exc_info: + await session.extract( + "bundle.tar", + _tar_bytes(members={"link/hello.txt": b"hello from tar"}), + ) + + assert exc_info.value.context["member"] == "link/hello.txt" + assert exc_info.value.context["reason"] == "symlink in parent path: link" + assert not (outside / "hello.txt").exists() + finally: + await session.shutdown() + + +@pytest.mark.asyncio +async def test_extract_zip_rejects_symlinked_parent_paths(tmp_path: Path) -> None: + session = _build_session(tmp_path) + await session.start() + try: + workspace = Path(session.state.manifest.root) + outside = tmp_path / "outside" + outside.mkdir() + os.symlink(outside, workspace / "link", target_is_directory=True) + + with pytest.raises(WorkspaceArchiveWriteError) as exc_info: + await session.extract( + "bundle.zip", + _zip_bytes(members={"link/hello.txt": b"hello from zip"}), + ) + + assert exc_info.value.context["member"] == "link/hello.txt" + assert exc_info.value.context["reason"] == "symlink in parent path: link" + assert not (outside / "hello.txt").exists() + finally: + await session.shutdown() + + +@pytest.mark.asyncio +async def test_unix_local_persist_workspace_excludes_resolved_mount_path(tmp_path: Path) -> None: + workspace_root = tmp_path / "workspace" + actual_mount_path = workspace_root / "actual" + actual_mount_path.mkdir(parents=True) + (actual_mount_path / "remote.txt").write_text("remote", encoding="utf-8") + (workspace_root / "keep.txt").write_text("keep", encoding="utf-8") + + state = UnixLocalSandboxSessionState( + manifest=Manifest( + root=str(workspace_root), + entries={"logical": GCSMount(bucket="bucket", mount_path=Path("actual"))}, + ), + snapshot=NoopSnapshot(id="noop"), + ) + session = UnixLocalSandboxSession.from_state(state) + + archive = await session.persist_workspace() + + with tarfile.open(fileobj=archive, mode="r:*") as tar: + names = set(tar.getnames()) + + assert "./keep.txt" in names + assert "./actual" not in names + assert "./actual/remote.txt" not in names + + +@pytest.mark.asyncio +async def test_extract_tar_reuses_directory_listings_during_symlink_checks(tmp_path: Path) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + session = _CountingExtractSession(workspace) + + await session.extract( + "bundle.tar", + _tar_bytes( + members={ + "nested/one.txt": b"one", + "nested/two.txt": b"two", + } + ), + ) + + assert (workspace / "nested" / "one.txt").read_text(encoding="utf-8") == "one" + assert (workspace / "nested" / "two.txt").read_text(encoding="utf-8") == "two" + assert session.ls_calls == [ + workspace, + workspace / "nested", + ] + + +@pytest.mark.asyncio +async def test_unix_local_helpers_reject_paths_outside_workspace_root(tmp_path: Path) -> None: + session = _build_session(tmp_path) + await session.start() + try: + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.ls("../outside") + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.mkdir("../outside", parents=True) + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.rm("../outside") + with pytest.raises(InvalidManifestPathError, match="must be relative"): + await session.extract("/tmp/bundle.tar", _tar_bytes(members={"a.txt": b"a"})) + finally: + await session.shutdown() + + +@pytest.mark.asyncio +async def test_unix_local_helpers_reject_symlink_escape_paths(tmp_path: Path) -> None: + session = _build_session(tmp_path) + await session.start() + try: + workspace = Path(session.state.manifest.root) + outside = tmp_path / "outside" + outside.mkdir() + os.symlink(outside, workspace / "link", target_is_directory=True) + + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.mkdir("link/nested", parents=True) + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.ls("link") + finally: + await session.shutdown() diff --git a/tests/test_sandbox_manifest.py b/tests/test_sandbox_manifest.py new file mode 100644 index 0000000000..350b440564 --- /dev/null +++ b/tests/test_sandbox_manifest.py @@ -0,0 +1,232 @@ +from pathlib import Path + +import pytest + +from agents.sandbox.codex_config import ( + DEFAULT_CODEX_VERSION, + CodexConfig, + apply_codex_to_manifest, + manifest_has_codex_entry, +) +from agents.sandbox.entries import Codex, Dir, File, GCSMount +from agents.sandbox.errors import InvalidManifestPathError +from agents.sandbox.manifest import Manifest + + +def test_manifest_rejects_nested_child_paths_that_escape_workspace() -> None: + manifest = Manifest( + entries={ + "safe": Dir( + children={ + "../outside.txt": File(content=b"nope"), + } + ) + } + ) + + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + manifest.validated_entries() + + +def test_manifest_rejects_nested_absolute_child_paths() -> None: + manifest = Manifest( + entries={ + "safe": Dir( + children={ + "/tmp/outside.txt": File(content=b"nope"), + } + ) + } + ) + + with pytest.raises(InvalidManifestPathError, match="must be relative"): + manifest.validated_entries() + + +def test_manifest_ephemeral_entry_paths_include_nested_children() -> None: + manifest = Manifest( + entries={ + "dir": Dir( + children={ + "keep.txt": File(content=b"keep"), + "tmp.txt": File(content=b"tmp", ephemeral=True), + } + ) + } + ) + + assert manifest.ephemeral_entry_paths() == {Path("dir/tmp.txt")} + + +def test_manifest_ephemeral_persistence_paths_include_resolved_mount_targets() -> None: + manifest = Manifest( + root="/workspace", + entries={ + "logical": GCSMount(bucket="bucket", mount_path=Path("actual")), + "dir": Dir( + children={ + "tmp.txt": File(content=b"tmp", ephemeral=True), + } + ), + }, + ) + + assert manifest.ephemeral_persistence_paths() == { + Path("logical"), + Path("actual"), + Path("dir/tmp.txt"), + } + + +def test_manifest_ephemeral_mount_targets_sort_by_resolved_depth() -> None: + parent = GCSMount(bucket="parent", mount_path=Path("repo")) + child = GCSMount(bucket="child", mount_path=Path("repo/sub")) + manifest = Manifest( + root="/workspace", + entries={ + "parent": parent, + "nested": Dir(children={"child": child}), + }, + ) + + assert manifest.ephemeral_mount_targets() == [ + (child, Path("/workspace/repo/sub")), + (parent, Path("/workspace/repo")), + ] + + +def test_manifest_ephemeral_mount_targets_normalize_non_escaping_mount_paths() -> None: + mount = GCSMount(bucket="bucket", mount_path=Path("/workspace/repo/../actual")) + manifest = Manifest(root="/workspace", entries={"logical": mount}) + + assert manifest.ephemeral_mount_targets() == [ + (mount, Path("/workspace/actual")), + ] + assert manifest.ephemeral_persistence_paths() == { + Path("logical"), + Path("actual"), + } + + +def test_manifest_ephemeral_mount_targets_reject_escaping_mount_paths() -> None: + manifest = Manifest( + root="/workspace", + entries={ + "logical": GCSMount(bucket="bucket", mount_path=Path("/workspace/../../tmp")), + }, + ) + + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + manifest.ephemeral_mount_targets() + + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + manifest.ephemeral_persistence_paths() + + +def test_manifest_describe_preserves_tree_rendering_after_renderer_extract() -> None: + manifest = Manifest( + root="/workspace", + entries={ + "repo": Dir( + description="project root", + children={ + "README.md": File(content=b"hi", description="overview"), + }, + ), + "data": GCSMount(bucket="bucket", description="shared data"), + }, + ) + + description = manifest.describe(depth=2) + + assert description.startswith("/workspace\n") + assert "data/" in description + assert "/workspace/data" in description + assert "repo/" in description + assert "/workspace/repo/README.md" in description + + +def test_apply_codex_to_manifest_adds_codex_entry_at_configured_path() -> None: + manifest = apply_codex_to_manifest( + Manifest(), + CodexConfig(path="tools/codex"), + ) + + validated = manifest.validated_entries() + + assert Path("tools/codex") in validated + entry = validated[Path("tools/codex")] + assert isinstance(entry, Codex) + assert entry.version == DEFAULT_CODEX_VERSION + assert entry.ephemeral is True + + +def test_apply_codex_to_manifest_uses_reserved_default_codex_path() -> None: + manifest = apply_codex_to_manifest(Manifest(), True) + + validated = manifest.validated_entries() + + assert Path(".codex_bin/codex") in validated + assert manifest.ephemeral_persistence_paths() == {Path(".codex_bin/codex")} + + +def test_apply_codex_to_manifest_treats_home_relative_path_as_workspace_relative() -> None: + manifest = apply_codex_to_manifest( + Manifest(), + CodexConfig(path="~/.codex/codex"), + ) + + validated = manifest.validated_entries() + + assert Path(".codex/codex") in validated + entry = validated[Path(".codex/codex")] + assert isinstance(entry, Codex) + + +def test_apply_codex_to_manifest_preserves_explicit_entry_at_configured_path() -> None: + explicit = File(content=b"custom") + manifest = apply_codex_to_manifest( + Manifest( + entries={"tools/codex": explicit}, + ), + CodexConfig(path="tools/codex"), + ) + + validated = manifest.validated_entries() + + preserved = validated["tools/codex"] + assert isinstance(preserved, File) + assert preserved == explicit + + +def test_apply_codex_to_manifest_accepts_absolute_path_within_manifest_root() -> None: + manifest = apply_codex_to_manifest( + Manifest(root="/workspace"), + CodexConfig(path="/workspace/tools/codex"), + ) + + validated = manifest.validated_entries() + + assert Path("tools/codex") in validated + entry = validated[Path("tools/codex")] + assert isinstance(entry, Codex) + + +def test_apply_codex_to_manifest_rejects_absolute_path_outside_manifest_root() -> None: + with pytest.raises(InvalidManifestPathError, match="must be relative"): + apply_codex_to_manifest( + Manifest(root="/workspace"), + CodexConfig(path="/tmp/codex"), + ) + + +def test_manifest_has_codex_entry_accepts_absolute_default_root_path_after_root_rewrite() -> None: + manifest = Manifest( + root="/tmp/session-root", + entries={"tools/codex": Codex(version=DEFAULT_CODEX_VERSION)}, + ) + + assert manifest_has_codex_entry( + manifest, + CodexConfig(path="/workspace/tools/codex"), + ) diff --git a/tests/test_sandbox_manifest_application.py b/tests/test_sandbox_manifest_application.py new file mode 100644 index 0000000000..22680d091d --- /dev/null +++ b/tests/test_sandbox_manifest_application.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from agents.sandbox.entries import Dir, File, GCSMount +from agents.sandbox.manifest import Manifest +from agents.sandbox.materialization import MaterializedFile +from agents.sandbox.session.manifest_application import ManifestApplier +from agents.sandbox.types import ExecResult, Group, User + + +def _materialized(dest: Path) -> list[MaterializedFile]: + return [MaterializedFile(path=dest, sha256=dest.as_posix())] + + +@pytest.mark.asyncio +async def test_manifest_applier_only_applies_ephemeral_entries_without_account_provisioning() -> ( + None +): + mkdir_calls: list[Path] = [] + exec_calls: list[tuple[str, ...]] = [] + apply_calls: list[tuple[str, Path, Path]] = [] + + async def mkdir(path: Path) -> None: + mkdir_calls.append(path) + + async def exec_checked_nonzero(*command: str) -> ExecResult: + exec_calls.append(command) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def apply_entry(entry: object, dest: Path, base_dir: Path) -> list[MaterializedFile]: + apply_calls.append((type(entry).__name__, dest, base_dir)) + return _materialized(dest) + + applier = ManifestApplier( + mkdir=mkdir, + exec_checked_nonzero=exec_checked_nonzero, + apply_entry=apply_entry, + ) + manifest = Manifest( + root="/workspace", + entries={ + "keep.txt": File(content=b"keep"), + "tmp.txt": File(content=b"tmp", ephemeral=True), + }, + users=[User(name="alice")], + groups=[Group(name="dev", users=[User(name="alice")])], + ) + + result = await applier.apply_manifest(manifest, only_ephemeral=True) + + assert mkdir_calls == [Path("/workspace")] + assert exec_calls == [] + assert apply_calls == [("File", Path("/workspace/tmp.txt"), Path("/"))] + assert result.files == _materialized(Path("/workspace/tmp.txt")) + + +@pytest.mark.asyncio +async def test_manifest_applier_only_ephemeral_reapplies_nested_ephemeral_children() -> None: + apply_calls: list[tuple[str, Path, Path]] = [] + + async def mkdir(_path: Path) -> None: + return None + + async def exec_checked_nonzero(*_command: str) -> ExecResult: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def apply_entry(entry: object, dest: Path, base_dir: Path) -> list[MaterializedFile]: + apply_calls.append((type(entry).__name__, dest, base_dir)) + return _materialized(dest) + + applier = ManifestApplier( + mkdir=mkdir, + exec_checked_nonzero=exec_checked_nonzero, + apply_entry=apply_entry, + ) + manifest = Manifest( + root="/workspace", + entries={ + "dir": Dir( + children={ + "keep.txt": File(content=b"keep"), + "tmp.txt": File(content=b"tmp", ephemeral=True), + } + ) + }, + ) + + result = await applier.apply_manifest(manifest, only_ephemeral=True) + + assert apply_calls == [("File", Path("/workspace/dir/tmp.txt"), Path("/"))] + assert result.files == _materialized(Path("/workspace/dir/tmp.txt")) + + +@pytest.mark.asyncio +async def test_manifest_applier_only_ephemeral_reapplies_full_ephemeral_directories() -> None: + applied_entries: list[tuple[object, Path, Path]] = [] + + async def mkdir(_path: Path) -> None: + return None + + async def exec_checked_nonzero(*_command: str) -> ExecResult: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def apply_entry(entry: object, dest: Path, base_dir: Path) -> list[MaterializedFile]: + applied_entries.append((entry, dest, base_dir)) + return _materialized(dest) + + applier = ManifestApplier( + mkdir=mkdir, + exec_checked_nonzero=exec_checked_nonzero, + apply_entry=apply_entry, + ) + manifest = Manifest( + root="/workspace", + entries={ + "tmp": Dir( + ephemeral=True, + children={ + "keep.txt": File(content=b"keep"), + "nested": Dir(children={"child.txt": File(content=b"child")}), + "tmp.txt": File(content=b"tmp", ephemeral=True), + }, + ) + }, + ) + + result = await applier.apply_manifest(manifest, only_ephemeral=True) + + assert len(applied_entries) == 1 + entry, dest, base_dir = applied_entries[0] + assert isinstance(entry, Dir) + assert dest == Path("/workspace/tmp") + assert base_dir == Path("/") + assert set(entry.children) == {"keep.txt", "nested", "tmp.txt"} + assert result.files == _materialized(Path("/workspace/tmp")) + + +@pytest.mark.asyncio +async def test_manifest_applier_respects_explicit_base_dir() -> None: + apply_calls: list[tuple[str, Path, Path]] = [] + + async def mkdir(_path: Path) -> None: + return None + + async def exec_checked_nonzero(*_command: str) -> ExecResult: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def apply_entry(entry: object, dest: Path, base_dir: Path) -> list[MaterializedFile]: + apply_calls.append((type(entry).__name__, dest, base_dir)) + return _materialized(dest) + + applier = ManifestApplier( + mkdir=mkdir, + exec_checked_nonzero=exec_checked_nonzero, + apply_entry=apply_entry, + ) + manifest = Manifest(entries={"file.txt": File(content=b"hello")}) + + result = await applier.apply_manifest(manifest, base_dir=Path("/tmp/project")) + + assert apply_calls == [("File", Path("/workspace/file.txt"), Path("/tmp/project"))] + assert result.files == _materialized(Path("/workspace/file.txt")) + + +@pytest.mark.asyncio +async def test_manifest_applier_provisions_groups_and_unique_users_before_entries() -> None: + exec_calls: list[tuple[str, ...]] = [] + + async def mkdir(_path: Path) -> None: + return None + + async def exec_checked_nonzero(*command: str) -> ExecResult: + exec_calls.append(command) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def apply_entry(_entry: object, _dest: Path, _base_dir: Path) -> list[MaterializedFile]: + return [] + + applier = ManifestApplier( + mkdir=mkdir, + exec_checked_nonzero=exec_checked_nonzero, + apply_entry=apply_entry, + ) + manifest = Manifest( + users=[User(name="alice")], + groups=[Group(name="dev", users=[User(name="alice"), User(name="bob")])], + ) + + result = await applier.apply_manifest(manifest) + + assert result.files == [] + assert exec_calls[0] == ("groupadd", "dev") + assert exec_calls.count(("groupadd", "alice")) == 0 + assert exec_calls.count(("groupadd", "bob")) == 0 + assert ("useradd", "-U", "-M", "-s", "/usr/sbin/nologin", "alice") in exec_calls + assert ("useradd", "-U", "-M", "-s", "/usr/sbin/nologin", "bob") in exec_calls + assert ("usermod", "-aG", "dev", "alice") in exec_calls + assert ("usermod", "-aG", "dev", "bob") in exec_calls + + +@pytest.mark.asyncio +async def test_apply_entry_batch_flushes_parallel_work_before_overlapping_paths() -> None: + events: list[tuple[str, Path]] = [] + + async def mkdir(_path: Path) -> None: + return None + + async def exec_checked_nonzero(*_command: str) -> ExecResult: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def apply_entry(_entry: object, dest: Path, _base_dir: Path) -> list[MaterializedFile]: + events.append(("start", dest)) + await asyncio.sleep(0) + events.append(("end", dest)) + return _materialized(dest) + + applier = ManifestApplier( + mkdir=mkdir, + exec_checked_nonzero=exec_checked_nonzero, + apply_entry=apply_entry, + ) + destinations = [ + Path("/workspace/alpha.txt"), + Path("/workspace/beta.txt"), + Path("/workspace/nested"), + Path("/workspace/nested/child.txt"), + ] + + files = await applier._apply_entry_batch( + [ + (destinations[0], File(content=b"a")), + (destinations[1], File(content=b"b")), + (destinations[2], Dir()), + (destinations[3], File(content=b"c")), + ], + base_dir=Path("/"), + ) + + assert [file.path for file in files] == destinations + child_start = events.index(("start", destinations[3])) + assert events.index(("end", destinations[0])) < child_start + assert events.index(("end", destinations[1])) < child_start + assert events.index(("end", destinations[2])) < child_start + + +@pytest.mark.asyncio +async def test_apply_entry_batch_flushes_before_and_after_mount_entries() -> None: + events: list[tuple[str, Path]] = [] + + async def mkdir(_path: Path) -> None: + return None + + async def exec_checked_nonzero(*_command: str) -> ExecResult: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def apply_entry(_entry: object, dest: Path, _base_dir: Path) -> list[MaterializedFile]: + events.append(("start", dest)) + await asyncio.sleep(0) + events.append(("end", dest)) + return _materialized(dest) + + applier = ManifestApplier( + mkdir=mkdir, + exec_checked_nonzero=exec_checked_nonzero, + apply_entry=apply_entry, + ) + destinations = [ + Path("/workspace/alpha.txt"), + Path("/workspace/beta.txt"), + Path("/workspace/mount"), + Path("/workspace/gamma.txt"), + ] + + files = await applier._apply_entry_batch( + [ + (destinations[0], File(content=b"a")), + (destinations[1], File(content=b"b")), + (destinations[2], GCSMount(bucket="sandbox-bucket")), + (destinations[3], File(content=b"c")), + ], + base_dir=Path("/"), + ) + + assert [file.path for file in files] == destinations + mount_start = events.index(("start", destinations[2])) + gamma_start = events.index(("start", destinations[3])) + assert events.index(("end", destinations[0])) < mount_start + assert events.index(("end", destinations[1])) < mount_start + assert events.index(("end", destinations[2])) < gamma_start diff --git a/tests/test_sandbox_mounts.py b/tests/test_sandbox_mounts.py new file mode 100644 index 0000000000..737ca2fb31 --- /dev/null +++ b/tests/test_sandbox_mounts.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import io +import uuid +from pathlib import Path + +import pytest + +from agents.sandbox import Manifest +from agents.sandbox.entries import ( + AzureBlobMount, + GCSMount, + MountpointMountPattern, + RcloneMountPattern, +) +from agents.sandbox.entries.mounts.patterns import MountpointMountConfig, RcloneMountConfig +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.session.sandbox_session_state import SandboxSessionState +from agents.sandbox.snapshot import NoopSnapshot +from agents.sandbox.types import ExecResult + + +class _MountConfigSession(BaseSandboxSession): + def __init__(self, *, session_id: uuid.UUID | None = None, config_text: str = "") -> None: + self.state = SandboxSessionState( + session_id=session_id or uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id=str(uuid.uuid4())), + ) + self._config_text = config_text + + async def read(self, path: Path) -> io.BytesIO: + _ = path + return io.BytesIO(self._config_text.encode("utf-8")) + + async def shutdown(self) -> None: + return None + + async def write(self, path: Path, data: io.IOBase) -> None: + _ = (path, data) + raise AssertionError("write() should not be called in these tests") + + async def running(self) -> bool: + return True + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = (command, timeout) + raise AssertionError("exec() should not be called in these tests") + + async def persist_workspace(self) -> io.IOBase: + raise AssertionError("persist_workspace() should not be called in these tests") + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + raise AssertionError("hydrate_workspace() should not be called in these tests") + + +@pytest.mark.asyncio +async def test_azure_blob_mount_builds_rclone_runtime_config_without_hidden_pattern_state() -> None: + session_id = uuid.uuid4() + pattern = RcloneMountPattern(config_file_path=Path("rclone.conf")) + remote_name = pattern.resolve_remote_name( + session_id=session_id.hex, + remote_kind="azureblob", + mount_type="azure_blob_mount", + ) + session = _MountConfigSession( + session_id=session_id, + config_text=f"[{remote_name}]\ntype = azureblob\n", + ) + mount = AzureBlobMount( + account="acct", + container="container", + mount_pattern=pattern, + ) + + apply_config = await mount._build_mount_config(session, pattern, include_config_text=True) + unmount_config = await mount._build_mount_config(session, pattern, include_config_text=False) + + assert isinstance(apply_config, RcloneMountConfig) + assert apply_config.remote_name == remote_name + assert apply_config.remote_path == "container" + assert apply_config.config_text is not None + assert "account = acct" in apply_config.config_text + assert isinstance(unmount_config, RcloneMountConfig) + assert unmount_config.remote_name == remote_name + assert unmount_config.config_text is None + + +@pytest.mark.asyncio +async def test_gcs_mount_uses_runtime_endpoint_override_without_mutating_pattern_options() -> None: + pattern = MountpointMountPattern() + mount = GCSMount(bucket="bucket", mount_pattern=pattern) + + config = await mount._build_mount_config( + _MountConfigSession(), + pattern, + include_config_text=False, + ) + + assert isinstance(config, MountpointMountConfig) + assert config.endpoint_url == "https://storage.googleapis.com" + assert pattern.options.endpoint_url is None diff --git a/tests/test_sandbox_parse_utils.py b/tests/test_sandbox_parse_utils.py new file mode 100644 index 0000000000..9a9b81eef4 --- /dev/null +++ b/tests/test_sandbox_parse_utils.py @@ -0,0 +1,26 @@ +from agents.sandbox.files import EntryKind +from agents.sandbox.util.parse_utils import parse_ls_la + + +def test_parse_ls_la_preserves_absolute_file_paths() -> None: + output = "-rwxr-xr-x 1 root root 48915747 Jan 1 00:00 /workspace/.codex_bin/codex\n" + + entries = parse_ls_la(output, base="/workspace/.codex_bin/codex") + + assert len(entries) == 1 + assert entries[0].path == "/workspace/.codex_bin/codex" + assert entries[0].kind == EntryKind.FILE + + +def test_parse_ls_la_prefixes_directory_entries_with_base() -> None: + output = ( + "drwxr-xr-x 2 root root 4096 Jan 1 00:00 .\n" + "drwxr-xr-x 3 root root 4096 Jan 1 00:00 ..\n" + "-rw-r--r-- 1 root root 123 Jan 1 00:00 notes.md\n" + ) + + entries = parse_ls_la(output, base="/workspace/docs") + + assert len(entries) == 1 + assert entries[0].path == "/workspace/docs/notes.md" + assert entries[0].kind == EntryKind.FILE diff --git a/tests/test_sandbox_retry.py b/tests/test_sandbox_retry.py new file mode 100644 index 0000000000..de43f3e98a --- /dev/null +++ b/tests/test_sandbox_retry.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import asyncio +from typing import cast + +import pytest + +from agents.sandbox.util.retry import ( + BackoffStrategy, + exception_chain_contains_type, + exception_chain_has_status_code, + iter_exception_chain, + retry_async, +) + + +class _ErrorWithHttpMetadata(Exception): + def __init__( + self, + message: str, + *, + status_code: int | None = None, + http_code: int | None = None, + response_status_code: int | None = None, + ) -> None: + super().__init__(message) + self.status_code = status_code + self.http_code = http_code + if response_status_code is not None: + self.response = type("_Response", (), {"status_code": response_status_code})() + + +def test_iter_exception_chain_supports_context_and_stops_on_cycles() -> None: + outer = RuntimeError("outer") + inner = ValueError("inner") + outer.__context__ = inner + + assert list(iter_exception_chain(outer)) == [outer, inner] + + cyclical_outer = RuntimeError("cyclical-outer") + cyclical_inner = ValueError("cyclical-inner") + cyclical_outer.__cause__ = cyclical_inner + cyclical_inner.__cause__ = cyclical_outer + + assert list(iter_exception_chain(cyclical_outer)) == [cyclical_outer, cyclical_inner] + + +def test_exception_chain_helpers_detect_types_and_status_codes() -> None: + outer = RuntimeError("outer") + inner = _ErrorWithHttpMetadata("inner", response_status_code=504) + outer.__cause__ = inner + + assert exception_chain_contains_type(outer, ()) is False + assert exception_chain_contains_type(outer, (_ErrorWithHttpMetadata,)) is True + assert exception_chain_contains_type(outer, (LookupError,)) is False + + assert exception_chain_has_status_code( + _ErrorWithHttpMetadata("status", status_code=500), + {500}, + ) + assert exception_chain_has_status_code( + _ErrorWithHttpMetadata("http", http_code=502), + {502}, + ) + assert exception_chain_has_status_code(outer, {504}) + assert exception_chain_has_status_code(outer, {503}) is False + + +def test_retry_async_validates_configuration() -> None: + with pytest.raises(ValueError, match="max_attempt must be >= 1"): + retry_async(max_attempt=0, retry_if=lambda _exc: True) + + with pytest.raises(ValueError, match="interval must be >= 0"): + retry_async(interval=-1, retry_if=lambda _exc: True) + + with pytest.raises(ValueError, match="backoff must be"): + retry_async( + backoff=cast(BackoffStrategy, "quadratic"), + retry_if=lambda _exc: True, + ) + + +@pytest.mark.parametrize( + ("backoff", "expected_delays"), + [ + (BackoffStrategy.FIXED, [0.5, 0.5]), + (BackoffStrategy.LINEAR, [0.5, 1.0]), + (BackoffStrategy.EXPONENTIAL, [0.5, 1.0]), + ], +) +@pytest.mark.asyncio +async def test_retry_async_retries_with_expected_backoff_and_async_hook( + monkeypatch: pytest.MonkeyPatch, + backoff: BackoffStrategy, + expected_delays: list[float], +) -> None: + sleep_delays: list[float] = [] + hook_calls: list[tuple[int, int, float]] = [] + attempts = 0 + + async def fake_sleep(delay: float) -> None: + sleep_delays.append(delay) + + async def on_retry( + _exc: Exception, + attempt: int, + max_attempt: int, + delay_s: float, + *_args: object, + **_kwargs: object, + ) -> None: + hook_calls.append((attempt, max_attempt, delay_s)) + + monkeypatch.setattr(asyncio, "sleep", fake_sleep) + + @retry_async( + interval=0.5, + max_attempt=3, + backoff=backoff, + retry_if=lambda exc, *_args, **_kwargs: isinstance(exc, RuntimeError), + on_retry=on_retry, + ) + async def flaky(label: str) -> str: + nonlocal attempts + attempts += 1 + if attempts < 3: + raise RuntimeError(label) + return f"ok:{label}" + + result = await flaky("sandbox") + + assert result == "ok:sandbox" + assert attempts == 3 + assert sleep_delays == expected_delays + assert hook_calls == [(1, 3, expected_delays[0]), (2, 3, expected_delays[1])] + assert str(backoff) == backoff.value + + +@pytest.mark.asyncio +async def test_retry_async_stops_without_sleep_when_retry_is_rejected( + monkeypatch: pytest.MonkeyPatch, +) -> None: + attempts = 0 + + async def fail_sleep(_delay: float) -> None: + raise AssertionError("sleep should not be called") + + monkeypatch.setattr(asyncio, "sleep", fail_sleep) + + @retry_async( + interval=0.5, + max_attempt=3, + backoff=BackoffStrategy.EXPONENTIAL, + retry_if=lambda _exc, *_args, **_kwargs: False, + on_retry=lambda *_args, **_kwargs: None, + ) + async def always_fail() -> None: + nonlocal attempts + attempts += 1 + raise RuntimeError("stop") + + with pytest.raises(RuntimeError, match="stop"): + await always_fail() + + assert attempts == 1 diff --git a/tests/test_sandbox_runtime.py b/tests/test_sandbox_runtime.py new file mode 100644 index 0000000000..72141c4747 --- /dev/null +++ b/tests/test_sandbox_runtime.py @@ -0,0 +1,3363 @@ +from __future__ import annotations + +import asyncio +import io +import json +import os +import shutil +import sys +import tempfile +import uuid +from collections.abc import Sequence +from pathlib import Path +from typing import Any, Literal, TypedDict, cast + +import pytest +from openai.types.responses.response_output_item import LocalShellCall, LocalShellCallAction +from openai.types.responses.response_reasoning_item import ResponseReasoningItem, Summary + +from agents import Agent, AgentHooks, LocalShellTool, RunHooks, Runner, function_tool +from agents.exceptions import InputGuardrailTripwireTriggered, UserError +from agents.guardrail import GuardrailFunctionOutput, InputGuardrail, OutputGuardrail +from agents.items import ModelResponse, ToolCallOutputItem, TResponseInputItem +from agents.model_settings import ModelSettings +from agents.prompts import GenerateDynamicPromptData, Prompt +from agents.run import CallModelData, ModelInputData, RunConfig +from agents.run_context import AgentHookContext, RunContextWrapper +from agents.run_state import RunState, _build_agent_identity_map +from agents.sandbox import Manifest, SandboxAgent, SandboxRunConfig +from agents.sandbox.capabilities import Capability +from agents.sandbox.codex_config import ( + CodexConfig, + apply_codex_to_manifest, + apply_codex_to_session_state, +) +from agents.sandbox.entries import BaseEntry, File +from agents.sandbox.errors import ExecNonZeroError, ExecTransportError, InvalidManifestPathError +from agents.sandbox.materialization import MaterializedFile +from agents.sandbox.runtime import SandboxRuntime +from agents.sandbox.runtime_session_manager import SandboxRuntimeSessionManager +from agents.sandbox.sandboxes import unix_local as unix_local_module +from agents.sandbox.sandboxes.unix_local import ( + UnixLocalSandboxClient, + UnixLocalSandboxSession, + UnixLocalSandboxSessionState, +) +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.session.sandbox_client import BaseSandboxClient +from agents.sandbox.session.sandbox_session import SandboxSession +from agents.sandbox.session.sandbox_session_state import SandboxSessionState +from agents.sandbox.snapshot import LocalSnapshotSpec, NoopSnapshot, SnapshotBase +from agents.sandbox.types import ExecResult, User +from agents.stream_events import RunItemStreamEvent +from agents.tool import Tool +from tests.fake_model import FakeModel +from tests.test_responses import ( + get_final_output_message, + get_function_tool, + get_function_tool_call, + get_handoff_tool_call, +) +from tests.utils.simple_session import SimpleListSession + + +class _FakeSession(BaseSandboxSession): + def __init__( + self, + manifest: Manifest, + *, + start_gate: asyncio.Event | None = None, + ) -> None: + self.state = SandboxSessionState( + manifest=manifest, + snapshot=NoopSnapshot(id=str(uuid.uuid4())), + ) + self._start_gate = start_gate + self._running = False + self.start_calls = 0 + self.stop_calls = 0 + self.shutdown_calls = 0 + self.close_dependency_calls = 0 + + async def start(self) -> None: + self.start_calls += 1 + if self._start_gate is not None: + await self._start_gate.wait() + self._running = True + + async def stop(self) -> None: + self.stop_calls += 1 + self._running = False + + async def shutdown(self) -> None: + self.shutdown_calls += 1 + + async def running(self) -> bool: + return self._running + + async def read(self, path: Path) -> io.BytesIO: + _ = path + raise AssertionError("read() should not be called in these tests") + + async def write(self, path: Path, data: io.IOBase) -> None: + _ = (path, data) + raise AssertionError("write() should not be called in these tests") + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = (command, timeout) + raise AssertionError("exec() should not be called in these tests") + + async def persist_workspace(self) -> io.IOBase: + return io.BytesIO() + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + + async def _aclose_dependencies(self) -> None: + self.close_dependency_calls += 1 + await super()._aclose_dependencies() + + +class _FailingStopSession(_FakeSession): + async def stop(self) -> None: + await super().stop() + raise RuntimeError("stop failed") + + +class _LiveSessionDeltaRecorder(_FakeSession): + def __init__(self, manifest: Manifest, *, fail_entry_batch_times: int = 0) -> None: + super().__init__(manifest) + self.apply_manifest_calls = 0 + self.applied_entry_batches: list[list[tuple[Path, BaseEntry]]] = [] + self._fail_entry_batch_times = fail_entry_batch_times + + async def apply_manifest(self, *, only_ephemeral: bool = False): + _ = only_ephemeral + self.apply_manifest_calls += 1 + raise AssertionError("apply_manifest() should not be used for running injected sessions") + + async def _apply_entry_batch( + self, + entries: Sequence[tuple[Path, BaseEntry]], + *, + base_dir: Path, + ) -> list[MaterializedFile]: + _ = base_dir + self.applied_entry_batches.append( + [(dest, artifact.model_copy(deep=True)) for dest, artifact in entries] + ) + if self._fail_entry_batch_times > 0: + self._fail_entry_batch_times -= 1 + raise RuntimeError("delta apply failed") + return [] + + +class _BlockingStopSession(_FakeSession): + def __init__(self, manifest: Manifest, stop_gate: asyncio.Event) -> None: + super().__init__(manifest) + self._stop_gate = stop_gate + + async def stop(self) -> None: + await super().stop() + await self._stop_gate.wait() + + +class _MarkerSnapshot(SnapshotBase): + __test__ = False + type: Literal["marker"] = "marker" + marker: str = "initial" + + async def persist(self, data: io.IOBase) -> None: + _ = data + + async def restore(self) -> io.IOBase: + return io.BytesIO() + + async def restorable(self) -> bool: + return False + + +class _PersistingStopSession(_BlockingStopSession): + def __init__(self, manifest: Manifest, stop_gate: asyncio.Event) -> None: + super().__init__(manifest, stop_gate) + self.state.snapshot = _MarkerSnapshot(id="marker") + + async def stop(self) -> None: + self.stop_calls += 1 + self._running = False + await self._stop_gate.wait() + snapshot = cast(_MarkerSnapshot, self.state.snapshot) + snapshot.marker = "persisted" + + +class _ProvisioningFailureSession(_FakeSession): + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd = [str(part) for part in command] + if cmd[:2] == ["mkdir", "-p"]: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + if cmd and cmd[0] in {"groupadd", "useradd"}: + return ExecResult(stdout=b"", stderr=f"missing {cmd[0]}".encode(), exit_code=1) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + +class _RestorableSnapshot(SnapshotBase): + __test__ = False + type: Literal["restorable"] = "restorable" + + async def persist(self, data: io.IOBase) -> None: + _ = data + + async def restore(self) -> io.IOBase: + return io.BytesIO(b"snapshot") + + async def restorable(self) -> bool: + return True + + +class _RestorableProvisioningFailureSession(_ProvisioningFailureSession): + def __init__(self, manifest: Manifest, *, provision_on_resume: bool = True) -> None: + super().__init__(manifest) + self.state.snapshot = _RestorableSnapshot(id="resume") + self.cleared_workspace_root = False + self.hydrate_calls = 0 + self._provision_on_resume = provision_on_resume + + async def start(self) -> None: + self.start_calls += 1 + self._running = True + await BaseSandboxSession.start(self) + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + self.hydrate_calls += 1 + + async def _clear_workspace_root_on_resume(self) -> None: + self.cleared_workspace_root = True + + def should_provision_manifest_accounts_on_resume(self) -> bool: + return self._provision_on_resume + + +@pytest.mark.asyncio +async def test_sandbox_session_aclose_runs_public_cleanup_lifecycle() -> None: + inner = _FakeSession(Manifest()) + session = SandboxSession(inner) + + await session.aclose() + + assert inner.stop_calls == 1 + assert inner.shutdown_calls == 1 + assert inner.close_dependency_calls == 1 + + +@pytest.mark.asyncio +async def test_sandbox_session_aclose_closes_dependencies_when_stop_fails() -> None: + inner = _FailingStopSession(Manifest()) + session = SandboxSession(inner) + + with pytest.raises(RuntimeError, match="stop failed"): + await session.aclose() + + assert inner.stop_calls == 1 + assert inner.shutdown_calls == 0 + assert inner.close_dependency_calls == 1 + + +def _extract_user_text(item: dict[str, object]) -> str: + content = item["content"] + if isinstance(content, str): + return content + if isinstance(content, list): + first = content[0] + if isinstance(first, dict): + return str(first.get("text", "")) + raise AssertionError(f"Unexpected content payload: {content!r}") + + +def _tripwire_input_guardrail( + _context: RunContextWrapper[Any], + _agent: Agent[Any], + _input: str | list[TResponseInputItem], +) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True) + + +def _get_reasoning_item() -> ResponseReasoningItem: + return ResponseReasoningItem( + id="rid", + type="reasoning", + summary=[Summary(text="thinking", type="summary_text")], + ) + + +class _CreateKwargs(TypedDict): + snapshot: object | None + manifest: Manifest | None + codex: bool | CodexConfig + options: dict[str, str] + + +class _FakeClient(BaseSandboxClient[dict[str, str]]): + backend_id = "fake" + + def __init__(self, session: _FakeSession) -> None: + self.inner_session = session + self.session = self._wrap_session(session) + self.create_kwargs: _CreateKwargs | None = None + self.resume_state: SandboxSessionState | None = None + self.delete_calls = 0 + + async def create( + self, + *, + snapshot: object | None = None, + manifest: Manifest | None = None, + codex: bool | CodexConfig = False, + options: dict[str, str], + ) -> SandboxSession: + base_manifest = manifest if manifest is not None else self.inner_session.state.manifest + self.create_kwargs = { + "snapshot": snapshot, + "manifest": apply_codex_to_manifest(base_manifest, codex), + "codex": codex, + "options": options, + } + if self.create_kwargs["manifest"] is not None: + self.inner_session.state.manifest = self.create_kwargs["manifest"] + return self.session + + async def delete(self, session: SandboxSession) -> SandboxSession: + self.delete_calls += 1 + return session + + async def resume( + self, + state: SandboxSessionState, + *, + codex: bool | CodexConfig = False, + ) -> SandboxSession: + self.resume_state = apply_codex_to_session_state(state, codex) + self.inner_session.state = self.resume_state + return self.session + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return SandboxSessionState.model_validate(payload) + + +class _ManifestSessionClient(BaseSandboxClient[None]): + backend_id = "manifest" + supports_default_options = True + + def __init__(self) -> None: + self.created_manifests: list[Manifest | None] = [] + + async def create( + self, + *, + snapshot: object | None = None, + manifest: Manifest | None = None, + codex: bool | CodexConfig = False, + options: None = None, + ) -> SandboxSession: + _ = (snapshot, options) + manifest = apply_codex_to_manifest(manifest, codex) + self.created_manifests.append(manifest) + assert manifest is not None + session = _FakeSession(manifest) + return self._wrap_session(session) + + async def delete(self, session: SandboxSession) -> SandboxSession: + return session + + async def resume( + self, + state: SandboxSessionState, + *, + codex: bool | CodexConfig = False, + ) -> SandboxSession: + resumed_state = apply_codex_to_session_state(state, codex) + return self._wrap_session(_FakeSession(resumed_state.manifest)) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return SandboxSessionState.model_validate(payload) + + +class _RecordingCapability(Capability): + type: str = "recording" + bound_session: BaseSandboxSession | None = None + instruction_text: str | None = None + provided_tools: list[Tool] + + def __init__( + self, + *, + instruction_text: str | None = None, + provided_tools: list[Tool] | None = None, + ) -> None: + super().__init__(type="recording") + self.bound_session = None + self.instruction_text = instruction_text + self.provided_tools = list(provided_tools or []) + + def bind(self, session: BaseSandboxSession) -> None: + self.bound_session = session + + def tools(self) -> list[Tool]: + return list(self.provided_tools) + + async def instructions(self, manifest: Manifest) -> str | None: + _ = manifest + return self.instruction_text + + +class _NestedStateCapability(Capability): + type: str = "nested-state" + state: dict[str, list[str]] + + def __init__(self) -> None: + super().__init__(type="nested-state") + self.state = {"seen": []} + + +class _NestedObjectState: + def __init__(self) -> None: + self.seen: list[str] = [] + + +class _NestedObjectCapability(Capability): + type: str = "nested-object-state" + state: _NestedObjectState + + def __init__(self) -> None: + super().__init__(type="nested-object-state") + self.state = _NestedObjectState() + + +class _AwaitableSessionCapability(Capability): + type: str = "awaitable-session" + bound_session: BaseSandboxSession | None = None + release_gate: asyncio.Event + first_instruction_started: asyncio.Event + second_instruction_started: asyncio.Event + + def __init__( + self, + *, + release_gate: asyncio.Event, + first_instruction_started: asyncio.Event, + second_instruction_started: asyncio.Event, + ) -> None: + super().__init__(type="awaitable-session") + self.bound_session = None + self.release_gate = release_gate + self.first_instruction_started = first_instruction_started + self.second_instruction_started = second_instruction_started + + def bind(self, session: BaseSandboxSession) -> None: + self.bound_session = session + + async def instructions(self, manifest: Manifest) -> str | None: + _ = manifest + assert self.bound_session is not None + readme = self.bound_session.state.manifest.entries["README.md"] + assert isinstance(readme, File) + readme_text = readme.content.decode() + if readme_text == "Session one instructions.": + self.first_instruction_started.set() + elif readme_text == "Session two instructions.": + self.second_instruction_started.set() + await self.release_gate.wait() + return readme_text + + +class _ManifestInstructionsCapability(Capability): + type: str = "manifest-instructions" + bound_session: BaseSandboxSession | None = None + + def __init__(self) -> None: + super().__init__(type="manifest-instructions") + self.bound_session = None + + def bind(self, session: BaseSandboxSession) -> None: + self.bound_session = session + + async def instructions(self, manifest: Manifest) -> str | None: + _ = manifest + assert self.bound_session is not None + readme = self.bound_session.state.manifest.entries["README.md"] + assert isinstance(readme, File) + return readme.content.decode() + + +class _ManifestMutationCapability(Capability): + type: str = "manifest-mutation" + + def __init__(self, *, rel_path: str = "cap.txt", content: bytes = b"capability") -> None: + super().__init__(type="manifest-mutation") + self.rel_path = rel_path + self.content = content + + def process_manifest(self, manifest: Manifest) -> Manifest: + manifest.entries[self.rel_path] = File(content=self.content) + return manifest + + +class _ManifestUsersCapability(Capability): + type: str = "manifest-users" + + def __init__(self) -> None: + super().__init__(type="manifest-users") + + def process_manifest(self, manifest: Manifest) -> Manifest: + manifest.users.append(User(name="sandbox-user")) + return manifest + + +class _SessionFileCapability(Capability): + type: str = "session-files" + bound_session: BaseSandboxSession | None = None + + def __init__(self) -> None: + super().__init__(type="session-files") + self.bound_session = None + + def bind(self, session: BaseSandboxSession) -> None: + self.bound_session = session + + def tools(self) -> list[Tool]: + @function_tool(name_override="write_file") + async def write_file(path: str, content: str) -> str: + assert self.bound_session is not None + await self.bound_session.write(Path(path), io.BytesIO(content.encode("utf-8"))) + return "wrote" + + @function_tool(name_override="read_file") + async def read_file(path: str) -> str: + assert self.bound_session is not None + data = await self.bound_session.read(Path(path)) + return cast(bytes, data.read()).decode("utf-8") + + return [write_file, read_file] + + +class _RecordingRunHooks(RunHooks[None]): + def __init__(self) -> None: + self.started_agents: list[Agent[None]] = [] + self.ended_agents: list[Agent[None]] = [] + self.llm_started_agents: list[Agent[None]] = [] + self.llm_ended_agents: list[Agent[None]] = [] + + async def on_agent_start(self, context: AgentHookContext[None], agent: Agent[None]) -> None: + _ = context + self.started_agents.append(agent) + + async def on_llm_start( + self, + context: RunContextWrapper[None], + agent: Agent[None], + system_prompt: str | None, + input_items: list[TResponseInputItem], + ) -> None: + _ = (context, system_prompt, input_items) + self.llm_started_agents.append(agent) + + async def on_llm_end( + self, + context: RunContextWrapper[None], + agent: Agent[None], + response: ModelResponse, + ) -> None: + _ = (context, response) + self.llm_ended_agents.append(agent) + + async def on_agent_end( + self, + context: AgentHookContext[None], + agent: Agent[None], + output: object, + ) -> None: + _ = (context, output) + self.ended_agents.append(agent) + + +class _RecordingAgentHooks(AgentHooks[None]): + def __init__(self) -> None: + self.started_agents: list[Agent[None]] = [] + self.ended_agents: list[Agent[None]] = [] + self.llm_started_agents: list[Agent[None]] = [] + self.llm_ended_agents: list[Agent[None]] = [] + + async def on_start(self, context: AgentHookContext[None], agent: Agent[None]) -> None: + _ = context + self.started_agents.append(agent) + + async def on_llm_start( + self, + context: RunContextWrapper[None], + agent: Agent[None], + system_prompt: str | None, + input_items: list[TResponseInputItem], + ) -> None: + _ = (context, system_prompt, input_items) + self.llm_started_agents.append(agent) + + async def on_llm_end( + self, + context: RunContextWrapper[None], + agent: Agent[None], + response: ModelResponse, + ) -> None: + _ = (context, response) + self.llm_ended_agents.append(agent) + + async def on_end( + self, + context: AgentHookContext[None], + agent: Agent[None], + output: object, + ) -> None: + _ = (context, output) + self.ended_agents.append(agent) + + +def _sandbox_run_config(client: _FakeClient | None = None) -> RunConfig: + return RunConfig( + sandbox=SandboxRunConfig( + client=client, + options={"image": "sandbox"} if client is not None else None, + ) + ) + + +def _unix_local_manifest(**kwargs: Any) -> Manifest: + return Manifest(**kwargs) + + +def _unix_local_run_config( + *, + client: UnixLocalSandboxClient | None = None, + session_state: SandboxSessionState | None = None, + manifest: Manifest | None = None, +) -> RunConfig: + sandbox_kwargs: dict[str, Any] = { + "client": client or UnixLocalSandboxClient(), + } + if session_state is not None: + sandbox_kwargs["session_state"] = session_state + else: + sandbox_kwargs["manifest"] = manifest or _unix_local_manifest() + return RunConfig(sandbox=SandboxRunConfig(**sandbox_kwargs)) + + +@pytest.mark.asyncio +async def test_runner_merges_sandbox_instructions_and_tools() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + capability_tool = get_function_tool("capability_tool", "ok") + capability = _RecordingCapability( + instruction_text="Capability instructions.", + provided_tools=[capability_tool], + ) + manifest = Manifest(entries={"README.md": File(content=b"Follow the repo contract.")}) + session = _FakeSession(manifest) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + developer_instructions="Developer instructions.", + default_manifest=manifest, + capabilities=[capability], + ) + + result = await Runner.run( + agent, + "hello", + run_config=_sandbox_run_config(client), + ) + + assert result.final_output == "done" + assert capability.bound_session is None + assert session.start_calls == 1 + assert session.stop_calls == 1 + assert session.shutdown_calls == 1 + assert session.close_dependency_calls == 1 + assert client.delete_calls == 1 + + state = result.to_state() + assert state._sandbox is not None + assert state._sandbox["backend_id"] == "fake" + assert state._sandbox["current_agent_name"] == agent.name + assert state._sandbox["current_agent_key"] == agent.name + sessions_by_agent = state._sandbox["sessions_by_agent"] + assert isinstance(sessions_by_agent, dict) + assert sessions_by_agent[agent.name] == { + "agent_name": agent.name, + "session_state": state._sandbox["session_state"], + } + + assert client.create_kwargs is not None + assert client.create_kwargs["manifest"] is not manifest + assert client.create_kwargs["options"] == {"image": "sandbox"} + assert isinstance(client.create_kwargs["snapshot"], LocalSnapshotSpec) + + assert model.first_turn_args is not None + assert model.first_turn_args["system_instructions"] == ( + "Base instructions.\n\nDeveloper instructions.\n\nCapability instructions." + ) + assert [tool.name for tool in model.first_turn_args["tools"]] == ["capability_tool"] + + input_items = model.first_turn_args["input"] + assert isinstance(input_items, list) + assert _extract_user_text(input_items[0]) == "hello" + + +@pytest.mark.asyncio +async def test_runner_requires_sandbox_config_for_sandbox_agent() -> None: + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + with pytest.raises(UserError, match="RunConfig\\(sandbox=.*\\)"): + await Runner.run(agent, "hello") + + +@pytest.mark.asyncio +async def test_runner_streamed_cleans_runner_owned_session() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + ) + + result = Runner.run_streamed( + agent, + "hello", + run_config=_sandbox_run_config(client), + ) + events = [event async for event in result.stream_events()] + + assert events + assert result.final_output == "done" + assert session.start_calls == 1 + assert session.stop_calls == 1 + assert session.shutdown_calls == 1 + assert session.close_dependency_calls == 1 + assert client.delete_calls == 1 + + state = result.to_state() + assert state._sandbox is not None + assert state._sandbox["backend_id"] == "fake" + assert state._sandbox["current_agent_name"] == agent.name + assert state._sandbox["current_agent_key"] == agent.name + sessions_by_agent = state._sandbox["sessions_by_agent"] + assert isinstance(sessions_by_agent, dict) + assert sessions_by_agent[agent.name] == { + "agent_name": agent.name, + "session_state": state._sandbox["session_state"], + } + + +@pytest.mark.asyncio +async def test_runner_streamed_guardrail_trip_blocks_runner_owned_sandbox_creation() -> None: + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + input_guardrails=[ + InputGuardrail( + guardrail_function=_tripwire_input_guardrail, + run_in_parallel=False, + ) + ], + ) + + with pytest.raises(InputGuardrailTripwireTriggered): + result = Runner.run_streamed(agent, "hello", run_config=_sandbox_run_config(client)) + async for _ in result.stream_events(): + pass + + assert client.create_kwargs is None + assert session.start_calls == 0 + assert session.stop_calls == 0 + assert session.shutdown_calls == 0 + assert session.close_dependency_calls == 0 + + +@pytest.mark.asyncio +async def test_runner_does_not_close_injected_sandbox_session() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + default_manifest = Manifest(entries={"default.txt": File(content=b"default")}) + session_manifest = Manifest(entries={"session.txt": File(content=b"session")}) + injected_session = _FakeSession(session_manifest) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + default_manifest=default_manifest, + codex=False, + ) + + result = await Runner.run( + agent, + "hello", + run_config=RunConfig( + sandbox=SandboxRunConfig( + session=injected_session, + manifest=Manifest(entries={"override.txt": File(content=b"override")}), + ) + ), + ) + + assert result.final_output == "done" + assert injected_session.start_calls == 1 + assert injected_session.stop_calls == 0 + assert injected_session.shutdown_calls == 0 + assert injected_session.close_dependency_calls == 0 + + assert model.first_turn_args is not None + input_items = model.first_turn_args["input"] + assert isinstance(input_items, str) or isinstance(input_items, list) + assert injected_session.state.manifest.entries == session_manifest.entries + + +@pytest.mark.asyncio +async def test_runner_does_not_restart_running_injected_sandbox_session() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + injected_session = _FakeSession(Manifest(entries={"session.txt": File(content=b"session")})) + injected_session._running = True + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + codex=False, + ) + + result = await Runner.run( + agent, + "hello", + run_config=RunConfig(sandbox=SandboxRunConfig(session=injected_session)), + ) + + assert result.final_output == "done" + assert injected_session.start_calls == 0 + assert injected_session.stop_calls == 0 + assert injected_session.shutdown_calls == 0 + + +@pytest.mark.asyncio +async def test_runner_passes_codex_requirement_to_client_created_sessions() -> None: + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + result = await Runner.run(agent, "hello", run_config=_sandbox_run_config(client)) + + assert result.final_output == "done" + assert client.create_kwargs is not None + assert client.create_kwargs["codex"] is True + manifest = client.create_kwargs["manifest"] + assert manifest is not None + manifest_paths = {manifest._coerce_rel_path(path) for path in manifest.entries} + assert Path(".codex_bin/codex") in manifest_paths + + +@pytest.mark.asyncio +async def test_runner_rejects_injected_session_missing_required_codex() -> None: + injected_session = _FakeSession(Manifest(entries={"session.txt": File(content=b"session")})) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + with pytest.raises(UserError, match="missing Codex"): + await Runner.run( + agent, + "hello", + run_config=RunConfig(sandbox=SandboxRunConfig(session=injected_session)), + ) + + +@pytest.mark.asyncio +async def test_runner_guardrail_trip_blocks_runner_owned_sandbox_creation() -> None: + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + input_guardrails=[ + InputGuardrail( + guardrail_function=_tripwire_input_guardrail, + run_in_parallel=False, + ) + ], + ) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run(agent, "hello", run_config=_sandbox_run_config(client)) + + assert client.create_kwargs is None + assert session.start_calls == 0 + assert session.stop_calls == 0 + assert session.shutdown_calls == 0 + assert session.close_dependency_calls == 0 + + +@pytest.mark.asyncio +async def test_runner_guardrail_trip_blocks_running_injected_session_mutation() -> None: + live_session = _LiveSessionDeltaRecorder(Manifest()) + live_session._running = True + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + capabilities=[_ManifestMutationCapability()], + input_guardrails=[ + InputGuardrail( + guardrail_function=_tripwire_input_guardrail, + run_in_parallel=False, + ) + ], + ) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run( + agent, + "hello", + run_config=RunConfig(sandbox=SandboxRunConfig(session=live_session)), + ) + + assert "cap.txt" not in live_session.state.manifest.entries + assert live_session.start_calls == 0 + assert live_session.applied_entry_batches == [] + assert live_session.stop_calls == 0 + assert live_session.shutdown_calls == 0 + + +@pytest.mark.asyncio +async def test_runner_streamed_guardrail_trip_blocks_running_injected_session_mutation() -> None: + live_session = _LiveSessionDeltaRecorder(Manifest()) + live_session._running = True + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + capabilities=[_ManifestMutationCapability()], + input_guardrails=[ + InputGuardrail( + guardrail_function=_tripwire_input_guardrail, + run_in_parallel=False, + ) + ], + ) + + with pytest.raises(InputGuardrailTripwireTriggered): + result = Runner.run_streamed( + agent, + "hello", + run_config=RunConfig(sandbox=SandboxRunConfig(session=live_session)), + ) + async for _ in result.stream_events(): + pass + + assert "cap.txt" not in live_session.state.manifest.entries + assert live_session.start_calls == 0 + assert live_session.applied_entry_batches == [] + assert live_session.stop_calls == 0 + assert live_session.shutdown_calls == 0 + + +@pytest.mark.asyncio +async def test_runner_uses_public_sandbox_agent_for_dynamic_instructions() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + session = _FakeSession(Manifest()) + client = _FakeClient(session) + seen_agents: list[Agent[Any]] = [] + + def dynamic_instructions(_ctx: RunContextWrapper[Any], current_agent: Agent[Any]) -> str: + seen_agents.append(current_agent) + return "Saw public agent." if current_agent is agent else "Saw execution clone." + + agent = SandboxAgent( + name="sandbox", + model=model, + instructions=dynamic_instructions, + capabilities=[ + _RecordingCapability( + instruction_text="Capability instructions.", + provided_tools=[get_function_tool("capability_tool", "ok")], + ) + ], + ) + + result = await Runner.run(agent, "hello", run_config=_sandbox_run_config(client)) + + assert result.final_output == "done" + assert seen_agents == [agent] + assert model.first_turn_args is not None + assert model.first_turn_args["system_instructions"] == ( + "Saw public agent.\n\nCapability instructions." + ) + + +@pytest.mark.asyncio +async def test_runner_uses_public_sandbox_agent_for_dynamic_prompts() -> None: + seen_agents: list[Agent[Any]] = [] + + def dynamic_prompt(data: GenerateDynamicPromptData) -> Prompt: + seen_agents.append(data.agent) + return {"id": "prompt_test", "variables": {"agent_name": data.agent.name}} + + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + prompt=dynamic_prompt, + capabilities=[_RecordingCapability(instruction_text="Capability instructions.")], + ) + + result = await Runner.run( + agent, "hello", run_config=_sandbox_run_config(_FakeClient(_FakeSession(Manifest()))) + ) + + assert result.final_output == "done" + assert seen_agents == [agent] + + streamed_agent = SandboxAgent( + name="streamed-sandbox", + model=FakeModel(initial_output=[get_final_output_message("streamed done")]), + instructions="Base instructions.", + prompt=dynamic_prompt, + capabilities=[_RecordingCapability(instruction_text="Capability instructions.")], + ) + streamed = Runner.run_streamed( + streamed_agent, + "hello", + run_config=_sandbox_run_config(_FakeClient(_FakeSession(Manifest()))), + ) + async for _ in streamed.stream_events(): + pass + + assert streamed.final_output == "streamed done" + assert seen_agents == [agent, streamed_agent] + + +@pytest.mark.asyncio +async def test_runner_uses_public_agent_for_call_model_input_filter() -> None: + seen_agents: list[Agent[Any]] = [] + + def capture_model_input(data: CallModelData[Any]) -> ModelInputData: + seen_agents.append(data.agent) + return data.model_data + + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + capabilities=[_RecordingCapability(instruction_text="Capability instructions.")], + ) + + result = await Runner.run( + agent, + "hello", + run_config=RunConfig( + sandbox=SandboxRunConfig( + client=_FakeClient(_FakeSession(Manifest())), + options={"image": "sandbox"}, + ), + call_model_input_filter=capture_model_input, + ), + ) + + assert result.final_output == "done" + assert seen_agents == [agent] + + +@pytest.mark.asyncio +async def test_runner_streamed_uses_public_agent_for_call_model_input_filter() -> None: + seen_agents: list[Agent[Any]] = [] + + def capture_model_input(data: CallModelData[Any]) -> ModelInputData: + seen_agents.append(data.agent) + return data.model_data + + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + capabilities=[_RecordingCapability(instruction_text="Capability instructions.")], + ) + + result = Runner.run_streamed( + agent, + "hello", + run_config=RunConfig( + sandbox=SandboxRunConfig( + client=_FakeClient(_FakeSession(Manifest())), + options={"image": "sandbox"}, + ), + call_model_input_filter=capture_model_input, + ), + ) + events = [event async for event in result.stream_events()] + + assert events + assert result.final_output == "done" + assert seen_agents == [agent] + + +@pytest.mark.asyncio +async def test_runner_reuses_prepared_sandbox_agent_across_turns_for_tool_choice_reset() -> None: + model = FakeModel() + tool = get_function_tool("capability_tool", "ok") + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("capability_tool", json.dumps({}))], + [get_final_output_message("done")], + ] + ) + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + tools=[tool], + model_settings=ModelSettings(tool_choice="required"), + ) + + result = await Runner.run(agent, "hello", run_config=_sandbox_run_config(client)) + + assert result.final_output == "done" + assert model.first_turn_args is not None + assert model.first_turn_args["model_settings"].tool_choice == "required" + assert model.last_turn_args["model_settings"].tool_choice is None + + +@pytest.mark.asyncio +async def test_runner_rebuilds_sandbox_resources_for_handoff_target_agent() -> None: + triage_model = FakeModel() + worker_model = FakeModel(initial_output=[get_final_output_message("done")]) + client = _ManifestSessionClient() + triage_manifest = Manifest(entries={"README.md": File(content=b"Triage workspace")}) + worker_manifest = Manifest(entries={"README.md": File(content=b"Worker workspace")}) + worker = SandboxAgent( + name="worker", + model=worker_model, + instructions="Worker instructions.", + default_manifest=worker_manifest, + capabilities=[_ManifestInstructionsCapability()], + ) + triage = SandboxAgent( + name="triage", + model=triage_model, + instructions="Triage instructions.", + default_manifest=triage_manifest, + capabilities=[_ManifestInstructionsCapability()], + handoffs=[worker], + ) + triage_model.turn_outputs = [[get_handoff_tool_call(worker)]] + + result = await Runner.run( + triage, + "route this", + run_config=RunConfig(sandbox=SandboxRunConfig(client=client)), + ) + + assert result.final_output == "done" + assert len(client.created_manifests) == 2 + assert client.created_manifests[0] is not None + assert client.created_manifests[1] is not None + assert ( + client.created_manifests[0].entries["README.md"] + != client.created_manifests[1].entries["README.md"] + ) + assert worker_model.first_turn_args is not None + assert worker_model.first_turn_args["system_instructions"] == ( + "Worker instructions.\n\nWorker workspace" + ) + + +@pytest.mark.asyncio +async def test_runner_resumed_handoff_materializes_manifest_for_new_sandbox_agent() -> None: + triage_model = FakeModel() + worker_model = FakeModel(initial_output=[get_final_output_message("done")]) + client = _ManifestSessionClient() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + triage_manifest = Manifest(entries={"README.md": File(content=b"Triage workspace")}) + worker_manifest = Manifest(entries={"README.md": File(content=b"Worker workspace")}) + worker = SandboxAgent( + name="worker", + model=worker_model, + instructions="Worker instructions.", + default_manifest=worker_manifest, + capabilities=[_ManifestInstructionsCapability()], + ) + triage = SandboxAgent( + name="triage", + model=triage_model, + instructions="Triage instructions.", + default_manifest=triage_manifest, + tools=[approval_tool], + capabilities=[_ManifestInstructionsCapability()], + handoffs=[worker], + ) + triage_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("approval_tool", json.dumps({}), call_id="call_resume")], + [get_handoff_tool_call(worker)], + ] + ) + + first_run = await Runner.run( + triage, + "route this", + run_config=RunConfig(sandbox=SandboxRunConfig(client=client)), + ) + + assert len(first_run.interruptions) == 1 + state = first_run.to_state() + state.approve(first_run.interruptions[0]) + + resumed = await Runner.run( + triage, + state, + run_config=RunConfig(sandbox=SandboxRunConfig(client=client)), + ) + + assert resumed.final_output == "done" + assert len(client.created_manifests) == 2 + assert client.created_manifests[1] is not None + assert worker_model.first_turn_args is not None + assert worker_model.first_turn_args["system_instructions"] == ( + "Worker instructions.\n\nWorker workspace" + ) + + +@pytest.mark.asyncio +async def test_unix_local_client_rewrites_default_manifest_root_to_temp_workspace() -> None: + client = UnixLocalSandboxClient() + manifest = _unix_local_manifest(entries={"default.txt": File(content=b"default")}) + + session = await client.create(manifest=manifest, options=None) + workspace_root = Path(session.state.manifest.root) + try: + session_manifest = session.state.manifest + session_state = cast(UnixLocalSandboxSessionState, session.state) + + assert session_manifest is not manifest + assert session_manifest.entries == manifest.entries + assert session_manifest.root != manifest.root + assert workspace_root.is_absolute() + assert workspace_root.name.startswith("uc-local-") + assert session_state.workspace_root_owned is True + assert manifest.root == "/workspace" + finally: + await client.delete(session) + assert not workspace_root.exists() + + +@pytest.mark.asyncio +async def test_runner_allows_fresh_unix_local_sessions_without_options() -> None: + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + codex=False, + ) + + result = await Runner.run( + agent, + "hello", + run_config=_unix_local_run_config(), + ) + + assert result.final_output == "done" + + +@pytest.mark.asyncio +async def test_unix_local_client_delete_preserves_caller_owned_workspace_root() -> None: + client = UnixLocalSandboxClient() + workspace_root = Path(tempfile.mkdtemp(prefix="caller-owned-")) + manifest = _unix_local_manifest(root=str(workspace_root)) + + session = await client.create(manifest=manifest, options=None) + assert cast(UnixLocalSandboxSessionState, session.state).workspace_root_owned is False + + await client.delete(session) + + assert workspace_root.exists() + shutil.rmtree(workspace_root) + + +@pytest.mark.asyncio +async def test_unix_local_runner_cleanup_preserves_resumed_caller_owned_workspace_root() -> None: + workspace_root = Path(tempfile.mkdtemp(prefix="resumed-owned-")) + state = UnixLocalSandboxSessionState( + session_id=uuid.uuid4(), + manifest=_unix_local_manifest(root=str(workspace_root)), + snapshot=NoopSnapshot(id=str(uuid.uuid4())), + ) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + codex=False, + ) + + try: + result = await Runner.run( + agent, + "hello", + run_config=_unix_local_run_config(session_state=state), + ) + finally: + assert workspace_root.exists() + shutil.rmtree(workspace_root) + + assert result.final_output == "done" + + +@pytest.mark.asyncio +async def test_unix_local_read_and_write_reject_paths_outside_workspace_root() -> None: + client = UnixLocalSandboxClient() + workspace_root = Path(tempfile.mkdtemp(prefix="workspace-root-")) + session = await client.create( + manifest=_unix_local_manifest(root=str(workspace_root)), + options=None, + ) + + try: + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.write(Path("../secret.txt"), io.BytesIO(b"nope")) + with pytest.raises(InvalidManifestPathError, match="must not escape root"): + await session.read(Path("../secret.txt")) + finally: + await client.delete(session) + shutil.rmtree(workspace_root) + + +@pytest.mark.asyncio +async def test_unix_local_rm_recursive_ignores_missing_paths() -> None: + client = UnixLocalSandboxClient() + workspace_root = Path(tempfile.mkdtemp(prefix="workspace-root-")) + session = await client.create( + manifest=_unix_local_manifest(root=str(workspace_root)), + options=None, + ) + + try: + await session.rm("missing-dir", recursive=True) + finally: + await client.delete(session) + shutil.rmtree(workspace_root) + + +@pytest.mark.asyncio +async def test_unix_local_rm_non_recursive_still_errors_for_missing_paths() -> None: + client = UnixLocalSandboxClient() + workspace_root = Path(tempfile.mkdtemp(prefix="workspace-root-")) + session = await client.create( + manifest=_unix_local_manifest(root=str(workspace_root)), + options=None, + ) + + try: + with pytest.raises(ExecNonZeroError): + await session.rm("missing-dir") + finally: + await client.delete(session) + shutil.rmtree(workspace_root) + + +@pytest.mark.asyncio +async def test_runner_streamed_ignores_sandbox_cleanup_failures_after_success() -> None: + session = _FailingStopSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + result = Runner.run_streamed(agent, "hello", run_config=_sandbox_run_config(client)) + events = [event async for event in result.stream_events()] + + assert events + assert result.final_output == "done" + assert result._sandbox_session is None + + +@pytest.mark.asyncio +async def test_runner_omits_sandbox_resume_state_when_cleanup_fails() -> None: + session = _FailingStopSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + result = await Runner.run(agent, "hello", run_config=_sandbox_run_config(client)) + state = result.to_state() + + assert result.final_output == "done" + assert result._sandbox_resume_state is None + assert result._sandbox_session is None + assert state._sandbox is None + + +@pytest.mark.asyncio +async def test_runner_clears_sandbox_session_from_non_streamed_results_after_cleanup() -> None: + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + result = await Runner.run(agent, "hello", run_config=_sandbox_run_config(client)) + + assert result.final_output == "done" + assert result._sandbox_session is None + + +@pytest.mark.asyncio +async def test_runner_streamed_cleans_sandbox_once_after_stream_completion() -> None: + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + result = Runner.run_streamed(agent, "hello", run_config=_sandbox_run_config(client)) + events = [event async for event in result.stream_events()] + await asyncio.sleep(0) + + assert events + assert result.final_output == "done" + assert result._sandbox_session is None + assert session.stop_calls == 1 + assert session.shutdown_calls == 1 + assert session.close_dependency_calls == 1 + assert client.delete_calls == 1 + + +@pytest.mark.asyncio +async def test_runner_uses_public_agent_for_non_streaming_output_guardrails() -> None: + seen_agents: list[Agent[None]] = [] + + async def output_guardrail( + _context: RunContextWrapper[None], + guardrail_agent: Agent[None], + _output: object, + ) -> GuardrailFunctionOutput: + seen_agents.append(guardrail_agent) + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False) + + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + capabilities=[_RecordingCapability(instruction_text="Capability instructions.")], + output_guardrails=[OutputGuardrail(guardrail_function=output_guardrail)], + ) + + result = await Runner.run( + agent, "hello", run_config=_sandbox_run_config(_FakeClient(_FakeSession(Manifest()))) + ) + + assert result.final_output == "done" + assert seen_agents == [agent] + + +@pytest.mark.asyncio +async def test_runner_streamed_immediate_cancel_skips_waiting_for_sandbox_cleanup() -> None: + stop_gate = asyncio.Event() + session = _BlockingStopSession(Manifest(), stop_gate) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + + result = Runner.run_streamed(agent, "hello", run_config=_sandbox_run_config(client)) + + async def consume_with_cancel() -> None: + async for _event in result.stream_events(): + result.cancel(mode="immediate") + break + + try: + await asyncio.wait_for(consume_with_cancel(), timeout=0.2) + finally: + stop_gate.set() + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_runner_streamed_run_loop_task_waits_for_sandbox_cleanup_and_persisted_state() -> ( + None +): + stop_gate = asyncio.Event() + session = _PersistingStopSession(Manifest(), stop_gate) + client = _FakeClient(session) + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_final_output_message("done")], + [get_final_output_message("again")], + ] + ) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + ) + run_config = _sandbox_run_config(client) + + result = Runner.run_streamed(agent, "hello", run_config=run_config) + assert result.run_loop_task is not None + + while session.stop_calls == 0: + await asyncio.sleep(0) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(result.run_loop_task), timeout=0.05) + + stop_gate.set() + await result.run_loop_task + + state = result.to_state() + assert state._sandbox is not None + session_state = state._sandbox["session_state"] + assert isinstance(session_state, dict) + snapshot = session_state["snapshot"] + assert isinstance(snapshot, dict) + assert snapshot["marker"] == "persisted" + + second = await Runner.run(agent, "again", run_config=run_config) + + assert second.final_output == "again" + + +@pytest.mark.asyncio +async def test_runner_rejects_unix_local_manifest_user_and_group_provisioning() -> None: + workspace_root = Path(tempfile.mkdtemp(prefix="unix-local-users-")) + session = await UnixLocalSandboxClient().create( + manifest=_unix_local_manifest( + root=str(workspace_root), + users=[User(name="sandbox-user")], + ), + options=None, + ) + + try: + with pytest.raises(ValueError, match="does not support manifest users or groups"): + await session.start() + finally: + shutil.rmtree(workspace_root) + + +@pytest.mark.asyncio +async def test_runner_persists_workspace_and_tool_choice_state_across_sandbox_resume() -> None: + client = UnixLocalSandboxClient() + file_capability = _SessionFileCapability() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "write_file", + json.dumps({"path": "note.txt", "content": "persist me"}), + call_id="call_write", + ) + ], + [ + get_function_tool_call( + "approval_tool", + json.dumps({}), + call_id="call_approval", + ) + ], + ] + ) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + tools=[approval_tool], + capabilities=[file_capability], + model_settings=ModelSettings(tool_choice="required"), + codex=False, + ) + + first_run = await Runner.run( + agent, + "hello", + run_config=_unix_local_run_config(client=client), + ) + + assert len(first_run.interruptions) == 1 + state = first_run.to_state() + assert state._sandbox is not None + assert state._sandbox["backend_id"] == "unix_local" + session_state = state._sandbox["session_state"] + assert isinstance(session_state, dict) + snapshot_payload = session_state.get("snapshot") + assert isinstance(snapshot_payload, dict) + assert snapshot_payload.get("type") == "local" + sessions_by_agent = state._sandbox["sessions_by_agent"] + assert isinstance(sessions_by_agent, dict) + assert sessions_by_agent[agent.name] == { + "agent_name": agent.name, + "session_state": session_state, + } + + state_json = state.to_json() + resumed_model = FakeModel() + resumed_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "read_file", + json.dumps({"path": "note.txt"}), + call_id="call_read", + ) + ], + [get_final_output_message("done")], + ] + ) + resumed_agent = SandboxAgent( + name="sandbox", + model=resumed_model, + instructions="Base instructions.", + tools=[approval_tool], + capabilities=[_SessionFileCapability()], + model_settings=ModelSettings(tool_choice="required"), + codex=False, + ) + + restored_state = await RunState.from_json(resumed_agent, state_json) + restored_state.approve(restored_state.get_interruptions()[0]) + resumed = await Runner.run( + resumed_agent, + restored_state, + run_config=_unix_local_run_config(client=client), + ) + + assert resumed.final_output == "done" + assert resumed_model.last_turn_args["model_settings"].tool_choice is None + assert any( + isinstance(item, ToolCallOutputItem) + and item.output == "persist me" + and item.agent is resumed_agent + for item in resumed.new_items + ) + + +@pytest.mark.asyncio +async def test_runner_restores_all_sandbox_agents_from_run_state_across_handoffs() -> None: + client = UnixLocalSandboxClient() + file_capability = _SessionFileCapability() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + triage_model = FakeModel() + worker_model = FakeModel() + worker = SandboxAgent( + name="worker", + model=worker_model, + instructions="Worker instructions.", + tools=[approval_tool], + codex=False, + ) + triage = SandboxAgent( + name="triage", + model=triage_model, + instructions="Triage instructions.", + capabilities=[file_capability], + handoffs=[worker], + codex=False, + ) + worker.handoffs = [triage] + triage_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "write_file", + json.dumps({"path": "note.txt", "content": "persist triage"}), + call_id="call_write", + ) + ], + [get_handoff_tool_call(worker)], + ] + ) + worker_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("approval_tool", json.dumps({}), call_id="call_approval")], + ] + ) + + first_run = await Runner.run( + triage, + "hello", + run_config=_unix_local_run_config(client=client), + ) + + assert len(first_run.interruptions) == 1 + state = first_run.to_state() + assert state._sandbox is not None + assert state._sandbox["backend_id"] == "unix_local" + assert state._sandbox["current_agent_name"] == worker.name + sessions_by_agent = state._sandbox["sessions_by_agent"] + assert isinstance(sessions_by_agent, dict) + assert set(sessions_by_agent) == {triage.name, worker.name} + + state_json = state.to_json() + resumed_triage_model = FakeModel() + resumed_worker_model = FakeModel() + resumed_worker = SandboxAgent( + name="worker", + model=resumed_worker_model, + instructions="Worker instructions.", + tools=[approval_tool], + codex=False, + ) + resumed_triage = SandboxAgent( + name="triage", + model=resumed_triage_model, + instructions="Triage instructions.", + capabilities=[_SessionFileCapability()], + handoffs=[resumed_worker], + codex=False, + ) + resumed_worker.handoffs = [resumed_triage] + resumed_worker_model.add_multiple_turn_outputs([[get_handoff_tool_call(resumed_triage)]]) + resumed_triage_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "read_file", + json.dumps({"path": "note.txt"}), + call_id="call_read", + ) + ], + [get_final_output_message("done")], + ] + ) + + restored_state = await RunState.from_json(resumed_triage, state_json) + restored_state.approve(restored_state.get_interruptions()[0]) + resumed = await Runner.run( + resumed_triage, + restored_state, + run_config=_unix_local_run_config(client=client), + ) + + assert resumed.final_output == "done" + assert any( + isinstance(item, ToolCallOutputItem) + and item.output == "persist triage" + and item.agent is resumed_triage + for item in resumed.new_items + ) + + +@pytest.mark.asyncio +async def test_runner_serializes_unique_sandbox_resume_keys_for_duplicate_agent_names() -> None: + client = UnixLocalSandboxClient() + file_capability = _SessionFileCapability() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + first_model = FakeModel() + second_model = FakeModel() + first = SandboxAgent( + name="sandbox", + model=first_model, + instructions="First instructions.", + capabilities=[file_capability], + codex=False, + ) + second = SandboxAgent( + name="sandbox", + model=second_model, + instructions="Second instructions.", + tools=[approval_tool], + codex=False, + ) + first.handoffs = [second] + second.handoffs = [first] + first_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "write_file", + json.dumps({"path": "note.txt", "content": "first"}), + call_id="call_write", + ) + ], + [get_handoff_tool_call(second)], + [ + get_function_tool_call( + "read_file", + json.dumps({"path": "note.txt"}), + call_id="call_read", + ) + ], + [get_final_output_message("done")], + ] + ) + second_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("approval_tool", json.dumps({}), call_id="call_approval")], + [get_handoff_tool_call(first)], + ] + ) + + first_run = await Runner.run( + first, + "hello", + run_config=_unix_local_run_config(client=client), + ) + + state = first_run.to_state() + assert state._sandbox is not None + sessions_by_agent = cast(dict[str, dict[str, object]], state._sandbox["sessions_by_agent"]) + assert len(sessions_by_agent) == 2 + assert state._sandbox["current_agent_key"] in sessions_by_agent + + state.approve(first_run.interruptions[0]) + resumed = await Runner.run( + first, + state, + run_config=_unix_local_run_config(client=client), + ) + + assert resumed.final_output == "done" + assert any( + isinstance(item, ToolCallOutputItem) and item.output == "first" and item.agent is first + for item in resumed.new_items + ) + + +def test_duplicate_name_sandbox_identity_map_uses_capability_and_manifest_config() -> None: + """Duplicate-name sandbox identities should stay stable when only sandbox config differs.""" + + def _make_agent(readme: bytes, capability_text: str) -> SandboxAgent[None]: + return SandboxAgent( + name="sandbox", + model=FakeModel(), + instructions="Base instructions.", + default_manifest=Manifest(entries={"README.md": File(content=readme)}), + capabilities=[_RecordingCapability(instruction_text=capability_text)], + ) + + def _identity_for(identity_map: dict[str, Agent[Any]], target: Agent[Any]) -> str: + return next(identity for identity, agent in identity_map.items() if agent is target) + + first_alpha = _make_agent(b"alpha", "Alpha capability.") + first_beta = _make_agent(b"beta", "Beta capability.") + first_root = Agent(name="triage", handoffs=[first_beta, first_alpha]) + first_alpha.handoffs = [first_root] + first_beta.handoffs = [first_root] + + second_alpha = _make_agent(b"alpha", "Alpha capability.") + second_beta = _make_agent(b"beta", "Beta capability.") + second_root = Agent(name="triage", handoffs=[second_alpha, second_beta]) + second_alpha.handoffs = [second_root] + second_beta.handoffs = [second_root] + + first_identity_map = _build_agent_identity_map(first_root) + second_identity_map = _build_agent_identity_map(second_root) + + assert _identity_for(first_identity_map, first_alpha) == _identity_for( + second_identity_map, second_alpha + ) + assert _identity_for(first_identity_map, first_beta) == _identity_for( + second_identity_map, second_beta + ) + + +@pytest.mark.asyncio +async def test_session_manager_reserves_current_duplicate_resume_key_for_current_agent() -> None: + manifest = Manifest(entries={"README.md": File(content=b"duplicate resume")}) + client = _FakeClient(_FakeSession(manifest)) + first = SandboxAgent(name="sandbox", model=FakeModel(), instructions="First.") + second = SandboxAgent(name="sandbox", model=FakeModel(), instructions="Second.") + first.handoffs = [second] + second.handoffs = [first] + first_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="first")) + ) + second_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="second")) + ) + run_state: RunState[Any, Agent[Any]] = cast( + RunState[Any, Agent[Any]], + RunState( + context=RunContextWrapper(context={}), + original_input="hello", + starting_agent=first, + ), + ) + run_state._current_agent = second + run_state._sandbox = { + "backend_id": "fake", + "current_agent_key": "sandbox#2", + "current_agent_name": second.name, + "session_state": second_session_state, + "sessions_by_agent": { + "sandbox": {"agent_name": first.name, "session_state": first_session_state}, + "sandbox#2": {"agent_name": second.name, "session_state": second_session_state}, + }, + } + manager = SandboxRuntimeSessionManager( + starting_agent=first, + sandbox_config=SandboxRunConfig(client=client, options={"image": "sandbox"}), + run_state=run_state, + ) + + assert ( + manager._resume_state_payload_for_agent(client=client, agent=first, agent_id=id(first)) + == first_session_state + ) + assert ( + manager._resume_state_payload_for_agent(client=client, agent=second, agent_id=id(second)) + == second_session_state + ) + + +def test_session_manager_generates_collision_free_resume_keys_for_literal_suffix_names() -> None: + client = _FakeClient(_FakeSession(Manifest())) + first = SandboxAgent(name="sandbox", model=FakeModel(), instructions="First.") + literal_suffix = SandboxAgent(name="sandbox#2", model=FakeModel(), instructions="Literal.") + second = SandboxAgent(name="sandbox", model=FakeModel(), instructions="Second.") + first.handoffs = [literal_suffix, second] + literal_suffix.handoffs = [first, second] + second.handoffs = [first, literal_suffix] + manager = SandboxRuntimeSessionManager( + starting_agent=first, + sandbox_config=SandboxRunConfig(client=client, options={"image": "sandbox"}), + run_state=None, + ) + + manager.acquire_agent(first) + manager.acquire_agent(literal_suffix) + manager.acquire_agent(second) + + assert manager._ensure_resume_key(first) == "sandbox" + assert manager._ensure_resume_key(literal_suffix) == "sandbox#2" + assert manager._ensure_resume_key(second) == "sandbox#3" + + +@pytest.mark.asyncio +async def test_session_manager_preserves_untouched_run_state_sessions_on_cleanup() -> None: + manifest = Manifest(entries={"README.md": File(content=b"duplicate resume")}) + client = _FakeClient(_FakeSession(manifest)) + triage = SandboxAgent(name="triage", model=FakeModel(), instructions="Triage.", codex=False) + worker = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.", codex=False) + triage.handoffs = [worker] + worker.handoffs = [triage] + triage_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="triage")) + ) + worker_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="worker")) + ) + run_state: RunState[Any, Agent[Any]] = cast( + RunState[Any, Agent[Any]], + RunState( + context=RunContextWrapper(context={}), + original_input="hello", + starting_agent=triage, + ), + ) + run_state._current_agent = worker + run_state._sandbox = { + "backend_id": "fake", + "current_agent_key": worker.name, + "current_agent_name": worker.name, + "session_state": worker_session_state, + "sessions_by_agent": { + triage.name: {"agent_name": triage.name, "session_state": triage_session_state}, + worker.name: {"agent_name": worker.name, "session_state": worker_session_state}, + }, + } + manager = SandboxRuntimeSessionManager( + starting_agent=triage, + sandbox_config=SandboxRunConfig(client=client, options={"image": "sandbox"}), + run_state=run_state, + ) + + manager.acquire_agent(worker) + await manager.ensure_session(agent=worker, capabilities=[], is_resumed_state=True) + payload = await manager.cleanup() + + assert payload is not None + sessions_by_agent = cast(dict[str, dict[str, object]], payload["sessions_by_agent"]) + assert set(sessions_by_agent) == {triage.name, worker.name} + assert sessions_by_agent[triage.name] == { + "agent_name": triage.name, + "session_state": triage_session_state, + } + assert sessions_by_agent[worker.name] == { + "agent_name": worker.name, + "session_state": worker_session_state, + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("resume_source", ["run_state", "session_state"]) +async def test_session_manager_reapplies_capability_manifest_mutations_on_resume( + resume_source: str, +) -> None: + client = _FakeClient(_FakeSession(Manifest())) + capability = _ManifestMutationCapability() + agent = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.") + session_state = SandboxSessionState( + manifest=Manifest(), + snapshot=NoopSnapshot(id="resume"), + ) + + run_state: RunState[Any, Agent[Any]] | None = None + if resume_source == "run_state": + run_state = cast( + RunState[Any, Agent[Any]], + RunState( + context=RunContextWrapper(context={}), + original_input="hello", + starting_agent=agent, + ), + ) + run_state._current_agent = agent + serialized_state = client.serialize_session_state(session_state) + run_state._sandbox = { + "backend_id": client.backend_id, + "current_agent_key": agent.name, + "current_agent_name": agent.name, + "session_state": serialized_state, + "sessions_by_agent": { + agent.name: { + "agent_name": agent.name, + "session_state": serialized_state, + } + }, + } + sandbox_config = SandboxRunConfig(client=client, options={"image": "sandbox"}) + else: + sandbox_config = SandboxRunConfig( + client=client, + session_state=session_state, + options={"image": "sandbox"}, + ) + + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=sandbox_config, + run_state=run_state, + ) + + manager.acquire_agent(agent) + session = await manager.ensure_session( + agent=agent, + capabilities=[capability], + is_resumed_state=True, + ) + + assert session.state.manifest.entries["cap.txt"] == File(content=b"capability") + assert client.resume_state is not None + assert client.resume_state.manifest.entries["cap.txt"] == File(content=b"capability") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("source", ["live_session", "session_state", "create"]) +async def test_session_manager_applies_capability_manifest_mutations_with_session_parity( + source: str, +) -> None: + capability = _ManifestMutationCapability() + agent = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.", codex=False) + run_state: RunState[Any, Agent[Any]] | None = None + + if source == "live_session": + live_session = _FakeSession(Manifest()) + sandbox_config = SandboxRunConfig(session=live_session) + else: + client = _FakeClient(_FakeSession(Manifest())) + if source == "session_state": + sandbox_config = SandboxRunConfig( + client=client, + session_state=SandboxSessionState( + manifest=Manifest(), + snapshot=NoopSnapshot(id="resume"), + ), + options={"image": "sandbox"}, + ) + else: + sandbox_config = SandboxRunConfig( + client=client, + manifest=Manifest(), + options={"image": "sandbox"}, + ) + + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=sandbox_config, + run_state=run_state, + ) + + manager.acquire_agent(agent) + session = await manager.ensure_session( + agent=agent, + capabilities=[capability], + is_resumed_state=False, + ) + + assert session.state.manifest.entries["cap.txt"] == File(content=b"capability") + if source == "session_state": + assert client.resume_state is not None + assert client.resume_state.manifest.entries["cap.txt"] == File(content=b"capability") + if source == "create": + assert client.create_kwargs is not None + manifest = client.create_kwargs["manifest"] + assert manifest is not None + assert manifest.entries["cap.txt"] == File(content=b"capability") + + +@pytest.mark.asyncio +async def test_session_manager_starts_stopped_injected_session_with_manifest_mutation() -> None: + live_session = _LiveSessionDeltaRecorder(Manifest()) + capability = _ManifestMutationCapability() + agent = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.", codex=False) + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=SandboxRunConfig(session=live_session), + run_state=None, + ) + + manager.acquire_agent(agent) + session = await manager.ensure_session( + agent=agent, + capabilities=[capability], + is_resumed_state=False, + ) + payload = await manager.cleanup() + + assert session is live_session + assert live_session.start_calls == 1 + assert live_session.apply_manifest_calls == 0 + assert live_session.stop_calls == 0 + assert live_session.shutdown_calls == 0 + assert session.state.manifest.entries["cap.txt"] == File(content=b"capability") + assert payload is None + + +@pytest.mark.asyncio +async def test_session_manager_materializes_running_injected_session_manifest_mutation() -> None: + live_session = _LiveSessionDeltaRecorder(Manifest()) + live_session._running = True + capability = _ManifestMutationCapability() + agent = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.", codex=False) + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=SandboxRunConfig(session=live_session), + run_state=None, + ) + + manager.acquire_agent(agent) + session = await manager.ensure_session( + agent=agent, + capabilities=[capability], + is_resumed_state=False, + ) + payload = await manager.cleanup() + + assert session is live_session + assert live_session.start_calls == 0 + assert live_session.apply_manifest_calls == 0 + assert live_session.applied_entry_batches == [ + [(Path("/workspace/cap.txt"), File(content=b"capability"))] + ] + assert session.state.manifest.entries["cap.txt"] == File(content=b"capability") + assert live_session.stop_calls == 0 + assert live_session.shutdown_calls == 0 + assert payload is None + + +@pytest.mark.asyncio +async def test_session_manager_retries_running_injected_session_delta_apply_after_failure() -> None: + live_session = _LiveSessionDeltaRecorder(Manifest(), fail_entry_batch_times=1) + live_session._running = True + capability = _ManifestMutationCapability() + agent = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.", codex=False) + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=SandboxRunConfig(session=live_session), + run_state=None, + ) + + manager.acquire_agent(agent) + with pytest.raises(RuntimeError, match="delta apply failed"): + await manager.ensure_session( + agent=agent, + capabilities=[capability], + is_resumed_state=False, + ) + + assert live_session.state.manifest.entries == {} + assert live_session.applied_entry_batches == [ + [(Path("/workspace/cap.txt"), File(content=b"capability"))] + ] + + session = await manager.ensure_session( + agent=agent, + capabilities=[capability], + is_resumed_state=False, + ) + payload = await manager.cleanup() + + assert session is live_session + assert live_session.state.manifest.entries["cap.txt"] == File(content=b"capability") + assert live_session.applied_entry_batches == [ + [(Path("/workspace/cap.txt"), File(content=b"capability"))], + [(Path("/workspace/cap.txt"), File(content=b"capability"))], + ] + assert payload is None + + +@pytest.mark.asyncio +async def test_session_manager_skips_rematerialization_for_unchanged_running_session() -> None: + live_session = _LiveSessionDeltaRecorder(Manifest()) + live_session._running = True + agent = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.", codex=False) + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=SandboxRunConfig(session=live_session), + run_state=None, + ) + + manager.acquire_agent(agent) + session = await manager.ensure_session( + agent=agent, + capabilities=[Capability(type="noop")], + is_resumed_state=False, + ) + payload = await manager.cleanup() + + assert session is live_session + assert live_session.start_calls == 0 + assert live_session.apply_manifest_calls == 0 + assert live_session.applied_entry_batches == [] + assert session.state.manifest.entries == {} + assert live_session.stop_calls == 0 + assert live_session.shutdown_calls == 0 + assert payload is None + + +@pytest.mark.asyncio +async def test_session_manager_rejects_running_injected_session_account_mutation() -> None: + live_session = _LiveSessionDeltaRecorder(Manifest()) + live_session._running = True + agent = SandboxAgent(name="worker", model=FakeModel(), instructions="Worker.", codex=False) + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=SandboxRunConfig(session=live_session), + run_state=None, + ) + + manager.acquire_agent(agent) + with pytest.raises(ValueError, match="manifest.users` or `manifest.groups"): + await manager.ensure_session( + agent=agent, + capabilities=[_ManifestUsersCapability()], + is_resumed_state=False, + ) + + assert live_session.apply_manifest_calls == 0 + assert live_session.applied_entry_batches == [] + assert live_session.state.manifest.users == [] + + +@pytest.mark.asyncio +async def test_session_manager_preserves_existing_payload_when_no_sandbox_session_is_used() -> None: + client = _FakeClient(_FakeSession(Manifest())) + agent = SandboxAgent(name="sandbox", model=FakeModel(), instructions="Base instructions.") + run_state: RunState[Any, Agent[Any]] = cast( + RunState[Any, Agent[Any]], + RunState( + context=RunContextWrapper(context={}), + original_input="hello", + starting_agent=agent, + ), + ) + existing_payload = { + "backend_id": "fake", + "current_agent_key": agent.name, + "current_agent_name": agent.name, + "session_state": {"snapshot": {"id": "persisted"}}, + "sessions_by_agent": { + agent.name: { + "agent_name": agent.name, + "session_state": {"snapshot": {"id": "persisted"}}, + } + }, + } + run_state._sandbox = existing_payload + manager = SandboxRuntimeSessionManager( + starting_agent=agent, + sandbox_config=SandboxRunConfig(client=client, options={"image": "sandbox"}), + run_state=run_state, + ) + + payload = await manager.cleanup() + + assert payload == existing_payload + assert payload is not existing_payload + + +@pytest.mark.asyncio +async def test_session_manager_uses_run_state_starting_agent_for_duplicate_resume_keys() -> None: + manifest = Manifest(entries={"README.md": File(content=b"duplicate resume")}) + client = _FakeClient(_FakeSession(manifest)) + first = SandboxAgent(name="sandbox", model=FakeModel(), instructions="First.") + second = SandboxAgent(name="sandbox", model=FakeModel(), instructions="Second.") + approver = Agent(name="approver", model=FakeModel(), instructions="Approve.", handoffs=[]) + approver.handoffs = [second, first] + first.handoffs = [second] + second.handoffs = [approver] + first_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="first")) + ) + second_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="second")) + ) + run_state: RunState[Any, Agent[Any]] = cast( + RunState[Any, Agent[Any]], + RunState( + context=RunContextWrapper(context={}), + original_input="hello", + starting_agent=first, + ), + ) + run_state._current_agent = approver + run_state._starting_agent = first + run_state._sandbox = { + "backend_id": "fake", + "current_agent_key": "sandbox#2", + "current_agent_name": second.name, + "session_state": second_session_state, + "sessions_by_agent": { + "sandbox": {"agent_name": first.name, "session_state": first_session_state}, + "sandbox#2": {"agent_name": second.name, "session_state": second_session_state}, + }, + } + manager = SandboxRuntimeSessionManager( + starting_agent=approver, + sandbox_config=SandboxRunConfig(client=client, options={"image": "sandbox"}), + run_state=run_state, + ) + + assert ( + manager._resume_state_payload_for_agent(client=client, agent=first, agent_id=id(first)) + == first_session_state + ) + assert ( + manager._resume_state_payload_for_agent(client=client, agent=second, agent_id=id(second)) + == second_session_state + ) + + +@pytest.mark.asyncio +async def test_session_manager_restores_duplicate_name_sessions_when_only_sandbox_config_differs(): + client = _FakeClient(_FakeSession(Manifest())) + + def _make_agent(readme: bytes, capability_text: str) -> SandboxAgent[None]: + return SandboxAgent( + name="sandbox", + model=FakeModel(), + instructions="Base instructions.", + default_manifest=Manifest(entries={"README.md": File(content=readme)}), + capabilities=[_RecordingCapability(instruction_text=capability_text)], + ) + + first = _make_agent(b"first", "First capability.") + second = _make_agent(b"second", "Second capability.") + root = Agent(name="triage", handoffs=[second, first]) + first.handoffs = [root] + second.handoffs = [root] + + first_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="first")) + ) + second_session_state = client.serialize_session_state( + SandboxSessionState(manifest=Manifest(), snapshot=NoopSnapshot(id="second")) + ) + + state: RunState[Any, Agent[Any]] = cast( + RunState[Any, Agent[Any]], + RunState( + context=RunContextWrapper(context={}), + original_input="hello", + starting_agent=root, + ), + ) + state._current_agent = second + state._sandbox = { + "backend_id": "fake", + "current_agent_key": "sandbox#2", + "current_agent_name": second.name, + "session_state": second_session_state, + "sessions_by_agent": { + "sandbox": {"agent_name": first.name, "session_state": first_session_state}, + "sandbox#2": {"agent_name": second.name, "session_state": second_session_state}, + }, + } + + restored_first = _make_agent(b"first", "First capability.") + restored_second = _make_agent(b"second", "Second capability.") + restored_root = Agent(name="triage", handoffs=[restored_first, restored_second]) + restored_first.handoffs = [restored_root] + restored_second.handoffs = [restored_root] + + restored_state = await RunState.from_json(restored_root, state.to_json()) + assert restored_state._current_agent is restored_second + + manager = SandboxRuntimeSessionManager( + starting_agent=restored_root, + sandbox_config=SandboxRunConfig(client=client, options={"image": "sandbox"}), + run_state=restored_state, + ) + + assert ( + manager._resume_state_payload_for_agent( + client=client, + agent=restored_first, + agent_id=id(restored_first), + ) + == first_session_state + ) + assert ( + manager._resume_state_payload_for_agent( + client=client, + agent=restored_second, + agent_id=id(restored_second), + ) + == second_session_state + ) + + +@pytest.mark.asyncio +async def test_runner_restores_duplicate_name_sandbox_sessions_after_json_roundtrip() -> None: + client = UnixLocalSandboxClient() + file_capability = _SessionFileCapability() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + first_model = FakeModel() + second_model = FakeModel() + first = SandboxAgent( + name="sandbox", + model=first_model, + instructions="First instructions.", + capabilities=[file_capability], + codex=False, + ) + second = SandboxAgent( + name="sandbox", + model=second_model, + instructions="Second instructions.", + tools=[approval_tool], + codex=False, + ) + first.handoffs = [second] + second.handoffs = [first] + first_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "write_file", + json.dumps({"path": "note.txt", "content": "first"}), + call_id="call_write", + ) + ], + [get_handoff_tool_call(second)], + ] + ) + second_model.add_multiple_turn_outputs( + [[get_function_tool_call("approval_tool", json.dumps({}), call_id="call_approval")]] + ) + + first_run = await Runner.run( + first, + "hello", + run_config=_unix_local_run_config(client=client), + ) + + state = first_run.to_state() + state_json = state.to_json() + + resumed_first_model = FakeModel() + resumed_second_model = FakeModel() + resumed_first = SandboxAgent( + name="sandbox", + model=resumed_first_model, + instructions="First instructions.", + capabilities=[_SessionFileCapability()], + codex=False, + ) + resumed_second = SandboxAgent( + name="sandbox", + model=resumed_second_model, + instructions="Second instructions.", + tools=[approval_tool], + codex=False, + ) + resumed_first.handoffs = [resumed_second] + resumed_second.handoffs = [resumed_first] + resumed_second_model.add_multiple_turn_outputs([[get_handoff_tool_call(resumed_first)]]) + resumed_first_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "read_file", + json.dumps({"path": "note.txt"}), + call_id="call_read", + ) + ], + [get_final_output_message("done")], + ] + ) + + restored_state = await RunState.from_json(resumed_first, state_json) + restored_state.approve(restored_state.get_interruptions()[0]) + resumed = await Runner.run( + resumed_first, + restored_state, + run_config=_unix_local_run_config(client=client), + ) + + assert resumed.final_output == "done" + assert any( + isinstance(item, ToolCallOutputItem) + and item.output == "first" + and item.agent is resumed_first + for item in resumed.new_items + ) + + +@pytest.mark.asyncio +async def test_runner_restores_legacy_current_sandbox_payload_after_json_roundtrip() -> None: + client = UnixLocalSandboxClient() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + initial_model = FakeModel() + initial_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "write_file", json.dumps({"path": "note.txt", "content": "legacy"}) + ) + ], + [get_function_tool_call("approval_tool", json.dumps({}), call_id="call_approval")], + ] + ) + agent = SandboxAgent( + name="sandbox", + model=initial_model, + instructions="Base instructions.", + tools=[approval_tool], + capabilities=[_SessionFileCapability()], + codex=False, + ) + + first_run = await Runner.run( + agent, + "hello", + run_config=_unix_local_run_config(client=client), + ) + state = first_run.to_state() + assert state._sandbox is not None + session_state = cast(dict[str, object], state._sandbox["session_state"]) + state._sandbox = { + "backend_id": "unix_local", + "current_agent_id": id(agent), + "session_state": session_state, + "sessions_by_agent": {str(id(agent)): session_state}, + } + + resumed_model = FakeModel() + resumed_model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "read_file", json.dumps({"path": "note.txt"}), call_id="call_read" + ) + ], + [get_final_output_message("done")], + ] + ) + resumed_agent = SandboxAgent( + name="sandbox", + model=resumed_model, + instructions="Base instructions.", + tools=[approval_tool], + capabilities=[_SessionFileCapability()], + codex=False, + ) + + restored_state = await RunState.from_json(resumed_agent, state.to_json()) + restored_state.approve(restored_state.get_interruptions()[0]) + resumed = await Runner.run( + resumed_agent, + restored_state, + run_config=_unix_local_run_config(client=client), + ) + + assert resumed.final_output == "done" + assert any( + isinstance(item, ToolCallOutputItem) + and item.output == "legacy" + and item.agent is resumed_agent + for item in resumed.new_items + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + sys.platform != "darwin" or shutil.which("sandbox-exec") is None, + reason="sandbox-exec is only available on macOS when installed", +) +async def test_unix_local_exec_confines_commands_to_workspace_root() -> None: + workspace_root = Path(tempfile.mkdtemp(prefix="unix-local-exec-")) + session = await UnixLocalSandboxClient().create( + manifest=_unix_local_manifest(root=str(workspace_root)), + options=None, + ) + + try: + async with session: + result = await session.exec("echo hi > note.txt && cat note.txt") + assert result.ok() + assert result.stdout.decode("utf-8", errors="replace").strip().endswith("hi") + + forbidden = await session.exec("cat /etc/passwd >/dev/null") + assert not forbidden.ok() + + outside_write = await session.exec("echo nope > /usr/local/test-codex-sandbox") + assert not outside_write.ok() + + sibling = workspace_root.parent / "escape.txt" + sibling.unlink(missing_ok=True) + escaped = await session.exec("echo nope > ../escape.txt") + assert not escaped.ok() + assert not sibling.exists() + finally: + shutil.rmtree(workspace_root, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_unix_local_exec_rejects_when_confinement_is_unavailable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + workspace_root = Path(tempfile.mkdtemp(prefix="unix-local-exec-")) + session = await UnixLocalSandboxClient().create( + manifest=_unix_local_manifest(root=str(workspace_root)), + options=None, + ) + unix_local = cast(Any, unix_local_module) + monkeypatch.setattr(unix_local.sys, "platform", "darwin") + monkeypatch.setattr(unix_local.shutil, "which", lambda _name: None) + + try: + with pytest.raises(ExecTransportError) as exc_info: + await session.exec("pwd") + finally: + shutil.rmtree(workspace_root, ignore_errors=True) + + assert exc_info.value.context["reason"] == "unix_local_confinement_unavailable" + + +@pytest.mark.asyncio +async def test_unix_local_exec_runs_without_wrapper_on_linux( + monkeypatch: pytest.MonkeyPatch, +) -> None: + workspace_root = Path(tempfile.mkdtemp(prefix="unix-local-exec-")) + session = await UnixLocalSandboxClient().create( + manifest=_unix_local_manifest(root=str(workspace_root)), + options=None, + ) + unix_local = cast(Any, unix_local_module) + monkeypatch.setattr(unix_local.sys, "platform", "linux") + + try: + async with session: + result = await session.exec("pwd") + finally: + shutil.rmtree(workspace_root, ignore_errors=True) + + assert result.ok() + assert result.stdout.decode("utf-8", errors="replace").strip() == str(workspace_root.resolve()) + + +def test_unix_local_confined_exec_command_allows_common_darwin_interpreter_roots( + monkeypatch: pytest.MonkeyPatch, +) -> None: + workspace_root = Path(tempfile.mkdtemp(prefix="unix-local-exec-")) + session = UnixLocalSandboxSession.from_state( + UnixLocalSandboxSessionState( + session_id=uuid.uuid4(), + manifest=_unix_local_manifest(root=str(workspace_root)), + snapshot=NoopSnapshot(id="darwin"), + workspace_root_owned=False, + ) + ) + unix_local = cast(Any, unix_local_module) + host_home = Path.home() + path_env = os.pathsep.join( + [ + "/opt/homebrew/bin", + "/usr/local/bin", + str(host_home / ".local" / "bin"), + ] + ) + + def _fake_which(name: str, path: str | None = None) -> str | None: + if name == "sandbox-exec": + return "/usr/bin/sandbox-exec" + if name == "python3": + assert path == path_env + return "/opt/homebrew/bin/python3" + return None + + monkeypatch.setattr(unix_local.sys, "platform", "darwin") + monkeypatch.setattr(unix_local.shutil, "which", _fake_which) + + command = session._confined_exec_command( + command_parts=["python3", "-V"], + workspace_root=workspace_root, + env={"PATH": path_env}, + ) + profile = command[2] + + assert command[:2] == ["/usr/bin/sandbox-exec", "-p"] + assert '(allow file-read-data file-read-metadata (subpath "/opt/homebrew"))' in profile + assert '(allow file-read-data file-read-metadata (subpath "/usr/local"))' in profile + assert ( + f'(allow file-read-data file-read-metadata (subpath "{host_home / ".local"}"))' in profile + ) + assert '(deny file-write* (subpath "/opt"))' in profile + assert '(allow file-write* (subpath "/opt/homebrew"))' not in profile + + +@pytest.mark.asyncio +async def test_sandbox_run_persists_only_new_session_input_items() -> None: + session = SimpleListSession( + history=[ + { + "role": "user", + "content": "old", + } + ] + ) + model = FakeModel(initial_output=[get_final_output_message("done")]) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + ) + + result = await Runner.run( + agent, + "new", + session=session, + run_config=_sandbox_run_config(_FakeClient(_FakeSession(Manifest()))), + ) + + assert result.final_output == "done" + saved_user_items = [ + item + for item in await session.get_items() + if isinstance(item, dict) and item.get("role") == "user" + ] + assert saved_user_items == [ + {"role": "user", "content": "old"}, + {"role": "user", "content": "new"}, + ] + + +@pytest.mark.asyncio +async def test_runner_streamed_emits_public_agent_for_tool_and_reasoning_events() -> None: + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [ + _get_reasoning_item(), + get_function_tool_call("tool1", json.dumps({}), call_id="call_tool"), + ], + [get_final_output_message("done")], + ] + ) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + tools=[get_function_tool("tool1", "tool result")], + ) + + result = Runner.run_streamed( + agent, + "hello", + run_config=_sandbox_run_config(_FakeClient(_FakeSession(Manifest()))), + ) + events = [event async for event in result.stream_events()] + relevant_events = [ + event + for event in events + if isinstance(event, RunItemStreamEvent) + and event.name in {"reasoning_item_created", "tool_called", "tool_output"} + ] + + assert relevant_events + assert all(event.item.agent is agent for event in relevant_events) + + +def test_capability_clone_deep_copies_nested_mutable_state() -> None: + capability = _NestedStateCapability() + + cloned = cast(_NestedStateCapability, capability.clone()) + cloned.state["seen"].append("turn-1") + + assert capability.state == {"seen": []} + assert cloned.state == {"seen": ["turn-1"]} + + +def test_capability_clone_deep_copies_nested_object_state() -> None: + capability = _NestedObjectCapability() + + cloned = cast(_NestedObjectCapability, capability.clone()) + cloned.state.seen.append("turn-1") + + assert capability.state.seen == [] + assert cloned.state.seen == ["turn-1"] + + +@pytest.mark.asyncio +async def test_apply_manifest_raises_on_account_provisioning_failures() -> None: + session = _ProvisioningFailureSession( + Manifest(users=[User(name="sandbox-user")]), + ) + + with pytest.raises(ExecNonZeroError): + await session.apply_manifest() + + +@pytest.mark.asyncio +async def test_apply_manifest_only_ephemeral_skips_account_provisioning_failures() -> None: + session = _ProvisioningFailureSession( + Manifest(users=[User(name="sandbox-user")]), + ) + + result = await session.apply_manifest(only_ephemeral=True) + + assert result.files == [] + + +@pytest.mark.asyncio +async def test_resume_reprovisions_manifest_accounts_before_reapplying_ephemeral_entries() -> None: + session = _RestorableProvisioningFailureSession( + Manifest(users=[User(name="sandbox-user")]), + ) + + with pytest.raises(ExecNonZeroError): + await session.start() + + assert session.cleared_workspace_root is True + assert session.hydrate_calls == 1 + + +@pytest.mark.asyncio +async def test_resume_can_skip_manifest_account_reprovisioning_when_os_state_is_preserved() -> None: + session = _RestorableProvisioningFailureSession( + Manifest(users=[User(name="sandbox-user")]), + provision_on_resume=False, + ) + + await session.start() + + assert session.cleared_workspace_root is True + assert session.hydrate_calls == 1 + + +@pytest.mark.asyncio +async def test_prepare_agent_rechecks_session_liveness_before_reusing_cached_agent() -> None: + session = _FakeSession(Manifest()) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + runtime = SandboxRuntime( + starting_agent=agent, + run_config=_sandbox_run_config(client), + run_state=None, + ) + context_wrapper = RunContextWrapper(context=None) + + first_prepared = await runtime.prepare_agent( + current_agent=agent, + current_input="hello", + context_wrapper=context_wrapper, + is_resumed_state=False, + ) + assert session.start_calls == 1 + + session._running = False + + second_prepared = await runtime.prepare_agent( + current_agent=agent, + current_input="hello again", + context_wrapper=context_wrapper, + is_resumed_state=False, + ) + + assert second_prepared.bindings.execution_agent is first_prepared.bindings.execution_agent + assert session.start_calls == 2 + + +@pytest.mark.asyncio +async def test_prepare_agent_starts_new_live_session_even_when_backend_reports_running() -> None: + session = _FakeSession(Manifest()) + session._running = True + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=FakeModel(initial_output=[get_final_output_message("done")]), + instructions="Base instructions.", + ) + runtime = SandboxRuntime( + starting_agent=agent, + run_config=_sandbox_run_config(client), + run_state=None, + ) + + await runtime.prepare_agent( + current_agent=agent, + current_input="hello", + context_wrapper=RunContextWrapper(context=None), + is_resumed_state=False, + ) + + assert session.start_calls == 1 + + +@pytest.mark.asyncio +async def test_runner_uses_public_agent_for_non_function_tool_outputs() -> None: + tool = LocalShellTool(executor=lambda _request: "shell result") + action = LocalShellCallAction( + command=["bash", "-lc", "echo sandbox"], + env={}, + type="exec", + timeout_ms=1000, + working_directory="/workspace", + ) + local_shell_call = LocalShellCall( + id="lsh_sandbox", + action=action, + call_id="call_local_shell", + status="completed", + type="local_shell_call", + ) + + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [local_shell_call], + [get_final_output_message("done")], + ] + ) + + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + tools=[tool], + ) + + result = await Runner.run( + agent, + "hello", + run_config=_sandbox_run_config(_FakeClient(_FakeSession(Manifest()))), + ) + + output_items = [ + item + for item in result.new_items + if isinstance(item, ToolCallOutputItem) + and isinstance(item.raw_item, dict) + and item.raw_item.get("type") == "local_shell_call_output" + ] + + assert output_items + assert all(item.agent is agent for item in output_items) + + +@pytest.mark.asyncio +async def test_sandbox_agent_as_tool_uses_runner_sandbox_prep() -> None: + child_model = FakeModel(initial_output=[get_final_output_message("child done")]) + parent_model = FakeModel( + initial_output=[ + get_function_tool_call("delegate_to_child", json.dumps({"input": "check sandbox"})) + ] + ) + parent_model.set_next_output([get_final_output_message("parent done")]) + + capability = _RecordingCapability(instruction_text="Use the sandbox carefully.") + manifest = Manifest(entries={"README.md": File(content=b"Use repo-safe commands only.")}) + session = _FakeSession(manifest) + client = _FakeClient(session) + + child = SandboxAgent( + name="child", + model=child_model, + instructions="Child base instructions.", + default_manifest=manifest, + capabilities=[capability], + ) + parent = Agent( + name="parent", + model=parent_model, + instructions="Parent instructions.", + tools=[child.as_tool("delegate_to_child", "Delegate to the sandbox child.")], + ) + + result = await Runner.run( + parent, + "hello", + run_config=_sandbox_run_config(client), + ) + + assert result.final_output == "parent done" + assert capability.bound_session is None + assert child_model.first_turn_args is not None + child_input = child_model.first_turn_args["input"] + assert isinstance(child_input, list) + assert _extract_user_text(child_input[0]) == "check sandbox" + + +@pytest.mark.asyncio +async def test_runner_reapplies_sandbox_prep_on_handoff() -> None: + triage_model = FakeModel() + worker_model = FakeModel(initial_output=[get_final_output_message("done")]) + manifest = Manifest(entries={"README.md": File(content=b"Shared repo instructions.")}) + session = _FakeSession(manifest) + client = _FakeClient(session) + + capability_one = _RecordingCapability(instruction_text="Triage capability.") + capability_two = _RecordingCapability(instruction_text="Worker capability.") + worker = SandboxAgent( + name="worker", + model=worker_model, + instructions="Worker instructions.", + default_manifest=manifest, + capabilities=[capability_two], + ) + triage = SandboxAgent( + name="triage", + model=triage_model, + instructions="Triage instructions.", + default_manifest=manifest, + capabilities=[capability_one], + handoffs=[worker], + ) + triage_model.turn_outputs = [[get_handoff_tool_call(worker)]] + + result = await Runner.run( + triage, + "route this", + run_config=_sandbox_run_config(client), + ) + + assert result.final_output == "done" + assert capability_one.bound_session is None + assert capability_two.bound_session is None + assert worker_model.first_turn_args is not None + assert worker_model.first_turn_args["system_instructions"] == ( + "Worker instructions.\n\nWorker capability." + ) + + +@pytest.mark.asyncio +async def test_runner_restores_sandbox_from_run_state() -> None: + model = FakeModel() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + return "approved" + + manifest = Manifest(entries={"README.md": File(content=b"Resume with sandbox state.")}) + session = _FakeSession(manifest) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + tools=[approval_tool], + default_manifest=manifest, + ) + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("approval_tool", json.dumps({}), call_id="call_resume")], + [get_final_output_message("done")], + ] + ) + + first_run = await Runner.run( + agent, + "hello", + run_config=_sandbox_run_config(client), + ) + + assert len(first_run.interruptions) == 1 + state = first_run.to_state() + assert state._sandbox is not None + state.approve(first_run.interruptions[0]) + + resumed = await Runner.run( + agent, + state, + run_config=_sandbox_run_config(client), + ) + + assert resumed.final_output == "done" + assert client.resume_state is not None + + +@pytest.mark.asyncio +async def test_runner_rejects_concurrent_reuse_of_same_sandbox_agent() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + start_gate = asyncio.Event() + session = _FakeSession(Manifest(), start_gate=start_gate) + client = _FakeClient(session) + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + ) + run_config = _sandbox_run_config(client) + + first_run = asyncio.create_task(Runner.run(agent, "hello", run_config=run_config)) + while session.start_calls == 0: + await asyncio.sleep(0) + + with pytest.raises(RuntimeError, match="cannot be reused concurrently"): + await Runner.run(agent, "again", run_config=run_config) + + start_gate.set() + result = await first_run + assert result.final_output == "done" + + +@pytest.mark.asyncio +async def test_runner_isolates_shared_capabilities_per_run() -> None: + release_gate = asyncio.Event() + first_instruction_started = asyncio.Event() + second_instruction_started = asyncio.Event() + shared_capability = _AwaitableSessionCapability( + release_gate=release_gate, + first_instruction_started=first_instruction_started, + second_instruction_started=second_instruction_started, + ) + + session_one = _FakeSession( + Manifest(entries={"README.md": File(content=b"Session one instructions.")}) + ) + session_two = _FakeSession( + Manifest(entries={"README.md": File(content=b"Session two instructions.")}) + ) + client_one = _FakeClient(session_one) + client_two = _FakeClient(session_two) + model_one = FakeModel(initial_output=[get_final_output_message("done one")]) + model_two = FakeModel(initial_output=[get_final_output_message("done two")]) + agent_one = SandboxAgent( + name="sandbox-one", + model=model_one, + instructions="Base instructions.", + capabilities=[shared_capability], + ) + agent_two = SandboxAgent( + name="sandbox-two", + model=model_two, + instructions="Base instructions.", + capabilities=[shared_capability], + ) + + first_run = asyncio.create_task( + Runner.run(agent_one, "hello one", run_config=_sandbox_run_config(client_one)) + ) + await first_instruction_started.wait() + + second_run = asyncio.create_task( + Runner.run(agent_two, "hello two", run_config=_sandbox_run_config(client_two)) + ) + await second_instruction_started.wait() + + release_gate.set() + first_result, second_result = await asyncio.gather(first_run, second_run) + + assert first_result.final_output == "done one" + assert second_result.final_output == "done two" + assert model_one.first_turn_args is not None + assert model_two.first_turn_args is not None + assert ( + model_one.first_turn_args["system_instructions"] + == "Base instructions.\n\nSession one instructions." + ) + assert ( + model_two.first_turn_args["system_instructions"] + == "Base instructions.\n\nSession two instructions." + ) + assert shared_capability.bound_session is None + + +@pytest.mark.asyncio +async def test_runner_deep_clones_capability_runtime_state() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + session = _FakeSession(Manifest(entries={"README.md": File(content=b"hello")})) + client = _FakeClient(session) + + class _MutableCapability(Capability): + def __init__(self) -> None: + super().__init__(type="mutable") + self.bound_labels: list[str] = [] + + def bind(self, session: BaseSandboxSession) -> None: + readme = session.state.manifest.entries["README.md"] + assert isinstance(readme, File) + self.bound_labels.append(readme.content.decode()) + + capability = _MutableCapability() + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + capabilities=[capability], + ) + + result = await Runner.run(agent, "hello", run_config=_sandbox_run_config(client)) + + assert result.final_output == "done" + assert capability.bound_labels == [] + + +@pytest.mark.asyncio +async def test_runner_keeps_public_agent_identity_for_hooks_and_streaming() -> None: + model = FakeModel(initial_output=[get_final_output_message("done")]) + session = _FakeSession(Manifest()) + client = _FakeClient(session) + run_hooks = _RecordingRunHooks() + agent_hooks = _RecordingAgentHooks() + agent = SandboxAgent( + name="sandbox", + model=model, + instructions="Base instructions.", + hooks=agent_hooks, + capabilities=[_RecordingCapability(instruction_text="Capability instructions.")], + ) + + result = await Runner.run( + agent, + "hello", + run_config=_sandbox_run_config(client), + hooks=run_hooks, + ) + + assert result.last_agent is agent + assert run_hooks.started_agents == [agent] + assert run_hooks.ended_agents == [agent] + assert run_hooks.llm_started_agents == [agent] + assert run_hooks.llm_ended_agents == [agent] + assert agent_hooks.started_agents == [agent] + assert agent_hooks.ended_agents == [agent] + assert agent_hooks.llm_started_agents == [agent] + assert agent_hooks.llm_ended_agents == [agent] + assert all(item.agent is agent for item in result.new_items) + + streamed_model = FakeModel(initial_output=[get_final_output_message("streamed done")]) + streamed_session = _FakeSession(Manifest()) + streamed_client = _FakeClient(streamed_session) + streamed_run_hooks = _RecordingRunHooks() + streamed_agent_hooks = _RecordingAgentHooks() + streamed_agent = SandboxAgent( + name="streamed-sandbox", + model=streamed_model, + instructions="Base instructions.", + hooks=streamed_agent_hooks, + capabilities=[_RecordingCapability(instruction_text="Capability instructions.")], + ) + + streamed_result = Runner.run_streamed( + streamed_agent, + "hello", + run_config=_sandbox_run_config(streamed_client), + hooks=streamed_run_hooks, + ) + streamed_events = [event async for event in streamed_result.stream_events()] + run_item_events = [event for event in streamed_events if isinstance(event, RunItemStreamEvent)] + + assert streamed_result.current_agent is streamed_agent + assert streamed_run_hooks.started_agents == [streamed_agent] + assert streamed_run_hooks.ended_agents == [streamed_agent] + assert streamed_run_hooks.llm_started_agents == [streamed_agent] + assert streamed_run_hooks.llm_ended_agents == [streamed_agent] + assert streamed_agent_hooks.started_agents == [streamed_agent] + assert streamed_agent_hooks.ended_agents == [streamed_agent] + assert streamed_agent_hooks.llm_started_agents == [streamed_agent] + assert streamed_agent_hooks.llm_ended_agents == [streamed_agent] + assert all(item.agent is streamed_agent for item in streamed_result.new_items) + assert run_item_events + assert all(event.item.agent is streamed_agent for event in run_item_events) diff --git a/tests/test_sandbox_session_manager.py b/tests/test_sandbox_session_manager.py new file mode 100644 index 0000000000..c7fc424a12 --- /dev/null +++ b/tests/test_sandbox_session_manager.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import asyncio +import uuid +from pathlib import Path + +import pytest + +from agents.sandbox.manifest import Manifest +from agents.sandbox.runtime_session_manager import SandboxRuntimeSessionManager +from agents.sandbox.sandboxes.unix_local import ( + UnixLocalSandboxSession, + UnixLocalSandboxSessionState, +) +from agents.sandbox.session import ( + CallbackSink, + EventPayloadPolicy, + Instrumentation, + UCEvent, + UCFinishEvent, +) +from agents.sandbox.session.sinks import ChainedSink, EventSink +from agents.sandbox.snapshot import LocalSnapshot, LocalSnapshotSpec, NoopSnapshotSpec + + +class _EventSink(EventSink): + def __init__(self, *, mode: str, on_error: str = "raise") -> None: + self.mode = mode # type: ignore[assignment] + self.on_error = on_error # type: ignore[assignment] + self.payload_policy = None + + async def handle(self, event: UCEvent) -> None: # pragma: no cover + _ = event + raise NotImplementedError + + +def _build_session(tmp_path: Path) -> UnixLocalSandboxSession: + state = UnixLocalSandboxSessionState( + manifest=Manifest(root=str(tmp_path / "workspace")), + snapshot=LocalSnapshot(id="x", base_path=tmp_path), + ) + return UnixLocalSandboxSession.from_state(state) + + +@pytest.mark.asyncio +async def test_instrumentation_per_op_policy_overrides_default(tmp_path: Path) -> None: + events: list[UCEvent] = [] + session = _build_session(tmp_path) + sink = CallbackSink(lambda event, _session: events.append(event), mode="sync") + sink.bind(session) + instrumentation = Instrumentation( + sinks=[sink], + payload_policy=EventPayloadPolicy(include_exec_output=False), + payload_policy_by_op={"exec": EventPayloadPolicy(include_exec_output=True)}, + ) + + event = UCFinishEvent( + session_id=uuid.uuid4(), + seq=1, + op="exec", + span_id=uuid.uuid4(), + ok=True, + duration_ms=0.0, + ) + event.stdout_bytes = b"hello" + event.stderr_bytes = b"" + + await instrumentation.emit(event) + + assert isinstance(events[0], UCFinishEvent) + assert events[0].stdout == "hello" + + +@pytest.mark.asyncio +async def test_instrumentation_per_sink_policy_overrides_per_op(tmp_path: Path) -> None: + first: list[UCEvent] = [] + second: list[UCEvent] = [] + session = _build_session(tmp_path) + sink_a = CallbackSink(lambda event, _session: first.append(event), mode="sync") + sink_b = CallbackSink( + lambda event, _session: second.append(event), + mode="sync", + payload_policy=EventPayloadPolicy(include_exec_output=True), + ) + sink_a.bind(session) + sink_b.bind(session) + + instrumentation = Instrumentation( + sinks=[sink_a, sink_b], + payload_policy=EventPayloadPolicy(include_exec_output=False), + payload_policy_by_op={"exec": EventPayloadPolicy(include_exec_output=False)}, + ) + + event = UCFinishEvent( + session_id=uuid.uuid4(), + seq=1, + op="exec", + span_id=uuid.uuid4(), + ok=True, + duration_ms=0.0, + ) + event.stdout_bytes = b"hello" + event.stderr_bytes = b"" + + await instrumentation.emit(event) + + assert isinstance(first[0], UCFinishEvent) + assert isinstance(second[0], UCFinishEvent) + assert first[0].stdout is None + assert second[0].stdout == "hello" + + +@pytest.mark.asyncio +async def test_instrumentation_redacts_raw_exec_bytes_when_output_disabled( + tmp_path: Path, +) -> None: + events: list[UCEvent] = [] + session = _build_session(tmp_path) + sink = CallbackSink(lambda event, _session: events.append(event), mode="sync") + sink.bind(session) + instrumentation = Instrumentation( + sinks=[sink], + payload_policy=EventPayloadPolicy(include_exec_output=False), + ) + + event = UCFinishEvent( + session_id=uuid.uuid4(), + seq=1, + op="exec", + span_id=uuid.uuid4(), + ok=True, + duration_ms=0.0, + ) + event.stdout_bytes = b"secret" + event.stderr_bytes = b"secret2" + + await instrumentation.emit(event) + + assert isinstance(events[0], UCFinishEvent) + assert events[0].stdout_bytes is None + assert events[0].stderr_bytes is None + + +@pytest.mark.asyncio +async def test_chained_sink_preserves_completion_order_across_modes() -> None: + completed = asyncio.Event() + + class SlowBestEffortSink(_EventSink): + async def handle(self, event: UCEvent) -> None: + _ = event + await asyncio.sleep(0) + completed.set() + + class AssertAfterSink(_EventSink): + async def handle(self, event: UCEvent) -> None: + _ = event + assert completed.is_set(), "later sink ran before earlier sink completed" + + sink_a = SlowBestEffortSink(mode="best_effort", on_error="raise") + sink_b = AssertAfterSink(mode="sync", on_error="raise") + instrumentation = Instrumentation(sinks=[ChainedSink(sink_a, sink_b)]) + + event = UCFinishEvent( + session_id=uuid.uuid4(), + seq=1, + op="running", + span_id=uuid.uuid4(), + ok=True, + duration_ms=0.0, + ) + await instrumentation.emit(event) + + +@pytest.mark.asyncio +async def test_async_sink_raise_propagates_to_emit() -> None: + class _FailingAsyncSink(_EventSink): + async def handle(self, event: UCEvent) -> None: + _ = event + await asyncio.sleep(0) + raise RuntimeError("boom") + + instrumentation = Instrumentation(sinks=[_FailingAsyncSink(mode="async", on_error="raise")]) + event = UCFinishEvent( + session_id=uuid.uuid4(), + seq=1, + op="running", + span_id=uuid.uuid4(), + ok=True, + duration_ms=0.0, + ) + + with pytest.raises(RuntimeError, match="boom"): + await instrumentation.emit(event) + + +def test_session_manager_uses_custom_snapshot_spec_without_resolving_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + called = False + + def _unexpected_default_resolution() -> LocalSnapshotSpec: + nonlocal called + called = True + raise AssertionError("default snapshot resolution should not run") + + monkeypatch.setattr( + "agents.sandbox.runtime_session_manager.resolve_default_local_snapshot_spec", + _unexpected_default_resolution, + ) + + custom = LocalSnapshotSpec(base_path=Path("/tmp/custom-sandbox-snapshots")) + resolved = SandboxRuntimeSessionManager._resolve_snapshot_spec(custom) + + assert resolved is custom + assert called is False + + +def test_session_manager_falls_back_to_noop_when_default_snapshot_resolution_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _raise_os_error() -> LocalSnapshotSpec: + raise OSError("read-only home") + + monkeypatch.setattr( + "agents.sandbox.runtime_session_manager.resolve_default_local_snapshot_spec", + _raise_os_error, + ) + + resolved = SandboxRuntimeSessionManager._resolve_snapshot_spec(None) + + assert isinstance(resolved, NoopSnapshotSpec) diff --git a/tests/test_sandbox_session_sinks.py b/tests/test_sandbox_session_sinks.py new file mode 100644 index 0000000000..ef74877235 --- /dev/null +++ b/tests/test_sandbox_session_sinks.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import asyncio +import io +import json +import tarfile +import uuid +from pathlib import Path + +import pytest + +from agents.sandbox.entries import Dir, File +from agents.sandbox.manifest import Manifest +from agents.sandbox.sandboxes.unix_local import ( + UnixLocalSandboxSession, + UnixLocalSandboxSessionState, +) +from agents.sandbox.session import ( + CallbackSink, + ChainedSink, + EventPayloadPolicy, + Instrumentation, + JsonlOutboxSink, + SandboxSession, + UCEvent, + UCFinishEvent, + UCStartEvent, + WorkspaceJsonlSink, +) +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.snapshot import LocalSnapshot + + +def _build_unix_local_session( + tmp_path: Path, + *, + manifest: Manifest | None = None, +) -> UnixLocalSandboxSession: + workspace = tmp_path / "workspace" + snapshot = LocalSnapshot(id=str(uuid.uuid4()), base_path=tmp_path) + session_manifest = ( + manifest.model_copy(update={"root": str(workspace)}, deep=True) + if manifest is not None + else Manifest(root=str(workspace)) + ) + state = UnixLocalSandboxSessionState( + manifest=session_manifest, + snapshot=snapshot, + ) + return UnixLocalSandboxSession.from_state(state) + + +@pytest.mark.asyncio +async def test_sandbox_session_exec_emits_stdout_when_enabled(tmp_path: Path) -> None: + events: list[UCEvent] = [] + instrumentation = Instrumentation( + sinks=[CallbackSink(lambda e, _sess: events.append(e), mode="sync")], + payload_policy=EventPayloadPolicy(include_exec_output=True), + ) + + inner = _build_unix_local_session(tmp_path) + async with SandboxSession(inner, instrumentation=instrumentation) as session: + result = await session.exec("echo hi") + assert result.ok() + + exec_finish = [event for event in events if event.op == "exec" and event.phase == "finish"][0] + assert isinstance(exec_finish, UCFinishEvent) + assert exec_finish.stdout is not None + assert "hi" in exec_finish.stdout + + +@pytest.mark.asyncio +async def test_sandbox_session_write_does_not_include_bytes_when_disabled( + tmp_path: Path, +) -> None: + events: list[UCEvent] = [] + instrumentation = Instrumentation( + sinks=[CallbackSink(lambda e, _sess: events.append(e), mode="sync")], + payload_policy=EventPayloadPolicy(include_write_len=False), + ) + + inner = _build_unix_local_session(tmp_path) + async with SandboxSession(inner, instrumentation=instrumentation) as session: + await session.write(Path("x.txt"), io.BytesIO(b"hello")) + + write_start = [event for event in events if event.op == "write" and event.phase == "start"][0] + assert "bytes" not in write_start.data + + +@pytest.mark.asyncio +async def test_jsonl_outbox_sink_appends_one_line_per_event(tmp_path: Path) -> None: + outbox = tmp_path / "events.jsonl" + sink = JsonlOutboxSink(outbox, mode="sync", on_error="raise") + + start_event = UCStartEvent( + session_id=uuid.uuid4(), + seq=1, + op="write", + span_id=uuid.uuid4(), + ) + finish_event = UCFinishEvent( + session_id=start_event.session_id, + seq=2, + op="write", + span_id=start_event.span_id, + ok=True, + duration_ms=0.0, + ) + + await sink.handle(start_event) + await sink.handle(finish_event) + + lines = outbox.read_text(encoding="utf-8").splitlines() + assert len(lines) == 2 + assert json.loads(lines[0])["phase"] == "start" + assert json.loads(lines[1])["phase"] == "finish" + + +@pytest.mark.asyncio +async def test_chained_sink_runs_in_order(tmp_path: Path) -> None: + outbox = tmp_path / "events.jsonl" + seen: list[int] = [] + + def _callback(_event: UCEvent, _session: BaseSandboxSession) -> None: + seen.append(len(outbox.read_text(encoding="utf-8").splitlines())) + + inner = _build_unix_local_session(tmp_path) + callback_sink = CallbackSink(_callback, mode="sync") + callback_sink.bind(inner) + + instrumentation = Instrumentation( + sinks=[ + ChainedSink( + JsonlOutboxSink(outbox, mode="sync", on_error="raise"), + callback_sink, + ) + ] + ) + + start_event = UCStartEvent( + session_id=uuid.uuid4(), + seq=1, + op="write", + span_id=uuid.uuid4(), + ) + finish_event = UCFinishEvent( + session_id=start_event.session_id, + seq=2, + op="write", + span_id=start_event.span_id, + ok=True, + duration_ms=0.0, + ) + + await instrumentation.emit(start_event) + await instrumentation.emit(finish_event) + + assert seen == [1, 2] + + +@pytest.mark.asyncio +async def test_workspace_jsonl_sink_writes_into_workspace_and_persists(tmp_path: Path) -> None: + inner = _build_unix_local_session(tmp_path) + instrumentation = Instrumentation( + sinks=[WorkspaceJsonlSink(mode="sync", on_error="raise", ephemeral=False)] + ) + wrapped = SandboxSession(inner, instrumentation=instrumentation) + + async with wrapped as session: + await session.exec("echo hi") + + outbox_stream = await inner.read(Path(f"logs/events-{inner.state.session_id}.jsonl")) + lines = outbox_stream.read().decode("utf-8").splitlines() + assert any(json.loads(line)["op"] == "exec" for line in lines) + + snapshot_path = tmp_path / f"{inner.state.snapshot.id}.tar" + with tarfile.open(snapshot_path, mode="r:*") as tar: + names = [member.name for member in tar.getmembers()] + assert any(f"logs/events-{inner.state.session_id}.jsonl" in name for name in names) + + +@pytest.mark.asyncio +async def test_workspace_jsonl_sink_supports_session_id_template(tmp_path: Path) -> None: + inner = _build_unix_local_session(tmp_path) + relpath = Path("logs/events-{session_id}.jsonl") + instrumentation = Instrumentation( + sinks=[ + WorkspaceJsonlSink( + mode="sync", + on_error="raise", + ephemeral=False, + workspace_relpath=relpath, + ) + ] + ) + wrapped = SandboxSession(inner, instrumentation=instrumentation) + + async with wrapped as session: + await session.exec("echo hi") + + expected_path = Path(f"logs/events-{inner.state.session_id}.jsonl") + outbox_stream = await inner.read(expected_path) + lines = outbox_stream.read().decode("utf-8").splitlines() + assert any(json.loads(line)["op"] == "exec" for line in lines) + + +@pytest.mark.asyncio +async def test_workspace_jsonl_sink_preserves_preexisting_outbox_contents(tmp_path: Path) -> None: + inner = _build_unix_local_session(tmp_path) + relpath = Path(f"logs/events-{inner.state.session_id}.jsonl") + old_line = b'{"old":true}\n' + + async with inner: + await inner.write(relpath, io.BytesIO(old_line)) + sink = WorkspaceJsonlSink(mode="sync", on_error="raise", ephemeral=False) + sink.bind(inner) + + start = UCStartEvent( + session_id=inner.state.session_id, + seq=1, + op="write", + span_id=uuid.uuid4(), + ) + finish = UCFinishEvent( + session_id=inner.state.session_id, + seq=2, + op="write", + span_id=start.span_id, + ok=True, + duration_ms=0.0, + ) + + await sink.handle(start) + await sink.handle(finish) + + outbox_stream = await inner.read(relpath) + lines = outbox_stream.read().decode("utf-8").splitlines() + + assert len(lines) == 3 + assert json.loads(lines[0]) == {"old": True} + assert json.loads(lines[1])["seq"] == 1 + assert json.loads(lines[2])["seq"] == 2 + + +@pytest.mark.asyncio +async def test_workspace_jsonl_sink_does_not_duplicate_lines_across_flushes( + tmp_path: Path, +) -> None: + inner = _build_unix_local_session(tmp_path) + relpath = Path(f"logs/events-{inner.state.session_id}.jsonl") + + async with inner: + sink = WorkspaceJsonlSink(mode="sync", on_error="raise", ephemeral=False, flush_every=1) + sink.bind(inner) + + for seq in (1, 2, 3): + await sink.handle( + UCStartEvent( + session_id=inner.state.session_id, + seq=seq, + op="write", + span_id=uuid.uuid4(), + ) + ) + + outbox_stream = await inner.read(relpath) + lines = outbox_stream.read().decode("utf-8").splitlines() + + assert [json.loads(line)["seq"] for line in lines] == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_workspace_jsonl_sink_ephemeral_excludes_runtime_outbox_with_existing_parent( + tmp_path: Path, +) -> None: + inner = _build_unix_local_session( + tmp_path, + manifest=Manifest( + entries={ + "logs": Dir( + children={ + "keep.txt": File(content=b"keep"), + } + ) + } + ), + ) + instrumentation = Instrumentation( + sinks=[WorkspaceJsonlSink(mode="sync", on_error="raise", ephemeral=True)] + ) + wrapped = SandboxSession(inner, instrumentation=instrumentation) + + async with wrapped as session: + await session.exec("echo hi") + relpath = Path(f"logs/events-{inner.state.session_id}.jsonl") + outbox_stream = await inner.read(relpath) + assert outbox_stream.read() + + logs_entry = inner.state.manifest.entries["logs"] + assert isinstance(logs_entry, Dir) + assert {str(child) for child in logs_entry.children.keys()} == {"keep.txt"} + + snapshot_path = tmp_path / f"{inner.state.snapshot.id}.tar" + with tarfile.open(snapshot_path, mode="r:*") as tar: + names = [member.name for member in tar.getmembers()] + assert any(name.endswith("logs/keep.txt") for name in names) + assert not any(f"logs/events-{inner.state.session_id}.jsonl" in name for name in names) + + +@pytest.mark.asyncio +async def test_workspace_jsonl_sink_flushes_on_stop_when_flush_every_gt_one( + tmp_path: Path, +) -> None: + inner = _build_unix_local_session(tmp_path) + instrumentation = Instrumentation( + sinks=[ + WorkspaceJsonlSink( + mode="sync", + on_error="raise", + ephemeral=False, + flush_every=10, + ) + ] + ) + wrapped = SandboxSession(inner, instrumentation=instrumentation) + + async with wrapped as session: + await session.exec("echo hi") + + outbox_stream = await inner.read(Path(f"logs/events-{inner.state.session_id}.jsonl")) + lines = outbox_stream.read().decode("utf-8").splitlines() + assert lines + + snapshot_path = tmp_path / f"{inner.state.snapshot.id}.tar" + with tarfile.open(snapshot_path, mode="r:*") as tar: + names = [member.name for member in tar.getmembers()] + assert any(f"logs/events-{inner.state.session_id}.jsonl" in name for name in names) + + +@pytest.mark.asyncio +async def test_callback_sink_receives_bound_inner_session(tmp_path: Path) -> None: + inner = _build_unix_local_session(tmp_path) + seen: list[tuple[str, BaseSandboxSession]] = [] + + def _callback(event: UCEvent, session: BaseSandboxSession) -> None: + seen.append((event.op, session)) + + instrumentation = Instrumentation(sinks=[CallbackSink(_callback, mode="sync")]) + wrapped = SandboxSession(inner, instrumentation=instrumentation) + + async with wrapped as session: + await session.exec("echo hi") + + assert seen + assert all(session is inner for _op, session in seen) + + +@pytest.mark.asyncio +async def test_sandbox_session_aclose_flushes_best_effort_sink_tasks(tmp_path: Path) -> None: + inner = _build_unix_local_session(tmp_path) + seen: list[tuple[str, str]] = [] + + async def _callback(event: UCEvent, _session: BaseSandboxSession) -> None: + await asyncio.sleep(0) + seen.append((event.op, event.phase)) + + instrumentation = Instrumentation( + sinks=[CallbackSink(_callback, mode="best_effort", on_error="log")] + ) + wrapped = SandboxSession(inner, instrumentation=instrumentation) + + await wrapped.start() + await wrapped.aclose() + + assert ("stop", "finish") in seen + assert ("shutdown", "finish") in seen diff --git a/tests/test_sandbox_session_utils.py b/tests/test_sandbox_session_utils.py new file mode 100644 index 0000000000..dc67c81b86 --- /dev/null +++ b/tests/test_sandbox_session_utils.py @@ -0,0 +1,406 @@ +from __future__ import annotations + +import io +import shlex +import uuid +from pathlib import Path + +import pytest + +from agents.sandbox.entries.codex import resolve_codex_target_triple_for_target +from agents.sandbox.errors import UnsupportedCodexTargetError +from agents.sandbox.files import EntryKind, FileEntry +from agents.sandbox.manifest import Manifest +from agents.sandbox.session import UCStartEvent +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.session.events import UCFinishEvent +from agents.sandbox.session.sandbox_session_state import SandboxSessionState +from agents.sandbox.session.utils import ( + _best_effort_stream_len, + _safe_decode, + event_to_json_line, +) +from agents.sandbox.snapshot import NoopSnapshot +from agents.sandbox.types import ExecResult, Permissions + + +class _CaptureExecSession(BaseSandboxSession): + def __init__(self) -> None: + self.state = SandboxSessionState( + manifest=Manifest(), + snapshot=NoopSnapshot(id="noop"), + ) + self.last_command: tuple[str, ...] | None = None + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + self.last_command = tuple(str(part) for part in command) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def read(self, path: Path) -> io.IOBase: + _ = path + raise AssertionError("read() should not be called in this test") + + async def write(self, path: Path, data: io.IOBase) -> None: + _ = (path, data) + raise AssertionError("write() should not be called in this test") + + async def running(self) -> bool: + return True + + async def persist_workspace(self) -> io.IOBase: + return io.BytesIO() + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + + async def shutdown(self) -> None: + return + + +class _ScriptedExecSession(BaseSandboxSession): + def __init__(self, responses: dict[tuple[str, ...], list[ExecResult] | ExecResult]) -> None: + self.state = SandboxSessionState( + manifest=Manifest(), + snapshot=NoopSnapshot(id="noop"), + ) + self.responses: dict[tuple[str, ...], list[ExecResult]] = {} + for command, response in responses.items(): + if isinstance(response, ExecResult): + self.responses[command] = [response] + else: + self.responses[command] = list(response) + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + key = tuple(str(part) for part in command) + if key not in self.responses or not self.responses[key]: + return ExecResult(stdout=b"", stderr=b"", exit_code=1) + return self.responses[key].pop(0) + + async def read(self, path: Path) -> io.IOBase: + _ = path + raise AssertionError("read() should not be called in this test") + + async def write(self, path: Path, data: io.IOBase) -> None: + _ = (path, data) + raise AssertionError("write() should not be called in this test") + + async def running(self) -> bool: + return True + + async def persist_workspace(self) -> io.IOBase: + return io.BytesIO() + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + + async def shutdown(self) -> None: + return + + +def test_safe_decode_truncates_and_appends_ellipsis() -> None: + assert _safe_decode(b"abcdef", max_chars=3) == "abc…" + + +def test_best_effort_stream_len_tracks_remaining_bytes_for_seekable_streams() -> None: + buffer = io.BytesIO(b"hello") + assert _best_effort_stream_len(buffer) == 5 + assert buffer.read(1) == b"h" + assert _best_effort_stream_len(buffer) == 4 + + +class _NoSeekableMethodStream(io.IOBase): + def __init__(self, payload: bytes) -> None: + self._buffer = io.BytesIO(payload) + + def tell(self) -> int: + return self._buffer.tell() + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + return self._buffer.seek(offset, whence) + + +def test_best_effort_stream_len_handles_streams_without_seekable_method() -> None: + stream = _NoSeekableMethodStream(b"hello") + + assert _best_effort_stream_len(stream) == 5 + stream.seek(2) + assert _best_effort_stream_len(stream) == 3 + + +def test_event_to_json_line_is_single_line() -> None: + event = UCStartEvent( + session_id=uuid.uuid4(), + seq=1, + op="write", + span_id=uuid.uuid4(), + data={"x": 1}, + ) + + line = event_to_json_line(event) + assert line.endswith("\n") + assert "\n" not in line[:-1] + + +def test_uc_finish_event_excludes_raw_bytes_from_json_dump() -> None: + event = UCFinishEvent( + session_id=uuid.uuid4(), + seq=1, + op="exec", + span_id=uuid.uuid4(), + ok=True, + duration_ms=0.0, + ) + event.stdout_bytes = b"secret" + event.stderr_bytes = b"secret2" + + dumped = event.model_dump(mode="json") + assert "stdout_bytes" not in dumped + assert "stderr_bytes" not in dumped + + +def test_file_entry_is_dir_uses_kind() -> None: + directory_entry = FileEntry( + path="/workspace/dir", + permissions=Permissions.from_str("drwxr-xr-x"), + owner="root", + group="root", + size=0, + kind=EntryKind.DIRECTORY, + ) + file_entry = FileEntry( + path="/workspace/file.txt", + permissions=Permissions.from_str("-rw-r--r--"), + owner="root", + group="root", + size=3, + kind=EntryKind.FILE, + ) + + assert directory_entry.is_dir() is True + assert file_entry.is_dir() is False + + +@pytest.mark.asyncio +async def test_exec_shell_true_quotes_multi_arg_commands() -> None: + session = _CaptureExecSession() + + await session.exec("printf", "%s\n", "hello world", "$(whoami)", "semi;colon", shell=True) + + assert session.last_command == ( + "sh", + "-lc", + shlex.join(["printf", "%s\n", "hello world", "$(whoami)", "semi;colon"]), + ) + + +@pytest.mark.asyncio +async def test_exec_shell_true_preserves_single_shell_snippet() -> None: + session = _CaptureExecSession() + + await session.exec("echo hello && echo goodbye", shell=True) + + assert session.last_command == ("sh", "-lc", "echo hello && echo goodbye") + + +@pytest.mark.asyncio +async def test_resolve_codex_github_asset_name_linux_gnu() -> None: + session = _ScriptedExecSession( + { + ("uname", "-s"): ExecResult(stdout=b"Linux\n", stderr=b"", exit_code=0), + ("uname", "-m"): ExecResult(stdout=b"x86_64\n", stderr=b"", exit_code=0), + ("getconf", "GNU_LIBC_VERSION"): ExecResult( + stdout=b"glibc 2.39\n", + stderr=b"", + exit_code=0, + ), + } + ) + + assert ( + await session.resolve_codex_github_asset_name() == "codex-x86_64-unknown-linux-gnu.tar.gz" + ) + + +@pytest.mark.asyncio +async def test_resolve_codex_github_asset_name_linux_musl() -> None: + session = _ScriptedExecSession( + { + ("uname", "-s"): ExecResult(stdout=b"Linux\n", stderr=b"", exit_code=0), + ("uname", "-m"): ExecResult(stdout=b"amd64\n", stderr=b"", exit_code=0), + ("getconf", "GNU_LIBC_VERSION"): ExecResult(stdout=b"", stderr=b"", exit_code=1), + ("ldd", "--version"): ExecResult( + stdout=b"", + stderr=b"musl libc (x86_64)\n", + exit_code=1, + ), + } + ) + + assert ( + await session.resolve_codex_github_asset_name() == "codex-x86_64-unknown-linux-musl.tar.gz" + ) + + +@pytest.mark.asyncio +async def test_resolve_codex_github_asset_name_linux_aarch64_gnu() -> None: + session = _ScriptedExecSession( + { + ("uname", "-s"): ExecResult(stdout=b"Linux\n", stderr=b"", exit_code=0), + ("uname", "-m"): ExecResult(stdout=b"aarch64\n", stderr=b"", exit_code=0), + ("getconf", "GNU_LIBC_VERSION"): ExecResult( + stdout=b"glibc 2.39\n", + stderr=b"", + exit_code=0, + ), + } + ) + + assert ( + await session.resolve_codex_github_asset_name() == "codex-aarch64-unknown-linux-gnu.tar.gz" + ) + + +@pytest.mark.asyncio +async def test_resolve_codex_github_asset_name_darwin() -> None: + session = _ScriptedExecSession( + { + ("uname", "-s"): ExecResult(stdout=b"Darwin\n", stderr=b"", exit_code=0), + ("uname", "-m"): ExecResult(stdout=b"x86_64\n", stderr=b"", exit_code=0), + } + ) + + assert await session.resolve_codex_github_asset_name() == "codex-x86_64-apple-darwin.tar.gz" + + +@pytest.mark.asyncio +async def test_resolve_codex_github_asset_name_darwin_arm64() -> None: + session = _ScriptedExecSession( + { + ("uname", "-s"): ExecResult(stdout=b"Darwin\n", stderr=b"", exit_code=0), + ("uname", "-m"): ExecResult(stdout=b"arm64\n", stderr=b"", exit_code=0), + } + ) + + assert await session.resolve_codex_github_asset_name() == "codex-aarch64-apple-darwin.tar.gz" + + +@pytest.mark.asyncio +async def test_resolve_codex_github_asset_name_windows() -> None: + session = _ScriptedExecSession( + { + ("uname", "-s"): ExecResult(stdout=b"", stderr=b"", exit_code=1), + ("cmd", "/c", "echo", "%OS%"): ExecResult( + stdout=b"Windows_NT\r\n", + stderr=b"", + exit_code=0, + ), + ("cmd", "/c", "echo", "%PROCESSOR_ARCHITECTURE%"): ExecResult( + stdout=b"AMD64\r\n", + stderr=b"", + exit_code=0, + ), + } + ) + + assert ( + await session.resolve_codex_github_asset_name() == "codex-x86_64-pc-windows-msvc.exe.tar.gz" + ) + + +@pytest.mark.asyncio +async def test_resolve_codex_github_asset_name_windows_arm64() -> None: + session = _ScriptedExecSession( + { + ("uname", "-s"): ExecResult(stdout=b"", stderr=b"", exit_code=1), + ("cmd", "/c", "echo", "%OS%"): ExecResult( + stdout=b"Windows_NT\r\n", + stderr=b"", + exit_code=0, + ), + ("cmd", "/c", "echo", "%PROCESSOR_ARCHITECTURE%"): ExecResult( + stdout=b"ARM64\r\n", + stderr=b"", + exit_code=0, + ), + } + ) + + assert ( + await session.resolve_codex_github_asset_name() + == "codex-aarch64-pc-windows-msvc.exe.tar.gz" + ) + + +def test_resolve_codex_target_triple_reports_unsupported_os() -> None: + with pytest.raises( + UnsupportedCodexTargetError, + match=( + "Unsupported Codex target operating system: freebsd. " + "Available operating systems: linux, darwin, windows." + ), + ) as exc_info: + resolve_codex_target_triple_for_target( + target_os="freebsd", + target_arch="x86_64", + ) + + assert exc_info.value.reason == "operating_system" + assert exc_info.value.target_os == "freebsd" + assert exc_info.value.supported_operating_systems == ("linux", "darwin", "windows") + + +def test_resolve_codex_target_triple_reports_unsupported_architecture() -> None: + with pytest.raises( + UnsupportedCodexTargetError, + match=( + "Unsupported Codex target architecture for darwin: ppc64le. " + "Available architectures: x86_64, aarch64." + ), + ) as exc_info: + resolve_codex_target_triple_for_target( + target_os="darwin", + target_arch="ppc64le", + ) + + assert exc_info.value.reason == "architecture" + assert exc_info.value.target_arch == "ppc64le" + assert exc_info.value.supported_architectures == ("x86_64", "aarch64") + + +def test_resolve_codex_target_triple_normalizes_arm_aliases() -> None: + assert ( + resolve_codex_target_triple_for_target( + target_os="darwin", + target_arch="arm64", + ) + == "aarch64-apple-darwin" + ) + + +def test_resolve_codex_target_triple_reports_unsupported_linux_libc() -> None: + with pytest.raises( + UnsupportedCodexTargetError, + match=( + "Unsupported Linux libc variant for Codex target resolution: uclibc. " + "Available libc variants: gnu, musl." + ), + ) as exc_info: + resolve_codex_target_triple_for_target( + target_os="linux", + target_arch="x86_64", + linux_libc="uclibc", + ) + + assert exc_info.value.reason == "linux_libc" + assert exc_info.value.linux_libc == "uclibc" + assert exc_info.value.supported_linux_libc_variants == ("gnu", "musl") diff --git a/tests/test_sandbox_skills_capability.py b/tests/test_sandbox_skills_capability.py new file mode 100644 index 0000000000..3809dc7bfa --- /dev/null +++ b/tests/test_sandbox_skills_capability.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from agents.sandbox import Manifest +from agents.sandbox.capabilities import Skill, Skills +from agents.sandbox.entries import Dir, File +from agents.sandbox.errors import SkillsConfigError + + +def _children_keys(entry: Dir) -> set[str]: + return {str(key if isinstance(key, Path) else Path(key)) for key in entry.children} + + +class TestSkillValidation: + def test_rejects_directory_content_artifact(self) -> None: + with pytest.raises(SkillsConfigError): + Skill(name="my-skill", description="desc", content=Dir()) + + +class TestSkillsValidation: + def test_requires_at_least_one_source(self) -> None: + with pytest.raises(SkillsConfigError): + Skills() + + def test_rejects_non_directory_from_artifact(self) -> None: + with pytest.raises(SkillsConfigError): + Skills(from_=File(content=b"not-a-dir")) + + def test_rejects_duplicate_skill_names(self) -> None: + with pytest.raises(SkillsConfigError): + Skills( + skills=[ + Skill(name="dup", description="first", content="a"), + Skill(name="dup", description="second", content="b"), + ] + ) + + def test_rejects_combining_literal_and_from_sources(self) -> None: + with pytest.raises(SkillsConfigError): + Skills( + from_=Dir( + children={"my-skill": Dir(children={"SKILL.md": File(content=b"imported")})} + ), + skills=[Skill(name="my-skill", description="desc", content="literal")], + ) + + +class TestSkillsManifest: + def test_literals_materialize_full_skill_structure(self) -> None: + capability = Skills( + skills=[ + Skill( + name="my-skill", + description="desc", + content="Use this skill.", + scripts={"run.sh": File(content=b"echo run")}, + references={"docs/readme.md": File(content=b"ref")}, + assets={"images/icon.txt": File(content=b"asset")}, + ) + ] + ) + + processed = capability.process_manifest(Manifest(root="/workspace")) + skill_entry = processed.entries[Path(".agents/skills/my-skill")] + assert isinstance(skill_entry, Dir) + assert _children_keys(skill_entry) == {"SKILL.md", "assets", "references", "scripts"} + + scripts = skill_entry.children["scripts"] + assert isinstance(scripts, Dir) + assert _children_keys(scripts) == {"run.sh"} + + references = skill_entry.children["references"] + assert isinstance(references, Dir) + assert _children_keys(references) == {"docs/readme.md"} + + assets = skill_entry.children["assets"] + assert isinstance(assets, Dir) + assert _children_keys(assets) == {"images/icon.txt"} + + def test_from_source_is_mapped_to_skills_root(self) -> None: + source = Dir(children={"imported": Dir(children={"SKILL.md": File(content=b"imported")})}) + capability = Skills(from_=source) + + processed = capability.process_manifest(Manifest(root="/workspace")) + assert processed.entries[Path(".agents/skills")] is source + + def test_literal_skills_are_idempotent_when_manifest_already_contains_same_skill(self) -> None: + capability = Skills( + skills=[ + Skill( + name="my-skill", + description="desc", + content="Use this skill.", + scripts={"run.sh": File(content=b"echo run")}, + ) + ] + ) + rendered_skill = capability.skills[0].as_dir_entry() + manifest = Manifest( + root="/workspace", + entries={".agents/skills/my-skill": rendered_skill}, + ) + + processed = capability.process_manifest(manifest) + assert processed.entries[".agents/skills/my-skill"] == rendered_skill + + def test_process_manifest_rejects_exact_path_collision(self) -> None: + capability = Skills(skills=[Skill(name="my-skill", description="desc", content="literal")]) + manifest = Manifest(root="/workspace", entries={Path(".agents/skills/my-skill"): Dir()}) + + with pytest.raises(SkillsConfigError): + capability.process_manifest(manifest) + + +class TestSkillsInstructions: + @pytest.mark.asyncio + async def test_instructions_return_none(self) -> None: + capability = Skills( + skills=[ + Skill(name="z-skill", description="z description", content="z"), + Skill(name="a-skill", description="a description", content="a"), + ] + ) + + instructions = await capability.instructions(Manifest(root="/workspace")) + assert instructions is None diff --git a/tests/test_sandbox_snapshot.py b/tests/test_sandbox_snapshot.py new file mode 100644 index 0000000000..629bebad93 --- /dev/null +++ b/tests/test_sandbox_snapshot.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import io +from pathlib import Path +from typing import Literal + +import pytest + +from agents.sandbox.manifest import Manifest +from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.session.sandbox_session_state import SandboxSessionState +from agents.sandbox.snapshot import NoopSnapshot, SnapshotBase +from agents.sandbox.types import ExecResult + + +class TestNoopSnapshot(SnapshotBase): + __test__ = False + type: Literal["test-noop"] = "test-noop" + + async def persist(self, data: io.IOBase) -> None: + _ = data + + async def restore(self) -> io.IOBase: + raise FileNotFoundError(Path("")) + + async def restorable(self) -> bool: + return False + + +def test_sandbox_session_state_roundtrip_preserves_custom_snapshot_type() -> None: + state = SandboxSessionState( + manifest=Manifest(), + snapshot=TestNoopSnapshot(id="custom-snapshot"), + ) + + payload = state.model_dump_json() + restored = SandboxSessionState.model_validate_json(payload) + + assert isinstance(restored.snapshot, TestNoopSnapshot) + assert restored.snapshot.id == "custom-snapshot" + + +def test_snapshot_parse_uses_registered_custom_snapshot_type() -> None: + parsed = SnapshotBase.parse({"type": "test-noop", "id": "registered"}) + + assert isinstance(parsed, TestNoopSnapshot) + assert parsed.id == "registered" + + +def test_duplicate_snapshot_type_registration_raises() -> None: + class TestDuplicateSnapshotA(SnapshotBase): + __test__ = False + type: Literal["test-duplicate"] = "test-duplicate" + + async def persist(self, data: io.IOBase) -> None: + _ = data + + async def restore(self) -> io.IOBase: + raise FileNotFoundError(Path("")) + + async def restorable(self) -> bool: + return False + + _ = TestDuplicateSnapshotA + + with pytest.raises(TypeError, match="already registered"): + + class TestDuplicateSnapshotB(SnapshotBase): + __test__ = False + type: Literal["test-duplicate"] = "test-duplicate" + + async def persist(self, data: io.IOBase) -> None: + _ = data + + async def restore(self) -> io.IOBase: + raise FileNotFoundError(Path("")) + + async def restorable(self) -> bool: + return False + + +def test_snapshot_subclasses_require_type_discriminator_default() -> None: + with pytest.raises(TypeError, match="must define a non-empty string default for `type`"): + + class TestMissingTypeSnapshot(SnapshotBase): + __test__ = False + + async def persist(self, data: io.IOBase) -> None: + _ = data + + async def restore(self) -> io.IOBase: + raise FileNotFoundError(Path("")) + + async def restorable(self) -> bool: + return False + + +class _PersistTrackingSession(BaseSandboxSession): + def __init__(self, snapshot: SnapshotBase) -> None: + self.state = SandboxSessionState( + manifest=Manifest(), + snapshot=snapshot, + ) + self.persist_workspace_calls = 0 + self.persist_payload = b"tracked" + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = (command, timeout) + raise AssertionError("_exec_internal() should not be called in this test") + + async def read(self, path: Path) -> io.IOBase: + _ = path + raise AssertionError("read() should not be called in this test") + + async def write(self, path: Path, data: io.IOBase) -> None: + _ = (path, data) + raise AssertionError("write() should not be called in this test") + + async def running(self) -> bool: + return True + + async def persist_workspace(self) -> io.IOBase: + self.persist_workspace_calls += 1 + return io.BytesIO(self.persist_payload) + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + + async def shutdown(self) -> None: + return + + +@pytest.mark.asyncio +async def test_noop_snapshot_stop_skips_workspace_persist() -> None: + session = _PersistTrackingSession(NoopSnapshot(id="noop")) + + await session.stop() + + assert session.persist_workspace_calls == 0 + + +@pytest.mark.asyncio +async def test_non_noop_snapshot_stop_persists_workspace() -> None: + snapshot = TestNoopSnapshot(id="custom-snapshot") + session = _PersistTrackingSession(snapshot) + + await session.stop() + + assert session.persist_workspace_calls == 1 diff --git a/tests/test_sandbox_snapshot_defaults.py b/tests/test_sandbox_snapshot_defaults.py new file mode 100644 index 0000000000..f752b4f859 --- /dev/null +++ b/tests/test_sandbox_snapshot_defaults.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import os +from pathlib import Path + +from agents.sandbox.snapshot import LocalSnapshotSpec +from agents.sandbox.snapshot_defaults import ( + _DEFAULT_LOCAL_SNAPSHOT_TTL_SECONDS, + cleanup_stale_default_local_snapshots, + default_local_snapshot_base_dir, + resolve_default_local_snapshot_spec, +) + + +def test_default_local_snapshot_base_dir_uses_xdg_state_home(tmp_path: Path) -> None: + state_home = tmp_path / "state" + result = default_local_snapshot_base_dir( + home=tmp_path / "home", + env={"XDG_STATE_HOME": str(state_home)}, + platform="linux", + os_name="posix", + ) + + assert result == state_home / "openai-agents-python" / "sandbox" / "snapshots" + + +def test_default_local_snapshot_base_dir_uses_macos_application_support(tmp_path: Path) -> None: + home = tmp_path / "home" + result = default_local_snapshot_base_dir( + home=home, + env={}, + platform="darwin", + os_name="posix", + ) + + assert ( + result + == home + / "Library" + / "Application Support" + / "openai-agents-python" + / "sandbox" + / "snapshots" + ) + + +def test_default_local_snapshot_base_dir_uses_localappdata_on_windows(tmp_path: Path) -> None: + local_app_data = tmp_path / "LocalAppData" + result = default_local_snapshot_base_dir( + home=tmp_path / "home", + env={"LOCALAPPDATA": str(local_app_data)}, + platform="win32", + os_name="nt", + ) + + assert result == local_app_data / "openai-agents-python" / "sandbox" / "snapshots" + + +def test_cleanup_stale_default_local_snapshots_removes_only_old_tar_files(tmp_path: Path) -> None: + managed_dir = tmp_path / "snapshots" + managed_dir.mkdir() + stale = managed_dir / "stale.tar" + fresh = managed_dir / "fresh.tar" + keep = managed_dir / "keep.txt" + stale.write_bytes(b"stale") + fresh.write_bytes(b"fresh") + keep.write_text("keep") + + now = 2_000_000_000.0 + stale_mtime = now - (_DEFAULT_LOCAL_SNAPSHOT_TTL_SECONDS + 60) + fresh_mtime = now - 60 + os.utime(stale, (stale_mtime, stale_mtime)) + os.utime(fresh, (fresh_mtime, fresh_mtime)) + + cleanup_stale_default_local_snapshots(managed_dir, now=now) + + assert not stale.exists() + assert fresh.exists() + assert keep.exists() + + +def test_resolve_default_local_snapshot_spec_keeps_existing_stale_files( + tmp_path: Path, +) -> None: + state_home = tmp_path / "state" + managed_dir = state_home / "openai-agents-python" / "sandbox" / "snapshots" + managed_dir.mkdir(parents=True) + stale = managed_dir / "stale.tar" + stale.write_bytes(b"stale") + now = 2_000_000_000.0 + stale_mtime = now - (_DEFAULT_LOCAL_SNAPSHOT_TTL_SECONDS + 60) + os.utime(stale, (stale_mtime, stale_mtime)) + + spec = resolve_default_local_snapshot_spec( + home=tmp_path / "home", + env={"XDG_STATE_HOME": str(state_home)}, + platform="linux", + os_name="posix", + now=now, + ) + + assert isinstance(spec, LocalSnapshotSpec) + assert spec.base_path == managed_dir + assert managed_dir.exists() + assert stale.exists() diff --git a/tests/test_server_conversation_tracker.py b/tests/test_server_conversation_tracker.py index baafac6fda..4ebbda7edd 100644 --- a/tests/test_server_conversation_tracker.py +++ b/tests/test_server_conversation_tracker.py @@ -10,6 +10,7 @@ from agents.result import RunResultStreaming from agents.run_config import ModelInputData, RunConfig from agents.run_context import RunContextWrapper +from agents.run_internal.agent_bindings import bind_public_agent from agents.run_internal.oai_conversation import OpenAIServerConversationTracker from agents.run_internal.run_loop import get_new_response, run_single_turn_streamed from agents.run_internal.tool_use_tracker import AgentToolUseTracker @@ -646,7 +647,7 @@ def _filter_input(payload: Any) -> ModelInputData: run_config = RunConfig(call_model_input_filter=_filter_input) await get_new_response( - agent, + bind_public_agent(agent), None, [item_1, item_2], None, @@ -705,7 +706,7 @@ def _filter_input(payload: Any) -> ModelInputData: await run_single_turn_streamed( streamed_result, - agent, + bind_public_agent(agent), RunHooks(), context_wrapper, run_config, @@ -780,7 +781,7 @@ def _filter_input(payload: Any) -> ModelInputData: await run_single_turn_streamed( streamed_result, - agent, + bind_public_agent(agent), RunHooks(), context_wrapper, run_config, diff --git a/tests/test_shell_tool.py b/tests/test_shell_tool.py index b513388d37..8a6a6ff857 100644 --- a/tests/test_shell_tool.py +++ b/tests/test_shell_tool.py @@ -204,7 +204,7 @@ async def test_execute_shell_calls_surfaces_missing_local_executor() -> None: context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) result = await execute_shell_calls( - agent=agent, + public_agent=agent, calls=[tool_run], context_wrapper=context_wrapper, hooks=RunHooks[Any](), diff --git a/tests/test_tool_use_tracker.py b/tests/test_tool_use_tracker.py index d2276c852d..9e6cf4c850 100644 --- a/tests/test_tool_use_tracker.py +++ b/tests/test_tool_use_tracker.py @@ -39,6 +39,59 @@ def test_tool_use_tracker_from_and_serialize_snapshots() -> None: assert serialize_tool_use_tracker(runtime_tracker) == {"serialize-agent": ["one", "two"]} +def test_serialize_and_hydrate_tool_use_tracker_preserves_duplicate_agent_identity() -> None: + second = Agent(name="duplicate") + first = Agent(name="duplicate", handoffs=[second]) + second.handoffs = [first] + + tracker = AgentToolUseTracker() + tracker.add_tool_use(second, ["approval_tool"]) + + snapshot = serialize_tool_use_tracker(tracker, starting_agent=first) + assert snapshot == {"duplicate#2": ["approval_tool"]} + + class _RunState: + def get_tool_use_tracker_snapshot(self) -> dict[str, list[str]]: + return snapshot + + hydrated = AgentToolUseTracker() + hydrate_tool_use_tracker( + tool_use_tracker=hydrated, + run_state=_RunState(), + starting_agent=first, + ) + + assert hydrated.agent_to_tools == [(second, ["approval_tool"])] + + +def test_tool_use_tracker_handles_literal_suffix_names_without_collision() -> None: + literal_suffix = Agent(name="sandbox#2") + first = Agent(name="sandbox", handoffs=[literal_suffix]) + second = Agent(name="sandbox") + literal_suffix.handoffs = [first, second] + first.handoffs = [literal_suffix, second] + second.handoffs = [first, literal_suffix] + + tracker = AgentToolUseTracker() + tracker.add_tool_use(second, ["approval_tool"]) + + snapshot = serialize_tool_use_tracker(tracker, starting_agent=first) + assert snapshot == {"sandbox#3": ["approval_tool"]} + + class _RunState: + def get_tool_use_tracker_snapshot(self) -> dict[str, list[str]]: + return snapshot + + hydrated = AgentToolUseTracker() + hydrate_tool_use_tracker( + tool_use_tracker=hydrated, + run_state=_RunState(), + starting_agent=first, + ) + + assert hydrated.agent_to_tools == [(second, ["approval_tool"])] + + def test_record_used_tools_uses_trace_names_for_namespaced_and_deferred_functions() -> None: agent = Agent(name="tracked-agent") tracker = AgentToolUseTracker() diff --git a/uv.lock b/uv.lock index 78531dc890..7c587907e9 100644 --- a/uv.lock +++ b/uv.lock @@ -127,6 +127,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -253,6 +262,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/ff/392bff89415399a979be4a65357a41d92729ae8580a66073d8ec8d810f98/backrefs-5.9-py39-none-any.whl", hash = "sha256:f48ee18f6252b8f5777a22a00a09a85de0ca931658f1dd96d4406a34f3748c60", size = 380265, upload-time = "2025-06-22T19:34:12.405Z" }, ] +[[package]] +name = "bracex" +version = "2.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/63/9a/fec38644694abfaaeca2798b58e276a8e61de49e2e37494ace423395febc/bracex-2.6.tar.gz", hash = "sha256:98f1347cd77e22ee8d967a30ad4e310b233f7754dbf31ff3fceb76145ba47dc7", size = 26642, upload-time = "2025-06-22T19:12:31.254Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/2a/9186535ce58db529927f6cf5990a849aa9e052eea3e2cfefe20b9e1802da/bracex-2.6-py3-none-any.whl", hash = "sha256:0b0049264e7340b3ec782b5cb99beb325f36c3782a32e36e876452fd49a09952", size = 11508, upload-time = "2025-06-22T19:12:29.781Z" }, +] + +[[package]] +name = "cbor2" +version = "5.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/8e/8b4fdde28e42ffcd741a37f4ffa9fb59cd4fe01625b544dfcfd9ccb54f01/cbor2-5.8.0.tar.gz", hash = "sha256:b19c35fcae9688ac01ef75bad5db27300c2537eb4ee00ed07e05d8456a0d4931", size = 107825, upload-time = "2025-12-30T18:44:22.455Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/05/486166d9e998d65d70810e63eeacc8c5f13d167d8797cf2d73a588beb335/cbor2-5.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2263c0c892194f10012ced24c322d025d9d7b11b41da1c357f3b3fe06676e6b7", size = 69882, upload-time = "2025-12-30T18:43:25.365Z" }, + { url = "https://files.pythonhosted.org/packages/4e/d0/ee976eaaf21c211eef651e1a921c109c3c3a3785d98307d74a70d142f341/cbor2-5.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ffe4ca079f6f8ed393f5c71a8de22651cb27bd50e74e2bcd6bc9c8f853a732b", size = 260696, upload-time = "2025-12-30T18:43:27.784Z" }, + { url = "https://files.pythonhosted.org/packages/66/7f/81cabd3aee6cc54b101a5214d5c3e541d275d7c05647c7dfc266c6aacf6f/cbor2-5.8.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0427bd166230fe4c4b72965c6f2b6273bf29016d97cf08b258fa48db851ea598", size = 252135, upload-time = "2025-12-30T18:43:29.418Z" }, + { url = "https://files.pythonhosted.org/packages/c2/0b/f38e8c579e7e2d88d446549bce35bde7d845199300bc456b4123d6e6f0af/cbor2-5.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c23a04947c37964d70028ca44ea2a8709f09b8adc0090f9b5710fa957e9bc545", size = 255342, upload-time = "2025-12-30T18:43:30.966Z" }, + { url = "https://files.pythonhosted.org/packages/5d/02/8413f1bd42c8f665fb85374151599cb4957848f0f307d08334a08dee544c/cbor2-5.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:218d5c7d2e8d13c7eded01a1b3fe2a9a1e51a7a843cefb8d38cb4bbbc6ad9bf7", size = 247191, upload-time = "2025-12-30T18:43:32.555Z" }, + { url = "https://files.pythonhosted.org/packages/e5/b8/edeffcad06b83d3661827973a8e6f5d51a9f5842e1ee9d191fdef60388ad/cbor2-5.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:4ce7d907a25448af7c13415281d739634edfd417228b274309b243ca52ad71f9", size = 69254, upload-time = "2025-12-30T18:43:33.717Z" }, + { url = "https://files.pythonhosted.org/packages/ce/1a/dde6537d8d1c2b3157ea6487ea417a5ad0157687d0e9a3ff806bf23c8cb1/cbor2-5.8.0-cp310-cp310-win_arm64.whl", hash = "sha256:628d0ea850aa040921a0e50a08180e7d20cf691432cec3eabc193f643eccfbde", size = 64946, upload-time = "2025-12-30T18:43:34.849Z" }, + { url = "https://files.pythonhosted.org/packages/88/4b/623435ef9b98e86b6956a41863d39ff4fe4d67983948b5834f55499681dd/cbor2-5.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:18ac191640093e6c7fbcb174c006ffec4106c3d8ab788e70272c1c4d933cbe11", size = 69875, upload-time = "2025-12-30T18:43:35.888Z" }, + { url = "https://files.pythonhosted.org/packages/58/17/f664201080b2a7d0f57c16c8e9e5922013b92f202e294863ec7e75b7ff7f/cbor2-5.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fddee9103a17d7bed5753f0c7fc6663faa506eb953e50d8287804eccf7b048e6", size = 268316, upload-time = "2025-12-30T18:43:37.161Z" }, + { url = "https://files.pythonhosted.org/packages/d0/e1/072745b4ff01afe9df2cd627f8fc51a1acedb5d3d1253765625d2929db91/cbor2-5.8.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8d2ea26fad620aba5e88d7541be8b10c5034a55db9a23809b7cb49f36803f05b", size = 258874, upload-time = "2025-12-30T18:43:38.878Z" }, + { url = "https://files.pythonhosted.org/packages/a7/10/61c262b886d22b62c56e8aac6d10fa06d0953c997879ab882a31a624952b/cbor2-5.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:de68b4b310b072b082d317adc4c5e6910173a6d9455412e6183d72c778d1f54c", size = 261971, upload-time = "2025-12-30T18:43:40.401Z" }, + { url = "https://files.pythonhosted.org/packages/7e/42/b7862f5e64364b10ad120ea53e87ec7e891fb268cb99c572348e647cf7e9/cbor2-5.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:418d2cf0e03e90160fa1474c05a40fe228bbb4a92d1628bdbbd13a48527cb34d", size = 254151, upload-time = "2025-12-30T18:43:41.938Z" }, + { url = "https://files.pythonhosted.org/packages/16/6a/8d3636cf75466c18615e7cfac0d345ee3c030f6c79535faed0c2c02b1839/cbor2-5.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:453200ffa1c285ea46ab5745736a015526d41f22da09cb45594624581d959770", size = 69169, upload-time = "2025-12-30T18:43:43.424Z" }, + { url = "https://files.pythonhosted.org/packages/9b/88/79b205bf869558b39a11de70750cb13679b27ba5654a43bed3f2aee7d1b4/cbor2-5.8.0-cp311-cp311-win_arm64.whl", hash = "sha256:f6615412fca973a8b472b3efc4dab01df71cc13f15d8b2c0a1cffac44500f12d", size = 64955, upload-time = "2025-12-30T18:43:44.7Z" }, + { url = "https://files.pythonhosted.org/packages/2f/4f/3a16e3e8fd7e5fd86751a4f1aad218a8d19a96e75ec3989c3e95a8fe1d8f/cbor2-5.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b3f91fa699a5ce22470e973601c62dd9d55dc3ca20ee446516ac075fcab27c9", size = 70270, upload-time = "2025-12-30T18:43:46.005Z" }, + { url = "https://files.pythonhosted.org/packages/38/81/0d0cf0796fe8081492a61c45278f03def21a929535a492dd97c8438f5dbe/cbor2-5.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:518c118a5e00001854adb51f3164e647aa99b6a9877d2a733a28cb5c0a4d6857", size = 286242, upload-time = "2025-12-30T18:43:47.026Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a9/fdab6c10190cfb8d639e01f2b168f2406fc847a2a6bc00e7de78c3381d0a/cbor2-5.8.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cff2a1999e49cd51c23d1b6786a012127fd8f722c5946e82bd7ab3eb307443f3", size = 285412, upload-time = "2025-12-30T18:43:48.563Z" }, + { url = "https://files.pythonhosted.org/packages/31/59/746a8e630996217a3afd523f583fcf7e3d16640d63f9a03f0f4e4f74b5b1/cbor2-5.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c4492160212374973cdc14e46f0565f2462721ef922b40f7ea11e7d613dfb2a", size = 278041, upload-time = "2025-12-30T18:43:49.92Z" }, + { url = "https://files.pythonhosted.org/packages/0f/a3/f3bbeb6dedd45c6e0cddd627ea790dea295eaf82c83f0e2159b733365ebd/cbor2-5.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:546c7c7c4c6bcdc54a59242e0e82cea8f332b17b4465ae628718fef1fce401ca", size = 278185, upload-time = "2025-12-30T18:43:51.192Z" }, + { url = "https://files.pythonhosted.org/packages/67/e5/9013d6b857ceb6cdb2851ffb5a887f53f2bab934a528c9d6fa73d9989d84/cbor2-5.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:074f0fa7535dd7fdee247c2c99f679d94f3aa058ccb1ccf4126cc72d6d89cbae", size = 69817, upload-time = "2025-12-30T18:43:52.352Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ab/7aa94ba3d44ecbc3a97bdb2fb6a8298063fe2e0b611e539a6fe41e36da20/cbor2-5.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:f95fed480b2a0d843f294d2a1ef4cc0f6a83c7922927f9f558e1f5a8dc54b7ca", size = 64923, upload-time = "2025-12-30T18:43:53.719Z" }, + { url = "https://files.pythonhosted.org/packages/a6/0d/5a3f20bafaefeb2c1903d961416f051c0950f0d09e7297a3aa6941596b29/cbor2-5.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6d8d104480845e2f28c6165b4c961bbe58d08cb5638f368375cfcae051c28015", size = 70332, upload-time = "2025-12-30T18:43:54.694Z" }, + { url = "https://files.pythonhosted.org/packages/57/66/177a3f089e69db69c987453ab4934086408c3338551e4984734597be9f80/cbor2-5.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:43efee947e5ab67d406d6e0dc61b5dee9d2f5e89ae176f90677a3741a20ca2e7", size = 285985, upload-time = "2025-12-30T18:43:55.733Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8e/9e17b8e4ed80a2ce97e2dfa5915c169dbb31599409ddb830f514b57f96cc/cbor2-5.8.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:be7ae582f50be539e09c134966d0fd63723fc4789b8dff1f6c2e3f24ae3eaf32", size = 285173, upload-time = "2025-12-30T18:43:57.321Z" }, + { url = "https://files.pythonhosted.org/packages/cc/33/9f92e107d78f88ac22723ac15d0259d220ba98c1d855e51796317f4c4114/cbor2-5.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:50f5c709561a71ea7970b4cd2bf9eda4eccacc0aac212577080fdfe64183e7f5", size = 278395, upload-time = "2025-12-30T18:43:58.497Z" }, + { url = "https://files.pythonhosted.org/packages/2f/3f/46b80050a4a35ce5cf7903693864a9fdea7213567dc8faa6e25cb375c182/cbor2-5.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a6790ecc73aa93e76d2d9076fc42bf91a9e69f2295e5fa702e776dbe986465bd", size = 278330, upload-time = "2025-12-30T18:43:59.656Z" }, + { url = "https://files.pythonhosted.org/packages/eb/d2/d41f8c04c783a4d204e364be2d38043d4f732a3bed6f4c732e321cf34c7b/cbor2-5.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:c114af8099fa65a19a514db87ce7a06e942d8fea2730afd49be39f8e16e7f5e0", size = 69841, upload-time = "2025-12-30T18:44:01.159Z" }, + { url = "https://files.pythonhosted.org/packages/1b/8c/0397a82f6e67665009951453c83058e4c77ba54b9a9017ede56d6870306c/cbor2-5.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:ab3ba00494ad8669a459b12a558448d309c271fa4f89b116ad496ee35db38fea", size = 64982, upload-time = "2025-12-30T18:44:02.138Z" }, + { url = "https://files.pythonhosted.org/packages/4b/0c/0654233d7543ac8a50f4785f172430ddc97538ba418eb305d6e529d1a120/cbor2-5.8.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ad72381477133046ce217617d839ea4e9454f8b77d9a6351b229e214102daeb7", size = 70710, upload-time = "2025-12-30T18:44:03.209Z" }, + { url = "https://files.pythonhosted.org/packages/84/62/4671d24e557d7f5a74a01b422c538925140c0495e57decde7e566f91d029/cbor2-5.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6da25190fad3434ce99876b11d4ca6b8828df6ca232cf7344cd14ae1166fb718", size = 285005, upload-time = "2025-12-30T18:44:05.109Z" }, + { url = "https://files.pythonhosted.org/packages/87/85/0c67d763a08e848c9a80d7e4723ba497cce676f41bc7ca1828ae90a0a872/cbor2-5.8.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c13919e3a24c5a6d286551fa288848a4cedc3e507c58a722ccd134e461217d99", size = 282435, upload-time = "2025-12-30T18:44:06.465Z" }, + { url = "https://files.pythonhosted.org/packages/b2/01/0650972b4dbfbebcfbe37cbba7fc3cd9019a8da6397ab3446e07175e342b/cbor2-5.8.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f8c40d32e5972047a777f9bf730870828f3cf1c43b3eb96fd0429c57a1d3b9e6", size = 277493, upload-time = "2025-12-30T18:44:07.609Z" }, + { url = "https://files.pythonhosted.org/packages/b3/6c/7704a4f32adc7f10f3b41ec067f500a4458f7606397af5e4cf2d368fd288/cbor2-5.8.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7627894bc0b3d5d0807f31e3107e11b996205470c4429dc2bb4ef8bfe7f64e1e", size = 276085, upload-time = "2025-12-30T18:44:09.021Z" }, + { url = "https://files.pythonhosted.org/packages/88/6d/e43452347630efe8133f5304127539100d937c138c0996d27ec63963ec2c/cbor2-5.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:b51c5e59becae746ca4de2bbaa8a2f5c64a68fec05cea62941b1a84a8335f7d1", size = 71657, upload-time = "2025-12-30T18:44:10.162Z" }, + { url = "https://files.pythonhosted.org/packages/8b/66/9a780ef34ab10a0437666232e885378cdd5f60197b1b5e61a62499e5a10a/cbor2-5.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:53b630f4db4b9f477ad84077283dd17ecf9894738aa17ef4938c369958e02a71", size = 67171, upload-time = "2025-12-30T18:44:11.619Z" }, + { url = "https://files.pythonhosted.org/packages/d6/4f/101071f880b4da05771128c0b89f41e334cff044dee05fb013c8f4be661c/cbor2-5.8.0-py3-none-any.whl", hash = "sha256:3727d80f539567b03a7aa11890e57798c67092c38df9e6c23abb059e0f65069c", size = 24374, upload-time = "2025-12-30T18:44:21.476Z" }, +] + [[package]] name = "certifi" version = "2025.8.3" @@ -576,6 +638,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, ] +[[package]] +name = "dockerfile-parse" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/92/df/929ee0b5d2c8bd8d713c45e71b94ab57c7e11e322130724d54f469b2cd48/dockerfile-parse-2.0.1.tar.gz", hash = "sha256:3184ccdc513221983e503ac00e1aa504a2aa8f84e5de673c46b0b6eee99ec7bc", size = 24556, upload-time = "2023-07-18T13:36:07.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/6c/79cd5bc1b880d8c1a9a5550aa8dacd57353fa3bb2457227e1fb47383eb49/dockerfile_parse-2.0.1-py2.py3-none-any.whl", hash = "sha256:bdffd126d2eb26acf1066acb54cb2e336682e1d72b974a40894fac76a4df17f6", size = 14845, upload-time = "2023-07-18T13:36:06.052Z" }, +] + +[[package]] +name = "e2b" +version = "2.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "dockerfile-parse" }, + { name = "httpcore" }, + { name = "httpx" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "python-dateutil" }, + { name = "rich" }, + { name = "typing-extensions" }, + { name = "wcmatch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/d9/46cbb319dc8bd8a0428cffce61bdc1c13fbe15390f84cdb6472412e17f3a/e2b-2.15.2.tar.gz", hash = "sha256:414379d2421d6827eeb2eb50a4d6b3fdb7d691b39ff73b5ea05ca4b532819831", size = 139751, upload-time = "2026-03-09T22:22:44.288Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/ae/859e500f2a0a419b0d7486d6f18e7e47f81cabd34c2f2e252dd21027ed85/e2b-2.15.2-py3-none-any.whl", hash = "sha256:19a56fbdea25974dc81426ed48337eae6cea91d404f5bcf8861a5a2c6e4d982a", size = 257033, upload-time = "2026-03-09T22:22:42.921Z" }, +] + +[[package]] +name = "e2b-code-interpreter" +version = "2.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "e2b" }, + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1e/eb/db6e51edd9f3402fd68d026572579b9b1bd833b10d990376a1e4c05d5b8d/e2b_code_interpreter-2.4.1.tar.gz", hash = "sha256:4b15014ee0d0dfcdc3072e1f409cbb87ca48f48d53d75629b7257e5513b9e7dd", size = 10700, upload-time = "2025-11-26T18:12:38.086Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1b/e7/09b9106ead227f7be14bd97c3181391ee498bb38933b1a9c566b72c8567a/e2b_code_interpreter-2.4.1-py3-none-any.whl", hash = "sha256:15d35f025b4a15033e119f2e12e7ac65657ad2b5a013fa9149e74581fbee778a", size = 13719, upload-time = "2025-11-26T18:12:36.7Z" }, +] + [[package]] name = "eval-type-backport" version = "0.2.2" @@ -1005,6 +1111,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/18/56999a1da3577d8ccc8698a575d6638e15fe25650cc88b2ce0a087f180b9/grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd", size = 14427, upload-time = "2024-10-29T06:27:38.228Z" }, ] +[[package]] +name = "grpclib" +version = "0.4.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h2" }, + { name = "multidict" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/28/5a2c299ec82a876a252c5919aa895a6f1d1d35c96417c5ce4a4660dc3a80/grpclib-0.4.9.tar.gz", hash = "sha256:cc589c330fa81004c6400a52a566407574498cb5b055fa927013361e21466c46", size = 84798, upload-time = "2025-12-14T22:23:14.349Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/90/b0cbbd9efcc82816c58f31a34963071aa19fb792a212a5d9caf8e0fc3097/grpclib-0.4.9-py3-none-any.whl", hash = "sha256:7762ec1c8ed94dfad597475152dd35cbd11aecaaca2f243e29702435ca24cf0e", size = 77063, upload-time = "2025-12-14T22:23:13.224Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -1014,6 +1133,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "h2" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, +] + [[package]] name = "hf-xet" version = "1.1.7" @@ -1029,6 +1161,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a3/73/e354eae84ceff117ec3560141224724794828927fcc013c5b449bf0b8745/hf_xet-1.1.7-cp37-abi3-win_amd64.whl", hash = "sha256:2e356da7d284479ae0f1dea3cf5a2f74fdf925d6dca84ac4341930d892c7cb34", size = 2820008, upload-time = "2025-08-06T00:30:57.056Z" }, ] +[[package]] +name = "hpack" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -1085,6 +1226,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452, upload-time = "2025-08-08T09:14:50.159Z" }, ] +[[package]] +name = "hyperframe" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -1552,6 +1702,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/dd/a24ee3de56954bfafb6ede7cd63c2413bb842cc48eb45e41c43a05a33074/mkdocstrings_python-1.16.12-py3-none-any.whl", hash = "sha256:22ded3a63b3d823d57457a70ff9860d5a4de9e8b1e482876fc9baabaf6f5f374", size = 124287, upload-time = "2025-06-03T12:52:47.819Z" }, ] +[[package]] +name = "modal" +version = "1.3.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "cbor2" }, + { name = "certifi" }, + { name = "click" }, + { name = "grpclib" }, + { name = "protobuf" }, + { name = "rich" }, + { name = "synchronicity" }, + { name = "toml" }, + { name = "typer" }, + { name = "types-certifi" }, + { name = "types-toml" }, + { name = "typing-extensions" }, + { name = "watchfiles" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/fd/f4a684209dab54d7dc9d92f48d779b30d04aa8b4c6dd1395d6c61967ee34/modal-1.3.5.tar.gz", hash = "sha256:2e320e7dbc8995ce0769796a9027248a8b976b519469cc4599d6855a1a53a123", size = 655193, upload-time = "2026-03-03T18:13:06.22Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/39/aa5c773a4dddef833f1c846bb4204b442588b99a1d15ab7818157e66b32c/modal-1.3.5-py3-none-any.whl", hash = "sha256:67e5d3635c2c355d63b3e30f9012dd2bc9c38d5747349335c7ba9da65edca1cb", size = 755272, upload-time = "2026-03-03T18:13:03.323Z" }, +] + [[package]] name = "multidict" version = "6.6.4" @@ -1898,6 +2073,7 @@ dependencies = [ { name = "requests" }, { name = "types-requests" }, { name = "typing-extensions" }, + { name = "websockets" }, ] [package.optional-dependencies] @@ -1905,12 +2081,22 @@ dapr = [ { name = "dapr" }, { name = "grpcio" }, ] +docker = [ + { name = "docker" }, +] +e2b = [ + { name = "e2b" }, + { name = "e2b-code-interpreter" }, +] encrypt = [ { name = "cryptography" }, ] litellm = [ { name = "litellm" }, ] +modal = [ + { name = "modal" }, +] realtime = [ { name = "websockets" }, ] @@ -1968,11 +2154,15 @@ requires-dist = [ { name = "asyncpg", marker = "extra == 'sqlalchemy'", specifier = ">=0.29.0" }, { name = "cryptography", marker = "extra == 'encrypt'", specifier = ">=45.0,<46" }, { name = "dapr", marker = "extra == 'dapr'", specifier = ">=1.16.0" }, + { name = "docker", marker = "extra == 'docker'", specifier = ">=6.1" }, + { name = "e2b", marker = "extra == 'e2b'", specifier = ">=2.12.1" }, + { name = "e2b-code-interpreter", marker = "extra == 'e2b'", specifier = ">=1.0" }, { name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" }, { name = "griffe", specifier = ">=1.5.6,<2" }, { name = "grpcio", marker = "extra == 'dapr'", specifier = ">=1.60.0" }, { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.81.0,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.19.0,<2" }, + { name = "modal", marker = "extra == 'modal'", specifier = ">=1.3.1" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "openai", specifier = ">=2.26.0,<3" }, { name = "pydantic", specifier = ">=2.12.2,<3" }, @@ -1981,10 +2171,11 @@ requires-dist = [ { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0" }, { name = "types-requests", specifier = ">=2.0,<3" }, { name = "typing-extensions", specifier = ">=4.12.2,<5" }, + { name = "websockets", specifier = ">=15.0,<16" }, { name = "websockets", marker = "extra == 'realtime'", specifier = ">=15.0,<16" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, ] -provides-extras = ["voice", "viz", "litellm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr"] +provides-extras = ["voice", "viz", "litellm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr", "docker", "modal", "e2b"] [package.metadata.requires-dev] dev = [ @@ -2011,7 +2202,7 @@ dev = [ { name = "pytest-asyncio" }, { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-xdist" }, - { name = "rich", specifier = ">=13.1.0,<14" }, + { name = "rich", specifier = ">=13.1.0,<15" }, { name = "ruff", specifier = "==0.9.2" }, { name = "sounddevice" }, { name = "testcontainers", specifier = "==4.12.0" }, @@ -2806,16 +2997,15 @@ wheels = [ [[package]] name = "rich" -version = "13.9.4" +version = "14.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py" }, { name = "pygments" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149, upload-time = "2024-11-01T16:43:57.873Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" }, + { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, ] [[package]] @@ -2978,6 +3168,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/4e/33df635528292bd2d18404e4daabcd74ca8a9853b2e1df85ed3d32d24362/ruff-0.9.2-py3-none-win_arm64.whl", hash = "sha256:a1b63fa24149918f8b37cef2ee6fff81f24f0d74b6f0bdc37bc3e1f2143e41c6", size = 10001738, upload-time = "2025-01-16T13:22:18.121Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -3090,6 +3289,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/1f/b876b1f83aef204198a42dc101613fefccb32258e5428b5f9259677864b4/starlette-0.47.2-py3-none-any.whl", hash = "sha256:c5847e96134e5c5371ee9fac6fdf1a67336d5815e09eb2a01fdb57a351ef915b", size = 72984, upload-time = "2025-07-20T17:31:56.738Z" }, ] +[[package]] +name = "synchronicity" +version = "0.11.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/26/8874d34755691994266d4a844ba8d53d10c2690ec67f246ca4d6b6f34cbb/synchronicity-0.11.1.tar.gz", hash = "sha256:3628df9ab34bd7be89b729104114841c62612c5d5ec43b76f4b7b243185ec1a8", size = 58131, upload-time = "2025-12-19T18:28:42.291Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/b9/71153db12f4ad029cfe9b7fbf9792ef3fc9ade4485d31a13470b52954e62/synchronicity-0.11.1-py3-none-any.whl", hash = "sha256:53959c7f8b9b852fb5ea4d3d290a47a04310ede483a4cf0f8452cb4b5fa09db2", size = 40399, upload-time = "2025-12-19T18:28:40.972Z" }, +] + [[package]] name = "testcontainers" version = "4.12.0" @@ -3208,6 +3419,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/f2/fd673d979185f5dcbac4be7d09461cbb99751554ffb6718d0013af8604cb/tokenizers-0.21.4-cp39-abi3-win_amd64.whl", hash = "sha256:475d807a5c3eb72c59ad9b5fcdb254f6e17f53dfcbb9903233b0dfa9c943b597", size = 2507568, upload-time = "2025-07-28T15:48:55.456Z" }, ] +[[package]] +name = "toml" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253, upload-time = "2020-11-01T01:40:22.204Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -3259,6 +3479,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] +[[package]] +name = "typer" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" }, +] + +[[package]] +name = "types-certifi" +version = "2021.10.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/52/68/943c3aeaf14624712a0357c4a67814dba5cea36d194f5c764dad7959a00c/types-certifi-2021.10.8.3.tar.gz", hash = "sha256:72cf7798d165bc0b76e1c10dd1ea3097c7063c42c21d664523b928e88b554a4f", size = 2095, upload-time = "2022-06-09T15:19:05.244Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/63/2463d89481e811f007b0e1cd0a91e52e141b47f9de724d20db7b861dcfec/types_certifi-2021.10.8.3-py3-none-any.whl", hash = "sha256:b2d1e325e69f71f7c78e5943d410e650b4707bb0ef32e4ddf3da37f54176e88a", size = 2136, upload-time = "2022-06-09T15:19:03.127Z" }, +] + [[package]] name = "types-pynput" version = "1.8.1.20250809" @@ -3280,6 +3524,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/6f/ec0012be842b1d888d46884ac5558fd62aeae1f0ec4f7a581433d890d4b5/types_requests-2.32.4.20250809-py3-none-any.whl", hash = "sha256:f73d1832fb519ece02c85b1f09d5f0dd3108938e7d47e7f94bbfa18a6782b163", size = 20644, upload-time = "2025-08-09T03:17:09.716Z" }, ] +[[package]] +name = "types-toml" +version = "0.10.8.20240310" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/86/47/3e4c75042792bff8e90d7991aa5c51812cc668828cc6cce711e97f63a607/types-toml-0.10.8.20240310.tar.gz", hash = "sha256:3d41501302972436a6b8b239c850b26689657e25281b48ff0ec06345b8830331", size = 4392, upload-time = "2024-03-10T02:18:37.518Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/a2/d32ab58c0b216912638b140ab2170ee4b8644067c293b170e19fba340ccc/types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d", size = 4777, upload-time = "2024-03-10T02:18:36.568Z" }, +] + [[package]] name = "typing-extensions" version = "4.14.1" @@ -3365,6 +3618,121 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "watchfiles" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/c9/8869df9b2a2d6c59d79220a4db37679e74f807c559ffe5265e08b227a210/watchfiles-1.1.1.tar.gz", hash = "sha256:a173cb5c16c4f40ab19cecf48a534c409f7ea983ab8fed0741304a1c0a31b3f2", size = 94440, upload-time = "2025-10-14T15:06:21.08Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/1a/206e8cf2dd86fddf939165a57b4df61607a1e0add2785f170a3f616b7d9f/watchfiles-1.1.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:eef58232d32daf2ac67f42dea51a2c80f0d03379075d44a587051e63cc2e368c", size = 407318, upload-time = "2025-10-14T15:04:18.753Z" }, + { url = "https://files.pythonhosted.org/packages/b3/0f/abaf5262b9c496b5dad4ed3c0e799cbecb1f8ea512ecb6ddd46646a9fca3/watchfiles-1.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:03fa0f5237118a0c5e496185cafa92878568b652a2e9a9382a5151b1a0380a43", size = 394478, upload-time = "2025-10-14T15:04:20.297Z" }, + { url = "https://files.pythonhosted.org/packages/b1/04/9cc0ba88697b34b755371f5ace8d3a4d9a15719c07bdc7bd13d7d8c6a341/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ca65483439f9c791897f7db49202301deb6e15fe9f8fe2fed555bf986d10c31", size = 449894, upload-time = "2025-10-14T15:04:21.527Z" }, + { url = "https://files.pythonhosted.org/packages/d2/9c/eda4615863cd8621e89aed4df680d8c3ec3da6a4cf1da113c17decd87c7f/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f0ab1c1af0cb38e3f598244c17919fb1a84d1629cc08355b0074b6d7f53138ac", size = 459065, upload-time = "2025-10-14T15:04:22.795Z" }, + { url = "https://files.pythonhosted.org/packages/84/13/f28b3f340157d03cbc8197629bc109d1098764abe1e60874622a0be5c112/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bc570d6c01c206c46deb6e935a260be44f186a2f05179f52f7fcd2be086a94d", size = 488377, upload-time = "2025-10-14T15:04:24.138Z" }, + { url = "https://files.pythonhosted.org/packages/86/93/cfa597fa9389e122488f7ffdbd6db505b3b915ca7435ecd7542e855898c2/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e84087b432b6ac94778de547e08611266f1f8ffad28c0ee4c82e028b0fc5966d", size = 595837, upload-time = "2025-10-14T15:04:25.057Z" }, + { url = "https://files.pythonhosted.org/packages/57/1e/68c1ed5652b48d89fc24d6af905d88ee4f82fa8bc491e2666004e307ded1/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:620bae625f4cb18427b1bb1a2d9426dc0dd5a5ba74c7c2cdb9de405f7b129863", size = 473456, upload-time = "2025-10-14T15:04:26.497Z" }, + { url = "https://files.pythonhosted.org/packages/d5/dc/1a680b7458ffa3b14bb64878112aefc8f2e4f73c5af763cbf0bd43100658/watchfiles-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:544364b2b51a9b0c7000a4b4b02f90e9423d97fbbf7e06689236443ebcad81ab", size = 455614, upload-time = "2025-10-14T15:04:27.539Z" }, + { url = "https://files.pythonhosted.org/packages/61/a5/3d782a666512e01eaa6541a72ebac1d3aae191ff4a31274a66b8dd85760c/watchfiles-1.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bbe1ef33d45bc71cf21364df962af171f96ecaeca06bd9e3d0b583efb12aec82", size = 630690, upload-time = "2025-10-14T15:04:28.495Z" }, + { url = "https://files.pythonhosted.org/packages/9b/73/bb5f38590e34687b2a9c47a244aa4dd50c56a825969c92c9c5fc7387cea1/watchfiles-1.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1a0bb430adb19ef49389e1ad368450193a90038b5b752f4ac089ec6942c4dff4", size = 622459, upload-time = "2025-10-14T15:04:29.491Z" }, + { url = "https://files.pythonhosted.org/packages/f1/ac/c9bb0ec696e07a20bd58af5399aeadaef195fb2c73d26baf55180fe4a942/watchfiles-1.1.1-cp310-cp310-win32.whl", hash = "sha256:3f6d37644155fb5beca5378feb8c1708d5783145f2a0f1c4d5a061a210254844", size = 272663, upload-time = "2025-10-14T15:04:30.435Z" }, + { url = "https://files.pythonhosted.org/packages/11/a0/a60c5a7c2ec59fa062d9a9c61d02e3b6abd94d32aac2d8344c4bdd033326/watchfiles-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:a36d8efe0f290835fd0f33da35042a1bb5dc0e83cbc092dcf69bce442579e88e", size = 287453, upload-time = "2025-10-14T15:04:31.53Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f8/2c5f479fb531ce2f0564eda479faecf253d886b1ab3630a39b7bf7362d46/watchfiles-1.1.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f57b396167a2565a4e8b5e56a5a1c537571733992b226f4f1197d79e94cf0ae5", size = 406529, upload-time = "2025-10-14T15:04:32.899Z" }, + { url = "https://files.pythonhosted.org/packages/fe/cd/f515660b1f32f65df671ddf6f85bfaca621aee177712874dc30a97397977/watchfiles-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:421e29339983e1bebc281fab40d812742268ad057db4aee8c4d2bce0af43b741", size = 394384, upload-time = "2025-10-14T15:04:33.761Z" }, + { url = "https://files.pythonhosted.org/packages/7b/c3/28b7dc99733eab43fca2d10f55c86e03bd6ab11ca31b802abac26b23d161/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e43d39a741e972bab5d8100b5cdacf69db64e34eb19b6e9af162bccf63c5cc6", size = 448789, upload-time = "2025-10-14T15:04:34.679Z" }, + { url = "https://files.pythonhosted.org/packages/4a/24/33e71113b320030011c8e4316ccca04194bf0cbbaeee207f00cbc7d6b9f5/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f537afb3276d12814082a2e9b242bdcf416c2e8fd9f799a737990a1dbe906e5b", size = 460521, upload-time = "2025-10-14T15:04:35.963Z" }, + { url = "https://files.pythonhosted.org/packages/f4/c3/3c9a55f255aa57b91579ae9e98c88704955fa9dac3e5614fb378291155df/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2cd9e04277e756a2e2d2543d65d1e2166d6fd4c9b183f8808634fda23f17b14", size = 488722, upload-time = "2025-10-14T15:04:37.091Z" }, + { url = "https://files.pythonhosted.org/packages/49/36/506447b73eb46c120169dc1717fe2eff07c234bb3232a7200b5f5bd816e9/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5f3f58818dc0b07f7d9aa7fe9eb1037aecb9700e63e1f6acfed13e9fef648f5d", size = 596088, upload-time = "2025-10-14T15:04:38.39Z" }, + { url = "https://files.pythonhosted.org/packages/82/ab/5f39e752a9838ec4d52e9b87c1e80f1ee3ccdbe92e183c15b6577ab9de16/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb9f66367023ae783551042d31b1d7fd422e8289eedd91f26754a66f44d5cff", size = 472923, upload-time = "2025-10-14T15:04:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/af/b9/a419292f05e302dea372fa7e6fda5178a92998411f8581b9830d28fb9edb/watchfiles-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aebfd0861a83e6c3d1110b78ad54704486555246e542be3e2bb94195eabb2606", size = 456080, upload-time = "2025-10-14T15:04:40.643Z" }, + { url = "https://files.pythonhosted.org/packages/b0/c3/d5932fd62bde1a30c36e10c409dc5d54506726f08cb3e1d8d0ba5e2bc8db/watchfiles-1.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5fac835b4ab3c6487b5dbad78c4b3724e26bcc468e886f8ba8cc4306f68f6701", size = 629432, upload-time = "2025-10-14T15:04:41.789Z" }, + { url = "https://files.pythonhosted.org/packages/f7/77/16bddd9779fafb795f1a94319dc965209c5641db5bf1edbbccace6d1b3c0/watchfiles-1.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:399600947b170270e80134ac854e21b3ccdefa11a9529a3decc1327088180f10", size = 623046, upload-time = "2025-10-14T15:04:42.718Z" }, + { url = "https://files.pythonhosted.org/packages/46/ef/f2ecb9a0f342b4bfad13a2787155c6ee7ce792140eac63a34676a2feeef2/watchfiles-1.1.1-cp311-cp311-win32.whl", hash = "sha256:de6da501c883f58ad50db3a32ad397b09ad29865b5f26f64c24d3e3281685849", size = 271473, upload-time = "2025-10-14T15:04:43.624Z" }, + { url = "https://files.pythonhosted.org/packages/94/bc/f42d71125f19731ea435c3948cad148d31a64fccde3867e5ba4edee901f9/watchfiles-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:35c53bd62a0b885bf653ebf6b700d1bf05debb78ad9292cf2a942b23513dc4c4", size = 287598, upload-time = "2025-10-14T15:04:44.516Z" }, + { url = "https://files.pythonhosted.org/packages/57/c9/a30f897351f95bbbfb6abcadafbaca711ce1162f4db95fc908c98a9165f3/watchfiles-1.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:57ca5281a8b5e27593cb7d82c2ac927ad88a96ed406aa446f6344e4328208e9e", size = 277210, upload-time = "2025-10-14T15:04:45.883Z" }, + { url = "https://files.pythonhosted.org/packages/74/d5/f039e7e3c639d9b1d09b07ea412a6806d38123f0508e5f9b48a87b0a76cc/watchfiles-1.1.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:8c89f9f2f740a6b7dcc753140dd5e1ab9215966f7a3530d0c0705c83b401bd7d", size = 404745, upload-time = "2025-10-14T15:04:46.731Z" }, + { url = "https://files.pythonhosted.org/packages/a5/96/a881a13aa1349827490dab2d363c8039527060cfcc2c92cc6d13d1b1049e/watchfiles-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd404be08018c37350f0d6e34676bd1e2889990117a2b90070b3007f172d0610", size = 391769, upload-time = "2025-10-14T15:04:48.003Z" }, + { url = "https://files.pythonhosted.org/packages/4b/5b/d3b460364aeb8da471c1989238ea0e56bec24b6042a68046adf3d9ddb01c/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8526e8f916bb5b9a0a777c8317c23ce65de259422bba5b31325a6fa6029d33af", size = 449374, upload-time = "2025-10-14T15:04:49.179Z" }, + { url = "https://files.pythonhosted.org/packages/b9/44/5769cb62d4ed055cb17417c0a109a92f007114a4e07f30812a73a4efdb11/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2edc3553362b1c38d9f06242416a5d8e9fe235c204a4072e988ce2e5bb1f69f6", size = 459485, upload-time = "2025-10-14T15:04:50.155Z" }, + { url = "https://files.pythonhosted.org/packages/19/0c/286b6301ded2eccd4ffd0041a1b726afda999926cf720aab63adb68a1e36/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30f7da3fb3f2844259cba4720c3fc7138eb0f7b659c38f3bfa65084c7fc7abce", size = 488813, upload-time = "2025-10-14T15:04:51.059Z" }, + { url = "https://files.pythonhosted.org/packages/c7/2b/8530ed41112dd4a22f4dcfdb5ccf6a1baad1ff6eed8dc5a5f09e7e8c41c7/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8979280bdafff686ba5e4d8f97840f929a87ed9cdf133cbbd42f7766774d2aa", size = 594816, upload-time = "2025-10-14T15:04:52.031Z" }, + { url = "https://files.pythonhosted.org/packages/ce/d2/f5f9fb49489f184f18470d4f99f4e862a4b3e9ac2865688eb2099e3d837a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcc5c24523771db3a294c77d94771abcfcb82a0e0ee8efd910c37c59ec1b31bb", size = 475186, upload-time = "2025-10-14T15:04:53.064Z" }, + { url = "https://files.pythonhosted.org/packages/cf/68/5707da262a119fb06fbe214d82dd1fe4a6f4af32d2d14de368d0349eb52a/watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db5d7ae38ff20153d542460752ff397fcf5c96090c1230803713cf3147a6803", size = 456812, upload-time = "2025-10-14T15:04:55.174Z" }, + { url = "https://files.pythonhosted.org/packages/66/ab/3cbb8756323e8f9b6f9acb9ef4ec26d42b2109bce830cc1f3468df20511d/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:28475ddbde92df1874b6c5c8aaeb24ad5be47a11f87cde5a28ef3835932e3e94", size = 630196, upload-time = "2025-10-14T15:04:56.22Z" }, + { url = "https://files.pythonhosted.org/packages/78/46/7152ec29b8335f80167928944a94955015a345440f524d2dfe63fc2f437b/watchfiles-1.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:36193ed342f5b9842edd3532729a2ad55c4160ffcfa3700e0d54be496b70dd43", size = 622657, upload-time = "2025-10-14T15:04:57.521Z" }, + { url = "https://files.pythonhosted.org/packages/0a/bf/95895e78dd75efe9a7f31733607f384b42eb5feb54bd2eb6ed57cc2e94f4/watchfiles-1.1.1-cp312-cp312-win32.whl", hash = "sha256:859e43a1951717cc8de7f4c77674a6d389b106361585951d9e69572823f311d9", size = 272042, upload-time = "2025-10-14T15:04:59.046Z" }, + { url = "https://files.pythonhosted.org/packages/87/0a/90eb755f568de2688cb220171c4191df932232c20946966c27a59c400850/watchfiles-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:91d4c9a823a8c987cce8fa2690923b069966dabb196dd8d137ea2cede885fde9", size = 288410, upload-time = "2025-10-14T15:05:00.081Z" }, + { url = "https://files.pythonhosted.org/packages/36/76/f322701530586922fbd6723c4f91ace21364924822a8772c549483abed13/watchfiles-1.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:a625815d4a2bdca61953dbba5a39d60164451ef34c88d751f6c368c3ea73d404", size = 278209, upload-time = "2025-10-14T15:05:01.168Z" }, + { url = "https://files.pythonhosted.org/packages/bb/f4/f750b29225fe77139f7ae5de89d4949f5a99f934c65a1f1c0b248f26f747/watchfiles-1.1.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:130e4876309e8686a5e37dba7d5e9bc77e6ed908266996ca26572437a5271e18", size = 404321, upload-time = "2025-10-14T15:05:02.063Z" }, + { url = "https://files.pythonhosted.org/packages/2b/f9/f07a295cde762644aa4c4bb0f88921d2d141af45e735b965fb2e87858328/watchfiles-1.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5f3bde70f157f84ece3765b42b4a52c6ac1a50334903c6eaf765362f6ccca88a", size = 391783, upload-time = "2025-10-14T15:05:03.052Z" }, + { url = "https://files.pythonhosted.org/packages/bc/11/fc2502457e0bea39a5c958d86d2cb69e407a4d00b85735ca724bfa6e0d1a/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14e0b1fe858430fc0251737ef3824c54027bedb8c37c38114488b8e131cf8219", size = 449279, upload-time = "2025-10-14T15:05:04.004Z" }, + { url = "https://files.pythonhosted.org/packages/e3/1f/d66bc15ea0b728df3ed96a539c777acfcad0eb78555ad9efcaa1274688f0/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f27db948078f3823a6bb3b465180db8ebecf26dd5dae6f6180bd87383b6b4428", size = 459405, upload-time = "2025-10-14T15:05:04.942Z" }, + { url = "https://files.pythonhosted.org/packages/be/90/9f4a65c0aec3ccf032703e6db02d89a157462fbb2cf20dd415128251cac0/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:059098c3a429f62fc98e8ec62b982230ef2c8df68c79e826e37b895bc359a9c0", size = 488976, upload-time = "2025-10-14T15:05:05.905Z" }, + { url = "https://files.pythonhosted.org/packages/37/57/ee347af605d867f712be7029bb94c8c071732a4b44792e3176fa3c612d39/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfb5862016acc9b869bb57284e6cb35fdf8e22fe59f7548858e2f971d045f150", size = 595506, upload-time = "2025-10-14T15:05:06.906Z" }, + { url = "https://files.pythonhosted.org/packages/a8/78/cc5ab0b86c122047f75e8fc471c67a04dee395daf847d3e59381996c8707/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:319b27255aacd9923b8a276bb14d21a5f7ff82564c744235fc5eae58d95422ae", size = 474936, upload-time = "2025-10-14T15:05:07.906Z" }, + { url = "https://files.pythonhosted.org/packages/62/da/def65b170a3815af7bd40a3e7010bf6ab53089ef1b75d05dd5385b87cf08/watchfiles-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c755367e51db90e75b19454b680903631d41f9e3607fbd941d296a020c2d752d", size = 456147, upload-time = "2025-10-14T15:05:09.138Z" }, + { url = "https://files.pythonhosted.org/packages/57/99/da6573ba71166e82d288d4df0839128004c67d2778d3b566c138695f5c0b/watchfiles-1.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c22c776292a23bfc7237a98f791b9ad3144b02116ff10d820829ce62dff46d0b", size = 630007, upload-time = "2025-10-14T15:05:10.117Z" }, + { url = "https://files.pythonhosted.org/packages/a8/51/7439c4dd39511368849eb1e53279cd3454b4a4dbace80bab88feeb83c6b5/watchfiles-1.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:3a476189be23c3686bc2f4321dd501cb329c0a0469e77b7b534ee10129ae6374", size = 622280, upload-time = "2025-10-14T15:05:11.146Z" }, + { url = "https://files.pythonhosted.org/packages/95/9c/8ed97d4bba5db6fdcdb2b298d3898f2dd5c20f6b73aee04eabe56c59677e/watchfiles-1.1.1-cp313-cp313-win32.whl", hash = "sha256:bf0a91bfb5574a2f7fc223cf95eeea79abfefa404bf1ea5e339c0c1560ae99a0", size = 272056, upload-time = "2025-10-14T15:05:12.156Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f3/c14e28429f744a260d8ceae18bf58c1d5fa56b50d006a7a9f80e1882cb0d/watchfiles-1.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:52e06553899e11e8074503c8e716d574adeeb7e68913115c4b3653c53f9bae42", size = 288162, upload-time = "2025-10-14T15:05:13.208Z" }, + { url = "https://files.pythonhosted.org/packages/dc/61/fe0e56c40d5cd29523e398d31153218718c5786b5e636d9ae8ae79453d27/watchfiles-1.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac3cc5759570cd02662b15fbcd9d917f7ecd47efe0d6b40474eafd246f91ea18", size = 277909, upload-time = "2025-10-14T15:05:14.49Z" }, + { url = "https://files.pythonhosted.org/packages/79/42/e0a7d749626f1e28c7108a99fb9bf524b501bbbeb9b261ceecde644d5a07/watchfiles-1.1.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:563b116874a9a7ce6f96f87cd0b94f7faf92d08d0021e837796f0a14318ef8da", size = 403389, upload-time = "2025-10-14T15:05:15.777Z" }, + { url = "https://files.pythonhosted.org/packages/15/49/08732f90ce0fbbc13913f9f215c689cfc9ced345fb1bcd8829a50007cc8d/watchfiles-1.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ad9fe1dae4ab4212d8c91e80b832425e24f421703b5a42ef2e4a1e215aff051", size = 389964, upload-time = "2025-10-14T15:05:16.85Z" }, + { url = "https://files.pythonhosted.org/packages/27/0d/7c315d4bd5f2538910491a0393c56bf70d333d51bc5b34bee8e68e8cea19/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce70f96a46b894b36eba678f153f052967a0d06d5b5a19b336ab0dbbd029f73e", size = 448114, upload-time = "2025-10-14T15:05:17.876Z" }, + { url = "https://files.pythonhosted.org/packages/c3/24/9e096de47a4d11bc4df41e9d1e61776393eac4cb6eb11b3e23315b78b2cc/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cb467c999c2eff23a6417e58d75e5828716f42ed8289fe6b77a7e5a91036ca70", size = 460264, upload-time = "2025-10-14T15:05:18.962Z" }, + { url = "https://files.pythonhosted.org/packages/cc/0f/e8dea6375f1d3ba5fcb0b3583e2b493e77379834c74fd5a22d66d85d6540/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:836398932192dae4146c8f6f737d74baeac8b70ce14831a239bdb1ca882fc261", size = 487877, upload-time = "2025-10-14T15:05:20.094Z" }, + { url = "https://files.pythonhosted.org/packages/ac/5b/df24cfc6424a12deb41503b64d42fbea6b8cb357ec62ca84a5a3476f654a/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:743185e7372b7bc7c389e1badcc606931a827112fbbd37f14c537320fca08620", size = 595176, upload-time = "2025-10-14T15:05:21.134Z" }, + { url = "https://files.pythonhosted.org/packages/8f/b5/853b6757f7347de4e9b37e8cc3289283fb983cba1ab4d2d7144694871d9c/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afaeff7696e0ad9f02cbb8f56365ff4686ab205fcf9c4c5b6fdfaaa16549dd04", size = 473577, upload-time = "2025-10-14T15:05:22.306Z" }, + { url = "https://files.pythonhosted.org/packages/e1/f7/0a4467be0a56e80447c8529c9fce5b38eab4f513cb3d9bf82e7392a5696b/watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f7eb7da0eb23aa2ba036d4f616d46906013a68caf61b7fdbe42fc8b25132e77", size = 455425, upload-time = "2025-10-14T15:05:23.348Z" }, + { url = "https://files.pythonhosted.org/packages/8e/e0/82583485ea00137ddf69bc84a2db88bd92ab4a6e3c405e5fb878ead8d0e7/watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:831a62658609f0e5c64178211c942ace999517f5770fe9436be4c2faeba0c0ef", size = 628826, upload-time = "2025-10-14T15:05:24.398Z" }, + { url = "https://files.pythonhosted.org/packages/28/9a/a785356fccf9fae84c0cc90570f11702ae9571036fb25932f1242c82191c/watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:f9a2ae5c91cecc9edd47e041a930490c31c3afb1f5e6d71de3dc671bfaca02bf", size = 622208, upload-time = "2025-10-14T15:05:25.45Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f4/0872229324ef69b2c3edec35e84bd57a1289e7d3fe74588048ed8947a323/watchfiles-1.1.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:d1715143123baeeaeadec0528bb7441103979a1d5f6fd0e1f915383fea7ea6d5", size = 404315, upload-time = "2025-10-14T15:05:26.501Z" }, + { url = "https://files.pythonhosted.org/packages/7b/22/16d5331eaed1cb107b873f6ae1b69e9ced582fcf0c59a50cd84f403b1c32/watchfiles-1.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:39574d6370c4579d7f5d0ad940ce5b20db0e4117444e39b6d8f99db5676c52fd", size = 390869, upload-time = "2025-10-14T15:05:27.649Z" }, + { url = "https://files.pythonhosted.org/packages/b2/7e/5643bfff5acb6539b18483128fdc0ef2cccc94a5b8fbda130c823e8ed636/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7365b92c2e69ee952902e8f70f3ba6360d0d596d9299d55d7d386df84b6941fb", size = 449919, upload-time = "2025-10-14T15:05:28.701Z" }, + { url = "https://files.pythonhosted.org/packages/51/2e/c410993ba5025a9f9357c376f48976ef0e1b1aefb73b97a5ae01a5972755/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bfff9740c69c0e4ed32416f013f3c45e2ae42ccedd1167ef2d805c000b6c71a5", size = 460845, upload-time = "2025-10-14T15:05:30.064Z" }, + { url = "https://files.pythonhosted.org/packages/8e/a4/2df3b404469122e8680f0fcd06079317e48db58a2da2950fb45020947734/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b27cf2eb1dda37b2089e3907d8ea92922b673c0c427886d4edc6b94d8dfe5db3", size = 489027, upload-time = "2025-10-14T15:05:31.064Z" }, + { url = "https://files.pythonhosted.org/packages/ea/84/4587ba5b1f267167ee715b7f66e6382cca6938e0a4b870adad93e44747e6/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:526e86aced14a65a5b0ec50827c745597c782ff46b571dbfe46192ab9e0b3c33", size = 595615, upload-time = "2025-10-14T15:05:32.074Z" }, + { url = "https://files.pythonhosted.org/packages/6a/0f/c6988c91d06e93cd0bb3d4a808bcf32375ca1904609835c3031799e3ecae/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04e78dd0b6352db95507fd8cb46f39d185cf8c74e4cf1e4fbad1d3df96faf510", size = 474836, upload-time = "2025-10-14T15:05:33.209Z" }, + { url = "https://files.pythonhosted.org/packages/b4/36/ded8aebea91919485b7bbabbd14f5f359326cb5ec218cd67074d1e426d74/watchfiles-1.1.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c85794a4cfa094714fb9c08d4a218375b2b95b8ed1666e8677c349906246c05", size = 455099, upload-time = "2025-10-14T15:05:34.189Z" }, + { url = "https://files.pythonhosted.org/packages/98/e0/8c9bdba88af756a2fce230dd365fab2baf927ba42cd47521ee7498fd5211/watchfiles-1.1.1-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:74d5012b7630714b66be7b7b7a78855ef7ad58e8650c73afc4c076a1f480a8d6", size = 630626, upload-time = "2025-10-14T15:05:35.216Z" }, + { url = "https://files.pythonhosted.org/packages/2a/84/a95db05354bf2d19e438520d92a8ca475e578c647f78f53197f5a2f17aaf/watchfiles-1.1.1-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:8fbe85cb3201c7d380d3d0b90e63d520f15d6afe217165d7f98c9c649654db81", size = 622519, upload-time = "2025-10-14T15:05:36.259Z" }, + { url = "https://files.pythonhosted.org/packages/1d/ce/d8acdc8de545de995c339be67711e474c77d643555a9bb74a9334252bd55/watchfiles-1.1.1-cp314-cp314-win32.whl", hash = "sha256:3fa0b59c92278b5a7800d3ee7733da9d096d4aabcfabb9a928918bd276ef9b9b", size = 272078, upload-time = "2025-10-14T15:05:37.63Z" }, + { url = "https://files.pythonhosted.org/packages/c4/c9/a74487f72d0451524be827e8edec251da0cc1fcf111646a511ae752e1a3d/watchfiles-1.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:c2047d0b6cea13b3316bdbafbfa0c4228ae593d995030fda39089d36e64fc03a", size = 287664, upload-time = "2025-10-14T15:05:38.95Z" }, + { url = "https://files.pythonhosted.org/packages/df/b8/8ac000702cdd496cdce998c6f4ee0ca1f15977bba51bdf07d872ebdfc34c/watchfiles-1.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:842178b126593addc05acf6fce960d28bc5fae7afbaa2c6c1b3a7b9460e5be02", size = 277154, upload-time = "2025-10-14T15:05:39.954Z" }, + { url = "https://files.pythonhosted.org/packages/47/a8/e3af2184707c29f0f14b1963c0aace6529f9d1b8582d5b99f31bbf42f59e/watchfiles-1.1.1-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:88863fbbc1a7312972f1c511f202eb30866370ebb8493aef2812b9ff28156a21", size = 403820, upload-time = "2025-10-14T15:05:40.932Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ec/e47e307c2f4bd75f9f9e8afbe3876679b18e1bcec449beca132a1c5ffb2d/watchfiles-1.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:55c7475190662e202c08c6c0f4d9e345a29367438cf8e8037f3155e10a88d5a5", size = 390510, upload-time = "2025-10-14T15:05:41.945Z" }, + { url = "https://files.pythonhosted.org/packages/d5/a0/ad235642118090f66e7b2f18fd5c42082418404a79205cdfca50b6309c13/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f53fa183d53a1d7a8852277c92b967ae99c2d4dcee2bfacff8868e6e30b15f7", size = 448408, upload-time = "2025-10-14T15:05:43.385Z" }, + { url = "https://files.pythonhosted.org/packages/df/85/97fa10fd5ff3332ae17e7e40e20784e419e28521549780869f1413742e9d/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6aae418a8b323732fa89721d86f39ec8f092fc2af67f4217a2b07fd3e93c6101", size = 458968, upload-time = "2025-10-14T15:05:44.404Z" }, + { url = "https://files.pythonhosted.org/packages/47/c2/9059c2e8966ea5ce678166617a7f75ecba6164375f3b288e50a40dc6d489/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f096076119da54a6080e8920cbdaac3dbee667eb91dcc5e5b78840b87415bd44", size = 488096, upload-time = "2025-10-14T15:05:45.398Z" }, + { url = "https://files.pythonhosted.org/packages/94/44/d90a9ec8ac309bc26db808a13e7bfc0e4e78b6fc051078a554e132e80160/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00485f441d183717038ed2e887a7c868154f216877653121068107b227a2f64c", size = 596040, upload-time = "2025-10-14T15:05:46.502Z" }, + { url = "https://files.pythonhosted.org/packages/95/68/4e3479b20ca305cfc561db3ed207a8a1c745ee32bf24f2026a129d0ddb6e/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a55f3e9e493158d7bfdb60a1165035f1cf7d320914e7b7ea83fe22c6023b58fc", size = 473847, upload-time = "2025-10-14T15:05:47.484Z" }, + { url = "https://files.pythonhosted.org/packages/4f/55/2af26693fd15165c4ff7857e38330e1b61ab8c37d15dc79118cdba115b7a/watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c91ed27800188c2ae96d16e3149f199d62f86c7af5f5f4d2c61a3ed8cd3666c", size = 455072, upload-time = "2025-10-14T15:05:48.928Z" }, + { url = "https://files.pythonhosted.org/packages/66/1d/d0d200b10c9311ec25d2273f8aad8c3ef7cc7ea11808022501811208a750/watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:311ff15a0bae3714ffb603e6ba6dbfba4065ab60865d15a6ec544133bdb21099", size = 629104, upload-time = "2025-10-14T15:05:49.908Z" }, + { url = "https://files.pythonhosted.org/packages/e3/bd/fa9bb053192491b3867ba07d2343d9f2252e00811567d30ae8d0f78136fe/watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:a916a2932da8f8ab582f242c065f5c81bed3462849ca79ee357dd9551b0e9b01", size = 622112, upload-time = "2025-10-14T15:05:50.941Z" }, + { url = "https://files.pythonhosted.org/packages/ba/4c/a888c91e2e326872fa4705095d64acd8aa2fb9c1f7b9bd0588f33850516c/watchfiles-1.1.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:17ef139237dfced9da49fb7f2232c86ca9421f666d78c264c7ffca6601d154c3", size = 409611, upload-time = "2025-10-14T15:06:05.809Z" }, + { url = "https://files.pythonhosted.org/packages/1e/c7/5420d1943c8e3ce1a21c0a9330bcf7edafb6aa65d26b21dbb3267c9e8112/watchfiles-1.1.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:672b8adf25b1a0d35c96b5888b7b18699d27d4194bac8beeae75be4b7a3fc9b2", size = 396889, upload-time = "2025-10-14T15:06:07.035Z" }, + { url = "https://files.pythonhosted.org/packages/0c/e5/0072cef3804ce8d3aaddbfe7788aadff6b3d3f98a286fdbee9fd74ca59a7/watchfiles-1.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77a13aea58bc2b90173bc69f2a90de8e282648939a00a602e1dc4ee23e26b66d", size = 451616, upload-time = "2025-10-14T15:06:08.072Z" }, + { url = "https://files.pythonhosted.org/packages/83/4e/b87b71cbdfad81ad7e83358b3e447fedd281b880a03d64a760fe0a11fc2e/watchfiles-1.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b495de0bb386df6a12b18335a0285dda90260f51bdb505503c02bcd1ce27a8b", size = 458413, upload-time = "2025-10-14T15:06:09.209Z" }, + { url = "https://files.pythonhosted.org/packages/d3/8e/e500f8b0b77be4ff753ac94dc06b33d8f0d839377fee1b78e8c8d8f031bf/watchfiles-1.1.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:db476ab59b6765134de1d4fe96a1a9c96ddf091683599be0f26147ea1b2e4b88", size = 408250, upload-time = "2025-10-14T15:06:10.264Z" }, + { url = "https://files.pythonhosted.org/packages/bd/95/615e72cd27b85b61eec764a5ca51bd94d40b5adea5ff47567d9ebc4d275a/watchfiles-1.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:89eef07eee5e9d1fda06e38822ad167a044153457e6fd997f8a858ab7564a336", size = 396117, upload-time = "2025-10-14T15:06:11.28Z" }, + { url = "https://files.pythonhosted.org/packages/c9/81/e7fe958ce8a7fb5c73cc9fb07f5aeaf755e6aa72498c57d760af760c91f8/watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce19e06cbda693e9e7686358af9cd6f5d61312ab8b00488bc36f5aabbaf77e24", size = 450493, upload-time = "2025-10-14T15:06:12.321Z" }, + { url = "https://files.pythonhosted.org/packages/6e/d4/ed38dd3b1767193de971e694aa544356e63353c33a85d948166b5ff58b9e/watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6f39af2eab0118338902798b5aa6664f46ff66bc0280de76fca67a7f262a49", size = 457546, upload-time = "2025-10-14T15:06:13.372Z" }, +] + +[[package]] +name = "wcmatch" +version = "10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bracex" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/3e/c0bdc27cf06f4e47680bd5803a07cb3dfd17de84cde92dd217dcb9e05253/wcmatch-10.1.tar.gz", hash = "sha256:f11f94208c8c8484a16f4f48638a85d771d9513f4ab3f37595978801cb9465af", size = 117421, upload-time = "2025-06-22T19:14:02.49Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/d8/0d1d2e9d3fabcf5d6840362adcf05f8cf3cd06a73358140c3a97189238ae/wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a", size = 39854, upload-time = "2025-06-22T19:14:00.978Z" }, +] + [[package]] name = "websockets" version = "15.0.1"