Skip to content

Commit 1f1e792

Browse files
committed
FEAT Add safe_extract_zip helper for defensive remote ZIP extraction
Signed-off-by: francose <13445813+francose@users.noreply.github.com>
1 parent 4c12796 commit 1f1e792

5 files changed

Lines changed: 425 additions & 10 deletions

File tree

pyrit/common/safe_extract.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Defensive ZIP extraction for untrusted remote archives.
6+
7+
Remote dataset loaders in PyRIT download ZIP archives from third-party sources
8+
and feed them to ``zipfile.ZipFile.extractall()``. ``extractall`` does not
9+
validate member paths, file sizes, or entry types, which leaves the loader
10+
vulnerable to Zip Slip (CWE-22), zip bombs, and symlink-based path escape if
11+
any upstream source is tampered with.
12+
13+
``safe_extract_zip`` validates every archive member before writing anything to
14+
disk. If any member fails validation, the destination directory is left empty.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import io
20+
import logging
21+
import os
22+
import stat
23+
import zipfile
24+
from pathlib import Path
25+
from typing import IO
26+
27+
logger = logging.getLogger(__name__)
28+
29+
# 5 GiB cumulative uncompressed size across all members
30+
DEFAULT_MAX_TOTAL_SIZE = 5 * 1024**3
31+
# 1 GiB cap on any single member
32+
DEFAULT_MAX_FILE_SIZE = 1 * 1024**3
33+
# 50_000 entries: above legitimate dataset sizes, defeats inode DoS
34+
DEFAULT_MAX_FILE_COUNT = 50_000
35+
# Reject members whose uncompressed/compressed ratio exceeds this (zip bomb)
36+
DEFAULT_MAX_COMPRESSION_RATIO = 100
37+
38+
ZipSource = str | os.PathLike | bytes | IO[bytes]
39+
40+
41+
class UnsafeArchiveError(Exception):
    """
    Raised when an archive member fails a safe-extraction precondition.

    The exception message names the failed check and, where applicable, the
    offending member.
    """
43+
44+
45+
def safe_extract_zip(
    source: ZipSource,
    dest_dir: str | os.PathLike,
    *,
    max_total_size: int = DEFAULT_MAX_TOTAL_SIZE,
    max_file_size: int = DEFAULT_MAX_FILE_SIZE,
    max_file_count: int = DEFAULT_MAX_FILE_COUNT,
    max_compression_ratio: int = DEFAULT_MAX_COMPRESSION_RATIO,
) -> Path:
    """
    Extract a ZIP archive after validating every member.

    Every member listed by the archive is validated before any bytes are
    written. The destination directory is created only after validation has
    passed, so an unreadable or rejected archive leaves nothing behind from
    this call -- not even an empty directory that a caller could mistake for a
    completed extraction.

    Errors raised by ``zipfile`` while opening or extracting the archive are
    not caught here. A failure during extraction itself (after validation) can
    therefore still leave partially extracted files in ``dest_dir``.

    Args:
        source: Path, bytes, or file-like object accepted by ``zipfile.ZipFile``.
        dest_dir: Directory to extract into. Created if it does not exist.
        max_total_size: Cap on the sum of uncompressed member sizes.
        max_file_size: Cap on any single member's uncompressed size.
        max_file_count: Cap on the number of members in the archive.
        max_compression_ratio: Reject members whose uncompressed/compressed
            ratio exceeds this value (zip bomb defense).

    Returns:
        Resolved destination directory.

    Raises:
        UnsafeArchiveError: If any member fails validation.
    """
    # ZipFile needs a path or a seekable file object; wrap raw bytes.
    if isinstance(source, (bytes, bytearray)):
        source = io.BytesIO(source)

    # Resolved up front so the traversal check compares against the real
    # location; the directory itself does not need to exist yet.
    dest_real = Path(dest_dir).resolve()

    with zipfile.ZipFile(source) as zf:
        members = zf.infolist()
        try:
            _validate_members(
                members,
                dest_real=dest_real,
                max_total_size=max_total_size,
                max_file_size=max_file_size,
                max_file_count=max_file_count,
                max_compression_ratio=max_compression_ratio,
            )
        except UnsafeArchiveError as exc:
            logger.warning("safe_extract_zip rejected archive: %s", exc)
            raise

        # Only touch the filesystem once the whole archive has been accepted.
        dest_real.mkdir(parents=True, exist_ok=True)
        for m in members:
            zf.extract(m, dest_real)

    return dest_real
