Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions src/agents/sandbox/session/archive_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import zipfile
from collections.abc import Awaitable, Callable, Iterator
from contextlib import contextmanager
from pathlib import Path, PurePosixPath
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Literal, cast

from ..errors import ExecNonZeroError, WorkspaceArchiveWriteError
from ..files import EntryKind, FileEntry
from ..util.tar_utils import UnsafeTarMemberError, safe_tar_member_rel_path
from ..util.tar_utils import UnsafeTarMemberError, safe_tar_member_rel_path, validate_tarfile


class UnsafeZipMemberError(ValueError):
Expand Down Expand Up @@ -46,6 +46,7 @@ async def extract_tar_archive(
child_entry_cache: dict[Path, dict[str, EntryKind]] = {}
try:
with tarfile.open(fileobj=data, mode="r:*") as archive:
validate_tarfile(archive, allow_symlinks=False)
for member in archive.getmembers():
rel_path = safe_tar_member_rel_path(member)
if rel_path is None:
Expand Down Expand Up @@ -112,6 +113,7 @@ async def extract_zip_archive(
try:
with zipfile_compatible_stream(data) as zip_data:
with zipfile.ZipFile(zip_data) as archive:
validate_zipfile(archive)
for member in archive.infolist():
rel_path = safe_zip_member_rel_path(member)
if rel_path is None:
Expand Down Expand Up @@ -281,6 +283,12 @@ def safe_zip_member_rel_path(member: zipfile.ZipInfo) -> Path | None:
if member.filename in ("", ".", "./"):
return None

windows_path = PureWindowsPath(member.filename)
if windows_path.drive:
raise UnsafeZipMemberError(member=member.filename, reason="windows drive path")
if "\\" in member.filename:
raise UnsafeZipMemberError(member=member.filename, reason="windows path separator")

rel = PurePosixPath(member.filename)
if rel.is_absolute():
raise UnsafeZipMemberError(member=member.filename, reason="absolute path")
Expand All @@ -294,6 +302,36 @@ def safe_zip_member_rel_path(member: zipfile.ZipInfo) -> Path | None:
return Path(*rel.parts)


def validate_zipfile(archive: zipfile.ZipFile) -> None:
    """Check the whole ZIP listing for conflicting member paths.

    Two passes are made over ``archive.infolist()``:

    1. Each member is normalised with ``safe_zip_member_rel_path`` (members
       for which it returns ``None`` are ignored).  A path that appears more
       than once is rejected unless every occurrence is a directory entry.
    2. Each retained member's ancestors are inspected; if an ancestor path
       is itself present in the archive as a non-directory entry, the member
       is rejected.

    Raises:
        UnsafeZipMemberError: on a duplicate path or on a member nested
            below a non-directory entry (and whatever
            ``safe_zip_member_rel_path`` raises for an individual member).
    """
    entry_for_path: dict[Path, zipfile.ZipInfo] = {}
    retained: list[tuple[zipfile.ZipInfo, Path]] = []

    # Pass 1: normalise names and detect duplicates.
    for info in archive.infolist():
        normalised = safe_zip_member_rel_path(info)
        if normalised is None:
            continue

        earlier = entry_for_path.get(normalised)
        if earlier is not None:
            # Repeated directory entries are tolerated; anything else is not.
            if not earlier.is_dir() or not info.is_dir():
                raise UnsafeZipMemberError(
                    member=info.filename,
                    reason=f"duplicate archive path: {normalised.as_posix()}",
                )
        entry_for_path[normalised] = info
        retained.append((info, normalised))

    # Pass 2: no member may live underneath a non-directory entry.
    archive_root = Path()
    for info, normalised in retained:
        for ancestor in normalised.parents:
            if ancestor == archive_root:
                # Reached the top of the relative path; nothing above it.
                break
            ancestor_entry = entry_for_path.get(ancestor)
            if ancestor_entry is None or ancestor_entry.is_dir():
                continue
            raise UnsafeZipMemberError(
                member=info.filename,
                reason=f"archive path descends through non-directory: {ancestor.as_posix()}",
            )


class _ZipFileStreamAdapter(io.IOBase):
# Python 3.10's zipfile._SharedFile reads `file.seekable` directly, so this
# adapter keeps ZIP-compatible random-access streams working across versions.
Expand Down
20 changes: 18 additions & 2 deletions src/agents/sandbox/util/tar_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tarfile
import tempfile
from collections.abc import Iterable
from pathlib import Path, PurePosixPath
from pathlib import Path, PurePosixPath, PureWindowsPath


class UnsafeTarMemberError(ValueError):
Expand All @@ -27,6 +27,14 @@ def _validate_archive_root_member(member: tarfile.TarInfo) -> None:
raise UnsafeTarMemberError(member=member.name, reason="archive root member must be directory")


def _raise_if_windows_member_path(member_name: str) -> None:
    """Reject a tar member name that carries Windows path syntax.

    Raises:
        UnsafeTarMemberError: with reason ``"windows drive path"`` when
            ``PureWindowsPath`` reports a drive for the name, otherwise with
            reason ``"windows path separator"`` when the name contains a
            backslash.
    """
    # The drive check runs first, so a name with both a drive and a backslash
    # is reported as a drive path.
    drive = PureWindowsPath(member_name).drive
    if drive:
        raise UnsafeTarMemberError(member=member_name, reason="windows drive path")

    has_backslash = "\\" in member_name
    if has_backslash:
        raise UnsafeTarMemberError(member=member_name, reason="windows path separator")


def safe_tar_member_rel_path(
member: tarfile.TarInfo,
*,
Expand All @@ -37,6 +45,7 @@ def safe_tar_member_rel_path(
if member.name in ("", ".", "./"):
_validate_archive_root_member(member)
return None
_raise_if_windows_member_path(member.name)
rel = PurePosixPath(member.name)
if rel.is_absolute():
raise UnsafeTarMemberError(member=member.name, reason="absolute path")
Expand Down Expand Up @@ -189,6 +198,7 @@ def validate_tarfile(
reject_symlink_rel_paths: Iterable[str | Path] = (),
skip_rel_paths: Iterable[str | Path] = (),
root_name: str | None = None,
allow_symlinks: bool = True,
) -> None:
"""Validate a workspace tar before handing it to a local or remote extractor.

Expand All @@ -212,7 +222,7 @@ def validate_tarfile(
root_name=root_name,
):
continue
rel_path = safe_tar_member_rel_path(member, allow_symlinks=True)
rel_path = safe_tar_member_rel_path(member, allow_symlinks=allow_symlinks)
if rel_path is None:
continue

Expand Down Expand Up @@ -242,6 +252,12 @@ def validate_tarfile(
member=member.name,
reason=f"archive path descends through symlink: {parent.as_posix()}",
)
parent_member = members_by_rel_path.get(parent)
if parent_member is not None and not parent_member.isdir():
raise UnsafeTarMemberError(
member=member.name,
reason=f"archive path descends through non-directory: {parent.as_posix()}",
)


def validate_tar_bytes(
Expand Down
102 changes: 102 additions & 0 deletions tests/sandbox/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,28 @@ def _zip_bytes(*, members: dict[str, bytes]) -> io.BytesIO:
return buf


async def _assert_extract_rejects_member(
    tmp_path: Path,
    archive_name: str,
    data: io.IOBase,
    *,
    expected_member: str,
    expected_reason: str,
) -> Path:
    """Assert that extracting ``data`` fails for the expected member and reason.

    A session is built from ``tmp_path`` and started, ``session.extract`` is
    expected to raise ``WorkspaceArchiveWriteError``, and the error's
    ``context`` must carry the given ``member`` and ``reason`` values.  The
    session is shut down whether or not the assertions pass.

    Returns:
        The workspace root taken from the session manifest, so callers can
        make further assertions about what was (not) written.
    """
    session = _build_session(tmp_path)
    await session.start()
    try:
        workspace_root = Path(session.state.manifest.root)
        with pytest.raises(WorkspaceArchiveWriteError) as raised:
            await session.extract(archive_name, data)

        # Checked outside the ``with`` block so they run after the raise.
        error_context = raised.value.context
        assert error_context["member"] == expected_member
        assert error_context["reason"] == expected_reason
    finally:
        await session.shutdown()
    return workspace_root


@pytest.mark.asyncio
async def test_extract_tar_writes_archive_and_unpacks_contents(tmp_path: Path) -> None:
session = _build_session(tmp_path)
Expand Down Expand Up @@ -300,6 +322,86 @@ async def test_extract_zip_rejects_symlinked_parent_paths(tmp_path: Path) -> Non
await session.shutdown()


@pytest.mark.asyncio
async def test_extract_tar_rejects_windows_drive_member_paths(tmp_path: Path) -> None:
    """Extracting a tar whose member name has a drive prefix must fail."""
    member_name = "C:/tmp/evil.txt"
    archive = _tar_bytes(members={member_name: b"evil"})

    await _assert_extract_rejects_member(
        tmp_path,
        "bundle.tar",
        archive,
        expected_member=member_name,
        expected_reason="windows drive path",
    )


@pytest.mark.asyncio
async def test_extract_zip_rejects_windows_drive_member_paths(tmp_path: Path) -> None:
    """Extracting a zip whose member name has a drive prefix must fail."""
    member_name = r"C:\tmp\evil.txt"
    archive = _zip_bytes(members={member_name: b"evil"})

    await _assert_extract_rejects_member(
        tmp_path,
        "bundle.zip",
        archive,
        expected_member=member_name,
        expected_reason="windows drive path",
    )


@pytest.mark.asyncio
async def test_extract_tar_rejects_windows_separator_member_paths(tmp_path: Path) -> None:
    """Extracting a tar whose member name contains a backslash must fail."""
    member_name = r"..\evil.txt"
    archive = _tar_bytes(members={member_name: b"evil"})

    await _assert_extract_rejects_member(
        tmp_path,
        "bundle.tar",
        archive,
        expected_member=member_name,
        expected_reason="windows path separator",
    )


@pytest.mark.asyncio
async def test_extract_zip_rejects_windows_separator_member_paths(tmp_path: Path) -> None:
    """Extracting a zip whose member name contains a backslash must fail."""
    member_name = r"\evil.txt"
    archive = _zip_bytes(members={member_name: b"evil"})

    await _assert_extract_rejects_member(
        tmp_path,
        "bundle.zip",
        archive,
        expected_member=member_name,
        expected_reason="windows path separator",
    )


@pytest.mark.asyncio
async def test_extract_tar_rejects_member_under_non_directory_member(tmp_path: Path) -> None:
    """A tar with a file nested below another file entry must be rejected.

    The archive lists ``nested/hello.txt`` and also a regular file named
    ``nested``; nothing called ``nested`` may exist in the workspace afterwards.
    """
    conflicting_members = {
        "nested/hello.txt": b"hello from tar",
        "nested": b"not a directory",
    }

    workspace_root = await _assert_extract_rejects_member(
        tmp_path,
        "bundle.tar",
        _tar_bytes(members=conflicting_members),
        expected_member="nested/hello.txt",
        expected_reason="archive path descends through non-directory: nested",
    )

    assert not (workspace_root / "nested").exists()


@pytest.mark.asyncio
async def test_extract_zip_rejects_member_under_non_directory_member(tmp_path: Path) -> None:
    """A zip with a file nested below another file entry must be rejected.

    The archive lists ``nested/hello.txt`` and also a regular file named
    ``nested``; nothing called ``nested`` may exist in the workspace afterwards.
    """
    conflicting_members = {
        "nested/hello.txt": b"hello from zip",
        "nested": b"not a directory",
    }

    workspace_root = await _assert_extract_rejects_member(
        tmp_path,
        "bundle.zip",
        _zip_bytes(members=conflicting_members),
        expected_member="nested/hello.txt",
        expected_reason="archive path descends through non-directory: nested",
    )

    assert not (workspace_root / "nested").exists()


@pytest.mark.asyncio
async def test_unix_local_persist_workspace_excludes_resolved_mount_path(tmp_path: Path) -> None:
workspace_root = tmp_path / "workspace"
Expand Down
29 changes: 29 additions & 0 deletions tests/sandbox/test_tar_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,35 @@ def test_validate_tar_bytes_rejects_root_symlink() -> None:
validate_tar_bytes(raw)


@pytest.mark.parametrize("member_name", ["C:/tmp/evil.txt", r"C:\tmp\evil.txt"])
def test_validate_tar_bytes_rejects_windows_drive_member_paths(member_name: str) -> None:
    """``validate_tar_bytes`` must refuse member names with a drive prefix."""
    evil_member = _file(member_name, b"evil")
    archive_bytes = _tar_bytes(evil_member)

    with pytest.raises(UnsafeTarMemberError, match="windows drive path"):
        validate_tar_bytes(archive_bytes)


@pytest.mark.parametrize("member_name", [r"..\evil.txt", r"\evil.txt", r"nested\evil.txt"])
def test_validate_tar_bytes_rejects_windows_separator_member_paths(member_name: str) -> None:
    """``validate_tar_bytes`` must refuse member names containing a backslash."""
    evil_member = _file(member_name, b"evil")
    archive_bytes = _tar_bytes(evil_member)

    with pytest.raises(UnsafeTarMemberError, match="windows path separator"):
        validate_tar_bytes(archive_bytes)


def test_validate_tar_bytes_rejects_member_under_non_directory_member() -> None:
    """``validate_tar_bytes`` must refuse a file nested below another file entry."""
    nested_child = _file("nested/hello.txt", b"hello")
    shadowing_file = _file("nested", b"not a directory")
    archive_bytes = _tar_bytes(nested_child, shadowing_file)

    expected_message = "archive path descends through non-directory: nested"
    with pytest.raises(UnsafeTarMemberError, match=expected_message):
        validate_tar_bytes(archive_bytes)


def test_strip_tar_member_prefix_returns_workspace_relative_archive() -> None:
raw = _tar_bytes(
_dir("workspace"),
Expand Down