diff --git a/pyproject.toml b/pyproject.toml index 317b378cb..3f2efa098 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ grpc = ["grpcio>=1.48.2,<2"] opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] -openai-agents = ["openai-agents>=0.17.1", "mcp>=1.9.4, <2"] +openai-agents = ["openai-agents>=0.17.5", "mcp>=1.9.4, <2"] google-adk = ["google-adk>=1.27.0,<2"] langgraph = ["langgraph>=1.1.0"] langsmith = ["langsmith>=0.7.34,<0.9"] @@ -257,3 +257,4 @@ exclude = ["temporalio/bridge/target/**/*"] # Prevent uv commands from building the package by default package = false exclude-newer = "2 weeks" +exclude-newer-package = { openai-agents = false } diff --git a/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py b/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py index 9e4d67644..4aa6fd38e 100644 --- a/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py +++ b/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py @@ -3,10 +3,12 @@ from __future__ import annotations import io -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterator, Sequence +from contextlib import contextmanager from pathlib import Path from typing import Any +from agents.sandbox.errors import SandboxError from agents.sandbox.session.sandbox_client import BaseSandboxClient from agents.sandbox.session.sandbox_session import SandboxSession @@ -34,6 +36,22 @@ from temporalio.contrib.openai_agents.sandbox._temporal_activity_models import ( ExecResult as ExecResultModel, ) +from temporalio.exceptions import ApplicationError + + +@contextmanager +def _translate_sandbox_errors() -> Iterator[None]: + # Temporal retries every activity exception by default, so only a SandboxError + # the library has classified as terminal (retryable is False) is turned into a + # non-retryable ApplicationError. + try: + yield + except SandboxError as e: + if e.retryable is False: + raise ApplicationError( + str(e), type=str(e.error_code), non_retryable=True + ) from e + raise class SandboxClientProvider: @@ -99,132 +117,154 @@ def _get_activities(self) -> Sequence[Callable[..., Any]]: @activity.defn(name=f"{prefix}-sandbox_client_create") async def create_session(args: CreateSessionArgs) -> SessionResult: - session = await self._client.create( - snapshot=args.snapshot_spec, - manifest=args.manifest, - options=args.client_options, - ) - self._sessions[str(session.state.session_id)] = session - return SessionResult( - state=session.state, supports_pty=session.supports_pty() - ) + with _translate_sandbox_errors(): + session = await self._client.create( + snapshot=args.snapshot_spec, + manifest=args.manifest, + options=args.client_options, + ) + self._sessions[str(session.state.session_id)] = session + return SessionResult( + state=session.state, supports_pty=session.supports_pty() + ) @activity.defn(name=f"{prefix}-sandbox_client_resume") async def resume_session(args: ResumeSessionArgs) -> SessionResult: - session = await self._client.resume(args.state) - self._sessions[str(session.state.session_id)] = session - return SessionResult( - state=session.state, supports_pty=session.supports_pty() - ) + with _translate_sandbox_errors(): + session = await self._client.resume(args.state) + self._sessions[str(session.state.session_id)] = session + return SessionResult( + state=session.state, supports_pty=session.supports_pty() + ) @activity.defn(name=f"{prefix}-sandbox_client_delete") async def delete_session(args: StopArgs) -> None: - session = await self._session(args) - await self._client.delete(session) - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await self._client.delete(session) + return None # -- Session-level operations (I/O and lifecycle) -- @activity.defn(name=f"{prefix}-sandbox_session_exec") async def exec_(args: ExecArgs) -> ExecResultModel: - session = await self._session(args) - result = await session.exec( - *args.command, - timeout=args.timeout, - shell=args.shell, - user=args.user, - ) - return ExecResultModel( - stdout=result.stdout, - stderr=result.stderr, - exit_code=result.exit_code, - ) + with _translate_sandbox_errors(): + session = await self._session(args) + result = await session.exec( + *args.command, + timeout=args.timeout, + shell=args.shell, + user=args.user, + ) + return ExecResultModel( + stdout=result.stdout, + stderr=result.stderr, + exit_code=result.exit_code, + ) @activity.defn(name=f"{prefix}-sandbox_session_read") async def read(args: ReadArgs) -> ReadResult: - session = await self._session(args) - handle = await session.read(Path(args.path)) - return ReadResult(data=handle.read()) + with _translate_sandbox_errors(): + session = await self._session(args) + handle = await session.read(Path(args.path)) + return ReadResult(data=handle.read()) @activity.defn(name=f"{prefix}-sandbox_session_write") async def write(args: WriteArgs) -> None: - session = await self._session(args) - await session.write(Path(args.path), io.BytesIO(args.data)) - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.write(Path(args.path), io.BytesIO(args.data)) + return None @activity.defn(name=f"{prefix}-sandbox_session_running") async def running(args: RunningArgs) -> RunningResult: - session = await self._session(args) - return RunningResult(is_running=await session.running()) + with _translate_sandbox_errors(): + session = await self._session(args) + return RunningResult(is_running=await session.running()) @activity.defn(name=f"{prefix}-sandbox_session_persist_workspace") async def persist_workspace( args: PersistWorkspaceArgs, ) -> PersistWorkspaceResult: - session = await self._session(args) - stream = await session.persist_workspace() - return PersistWorkspaceResult(data=stream.read()) + with _translate_sandbox_errors(): + session = await self._session(args) + stream = await session.persist_workspace() + return PersistWorkspaceResult(data=stream.read()) @activity.defn(name=f"{prefix}-sandbox_session_hydrate_workspace") async def hydrate_workspace(args: HydrateWorkspaceArgs) -> None: - session = await self._session(args) - await session.hydrate_workspace(io.BytesIO(args.data)) - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.hydrate_workspace(io.BytesIO(args.data)) + return None @activity.defn(name=f"{prefix}-sandbox_session_pty_exec_start") async def pty_exec_start(args: PtyExecStartArgs) -> PtyExecUpdateResult: - session = await self._session(args) - update = await session.pty_exec_start( - *args.command, - timeout=args.timeout, - shell=args.shell, - user=args.user, - tty=args.tty, - yield_time_s=args.yield_time_s, - max_output_tokens=args.max_output_tokens, - ) - return PtyExecUpdateResult( - process_id=update.process_id, - output=update.output, - exit_code=update.exit_code, - original_token_count=update.original_token_count, - ) + with _translate_sandbox_errors(): + session = await self._session(args) + update = await session.pty_exec_start( + *args.command, + timeout=args.timeout, + shell=args.shell, + user=args.user, + tty=args.tty, + yield_time_s=args.yield_time_s, + max_output_tokens=args.max_output_tokens, + ) + return PtyExecUpdateResult( + process_id=update.process_id, + output=update.output, + exit_code=update.exit_code, + original_token_count=update.original_token_count, + ) @activity.defn(name=f"{prefix}-sandbox_session_pty_write_stdin") async def pty_write_stdin(args: PtyWriteStdinArgs) -> PtyExecUpdateResult: - session = await self._session(args) - update = await session.pty_write_stdin( - session_id=args.session_id, - chars=args.chars, - yield_time_s=args.yield_time_s, - max_output_tokens=args.max_output_tokens, - ) - return PtyExecUpdateResult( - process_id=update.process_id, - output=update.output, - exit_code=update.exit_code, - original_token_count=update.original_token_count, - ) + with _translate_sandbox_errors(): + session = await self._session(args) + update = await session.pty_write_stdin( + session_id=args.session_id, + chars=args.chars, + yield_time_s=args.yield_time_s, + max_output_tokens=args.max_output_tokens, + ) + return PtyExecUpdateResult( + process_id=update.process_id, + output=update.output, + exit_code=update.exit_code, + original_token_count=update.original_token_count, + ) @activity.defn(name=f"{prefix}-sandbox_session_start") async def start(args: StartArgs) -> None: - session = await self._session(args) - await session.start() - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.start() + return None @activity.defn(name=f"{prefix}-sandbox_session_stop") async def session_stop(args: StopArgs) -> None: - session = await self._session(args) - await session.stop() - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.stop() + return None @activity.defn(name=f"{prefix}-sandbox_session_shutdown") async def session_shutdown(args: StopArgs) -> None: key = str(args.state.session_id) session = self._sessions.get(key) - if session is not None: - await session.shutdown() + if session is None: + return None + try: + with _translate_sandbox_errors(): + await session.shutdown() + except ApplicationError: + # Terminal failure: the session is dead, so evict it before + # re-raising. A retryable error instead propagates with the + # entry kept so the activity's retry can still shut it down. del self._sessions[key] + raise + del self._sessions[key] return None return [ diff --git a/tests/contrib/openai_agents/test_openai_sandbox.py b/tests/contrib/openai_agents/test_openai_sandbox.py index 74ff80e85..3338f8d64 100644 --- a/tests/contrib/openai_agents/test_openai_sandbox.py +++ b/tests/contrib/openai_agents/test_openai_sandbox.py @@ -9,6 +9,11 @@ import pytest from agents import Agent, FunctionTool, RunConfig, Runner, Tool from agents.sandbox import Capability, Manifest, SandboxAgent, SandboxRunConfig +from agents.sandbox.errors import ( + ExecTransportError, + SandboxError, + WorkspaceArchiveReadError, +) from agents.sandbox.session.base_sandbox_session import BaseSandboxSession from agents.sandbox.session.sandbox_client import ( BaseSandboxClient, @@ -55,6 +60,7 @@ TestModelProvider, ) from temporalio.contrib.openai_agents.workflow import temporal_sandbox_client +from temporalio.exceptions import ApplicationError from temporalio.workflow import ActivityConfig from tests.helpers import new_worker @@ -569,6 +575,149 @@ async def test_multiple_providers_register_distinct_activities(): ) +# ── SandboxError retryable mapping tests ── + + +class _ExecRaisingSession(_MockSandboxSession): + """Mock session whose exec() raises a chosen SandboxError.""" + + def __init__(self, error: SandboxError) -> None: + super().__init__() + self._error = error + + async def _exec_internal( + self, + *command: str | Path, # type: ignore[reportUnusedParameter] + timeout: float | None = None, # type: ignore[reportUnusedParameter] + ) -> ExecResult: + raise self._error + + +async def _exec_with_error(error: SandboxError) -> None: + provider = SandboxClientProvider( + "mock", _MockSandboxClient(_ExecRaisingSession(error)) + ) + acts = _activity_map(provider) + state = ( + await acts["mock-sandbox_client_create"]( + CreateSessionArgs( + snapshot_spec=None, manifest=Manifest(), client_options=None + ) + ) + ).state + await acts["mock-sandbox_session_exec"]( + ExecArgs(state=state, command=["boom"], shell=True) + ) + + +async def test_exec_terminal_error_becomes_non_retryable_application_error(): + """retryable is False should map to a non-retryable ApplicationError.""" + with pytest.raises(ApplicationError) as exc_info: + await _exec_with_error(ExecTransportError(command=["boom"], retryable=False)) + assert exc_info.value.non_retryable is True + assert exc_info.value.type == "exec_transport_error" + + +async def test_exec_transient_error_propagates_unchanged(): + """retryable is True should let the original SandboxError propagate.""" + with pytest.raises(ExecTransportError): + await _exec_with_error(ExecTransportError(command=["boom"], retryable=True)) + + +async def test_exec_unclassified_error_propagates_unchanged(): + """retryable is None should let the original SandboxError propagate (not converted).""" + with pytest.raises(ExecTransportError): + await _exec_with_error(ExecTransportError(command=["boom"], retryable=None)) + + +class _ShutdownRaisingSession(_MockSandboxSession): + """Mock session whose shutdown() raises a chosen SandboxError.""" + + def __init__(self, error: SandboxError) -> None: + super().__init__() + self._error = error + + async def shutdown(self) -> None: + raise self._error + + +async def _create_shutdown_raising( + error: SandboxError, +) -> tuple[dict[str, Any], SandboxClientProvider, StopArgs, str]: + provider = SandboxClientProvider( + "mock", _MockSandboxClient(_ShutdownRaisingSession(error)) + ) + acts = _activity_map(provider) + state = ( + await acts["mock-sandbox_client_create"]( + CreateSessionArgs( + snapshot_spec=None, manifest=Manifest(), client_options=None + ) + ) + ).state + key = str(state.session_id) + assert key in provider._sessions + return acts, provider, StopArgs(state=state), key + + +async def test_shutdown_terminal_error_evicts_session_and_raises(): + """A terminal shutdown error maps to a non-retryable ApplicationError and + evicts the dead session from the cache.""" + acts, provider, args, key = await _create_shutdown_raising( + ExecTransportError(command=["shutdown"], retryable=False) + ) + + with pytest.raises(ApplicationError) as exc_info: + await acts["mock-sandbox_session_shutdown"](args) + assert exc_info.value.non_retryable is True + assert key not in provider._sessions + + +async def test_shutdown_retryable_error_keeps_session_cached(): + """A retryable shutdown error propagates unchanged and leaves the session + cached so the activity's retry can still shut it down.""" + acts, provider, args, key = await _create_shutdown_raising( + ExecTransportError(command=["shutdown"], retryable=True) + ) + + with pytest.raises(ExecTransportError): + await acts["mock-sandbox_session_shutdown"](args) + assert key in provider._sessions + + +class _RunningRaisingSession(_MockSandboxSession): + """Mock session whose running() raises a chosen SandboxError.""" + + def __init__(self, error: SandboxError) -> None: + super().__init__() + self._error = error + + async def running(self) -> bool: + raise self._error + + +async def test_running_terminal_error_becomes_non_retryable_application_error(): + """A terminal SandboxError from a non-exec activity also maps to a + non-retryable ApplicationError, with type set to its error_code.""" + error = WorkspaceArchiveReadError(path=Path("/workspace"), retryable=False) + provider = SandboxClientProvider( + "mock", _MockSandboxClient(_RunningRaisingSession(error)) + ) + acts = _activity_map(provider) + state = ( + await acts["mock-sandbox_client_create"]( + CreateSessionArgs( + snapshot_spec=None, manifest=Manifest(), client_options=None + ) + ) + ).state + + with pytest.raises(ApplicationError) as exc_info: + await acts["mock-sandbox_session_running"](RunningArgs(state=state)) + assert exc_info.value.non_retryable is True + assert exc_info.value.type == "workspace_archive_read_error" + + # ── End-to-end test: Runner + SandboxAgent through Temporal activities ── diff --git a/uv.lock b/uv.lock index 1f46fc241..15a79011a 100644 --- a/uv.lock +++ b/uv.lock @@ -9,9 +9,12 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-20T23:41:58.595699Z" +exclude-newer = "2026-06-01T18:36:48.998335583Z" exclude-newer-span = "P2W" +[options.exclude-newer-package] +openai-agents = false + [[package]] name = "aioboto3" version = "15.5.0" @@ -3424,7 +3427,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.17.3" +version = "0.17.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "griffelib" }, @@ -3436,9 +3439,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fd/16/b79c1849125eb6d19cae98c21ff35caa2e55b5ec8d7a02b354b711917ef7/openai_agents-0.17.3.tar.gz", hash = "sha256:63b6dda6bd4fb51169e2a2cbd5d187a4e5ce823bbd15f965c8ed1d3b89072eec", size = 5406135, upload-time = "2026-05-19T01:28:15.971Z" } +sdist = { url = "https://files.pythonhosted.org/packages/72/fe/ef185f2a21f2fba1b0b107f72a7646bb51369d4c4025e2ab4d1ec65764f3/openai_agents-0.17.5.tar.gz", hash = "sha256:5dd46943b993e1a68a78acd254fc6a00cf0455fc3dcc802078ea26964b14278c", size = 5420036, upload-time = "2026-06-11T04:12:35.775Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/80/ec/775a14cfd5f12f4ffe458c7ac9527831093c72e8c1aef682898fc6394106/openai_agents-0.17.3-py3-none-any.whl", hash = "sha256:a048bb0752d40913d18bccf6562f56260b603bb57c972597b6da58f60123f4bd", size = 841541, upload-time = "2026-05-19T01:28:13.334Z" }, + { url = "https://files.pythonhosted.org/packages/b9/f0/9184cd6d3d089a568fc544f1c7f0965d63818fa310c912b30abd333ea138/openai_agents-0.17.5-py3-none-any.whl", hash = "sha256:9afa8a67f0b9fbcdfd2d1545b38d3c52d47e4182921cb79952ad61580d950973", size = 846844, upload-time = "2026-06-11T04:12:32.485Z" }, ] [package.optional-dependencies] @@ -5509,7 +5512,7 @@ requires-dist = [ { name = "langsmith", marker = "extra == 'langsmith'", specifier = ">=0.7.34,<0.9" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, { name = "nexus-rpc", specifier = "==1.4.0" }, - { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.17.1" }, + { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.17.5" }, { name = "opentelemetry-api", marker = "extra == 'lambda-worker-otel'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-exporter-otlp-proto-grpc", marker = "extra == 'lambda-worker-otel'", specifier = ">=1.11.1,<2" },