Skip to content

Commit 46ac099

Browse files
ayushhgarg-workayushhgarg03jayesh-tanna
authored
Path traversal protection in unzip to temp file (#45691)
* Path traversal protection in unzip to temp file
* ghcp changes
* fix black formatting and pylint errors
* tox formatting

---------

Co-authored-by: Ayushh Garg <ayushhgarg@microsoft.com>
Co-authored-by: Jayesh Tanna <tanna.jay90@gmail.com>
1 parent ff8bfa9 commit 46ac099

2 files changed

Lines changed: 186 additions & 6 deletions

File tree

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@
4141
def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path:
    """Extract a zip archive into a per-job directory under the system temp directory.

    The directory is ``<tempdir>/<AZUREML_RUNS_DIR>/<job_definition.name>`` and is created
    if it does not exist. Every archive member name is checked before anything is
    extracted: a member whose resolved path is neither the directory itself nor located
    beneath it aborts the call.

    :param job_definition: Job whose ``name`` selects the extraction directory.
    :type job_definition: JobBaseData
    :param zip_content: Bytes-like content of the zip archive (wrapped in ``io.BytesIO``).
    :type zip_content: Any
    :return: The directory the archive was extracted into.
    :rtype: Path
    :raises ValueError: If any archive member would resolve outside the extraction directory.
    """
    temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name)
    temp_dir.mkdir(parents=True, exist_ok=True)

    extraction_root = temp_dir.resolve()
    # The trailing separator keeps a sibling such as "<root>-evil" from passing as a child.
    child_prefix = str(extraction_root) + os.sep

    with zipfile.ZipFile(io.BytesIO(zip_content)) as archive:
        for entry in archive.namelist():
            target = (extraction_root / entry).resolve()
            # The root itself is allowed so that directory entries resolving to it pass.
            stays_inside = target == extraction_root or str(target).startswith(child_prefix)
            if not stays_inside:
                raise ValueError(
                    f"Zip archive contains a path traversal entry and cannot be extracted safely: {entry}"
                )
        archive.extractall(temp_dir)
    return temp_dir
4755

@@ -142,7 +150,7 @@ def get_execution_service_response(
142150
try:
143151
local = job_definition.properties.services.get("Local", None)
144152

145-
(url, encodedBody) = local.endpoint.split(EXECUTION_SERVICE_URL_KEY)
153+
url, encodedBody = local.endpoint.split(EXECUTION_SERVICE_URL_KEY)
146154
body = urllib.parse.unquote_plus(encodedBody)
147155
body_dict: Dict = json.loads(body)
148156
response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token})
@@ -167,6 +175,51 @@ def is_local_run(job_definition: JobBaseData) -> bool:
167175
return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint
168176

169177

178+
def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None:
    """Extract all tar archive members safely, preventing path traversal (TarSlip).

    The whole archive is validated before anything is written, so an archive
    that is rejected leaves no partially extracted content in ``dest_dir``.

    When the interpreter provides ``tarfile.data_filter`` (Python 3.12+, and
    older releases that received the extraction-filter backport), every member
    is first checked with that filter and the archive is then extracted with
    ``filter="data"``; which links and special entries are acceptable is
    decided by that filter. Otherwise each member is validated manually:
    symbolic links, hard links and any entry that is not a regular file or a
    directory are rejected, and every member's resolved path must stay inside
    the destination directory.

    :param tar: An opened tarfile.TarFile object. Its member list is read for
        validation before extraction, so it should be opened in a seekable
        (non-stream) mode.
    :type tar: tarfile.TarFile
    :param dest_dir: The destination directory for extraction.
    :type dest_dir: str
    :raises ValueError: If the 'data' filter rejects the archive (or another
        tarfile error occurs on that path), or, on the manual path, if a member
        would escape the destination directory or is a symlink, hard link, or
        unsupported special entry type.
    """
    resolved_dest = os.path.realpath(dest_dir)

    # Built-in extraction filter available: let it judge every member.
    if hasattr(tarfile, "data_filter"):
        try:
            # Dry-run the filter over all members first. extractall() applies the
            # filter lazily, so without this pass the members that precede a bad
            # entry would already be on disk when the error is raised.
            for member in tar.getmembers():
                tarfile.data_filter(member, resolved_dest)
            tar.extractall(resolved_dest, filter="data")
        except tarfile.TarError as exc:
            raise ValueError(f"Failed to safely extract tar archive: {exc}") from exc
        return

    # No built-in filter: validate each member by hand.
    # The trailing separator keeps a sibling like "<dest>-evil" from passing as a child.
    dest_prefix = resolved_dest + os.sep
    for member in tar.getmembers():
        # Reject symbolic and hard links
        if member.issym() or member.islnk():
            raise ValueError(
                f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}"
            )
        # Reject any non-regular, non-directory entries (e.g., devices, FIFOs)
        if not (member.isfile() or member.isdir()):
            raise ValueError(
                f"Tar archive contains an unsupported special entry type and cannot be extracted safely: "
                f"{member.name}"
            )
        member_path = os.path.realpath(os.path.join(resolved_dest, member.name))
        if member_path != resolved_dest and not member_path.startswith(dest_prefix):
            raise ValueError(
                f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}"
            )
    # All members validated; safe to extract
    tar.extractall(resolved_dest)
