Skip to content

Commit 79636d1

Browse files
Tighten regex in artifact tool download (#46714)
Co-authored-by: Ayushh Garg <ayushhgarg@microsoft.com>
1 parent 685f7bd commit 79636d1

1 file changed

Lines changed: 23 additions & 3 deletions

File tree

sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_artifact_utils.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,26 @@ def _get_checksum_path(cls, path):
158158
artifact_path = Path(path)
159159
return artifact_path.parent / f"{artifact_path.name}_{cls.POSTFIX_CHECKSUM}"
160160

161+
@staticmethod
def _safe_extractall(zip_file: zipfile.ZipFile, destination: Union[str, os.PathLike]) -> None:
    """Extract every member of a zip archive, refusing any that would land outside ``destination``.

    The destination and each member's target are resolved to absolute paths and compared.
    All members are checked before extraction starts, so ``extractall`` is never called
    when any single member is rejected.

    :param zip_file: The opened zip archive whose members should be extracted.
    :type zip_file: zipfile.ZipFile
    :param destination: The directory that must contain every extracted member.
    :type destination: Union[str, os.PathLike]
    :raises RuntimeError: If the resolved path of a member is neither ``destination``
        itself nor located underneath it.
    """
    root = Path(destination).resolve()

    def _stays_inside(entry_name: str) -> bool:
        # Joining with an absolute member name discards ``root``; ``..`` segments are
        # collapsed by ``resolve()``. Either way the result is compared against ``root``.
        resolved = (root / entry_name).resolve()
        return resolved == root or root in resolved.parents

    # Validation pass: runs to completion (or raises) before anything is extracted.
    for entry_name in zip_file.namelist():
        if not _stays_inside(entry_name):
            raise RuntimeError(f"Illegal path traversal detected in zip archive member: {entry_name}")

    zip_file.extractall(root)  # nosec B202
180+
161181
def _redirect_artifacts_tool_path(self, organization: Optional[str]):
162182
"""Downloads the artifacts tool and redirects `az artifact` command to it.
163183
@@ -171,12 +191,12 @@ def _redirect_artifacts_tool_path(self, organization: Optional[str]):
171191
if not organization:
172192
organization, _ = self.get_organization_project_by_git()
173193

174-
organization_pattern = r"https:\/\/(.*)\.visualstudio\.com"
194+
organization_pattern = r"https:\/\/([^/]+)\.visualstudio\.com"
175195
result = re.findall(pattern=organization_pattern, string=organization)
176196
if result:
177197
organization_name = result[0]
178198
else:
179-
organization_pattern = r"https:\/\/dev\.azure\.com\/(.*)"
199+
organization_pattern = r"https:\/\/dev\.azure\.com\/([^/]+)"
180200
result = re.findall(pattern=organization_pattern, string=organization)
181201
if not result:
182202
raise RuntimeError("Cannot find artifact organization.")
@@ -204,7 +224,7 @@ def _redirect_artifacts_tool_path(self, organization: Optional[str]):
204224
artifacts_tool_uri = response.json()["uri"]
205225
response = requests_pipeline.get(artifacts_tool_uri) # pylint: disable=too-many-function-args
206226
with zipfile.ZipFile(BytesIO(response.content)) as zip_file:
207-
zip_file.extractall(artifacts_tool_path)
227+
self._safe_extractall(zip_file, artifacts_tool_path)
208228
os.environ["AZURE_DEVOPS_EXT_ARTIFACTTOOL_OVERRIDE_PATH"] = str(artifacts_tool_path.resolve())
209229
self._artifacts_tool_path = artifacts_tool_path
210230
else:

0 commit comments

Comments
 (0)