Skip to content

Commit bac6c33

Browse files
committed
FEAT Add safe_extract_zip helper for defensive remote ZIP extraction
Signed-off-by: francose <13445813+francose@users.noreply.github.com>
1 parent 4c12796 commit bac6c33

5 files changed

Lines changed: 331 additions & 10 deletions

File tree

pyrit/common/safe_extract.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Defensive ZIP extraction for untrusted remote archives.
6+
7+
Remote dataset loaders in PyRIT download ZIP archives from third-party sources
8+
and feed them to ``zipfile.ZipFile.extractall()``. ``extractall`` does not
9+
validate member paths, file sizes, or entry types, which leaves the loader
10+
vulnerable to Zip Slip (CWE-22), zip bombs, and symlink-based path escape if
11+
any upstream source is tampered with.
12+
13+
``safe_extract_zip`` validates every archive member before writing anything to
14+
disk. If any member fails validation, nothing from the archive is extracted.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import io
20+
import logging
21+
import os
22+
import stat
23+
import zipfile
24+
from pathlib import Path
25+
from typing import IO
26+
27+
logger = logging.getLogger(__name__)
28+
29+
# 5 GiB cumulative uncompressed size across all members
30+
DEFAULT_MAX_TOTAL_SIZE = 5 * 1024**3
31+
# 1 GiB cap on any single member
32+
DEFAULT_MAX_FILE_SIZE = 1 * 1024**3
33+
# 50_000 entries: above legitimate dataset sizes, defeats inode DoS
34+
DEFAULT_MAX_FILE_COUNT = 50_000
35+
# Reject members whose uncompressed/compressed ratio exceeds this (zip bomb)
36+
DEFAULT_MAX_COMPRESSION_RATIO = 100
37+
38+
ZipSource = str | os.PathLike | bytes | IO[bytes]
39+
40+
41+
class UnsafeArchiveError(Exception):
    """Signal that a ZIP member violated one of the safe-extraction checks."""
43+
44+
45+
def safe_extract_zip(
    source: ZipSource,
    dest_dir: str | os.PathLike,
    *,
    max_total_size: int = DEFAULT_MAX_TOTAL_SIZE,
    max_file_size: int = DEFAULT_MAX_FILE_SIZE,
    max_file_count: int = DEFAULT_MAX_FILE_COUNT,
    max_compression_ratio: int = DEFAULT_MAX_COMPRESSION_RATIO,
) -> Path:
    """
    Extract a ZIP archive after validating every member.

    Validation runs in a single pass over the archive's central directory
    before anything is written. The destination directory itself is only
    created once the archive has been opened and every member has passed, so
    a rejected or unreadable archive leaves neither a new directory nor any
    file behind from this call.

    Args:
        source: Path, bytes, or file-like object accepted by ``zipfile.ZipFile``.
            Raw ``bytes``/``bytearray`` are wrapped in an in-memory buffer.
        dest_dir: Directory to extract into. Created if it does not exist.
        max_total_size: Cap on the sum of uncompressed member sizes.
        max_file_size: Cap on any single member's uncompressed size.
        max_file_count: Cap on the number of members in the archive.
        max_compression_ratio: Reject members whose uncompressed/compressed
            ratio exceeds this value (zip bomb defense).

    Returns:
        Resolved destination directory.

    Raises:
        UnsafeArchiveError: If any member fails validation.
        zipfile.BadZipFile: If ``zipfile.ZipFile`` cannot read ``source`` as
            a ZIP archive.
    """
    if isinstance(source, (bytes, bytearray)):
        source = io.BytesIO(source)

    # Resolve up front so traversal checks compare against the real location;
    # a non-strict resolve() works even though the directory may not exist yet.
    dest_real = Path(dest_dir).resolve()

    with zipfile.ZipFile(source) as zf:
        members = zf.infolist()
        _validate_members(
            members,
            dest_real=dest_real,
            max_total_size=max_total_size,
            max_file_size=max_file_size,
            max_file_count=max_file_count,
            max_compression_ratio=max_compression_ratio,
        )
        # Only touch the filesystem after every member has been accepted, so
        # callers that treat the directory's existence as "already extracted"
        # are not misled by a failed attempt.
        dest_real.mkdir(parents=True, exist_ok=True)
        # NOTE: members are written one at a time with no rollback; an error
        # raised during extraction itself can leave earlier members on disk.
        for m in members:
            zf.extract(m, dest_real)

    return dest_real