221+
222+
170223
class CommonRuntimeHelper:
171224
COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json"
172225
COMMON_RUNTIME_JOB_SPEC = "common_runtime_job_spec.json"
@@ -208,10 +261,14 @@ def __init__(self, job_name: str):
208261
CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME,
209262
)
210263
self.stdout = open( # pylint: disable=consider-using-with
211-
os.path.join(self.common_runtime_temp_folder, "stdout"), "w+", encoding=DefaultOpenEncoding.WRITE
264+
os.path.join(self.common_runtime_temp_folder, "stdout"),
265+
"w+",
266+
encoding=DefaultOpenEncoding.WRITE,
212267
)
213268
self.stderr = open( # pylint: disable=consider-using-with
214-
os.path.join(self.common_runtime_temp_folder, "stderr"), "w+", encoding=DefaultOpenEncoding.WRITE
269+
os.path.join(self.common_runtime_temp_folder, "stderr"),
270+
"w+",
271+
encoding=DefaultOpenEncoding.WRITE,
215272
)
216273

217274
# Bug Item number: 2885723
@@ -266,8 +323,7 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers.
266323
for chunk in data_stream:
267324
f.write(chunk)
268325
with tarfile.open(tar_file, mode="r") as tar:
269-
for file_name in tar.getnames():
270-
tar.extract(file_name, os.path.dirname(path_in_host))
326+
_safe_tar_extractall(tar, os.path.dirname(path_in_host))
271327
os.remove(tar_file)
272328
except docker.errors.APIError as e:
273329
msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}"
@@ -408,7 +464,7 @@ def start_run_if_local(
408464
:rtype: str
409465
"""
410466
token = credential.get_token(ws_base_url + "/.default").token
411-
(zip_content, snapshot_id) = get_execution_service_response(job_definition, token, requests_pipeline)
467+
zip_content, snapshot_id = get_execution_service_response(job_definition, token, requests_pipeline)
412468

413469
try:
414470
temp_dir = unzip_to_temporary_file(job_definition, zip_content)

sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
import io
12
import os
3+
import shutil
4+
import tarfile
25
import tempfile
6+
import zipfile
37
from pathlib import Path
8+
from unittest.mock import MagicMock
49

510
import pytest
611

712
from azure.ai.ml.operations._local_job_invoker import (
813
_get_creationflags_and_startupinfo_for_background_process,
14+
_safe_tar_extractall,
915
patch_invocation_script_serialization,
16+
unzip_to_temporary_file,
1017
)
1118

1219

@@ -61,3 +68,120 @@ def test_creation_flags(self):
6168
flags = _get_creationflags_and_startupinfo_for_background_process("linux")
6269

6370
assert flags == {"stderr": -2, "stdin": -3, "stdout": -3}
71+
72+
73+
def _make_job_definition(name="test-run"):
    """Return a MagicMock standing in for a job definition whose ``name`` attribute is ``name``."""
    stub = MagicMock()
    # ``name`` is set after construction: MagicMock(name=...) would name the mock itself instead.
    stub.configure_mock(name=name)
    return stub
77+
78+
79+
@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestUnzipPathTraversalPrevention:
    """Checks that unzip_to_temporary_file extracts benign zips and refuses escaping entries."""

    @staticmethod
    def _zip_bytes(entries):
        """Return the bytes of an in-memory zip holding ``entries`` ((name, text) pairs) in order."""
        stream = io.BytesIO()
        with zipfile.ZipFile(stream, "w") as archive:
            for member_name, text in entries:
                archive.writestr(member_name, text)
        return stream.getvalue()

    def test_normal_zip_extracts_successfully(self):
        payload = self._zip_bytes(
            [
                ("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n"),
                ("azureml-setup/config.json", '{"key": "value"}'),
            ]
        )

        extracted = unzip_to_temporary_file(_make_job_definition("safe-run"), payload)

        try:
            assert extracted.exists()
            setup_dir = extracted / "azureml-setup"
            assert (setup_dir / "invocation.sh").exists()
            assert (setup_dir / "config.json").exists()
        finally:
            # Remove what this test extracted into the shared temp location.
            if extracted.exists():
                shutil.rmtree(extracted, ignore_errors=True)

    def test_zip_with_path_traversal_is_rejected(self):
        payload = self._zip_bytes(
            [
                ("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n"),
                ("../../etc/evil.sh", "#!/bin/bash\necho pwned\n"),
            ]
        )

        job = _make_job_definition("traversal-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job, payload)

    def test_zip_with_absolute_path_is_rejected(self):
        # Pick an absolute member name in the form the current platform treats as absolute.
        evil_name = "C:/Windows/Temp/evil.sh" if os.name == "nt" else "/tmp/evil.sh"
        payload = self._zip_bytes([(evil_name, "#!/bin/bash\necho pwned\n")])

        job = _make_job_definition("absolute-path-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job, payload)
125+
126+
127+
@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestSafeTarExtract:
    """Checks that _safe_tar_extractall extracts benign tars and refuses traversal and link entries."""

    @staticmethod
    def _archive_with(entry, payload=None):
        """Return a rewound in-memory tar containing ``entry``; ``payload`` is its file content, if any."""
        stream = io.BytesIO()
        with tarfile.open(fileobj=stream, mode="w") as writer:
            if payload is None:
                # Link entries carry no data.
                writer.addfile(entry)
            else:
                entry.size = len(payload)
                writer.addfile(entry, io.BytesIO(payload))
        stream.seek(0)
        return stream

    def test_normal_tar_extracts_successfully(self):
        with tempfile.TemporaryDirectory() as dest:
            stream = self._archive_with(tarfile.TarInfo(name="vm-bootstrapper"), b"normal content")

            with tarfile.open(fileobj=stream, mode="r") as reader:
                _safe_tar_extractall(reader, dest)

            assert os.path.exists(os.path.join(dest, "vm-bootstrapper"))

    def test_tar_with_path_traversal_is_rejected(self):
        with tempfile.TemporaryDirectory() as dest:
            stream = self._archive_with(tarfile.TarInfo(name="../../evil_script.sh"), b"evil content")

            with tarfile.open(fileobj=stream, mode="r") as reader, pytest.raises(ValueError):
                _safe_tar_extractall(reader, dest)

    def test_tar_with_symlink_is_rejected(self):
        with tempfile.TemporaryDirectory() as dest:
            link = tarfile.TarInfo(name="evil_link")
            link.type = tarfile.SYMTYPE
            link.linkname = "/etc/passwd"
            stream = self._archive_with(link)

            with tarfile.open(fileobj=stream, mode="r") as reader, pytest.raises(ValueError):
                _safe_tar_extractall(reader, dest)

    def test_tar_with_hardlink_is_rejected(self):
        with tempfile.TemporaryDirectory() as dest:
            link = tarfile.TarInfo(name="evil_hardlink")
            link.type = tarfile.LNKTYPE
            link.linkname = "/etc/shadow"
            stream = self._archive_with(link)

            with tarfile.open(fileobj=stream, mode="r") as reader, pytest.raises(ValueError):
                _safe_tar_extractall(reader, dest)

0 commit comments

Comments
 (0)