101+
102+
103+
def _validate_members(
    members: list[zipfile.ZipInfo],
    *,
    dest_real: Path,
    max_total_size: int,
    max_file_size: int,
    max_file_count: int,
    max_compression_ratio: int,
) -> None:
    """
    Run every safety check over the archive's member list.

    The entry-count cap is checked first. Each member is then checked, in
    order, for entry type, absolute path, path traversal, individual size and
    compression ratio, and its declared size is added to a running total that
    is compared against ``max_total_size`` after every member.

    Args:
        members: Member records of the archive being validated.
        dest_real: Resolved extraction directory used by the traversal check.
        max_total_size: Cap on the sum of declared uncompressed sizes.
        max_file_size: Cap on any single member's declared uncompressed size.
        max_file_count: Cap on the number of members.
        max_compression_ratio: Cap on a member's uncompressed/compressed ratio.

    Raises:
        UnsafeArchiveError: On the first check that fails.
    """
    if len(members) > max_file_count:
        raise UnsafeArchiveError(f"archive contains {len(members)} entries (max {max_file_count})")

    running_total = 0
    for info in members:
        # Per-member checks; the first failure aborts validation.
        _reject_disallowed_entry_type(info)
        _reject_absolute_path(info)
        _reject_path_traversal(info, dest_real)
        _reject_oversized_member(info, max_file_size=max_file_size)
        _reject_compression_bomb(info, max_ratio=max_compression_ratio)

        # Cumulative cap: stop as soon as the declared sizes add up past it.
        running_total += info.file_size
        if running_total > max_total_size:
            raise UnsafeArchiveError(f"total uncompressed size exceeds {max_total_size} bytes")
126+
127+
128+
def _reject_disallowed_entry_type(m: zipfile.ZipInfo) -> None:
    """
    Refuse members whose mode bits mark them as a symlink or special file.

    The upper 16 bits of ``external_attr`` are read as a Unix mode. They are
    inspected regardless of ``create_system``, because that field is metadata
    supplied by the archive itself and cannot be trusted to say whether the
    mode bits are meaningful.

    Raises:
        UnsafeArchiveError: If the mode's file-type bits denote a symbolic
            link, block device, character device, FIFO, or socket.
    """
    forbidden_types = {
        stat.S_IFLNK,
        stat.S_IFBLK,
        stat.S_IFCHR,
        stat.S_IFIFO,
        stat.S_IFSOCK,
    }
    unix_mode = m.external_attr >> 16
    # S_IFMT isolates the file-type bits, which is what the S_IS* helpers test.
    if stat.S_IFMT(unix_mode) in forbidden_types:
        raise UnsafeArchiveError(f"disallowed entry type: {m.filename}")
137+
138+
139+
def _reject_absolute_path(m: zipfile.ZipInfo) -> None:
    """
    Refuse member names that are rooted or carry a drive-style prefix.

    Raises:
        UnsafeArchiveError: If the name begins with ``/`` or ``\\``, or if its
            second character is ``:``.
    """
    name = m.filename
    # Slicing keeps both tests safe for empty and one-character names.
    if name[:1] in ("/", "\\"):
        raise UnsafeArchiveError(f"absolute path in archive: {name}")
    if name[1:2] == ":":
        raise UnsafeArchiveError(f"drive-letter path in archive: {name}")
145+
146+
147+
def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None:
    """
    Refuse members whose resolved target lies outside the destination.

    The member name is joined onto ``dest_real`` and resolved; the result must
    be ``dest_real`` itself or a path beneath it.

    Args:
        m: Archive member to check.
        dest_real: Resolved extraction directory.

    Raises:
        UnsafeArchiveError: If the name cannot be resolved as a path, or if the
            resolved target is not located under ``dest_real``.
    """
    try:
        resolved = dest_real.joinpath(m.filename).resolve()
    except ValueError as exc:
        # Raised for names the path layer cannot handle, e.g. an embedded
        # null byte.
        raise UnsafeArchiveError(f"invalid characters in archive entry: {m.filename!r}") from exc

    try:
        # relative_to raises ValueError when resolved is not under dest_real.
        resolved.relative_to(dest_real)
    except ValueError as exc:
        raise UnsafeArchiveError(f"path traversal in archive: {m.filename!r} escapes {dest_real}") from exc