97+
98+
99+
def _validate_members(
    members: list[zipfile.ZipInfo],
    *,
    dest_real: Path,
    max_total_size: int,
    max_file_size: int,
    max_file_count: int,
    max_compression_ratio: int,
) -> None:
    """
    Run every safety check against the archive's members.

    The entry-count cap is checked first. Each member then goes through the
    per-member checks in a fixed order, and the running sum of declared
    uncompressed sizes is compared with ``max_total_size`` after each one.

    Raises:
        UnsafeArchiveError: On the first check that fails.
    """
    if not len(members) <= max_file_count:
        raise UnsafeArchiveError(f"archive contains {len(members)} entries (max {max_file_count})")

    cumulative_size = 0
    for info in members:
        _reject_disallowed_entry_type(info)
        _reject_absolute_path(info)
        _reject_path_traversal(info, dest_real)
        _reject_oversized_member(info, max_file_size=max_file_size)
        _reject_compression_bomb(info, max_ratio=max_compression_ratio)

        cumulative_size = cumulative_size + info.file_size
        if cumulative_size > max_total_size:
            raise UnsafeArchiveError(f"total uncompressed size exceeds {max_total_size} bytes")
122+
123+
124+
def _reject_disallowed_entry_type(m: zipfile.ZipInfo) -> None:
    """
    Reject symlinks, block/char devices, FIFOs, and sockets.

    Only entries with ``create_system == 3`` are inspected; for those the Unix
    mode is read from the upper 16 bits of ``external_attr``.

    Raises:
        UnsafeArchiveError: If the mode marks one of the special file types.
    """
    if m.create_system == 3:
        unix_mode = m.external_attr >> 16
        special_type_checks = (
            stat.S_ISLNK,
            stat.S_ISBLK,
            stat.S_ISCHR,
            stat.S_ISFIFO,
            stat.S_ISSOCK,
        )
        if any(is_special(unix_mode) for is_special in special_type_checks):
            raise UnsafeArchiveError(f"disallowed entry type: {m.filename}")
131+
132+
133+
def _reject_absolute_path(m: zipfile.ZipInfo) -> None:
    """
    Reject member names that are rooted or carry a drive-letter prefix.

    Raises:
        UnsafeArchiveError: If the name starts with ``/`` or ``\\``, or has a
            colon as its second character.
    """
    member_name = m.filename
    # Slicing instead of indexing keeps empty and one-character names safe.
    if member_name[:1] in ("/", "\\"):
        raise UnsafeArchiveError(f"absolute path in archive: {member_name}")
    if member_name[1:2] == ":":
        raise UnsafeArchiveError(f"drive-letter path in archive: {member_name}")
139+
140+
141+
def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None:
    """
    Reject a member whose resolved location falls outside ``dest_real``.

    The member name is joined onto ``dest_real`` and resolved, then required
    to be relative to ``dest_real``.

    Raises:
        UnsafeArchiveError: If the resolved target is not under ``dest_real``.
    """
    resolved_target = Path(dest_real, m.filename).resolve()
    try:
        resolved_target.relative_to(dest_real)
    except ValueError as escape_error:
        message = f"path traversal in archive: {m.filename!r} escapes {dest_real}"
        raise UnsafeArchiveError(message) from escape_error
147+
148+
149+
def _reject_oversized_member(m: zipfile.ZipInfo, *, max_file_size: int) -> None:
    """
    Reject a member whose declared uncompressed size is above the cap.

    Raises:
        UnsafeArchiveError: If ``m.file_size`` is greater than ``max_file_size``.
    """
    declared_size = m.file_size
    if declared_size <= max_file_size:
        return
    raise UnsafeArchiveError(
        f"member {m.filename!r} uncompressed size {declared_size} exceeds cap {max_file_size}"
    )
