Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
grpc = ["grpcio>=1.48.2,<2"]
opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"]
pydantic = ["pydantic>=2.0.0,<3"]
openai-agents = ["openai-agents>=0.17.1", "mcp>=1.9.4, <2"]
openai-agents = ["openai-agents>=0.17.5", "mcp>=1.9.4, <2"]
google-adk = ["google-adk>=1.27.0,<2"]
langgraph = ["langgraph>=1.1.0"]
langsmith = ["langsmith>=0.7.34,<0.9"]
Expand Down Expand Up @@ -257,3 +257,4 @@ exclude = ["temporalio/bridge/target/**/*"]
# Prevent uv commands from building the package by default
package = false
exclude-newer = "2 weeks"
exclude-newer-package = { openai-agents = false }
202 changes: 121 additions & 81 deletions temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from __future__ import annotations

import io
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from pathlib import Path
from typing import Any

from agents.sandbox.errors import SandboxError
from agents.sandbox.session.sandbox_client import BaseSandboxClient
from agents.sandbox.session.sandbox_session import SandboxSession

Expand Down Expand Up @@ -34,6 +36,22 @@
from temporalio.contrib.openai_agents.sandbox._temporal_activity_models import (
ExecResult as ExecResultModel,
)
from temporalio.exceptions import ApplicationError


@contextmanager
def _translate_sandbox_errors() -> Iterator[None]:
# Temporal retries every activity exception by default, so only a SandboxError
# the library has classified as terminal (retryable is False) is turned into a
# non-retryable ApplicationError.
try:
yield
except SandboxError as e:
if e.retryable is False:
raise ApplicationError(
str(e), type=str(e.error_code), non_retryable=True
) from e
raise


class SandboxClientProvider:
Expand Down Expand Up @@ -99,132 +117,154 @@ def _get_activities(self) -> Sequence[Callable[..., Any]]:

@activity.defn(name=f"{prefix}-sandbox_client_create")
async def create_session(args: CreateSessionArgs) -> SessionResult:
session = await self._client.create(
snapshot=args.snapshot_spec,
manifest=args.manifest,
options=args.client_options,
)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)
with _translate_sandbox_errors():
session = await self._client.create(
snapshot=args.snapshot_spec,
manifest=args.manifest,
options=args.client_options,
)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)

@activity.defn(name=f"{prefix}-sandbox_client_resume")
async def resume_session(args: ResumeSessionArgs) -> SessionResult:
session = await self._client.resume(args.state)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)
with _translate_sandbox_errors():
session = await self._client.resume(args.state)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)

@activity.defn(name=f"{prefix}-sandbox_client_delete")
async def delete_session(args: StopArgs) -> None:
session = await self._session(args)
await self._client.delete(session)
return None
with _translate_sandbox_errors():
session = await self._session(args)
await self._client.delete(session)
return None

# -- Session-level operations (I/O and lifecycle) --

@activity.defn(name=f"{prefix}-sandbox_session_exec")
async def exec_(args: ExecArgs) -> ExecResultModel:
session = await self._session(args)
result = await session.exec(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
)
return ExecResultModel(
stdout=result.stdout,
stderr=result.stderr,
exit_code=result.exit_code,
)
with _translate_sandbox_errors():
session = await self._session(args)
result = await session.exec(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
)
return ExecResultModel(
stdout=result.stdout,
stderr=result.stderr,
exit_code=result.exit_code,
)

@activity.defn(name=f"{prefix}-sandbox_session_read")
async def read(args: ReadArgs) -> ReadResult:
session = await self._session(args)
handle = await session.read(Path(args.path))
return ReadResult(data=handle.read())
with _translate_sandbox_errors():
session = await self._session(args)
handle = await session.read(Path(args.path))
return ReadResult(data=handle.read())

@activity.defn(name=f"{prefix}-sandbox_session_write")
async def write(args: WriteArgs) -> None:
session = await self._session(args)
await session.write(Path(args.path), io.BytesIO(args.data))
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.write(Path(args.path), io.BytesIO(args.data))
return None

@activity.defn(name=f"{prefix}-sandbox_session_running")
async def running(args: RunningArgs) -> RunningResult:
session = await self._session(args)
return RunningResult(is_running=await session.running())
with _translate_sandbox_errors():
session = await self._session(args)
return RunningResult(is_running=await session.running())

@activity.defn(name=f"{prefix}-sandbox_session_persist_workspace")
async def persist_workspace(
args: PersistWorkspaceArgs,
) -> PersistWorkspaceResult:
session = await self._session(args)
stream = await session.persist_workspace()
return PersistWorkspaceResult(data=stream.read())
with _translate_sandbox_errors():
session = await self._session(args)
stream = await session.persist_workspace()
return PersistWorkspaceResult(data=stream.read())

@activity.defn(name=f"{prefix}-sandbox_session_hydrate_workspace")
async def hydrate_workspace(args: HydrateWorkspaceArgs) -> None:
session = await self._session(args)
await session.hydrate_workspace(io.BytesIO(args.data))
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.hydrate_workspace(io.BytesIO(args.data))
return None

@activity.defn(name=f"{prefix}-sandbox_session_pty_exec_start")
async def pty_exec_start(args: PtyExecStartArgs) -> PtyExecUpdateResult:
session = await self._session(args)
update = await session.pty_exec_start(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
tty=args.tty,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)
with _translate_sandbox_errors():
session = await self._session(args)
update = await session.pty_exec_start(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
tty=args.tty,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)

@activity.defn(name=f"{prefix}-sandbox_session_pty_write_stdin")
async def pty_write_stdin(args: PtyWriteStdinArgs) -> PtyExecUpdateResult:
session = await self._session(args)
update = await session.pty_write_stdin(
session_id=args.session_id,
chars=args.chars,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)
with _translate_sandbox_errors():
session = await self._session(args)
update = await session.pty_write_stdin(
session_id=args.session_id,
chars=args.chars,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)

@activity.defn(name=f"{prefix}-sandbox_session_start")
async def start(args: StartArgs) -> None:
session = await self._session(args)
await session.start()
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.start()
return None

@activity.defn(name=f"{prefix}-sandbox_session_stop")
async def session_stop(args: StopArgs) -> None:
session = await self._session(args)
await session.stop()
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.stop()
return None

@activity.defn(name=f"{prefix}-sandbox_session_shutdown")
async def session_shutdown(args: StopArgs) -> None:
key = str(args.state.session_id)
session = self._sessions.get(key)
if session is not None:
await session.shutdown()
if session is None:
return None
try:
with _translate_sandbox_errors():
await session.shutdown()
except ApplicationError:
# Terminal failure: the session is dead, so evict it before
# re-raising. A retryable error instead propagates with the
# entry kept so the activity's retry can still shut it down.
del self._sessions[key]
raise
del self._sessions[key]
return None

return [
Expand Down
Loading
Loading