diff --git a/src/agents/sandbox/session/archive_ops.py b/src/agents/sandbox/session/archive_ops.py new file mode 100644 index 0000000000..131f667018 --- /dev/null +++ b/src/agents/sandbox/session/archive_ops.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import io +import shutil +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING, Literal, cast + +from ..errors import InvalidCompressionSchemeError +from .archive_extraction import WorkspaceArchiveExtractor, safe_zip_member_rel_path + +if TYPE_CHECKING: + from .base_sandbox_session import BaseSandboxSession + + +async def extract_archive( + session: BaseSandboxSession, + path: Path | str, + data: io.IOBase, + *, + compression_scheme: Literal["tar", "zip"] | None = None, +) -> None: + if isinstance(path, str): + path = Path(path) + + if compression_scheme is None: + suffix = path.suffix.removeprefix(".") + compression_scheme = cast(Literal["tar", "zip"], suffix) if suffix else None + + if compression_scheme is None or compression_scheme not in ["zip", "tar"]: + raise InvalidCompressionSchemeError(path=path, scheme=compression_scheme) + + normalized_path = await session._validate_path_access(path, for_write=True) + destination_root = normalized_path.parent + + # Materialize the archive into a local spool once because both `write()` and the + # extraction step consume the stream, and zip extraction may require seeking. + spool = tempfile.SpooledTemporaryFile(max_size=16 * 1024 * 1024, mode="w+b") + try: + shutil.copyfileobj(data, spool) + spool.seek(0) + await session.write(normalized_path, spool) + spool.seek(0) + + if compression_scheme == "tar": + await session._extract_tar_archive( + archive_path=normalized_path, + destination_root=destination_root, + data=spool, + ) + else: + await session._extract_zip_archive( + archive_path=normalized_path, + destination_root=destination_root, + data=spool, + ) + finally: + spool.close() + + +async def extract_tar_archive( + session: BaseSandboxSession, + *, + archive_path: Path, + destination_root: Path, + data: io.IOBase, +) -> None: + extractor = _build_workspace_archive_extractor(session) + await extractor.extract_tar_archive( + archive_path=archive_path, + destination_root=destination_root, + data=data, + ) + + +async def extract_zip_archive( + session: BaseSandboxSession, + *, + archive_path: Path, + destination_root: Path, + data: io.IOBase, +) -> None: + extractor = _build_workspace_archive_extractor(session) + await extractor.extract_zip_archive( + archive_path=archive_path, + destination_root=destination_root, + data=data, + ) + + +def _build_workspace_archive_extractor(session: BaseSandboxSession) -> WorkspaceArchiveExtractor: + return WorkspaceArchiveExtractor( + mkdir=lambda path: session.mkdir(path, parents=True), + write=session.write, + ls=lambda path: session.ls(path), + ) + + +__all__ = [ + "extract_archive", + "extract_tar_archive", + "extract_zip_archive", + "safe_zip_member_rel_path", +] diff --git a/src/agents/sandbox/session/base_sandbox_session.py b/src/agents/sandbox/session/base_sandbox_session.py index 48e2ab563c..cef10c0075 100644 --- a/src/agents/sandbox/session/base_sandbox_session.py +++ b/src/agents/sandbox/session/base_sandbox_session.py @@ -1,13 +1,9 @@ import abc -import hashlib import io -import json import shlex -import shutil -import tempfile from collections.abc import Awaitable, Callable, Mapping, Sequence from pathlib import Path, PurePath -from typing import Literal, TypeVar, cast +from typing import Literal, TypeVar from typing_extensions import Self @@ -23,17 +19,15 @@ ExecNonZeroError, ExecTransportError, ExposedPortUnavailableError, - InvalidCompressionSchemeError, InvalidManifestPathError, MountConfigError, PtySessionNotFoundError, WorkspaceArchiveWriteError, WorkspaceReadNotFoundError, ) -from ..files import EntryKind, FileEntry +from ..files import FileEntry from ..manifest import Manifest from ..materialization import MaterializationResult, MaterializedFile -from ..snapshot import NoopSnapshot from ..types import ExecResult, ExposedPortEndpoint, User from ..util.parse_utils import parse_ls_la from ..workspace_paths import ( @@ -43,23 +37,17 @@ posix_path_for_error, sandbox_path_str, ) -from .archive_extraction import ( - WorkspaceArchiveExtractor, - safe_zip_member_rel_path, -) +from . import archive_ops, manifest_ops, snapshot_lifecycle from .dependencies import Dependencies -from .manifest_application import ManifestApplier from .pty_types import PtyExecUpdate from .runtime_helpers import ( RESOLVE_WORKSPACE_PATH_HELPER, - WORKSPACE_FINGERPRINT_HELPER, RuntimeHelperScript, ) from .sandbox_session_state import SandboxSessionState _PtyEntryT = TypeVar("_PtyEntryT") _RUNTIME_HELPER_CACHE_KEY_UNSET = object() -_SNAPSHOT_FINGERPRINT_VERSION = "workspace_tar_sha256_v1" _WORKSPACE_ROOT_PROBE_TIMEOUT_S = 10.0 _WRITE_ACCESS_CHECK_SCRIPT = ( 'target="$1"\n' @@ -294,35 +282,7 @@ async def _before_stop(self) -> None: async def _persist_snapshot(self) -> None: """Persist/snapshot the workspace.""" - if isinstance(self.state.snapshot, NoopSnapshot): - return - - fingerprint_record: dict[str, str] | None = None - try: - fingerprint_record = await self._compute_and_cache_snapshot_fingerprint() - except Exception: - fingerprint_record = None - - workspace_archive = await self.persist_workspace() - try: - await self.state.snapshot.persist(workspace_archive, dependencies=self.dependencies) - except Exception: - if fingerprint_record is not None: - await self._delete_cached_snapshot_fingerprint_best_effort() - raise - finally: - try: - workspace_archive.close() - except Exception: - pass - - if fingerprint_record is None: - self.state.snapshot_fingerprint = None - self.state.snapshot_fingerprint_version = None - return - - self.state.snapshot_fingerprint = fingerprint_record["fingerprint"] - self.state.snapshot_fingerprint_version = fingerprint_record["version"] + await snapshot_lifecycle.persist_snapshot(self) def _wrap_stop_error(self, error: Exception) -> Exception: """Return a provider-specific stop error, or the original error.""" @@ -1015,42 +975,12 @@ async def extract( :param compression_scheme: either "tar" or "zip". If not provided, it will try to infer from the path. """ - if isinstance(path, str): - path = Path(path) - - if compression_scheme is None: - suffix = path.suffix.removeprefix(".") - compression_scheme = cast(Literal["tar", "zip"], suffix) if suffix else None - - if compression_scheme is None or compression_scheme not in ["zip", "tar"]: - raise InvalidCompressionSchemeError(path=path, scheme=compression_scheme) - - normalized_path = await self._validate_path_access(path, for_write=True) - destination_root = normalized_path.parent - - # Materialize the archive into a local spool once because both `write()` and the - # extraction step consume the stream, and zip extraction may require seeking. - spool = tempfile.SpooledTemporaryFile(max_size=16 * 1024 * 1024, mode="w+b") - try: - shutil.copyfileobj(data, spool) - spool.seek(0) - await self.write(normalized_path, spool) - spool.seek(0) - - if compression_scheme == "tar": - await self._extract_tar_archive( - archive_path=normalized_path, - destination_root=destination_root, - data=spool, - ) - else: - await self._extract_zip_archive( - archive_path=normalized_path, - destination_root=destination_root, - data=spool, - ) - finally: - spool.close() + await archive_ops.extract_archive( + self, + path, + data, + compression_scheme=compression_scheme, + ) async def apply_patch( self, @@ -1076,12 +1006,8 @@ async def _extract_tar_archive( destination_root: Path, data: io.IOBase, ) -> None: - extractor = WorkspaceArchiveExtractor( - mkdir=lambda path: self.mkdir(path, parents=True), - write=self.write, - ls=lambda path: self.ls(path), - ) - await extractor.extract_tar_archive( + await archive_ops.extract_tar_archive( + self, archive_path=archive_path, destination_root=destination_root, data=data, @@ -1094,12 +1020,8 @@ async def _extract_zip_archive( destination_root: Path, data: io.IOBase, ) -> None: - extractor = WorkspaceArchiveExtractor( - mkdir=lambda path: self.mkdir(path, parents=True), - write=self.write, - ls=lambda path: self.ls(path), - ) - await extractor.extract_zip_archive( + await archive_ops.extract_zip_archive( + self, archive_path=archive_path, destination_root=destination_root, data=data, @@ -1107,7 +1029,7 @@ async def _extract_zip_archive( @staticmethod def _safe_zip_member_rel_path(member) -> Path | None: - return safe_zip_member_rel_path(member) + return archive_ops.safe_zip_member_rel_path(member) async def _apply_manifest( self, @@ -1115,17 +1037,10 @@ async def _apply_manifest( only_ephemeral: bool = False, provision_accounts: bool = True, ) -> MaterializationResult: - applier = ManifestApplier( - mkdir=lambda path: self.mkdir(path, parents=True), - exec_checked_nonzero=self._exec_checked_nonzero, - apply_entry=lambda artifact, dest, base_dir: artifact.apply(self, dest, base_dir), - max_entry_concurrency=self._max_manifest_entry_concurrency, - ) - return await applier.apply_manifest( - self.state.manifest, + return await manifest_ops.apply_manifest( + self, only_ephemeral=only_ephemeral, provision_accounts=provision_accounts, - base_dir=self._manifest_base_dir(), ) async def apply_manifest(self, *, only_ephemeral: bool = False) -> MaterializationResult: @@ -1135,12 +1050,7 @@ async def apply_manifest(self, *, only_ephemeral: bool = False) -> Materializati ) async def provision_manifest_accounts(self) -> None: - applier = ManifestApplier( - mkdir=lambda path: self.mkdir(path, parents=True), - exec_checked_nonzero=self._exec_checked_nonzero, - apply_entry=lambda artifact, dest, base_dir: artifact.apply(self, dest, base_dir), - ) - await applier.provision_accounts(self.state.manifest) + await manifest_ops.provision_manifest_accounts(self) def should_provision_manifest_accounts_on_resume(self) -> bool: """Return whether resume should reprovision manifest-managed users and groups.""" @@ -1155,142 +1065,62 @@ async def _reapply_ephemeral_manifest_on_resume(self) -> None: async def _restore_snapshot_into_workspace_on_resume(self) -> None: """Clear the live workspace contents and repopulate them from the persisted snapshot.""" - await self._clear_workspace_root_on_resume() - workspace_archive = await self.state.snapshot.restore(dependencies=self.dependencies) - try: - await self.hydrate_workspace(workspace_archive) - finally: - try: - workspace_archive.close() - except Exception: - pass + await snapshot_lifecycle.restore_snapshot_into_workspace_on_resume(self) async def _live_workspace_matches_snapshot_on_resume(self) -> bool: """Return whether the running sandbox workspace definitely matches the stored snapshot.""" - stored_fingerprint = self.state.snapshot_fingerprint - stored_version = self.state.snapshot_fingerprint_version - if not stored_fingerprint or not stored_version: - return False - - try: - cached_record = await self._compute_and_cache_snapshot_fingerprint() - except Exception: - return False - - return ( - cached_record.get("fingerprint") == stored_fingerprint - and cached_record.get("version") == stored_version - ) + return await snapshot_lifecycle.live_workspace_matches_snapshot_on_resume(self) async def _can_skip_snapshot_restore_on_resume(self, *, is_running: bool) -> bool: """Return whether resume can safely reuse the running workspace without restore.""" - if not is_running: - return False - return await self._live_workspace_matches_snapshot_on_resume() + return await snapshot_lifecycle.can_skip_snapshot_restore_on_resume( + self, + is_running=is_running, + ) def _snapshot_fingerprint_cache_path(self) -> Path: """Return the runtime-owned path for this session's cached snapshot fingerprint.""" - cache_path = coerce_posix_path( - f"/tmp/openai-agents/session-state/{self.state.session_id.hex}/fingerprint.json" - ) - if self._workspace_path_policy().root_is_existing_host_path(): - return Path(cache_path.as_posix()) - return posix_path_as_path(cache_path) + return snapshot_lifecycle.snapshot_fingerprint_cache_path(self) def _workspace_fingerprint_skip_relpaths(self) -> set[Path]: """Return workspace paths that should be omitted from snapshot fingerprinting.""" - skip_paths = self._persist_workspace_skip_relpaths() - skip_paths.update(self._workspace_resume_mount_skip_relpaths()) - return skip_paths + return snapshot_lifecycle.workspace_fingerprint_skip_relpaths(self) async def _compute_and_cache_snapshot_fingerprint(self) -> dict[str, str]: """Compute the current workspace fingerprint in-container and atomically cache it.""" - helper_path = await self._ensure_runtime_helper_installed(WORKSPACE_FINGERPRINT_HELPER) - command = [ - str(helper_path), - self._workspace_root_path().as_posix(), - self._snapshot_fingerprint_version(), - self._snapshot_fingerprint_cache_path().as_posix(), - self._resume_manifest_digest(), - ] - command.extend( - rel_path.as_posix() - for rel_path in sorted( - self._workspace_fingerprint_skip_relpaths(), - key=lambda path: path.as_posix(), - ) - ) - result = await self.exec(*command, shell=False) - if not result.ok(): - raise ExecNonZeroError(result, command=("compute_workspace_fingerprint", *command[1:])) - return self._parse_snapshot_fingerprint_record(result.stdout) + return await snapshot_lifecycle.compute_and_cache_snapshot_fingerprint(self) async def _read_cached_snapshot_fingerprint(self) -> dict[str, str]: """Read the cached snapshot fingerprint record from the running sandbox.""" - result = await self.exec( - "cat", - "--", - self._snapshot_fingerprint_cache_path().as_posix(), - shell=False, - ) - if not result.ok(): - raise ExecNonZeroError( - result, - command=("cat", self._snapshot_fingerprint_cache_path().as_posix()), - ) - return self._parse_snapshot_fingerprint_record(result.stdout) + return await snapshot_lifecycle.read_cached_snapshot_fingerprint(self) def _parse_snapshot_fingerprint_record( self, payload: bytes | bytearray | str ) -> dict[str, str]: """Validate and normalize a cached snapshot fingerprint JSON payload.""" - raw = payload.decode("utf-8") if isinstance(payload, bytes | bytearray) else payload - data = json.loads(raw) - if not isinstance(data, dict): - raise ValueError("snapshot fingerprint payload must be a JSON object") - fingerprint = data.get("fingerprint") - version = data.get("version") - if not isinstance(fingerprint, str) or not fingerprint: - raise ValueError("snapshot fingerprint payload is missing `fingerprint`") - if not isinstance(version, str) or not version: - raise ValueError("snapshot fingerprint payload is missing `version`") - return {"fingerprint": fingerprint, "version": version} + return snapshot_lifecycle.parse_snapshot_fingerprint_record(payload) async def _delete_cached_snapshot_fingerprint_best_effort(self) -> None: """Remove the cached snapshot fingerprint file without raising on cleanup failure.""" - try: - await self.exec( - "rm", - "-f", - "--", - self._snapshot_fingerprint_cache_path().as_posix(), - shell=False, - ) - except Exception: - return + await snapshot_lifecycle.delete_cached_snapshot_fingerprint_best_effort(self) def _snapshot_fingerprint_version(self) -> str: """Return the version tag for the current snapshot fingerprint algorithm.""" - return _SNAPSHOT_FINGERPRINT_VERSION + return snapshot_lifecycle.snapshot_fingerprint_version() def _resume_manifest_digest(self) -> str: """Return a stable digest of the manifest state that affects resume correctness.""" - manifest_payload = json.dumps( - self.state.manifest.model_dump(mode="json"), - sort_keys=True, - separators=(",", ":"), - ).encode("utf-8") - return hashlib.sha256(manifest_payload).hexdigest() + return snapshot_lifecycle.resume_manifest_digest(self) async def _apply_entry_batch( self, @@ -1298,17 +1128,7 @@ async def _apply_entry_batch( *, base_dir: Path, ) -> list[MaterializedFile]: - applier = ManifestApplier( - mkdir=lambda path: self.mkdir(path, parents=True), - exec_checked_nonzero=self._exec_checked_nonzero, - apply_entry=lambda artifact, dest, current_base_dir: artifact.apply( - self, - dest, - current_base_dir, - ), - max_entry_concurrency=self._max_manifest_entry_concurrency, - ) - return await applier._apply_entry_batch(entries, base_dir=base_dir) + return await manifest_ops.apply_entry_batch(self, entries, base_dir=base_dir) def _manifest_base_dir(self) -> Path: return Path.cwd() @@ -1329,24 +1149,10 @@ async def _clear_workspace_root_on_resume(self) -> None: fail with "failed to find initial working directory". """ - skip_rel_paths = self._workspace_resume_mount_skip_relpaths() - if any(rel_path in (Path(""), Path(".")) for rel_path in skip_rel_paths): - return - - await self._clear_workspace_dir_on_resume_pruned( - current_dir=self._workspace_root_path(), - skip_rel_paths=skip_rel_paths, - ) + await snapshot_lifecycle.clear_workspace_root_on_resume(self) def _workspace_resume_mount_skip_relpaths(self) -> set[Path]: - root = self._workspace_root_path() - skip_rel_paths: set[Path] = set() - for _mount, mount_path in self.state.manifest.ephemeral_mount_targets(): - try: - skip_rel_paths.add(mount_path.relative_to(root)) - except ValueError: - continue - return skip_rel_paths + return snapshot_lifecycle.workspace_resume_mount_skip_relpaths(self) async def _clear_workspace_dir_on_resume_pruned( self, @@ -1354,32 +1160,8 @@ async def _clear_workspace_dir_on_resume_pruned( current_dir: Path, skip_rel_paths: set[Path], ) -> None: - root = self._workspace_root_path() - try: - entries = await self.ls(current_dir) - except ExecNonZeroError: - # If the root or subtree doesn't exist (or isn't listable), treat it as empty and let - # hydrate/apply create it as needed. - return - - for entry in entries: - child = Path(entry.path) - try: - child_rel = child.relative_to(root) - except ValueError: - await self.rm(child, recursive=True) - continue - - if child_rel in skip_rel_paths: - continue - if any(child_rel in skip_rel_path.parents for skip_rel_path in skip_rel_paths): - if entry.kind == EntryKind.DIRECTORY: - await self._clear_workspace_dir_on_resume_pruned( - current_dir=child, - skip_rel_paths=skip_rel_paths, - ) - else: - await self.rm(child, recursive=True) - continue - # `parse_ls_la` filters "." and ".." already; remove everything else recursively. - await self.rm(child, recursive=True) + await snapshot_lifecycle.clear_workspace_dir_on_resume_pruned( + self, + current_dir=current_dir, + skip_rel_paths=skip_rel_paths, + ) diff --git a/src/agents/sandbox/session/manifest_ops.py b/src/agents/sandbox/session/manifest_ops.py new file mode 100644 index 0000000000..04eab029d4 --- /dev/null +++ b/src/agents/sandbox/session/manifest_ops.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from ..entries import BaseEntry +from ..materialization import MaterializationResult, MaterializedFile +from .manifest_application import ManifestApplier + +if TYPE_CHECKING: + from collections.abc import Sequence + + from .base_sandbox_session import BaseSandboxSession + + +async def apply_manifest( + session: BaseSandboxSession, + *, + only_ephemeral: bool = False, + provision_accounts: bool = True, +) -> MaterializationResult: + applier = _build_manifest_applier(session, include_entry_concurrency=True) + return await applier.apply_manifest( + session.state.manifest, + only_ephemeral=only_ephemeral, + provision_accounts=provision_accounts, + base_dir=session._manifest_base_dir(), + ) + + +async def provision_manifest_accounts(session: BaseSandboxSession) -> None: + applier = _build_manifest_applier(session, include_entry_concurrency=False) + await applier.provision_accounts(session.state.manifest) + + +async def apply_entry_batch( + session: BaseSandboxSession, + entries: Sequence[tuple[Path, BaseEntry]], + *, + base_dir: Path, +) -> list[MaterializedFile]: + applier = _build_manifest_applier(session, include_entry_concurrency=True) + return await applier._apply_entry_batch(entries, base_dir=base_dir) + + +def _build_manifest_applier( + session: BaseSandboxSession, + *, + include_entry_concurrency: bool, +) -> ManifestApplier: + max_entry_concurrency = ( + session._max_manifest_entry_concurrency if include_entry_concurrency else None + ) + return ManifestApplier( + mkdir=lambda path: session.mkdir(path, parents=True), + exec_checked_nonzero=session._exec_checked_nonzero, + apply_entry=lambda artifact, dest, base_dir: artifact.apply(session, dest, base_dir), + max_entry_concurrency=max_entry_concurrency, + ) + + +__all__ = [ + "apply_entry_batch", + "apply_manifest", + "provision_manifest_accounts", +] diff --git a/src/agents/sandbox/session/snapshot_lifecycle.py b/src/agents/sandbox/session/snapshot_lifecycle.py new file mode 100644 index 0000000000..1145f8a247 --- /dev/null +++ b/src/agents/sandbox/session/snapshot_lifecycle.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import hashlib +import io +import json +from pathlib import Path +from typing import TYPE_CHECKING + +from ..errors import ExecNonZeroError +from ..files import EntryKind +from ..snapshot import NoopSnapshot +from ..workspace_paths import coerce_posix_path, posix_path_as_path +from .runtime_helpers import WORKSPACE_FINGERPRINT_HELPER + +if TYPE_CHECKING: + from .base_sandbox_session import BaseSandboxSession + +SNAPSHOT_FINGERPRINT_VERSION = "workspace_tar_sha256_v1" + + +async def persist_snapshot(session: BaseSandboxSession) -> None: + if isinstance(session.state.snapshot, NoopSnapshot): + return + + fingerprint_record: dict[str, str] | None = None + try: + fingerprint_record = await session._compute_and_cache_snapshot_fingerprint() + except Exception: + fingerprint_record = None + + workspace_archive = await session.persist_workspace() + try: + await session.state.snapshot.persist(workspace_archive, dependencies=session.dependencies) + except Exception: + if fingerprint_record is not None: + await session._delete_cached_snapshot_fingerprint_best_effort() + raise + finally: + _close_best_effort(workspace_archive) + + if fingerprint_record is None: + session.state.snapshot_fingerprint = None + session.state.snapshot_fingerprint_version = None + return + + session.state.snapshot_fingerprint = fingerprint_record["fingerprint"] + session.state.snapshot_fingerprint_version = fingerprint_record["version"] + + +async def restore_snapshot_into_workspace_on_resume(session: BaseSandboxSession) -> None: + await session._clear_workspace_root_on_resume() + workspace_archive = await session.state.snapshot.restore(dependencies=session.dependencies) + try: + await session.hydrate_workspace(workspace_archive) + finally: + _close_best_effort(workspace_archive) + + +async def live_workspace_matches_snapshot_on_resume(session: BaseSandboxSession) -> bool: + stored_fingerprint = session.state.snapshot_fingerprint + stored_version = session.state.snapshot_fingerprint_version + if not stored_fingerprint or not stored_version: + return False + + try: + cached_record = await session._compute_and_cache_snapshot_fingerprint() + except Exception: + return False + + return ( + cached_record.get("fingerprint") == stored_fingerprint + and cached_record.get("version") == stored_version + ) + + +async def can_skip_snapshot_restore_on_resume( + session: BaseSandboxSession, + *, + is_running: bool, +) -> bool: + if not is_running: + return False + return await live_workspace_matches_snapshot_on_resume(session) + + +def snapshot_fingerprint_cache_path(session: BaseSandboxSession) -> Path: + cache_path = coerce_posix_path( + f"/tmp/openai-agents/session-state/{session.state.session_id.hex}/fingerprint.json" + ) + if session._workspace_path_policy().root_is_existing_host_path(): + return Path(cache_path.as_posix()) + return posix_path_as_path(cache_path) + + +def workspace_fingerprint_skip_relpaths(session: BaseSandboxSession) -> set[Path]: + skip_paths = session._persist_workspace_skip_relpaths() + skip_paths.update(session._workspace_resume_mount_skip_relpaths()) + return skip_paths + + +async def compute_and_cache_snapshot_fingerprint( + session: BaseSandboxSession, +) -> dict[str, str]: + helper_path = await session._ensure_runtime_helper_installed(WORKSPACE_FINGERPRINT_HELPER) + command = [ + str(helper_path), + session._workspace_root_path().as_posix(), + session._snapshot_fingerprint_version(), + session._snapshot_fingerprint_cache_path().as_posix(), + session._resume_manifest_digest(), + ] + command.extend( + rel_path.as_posix() + for rel_path in sorted( + session._workspace_fingerprint_skip_relpaths(), + key=lambda path: path.as_posix(), + ) + ) + result = await session.exec(*command, shell=False) + if not result.ok(): + raise ExecNonZeroError(result, command=("compute_workspace_fingerprint", *command[1:])) + return parse_snapshot_fingerprint_record(result.stdout) + + +async def read_cached_snapshot_fingerprint(session: BaseSandboxSession) -> dict[str, str]: + result = await session.exec( + "cat", + "--", + session._snapshot_fingerprint_cache_path().as_posix(), + shell=False, + ) + if not result.ok(): + raise ExecNonZeroError( + result, + command=("cat", session._snapshot_fingerprint_cache_path().as_posix()), + ) + return parse_snapshot_fingerprint_record(result.stdout) + + +def parse_snapshot_fingerprint_record(payload: bytes | bytearray | str) -> dict[str, str]: + raw = payload.decode("utf-8") if isinstance(payload, bytes | bytearray) else payload + data = json.loads(raw) + if not isinstance(data, dict): + raise ValueError("snapshot fingerprint payload must be a JSON object") + fingerprint = data.get("fingerprint") + version = data.get("version") + if not isinstance(fingerprint, str) or not fingerprint: + raise ValueError("snapshot fingerprint payload is missing `fingerprint`") + if not isinstance(version, str) or not version: + raise ValueError("snapshot fingerprint payload is missing `version`") + return {"fingerprint": fingerprint, "version": version} + + +async def delete_cached_snapshot_fingerprint_best_effort(session: BaseSandboxSession) -> None: + try: + await session.exec( + "rm", + "-f", + "--", + session._snapshot_fingerprint_cache_path().as_posix(), + shell=False, + ) + except Exception: + return + + +def snapshot_fingerprint_version() -> str: + return SNAPSHOT_FINGERPRINT_VERSION + + +def resume_manifest_digest(session: BaseSandboxSession) -> str: + manifest_payload = json.dumps( + session.state.manifest.model_dump(mode="json"), + sort_keys=True, + separators=(",", ":"), + ).encode("utf-8") + return hashlib.sha256(manifest_payload).hexdigest() + + +async def clear_workspace_root_on_resume(session: BaseSandboxSession) -> None: + skip_rel_paths = session._workspace_resume_mount_skip_relpaths() + if any(rel_path in (Path(""), Path(".")) for rel_path in skip_rel_paths): + return + + await session._clear_workspace_dir_on_resume_pruned( + current_dir=session._workspace_root_path(), + skip_rel_paths=skip_rel_paths, + ) + + +def workspace_resume_mount_skip_relpaths(session: BaseSandboxSession) -> set[Path]: + root = session._workspace_root_path() + skip_rel_paths: set[Path] = set() + for _mount, mount_path in session.state.manifest.ephemeral_mount_targets(): + try: + skip_rel_paths.add(mount_path.relative_to(root)) + except ValueError: + continue + return skip_rel_paths + + +async def clear_workspace_dir_on_resume_pruned( + session: BaseSandboxSession, + *, + current_dir: Path, + skip_rel_paths: set[Path], +) -> None: + root = session._workspace_root_path() + try: + entries = await session.ls(current_dir) + except ExecNonZeroError: + # If the root or subtree doesn't exist (or isn't listable), treat it as empty and let + # hydrate/apply create it as needed. + return + + for entry in entries: + child = Path(entry.path) + try: + child_rel = child.relative_to(root) + except ValueError: + await session.rm(child, recursive=True) + continue + + if child_rel in skip_rel_paths: + continue + if any(child_rel in skip_rel_path.parents for skip_rel_path in skip_rel_paths): + if entry.kind == EntryKind.DIRECTORY: + await session._clear_workspace_dir_on_resume_pruned( + current_dir=child, + skip_rel_paths=skip_rel_paths, + ) + else: + await session.rm(child, recursive=True) + continue + # `parse_ls_la` filters "." and ".." already; remove everything else recursively. + await session.rm(child, recursive=True) + + +def _close_best_effort(stream: io.IOBase) -> None: + try: + stream.close() + except Exception: + pass + + +__all__ = [ + "SNAPSHOT_FINGERPRINT_VERSION", + "can_skip_snapshot_restore_on_resume", + "clear_workspace_dir_on_resume_pruned", + "clear_workspace_root_on_resume", + "compute_and_cache_snapshot_fingerprint", + "delete_cached_snapshot_fingerprint_best_effort", + "live_workspace_matches_snapshot_on_resume", + "parse_snapshot_fingerprint_record", + "persist_snapshot", + "read_cached_snapshot_fingerprint", + "restore_snapshot_into_workspace_on_resume", + "resume_manifest_digest", + "snapshot_fingerprint_cache_path", + "snapshot_fingerprint_version", + "workspace_fingerprint_skip_relpaths", + "workspace_resume_mount_skip_relpaths", +]