152+
153+
154+
def _reject_compression_bomb(m: zipfile.ZipInfo, *, max_ratio: int) -> None:
    """
    Reject a member whose declared sizes imply an excessive expansion.

    Args:
        m: Archive member to inspect.
        max_ratio: Largest permitted uncompressed/compressed size ratio.

    Raises:
        UnsafeArchiveError: If the ratio exceeds ``max_ratio``, or if the
            member declares a non-empty payload with no compressed bytes.
    """
    if m.file_size <= 0:
        # Nothing to expand, so there is no ratio to evaluate.
        return
    if m.compress_size <= 0:
        # A non-empty payload from zero compressed bytes is an unbounded ratio.
        # Reject it rather than letting it skip the check via a division guard.
        raise UnsafeArchiveError(
            f"member {m.filename!r} declares {m.file_size} uncompressed bytes but no compressed data"
        )
    ratio = m.file_size / m.compress_size
    if ratio > max_ratio:
        raise UnsafeArchiveError(f"member {m.filename!r} compression ratio {ratio:.1f} exceeds cap {max_ratio}")

pyrit/datasets/seed_datasets/remote/figstep_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import re
88
import uuid
9-
import zipfile
109
from enum import Enum
1110
from pathlib import Path
1211
from typing import TYPE_CHECKING, Literal
@@ -15,6 +14,7 @@
1514

1615
from pyrit.common.net_utility import make_request_and_raise_if_error_async
1716
from pyrit.common.path import DB_DATA_PATH
17+
from pyrit.common.safe_extract import safe_extract_zip
1818
from pyrit.datasets.seed_datasets.remote._image_cache import (
1919
fetch_and_cache_image_async,
2020
)
@@ -562,9 +562,7 @@ async def _download_and_extract_pro_zip_async(self, *, cache: bool) -> Path:
562562
zip_bytes = response.content
563563

564564
def _extract() -> None:
565-
extract_dir.mkdir(parents=True, exist_ok=True)
566-
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
567-
zf.extractall(extract_dir)
565+
safe_extract_zip(io.BytesIO(zip_bytes), extract_dir)
568566

569567
await asyncio.to_thread(_extract)
570568
return extract_dir

pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import logging
55
import pathlib
66
import uuid
7-
import zipfile
87
from enum import Enum
98
from typing import Literal
109

10+
from pyrit.common.safe_extract import safe_extract_zip
1111
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
1212
_RemoteDatasetLoader,
1313
)
@@ -149,8 +149,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
149149
# Only unzip if the target directory does not already exist
150150
if not zip_extracted_path.exists():
151151
logger.info(f"Extracting {zip_file_path} to {self.zip_dir}")
152-
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
153-
zip_ref.extractall(self.zip_dir)
152+
safe_extract_zip(zip_file_path, self.zip_dir)
154153

155154
try:
156155
logger.info(f"Loading JailBreakV-28K dataset from {self.source}")

pyrit/datasets/seed_datasets/remote/vlguard_dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import os
88
import uuid
9-
import zipfile
109
from enum import Enum
1110
from pathlib import Path
1211
from typing import TYPE_CHECKING
@@ -15,6 +14,7 @@
1514
from typing_extensions import override
1615

1716
from pyrit.common.path import DB_DATA_PATH
17+
from pyrit.common.safe_extract import safe_extract_zip
1818
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
1919
_RemoteDatasetLoader,
2020
)
@@ -329,8 +329,7 @@ def _download_sync() -> tuple[str, str]:
329329
zip_path = cache_dir / "test.zip"
330330
if zip_path.exists():
331331
logger.info("Extracting VLGuard test images...")
332-
with zipfile.ZipFile(str(zip_path), "r") as zf:
333-
zf.extractall(str(cache_dir))
332+
safe_extract_zip(zip_path, cache_dir)
334333

335334
with open(json_path, encoding="utf-8") as f:
336335
metadata = json.load(f)

0 commit comments

Comments
 (0)