Skip to content

Commit a2e7b8c

Browse files
committed
FEAT Add safe_extract_zip helper for defensive remote ZIP extraction
Signed-off-by: francose <13445813+francose@users.noreply.github.com>
1 parent 4c12796 commit a2e7b8c

5 files changed

Lines changed: 461 additions & 10 deletions

File tree

pyrit/common/safe_extract.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Defensive ZIP extraction for untrusted remote archives.
6+
7+
Remote dataset loaders in PyRIT download ZIP archives from third-party sources
8+
and feed them to ``zipfile.ZipFile.extractall()``. ``extractall`` does not
9+
validate member paths, file sizes, or entry types, which leaves the loader
10+
vulnerable to Zip Slip (CWE-22), zip bombs, and symlink-based path escape if
11+
any upstream source is tampered with.
12+
13+
``safe_extract_zip`` validates every archive member before writing anything to
14+
disk. If any member fails validation, no archive members are written from the
15+
failing call (pre-existing contents of ``dest_dir`` are untouched).
16+
"""
17+
18+
from __future__ import annotations
19+
20+
import io
21+
import logging
22+
import os
23+
import stat
24+
import zipfile
25+
from pathlib import Path
26+
from typing import IO
27+
28+
logger = logging.getLogger(__name__)
29+
30+
# 5 GiB cumulative uncompressed size across all members
31+
DEFAULT_MAX_TOTAL_SIZE = 5 * 1024**3
32+
# 1 GiB cap on any single member
33+
DEFAULT_MAX_FILE_SIZE = 1 * 1024**3
34+
# 50_000 entries: above legitimate dataset sizes, defeats inode DoS
35+
DEFAULT_MAX_FILE_COUNT = 50_000
36+
# Reject members whose uncompressed/compressed ratio exceeds this (zip bomb)
37+
DEFAULT_MAX_COMPRESSION_RATIO = 100
38+
39+
# Sanitized permissions applied to extracted entries, stripping any setuid /
40+
# setgid / sticky / world-write bits the archive may have requested.
41+
_EXTRACTED_FILE_MODE = 0o644
42+
_EXTRACTED_DIR_MODE = 0o755
43+
44+
# Predicates for entry types we refuse to extract.
45+
_DISALLOWED_TYPE_PREDICATES = (
46+
stat.S_ISLNK,
47+
stat.S_ISBLK,
48+
stat.S_ISCHR,
49+
stat.S_ISFIFO,
50+
stat.S_ISSOCK,
51+
)
52+
53+
ZipSource = str | os.PathLike | bytes | IO[bytes]
54+
55+
56+
class UnsafeArchiveError(Exception):
57+
"""Raised when an archive member fails a safe-extraction precondition."""
58+
59+
60+
def safe_extract_zip(
61+
*,
62+
source: ZipSource,
63+
dest_dir: str | os.PathLike,
64+
max_total_size: int = DEFAULT_MAX_TOTAL_SIZE,
65+
max_file_size: int = DEFAULT_MAX_FILE_SIZE,
66+
max_file_count: int = DEFAULT_MAX_FILE_COUNT,
67+
max_compression_ratio: int = DEFAULT_MAX_COMPRESSION_RATIO,
68+
) -> Path:
69+
"""
70+
Extract a ZIP archive after validating every member.
71+
72+
Validation runs in a single pass over the archive's central directory
73+
before any bytes are written. If any check fails, ``UnsafeArchiveError`` is
74+
raised and no archive members are written from this call. After extraction
75+
each member's filesystem mode is replaced with a sanitized default so a
76+
tampered archive cannot set setuid/setgid/sticky/exec bits on the host.
77+
78+
Args:
79+
source: Path, bytes, or file-like object accepted by ``zipfile.ZipFile``.
80+
dest_dir: Directory to extract into. Created if it does not exist.
81+
max_total_size: Cap on the sum of uncompressed member sizes.
82+
max_file_size: Cap on any single member's uncompressed size.
83+
max_file_count: Cap on the number of members in the archive.
84+
max_compression_ratio: Reject members whose uncompressed/compressed
85+
ratio exceeds this value (zip bomb defense).
86+
87+
Returns:
88+
Resolved destination directory.
89+
90+
Raises:
91+
UnsafeArchiveError: If any member fails validation.
92+
"""
93+
if isinstance(source, (bytes, bytearray)):
94+
source = io.BytesIO(source)
95+
96+
dest_real = Path(dest_dir).resolve()
97+
dest_real.mkdir(parents=True, exist_ok=True)
98+
99+
with zipfile.ZipFile(source) as zf:
100+
members = zf.infolist()
101+
try:
102+
_validate_members(
103+
members,
104+
dest_real=dest_real,
105+
max_total_size=max_total_size,
106+
max_file_size=max_file_size,
107+
max_file_count=max_file_count,
108+
max_compression_ratio=max_compression_ratio,
109+
)
110+
except UnsafeArchiveError as exc:
111+
logger.warning("safe_extract_zip rejected archive: %s", exc)
112+
raise
113+
for m in members:
114+
extracted = Path(zf.extract(m, dest_real))
115+
_sanitize_extracted_permissions(extracted)
116+
117+
return dest_real
118+
119+
120+
def _sanitize_extracted_permissions(path: Path) -> None:
121+
# zipfile.ZipFile.extract applies the archive's external_attr mode bits on
122+
# POSIX, so a tampered archive can request setuid/setgid/sticky or
123+
# executable bits on extracted entries. Replace with a sane default.
124+
try:
125+
if path.is_dir():
126+
os.chmod(path, _EXTRACTED_DIR_MODE)
127+
else:
128+
os.chmod(path, _EXTRACTED_FILE_MODE)
129+
except OSError as exc:
130+
logger.warning("safe_extract_zip could not chmod %s: %s", path, exc)
131+
132+
133+
def _validate_members(
134+
members: list[zipfile.ZipInfo],
135+
*,
136+
dest_real: Path,
137+
max_total_size: int,
138+
max_file_size: int,
139+
max_file_count: int,
140+
max_compression_ratio: int,
141+
) -> None:
142+
if len(members) > max_file_count:
143+
raise UnsafeArchiveError(f"archive contains {len(members)} entries (max {max_file_count})")
144+
145+
total = 0
146+
for m in members:
147+
_reject_disallowed_entry_type(m)
148+
_reject_absolute_path(m)
149+
_reject_path_traversal(m, dest_real)
150+
_reject_oversized_member(m, max_file_size=max_file_size)
151+
_reject_compression_bomb(m, max_ratio=max_compression_ratio)
152+
153+
total += m.file_size
154+
if total > max_total_size:
155+
raise UnsafeArchiveError(f"total uncompressed size exceeds {max_total_size} bytes")
156+
157+
158+
def _reject_disallowed_entry_type(m: zipfile.ZipInfo) -> None:
159+
# The upper 16 bits of external_attr hold the Unix mode when the archive
160+
# was created on a Unix system. Check unconditionally because create_system
161+
# is attacker-controlled metadata: a zip crafted with create_system=0 (DOS)
162+
# but Unix-style mode bits set should still be rejected.
163+
mode = m.external_attr >> 16
164+
if any(predicate(mode) for predicate in _DISALLOWED_TYPE_PREDICATES):
165+
raise UnsafeArchiveError(f"disallowed entry type: {m.filename}")
166+
167+
168+
def _reject_absolute_path(m: zipfile.ZipInfo) -> None:
169+
name = m.filename
170+
if name.startswith(("/", "\\")):
171+
raise UnsafeArchiveError(f"absolute path in archive: {name}")
172+
if len(name) >= 2 and name[1] == ":":
173+
raise UnsafeArchiveError(f"drive-letter path in archive: {name}")
174+
175+
176+
def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None:
177+
# Explicit null-byte check: Path.resolve() only raises ValueError for
178+
# embedded null bytes on POSIX. On Windows the path round-trips with the
179+
# null byte intact, so we need an OS-independent guard up front.
180+
if "\x00" in m.filename:
181+
raise UnsafeArchiveError(f"invalid characters in archive entry: {m.filename!r}")
182+
try:
183+
target = (dest_real / m.filename).resolve()
184+
except ValueError as exc:
185+
# Fallback for any other ValueError from Path construction or resolve.
186+
raise UnsafeArchiveError(f"invalid characters in archive entry: {m.filename!r}") from exc
187+
try:
188+
target.relative_to(dest_real)
189+
except ValueError as exc:
190+
raise UnsafeArchiveError(f"path traversal in archive: {m.filename!r} escapes {dest_real}") from exc
191+
192+
193+
def _reject_oversized_member(m: zipfile.ZipInfo, *, max_file_size: int) -> None:
194+
if m.file_size > max_file_size:
195+
raise UnsafeArchiveError(f"member {m.filename!r} uncompressed size {m.file_size} exceeds cap {max_file_size}")
196+
197+
198+
def _reject_compression_bomb(m: zipfile.ZipInfo, *, max_ratio: int) -> None:
199+
if m.file_size <= 0:
200+
return
201+
if m.compress_size <= 0:
202+
# Declared non-zero uncompressed size with zero compressed size is
203+
# malformed metadata, refuse rather than skip the ratio check.
204+
raise UnsafeArchiveError(
205+
f"member {m.filename!r} declares uncompressed size {m.file_size} but compressed size {m.compress_size}"
206+
)
207+
ratio = m.file_size / m.compress_size
208+
if ratio > max_ratio:
209+
raise UnsafeArchiveError(f"member {m.filename!r} compression ratio {ratio:.1f} exceeds cap {max_ratio}")

pyrit/datasets/seed_datasets/remote/figstep_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import re
88
import uuid
9-
import zipfile
109
from enum import Enum
1110
from pathlib import Path
1211
from typing import TYPE_CHECKING, Literal
@@ -15,6 +14,7 @@
1514

1615
from pyrit.common.net_utility import make_request_and_raise_if_error_async
1716
from pyrit.common.path import DB_DATA_PATH
17+
from pyrit.common.safe_extract import safe_extract_zip
1818
from pyrit.datasets.seed_datasets.remote._image_cache import (
1919
fetch_and_cache_image_async,
2020
)
@@ -562,9 +562,7 @@ async def _download_and_extract_pro_zip_async(self, *, cache: bool) -> Path:
562562
zip_bytes = response.content
563563

564564
def _extract() -> None:
565-
extract_dir.mkdir(parents=True, exist_ok=True)
566-
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
567-
zf.extractall(extract_dir)
565+
safe_extract_zip(source=io.BytesIO(zip_bytes), dest_dir=extract_dir)
568566

569567
await asyncio.to_thread(_extract)
570568
return extract_dir

pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
import asyncio
45
import logging
56
import pathlib
67
import uuid
7-
import zipfile
88
from enum import Enum
99
from typing import Literal
1010

11+
from pyrit.common.safe_extract import safe_extract_zip
1112
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
1213
_RemoteDatasetLoader,
1314
)
@@ -149,8 +150,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
149150
# Only unzip if the target directory does not already exist
150151
if not zip_extracted_path.exists():
151152
logger.info(f"Extracting {zip_file_path} to {self.zip_dir}")
152-
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
153-
zip_ref.extractall(self.zip_dir)
153+
await asyncio.to_thread(safe_extract_zip, source=zip_file_path, dest_dir=self.zip_dir)
154154

155155
try:
156156
logger.info(f"Loading JailBreakV-28K dataset from {self.source}")

pyrit/datasets/seed_datasets/remote/vlguard_dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import os
88
import uuid
9-
import zipfile
109
from enum import Enum
1110
from pathlib import Path
1211
from typing import TYPE_CHECKING
@@ -15,6 +14,7 @@
1514
from typing_extensions import override
1615

1716
from pyrit.common.path import DB_DATA_PATH
17+
from pyrit.common.safe_extract import safe_extract_zip
1818
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
1919
_RemoteDatasetLoader,
2020
)
@@ -329,8 +329,7 @@ def _download_sync() -> tuple[str, str]:
329329
zip_path = cache_dir / "test.zip"
330330
if zip_path.exists():
331331
logger.info("Extracting VLGuard test images...")
332-
with zipfile.ZipFile(str(zip_path), "r") as zf:
333-
zf.extractall(str(cache_dir))
332+
await asyncio.to_thread(safe_extract_zip, source=zip_path, dest_dir=cache_dir)
334333

335334
with open(json_path, encoding="utf-8") as f:
336335
metadata = json.load(f)

0 commit comments

Comments
 (0)