157+
158+
159+
def _reject_oversized_member(m: zipfile.ZipInfo, *, max_file_size: int) -> None:
    """
    Refuse a member whose declared uncompressed size is above the per-file cap.

    Raises:
        UnsafeArchiveError: If ``m.file_size`` is greater than ``max_file_size``.
    """
    # A size exactly at the cap is still accepted.
    if m.file_size <= max_file_size:
        return
    raise UnsafeArchiveError(f"member {m.filename!r} uncompressed size {m.file_size} exceeds cap {max_file_size}")
162+
163+
164+
def _reject_compression_bomb(m: zipfile.ZipInfo, *, max_ratio: int) -> None:
    """
    Refuse a member whose declared sizes imply an excessive compression ratio.

    Members declaring no uncompressed data are accepted without a ratio check.

    Raises:
        UnsafeArchiveError: If the member declares a positive uncompressed size
            with a non-positive compressed size, or if uncompressed/compressed
            is greater than ``max_ratio``.
    """
    declared_size = m.file_size
    if declared_size <= 0:
        return

    packed_size = m.compress_size
    if packed_size <= 0:
        # Data claimed to come from zero stored bytes is malformed metadata;
        # reject it rather than silently skipping the ratio check.
        raise UnsafeArchiveError(
            f"member {m.filename!r} declares uncompressed size {m.file_size} but compressed size {m.compress_size}"
        )

    ratio = declared_size / packed_size
    if ratio > max_ratio:
        raise UnsafeArchiveError(f"member {m.filename!r} compression ratio {ratio:.1f} exceeds cap {max_ratio}")

pyrit/datasets/seed_datasets/remote/figstep_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import re
88
import uuid
9-
import zipfile
109
from enum import Enum
1110
from pathlib import Path
1211
from typing import TYPE_CHECKING, Literal
@@ -15,6 +14,7 @@
1514

1615
from pyrit.common.net_utility import make_request_and_raise_if_error_async
1716
from pyrit.common.path import DB_DATA_PATH
17+
from pyrit.common.safe_extract import safe_extract_zip
1818
from pyrit.datasets.seed_datasets.remote._image_cache import (
1919
fetch_and_cache_image_async,
2020
)
@@ -562,9 +562,7 @@ async def _download_and_extract_pro_zip_async(self, *, cache: bool) -> Path:
562562
zip_bytes = response.content
563563

564564
def _extract() -> None:
565-
extract_dir.mkdir(parents=True, exist_ok=True)
566-
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
567-
zf.extractall(extract_dir)
565+
safe_extract_zip(io.BytesIO(zip_bytes), extract_dir)
568566

569567
await asyncio.to_thread(_extract)
570568
return extract_dir

pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import logging
55
import pathlib
66
import uuid
7-
import zipfile
87
from enum import Enum
98
from typing import Literal
109

10+
from pyrit.common.safe_extract import safe_extract_zip
1111
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
1212
_RemoteDatasetLoader,
1313
)
@@ -149,8 +149,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
149149
# Only unzip if the target directory does not already exist
150150
if not zip_extracted_path.exists():
151151
logger.info(f"Extracting {zip_file_path} to {self.zip_dir}")
152-
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
153-
zip_ref.extractall(self.zip_dir)
152+
safe_extract_zip(zip_file_path, self.zip_dir)
154153

155154
try:
156155
logger.info(f"Loading JailBreakV-28K dataset from {self.source}")

pyrit/datasets/seed_datasets/remote/vlguard_dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import os
88
import uuid
9-
import zipfile
109
from enum import Enum
1110
from pathlib import Path
1211
from typing import TYPE_CHECKING
@@ -15,6 +14,7 @@
1514
from typing_extensions import override
1615

1716
from pyrit.common.path import DB_DATA_PATH
17+
from pyrit.common.safe_extract import safe_extract_zip
1818
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
1919
_RemoteDatasetLoader,
2020
)
@@ -329,8 +329,7 @@ def _download_sync() -> tuple[str, str]:
329329
zip_path = cache_dir / "test.zip"
330330
if zip_path.exists():
331331
logger.info("Extracting VLGuard test images...")
332-
with zipfile.ZipFile(str(zip_path), "r") as zf:
333-
zf.extractall(str(cache_dir))
332+
safe_extract_zip(zip_path, cache_dir)
334333

335334
with open(json_path, encoding="utf-8") as f:
336335
metadata = json.load(f)

0 commit comments

Comments
 (0)