|
1 | 1 | import json |
2 | 2 | import os |
3 | 3 | import platform |
| 4 | +from pathlib import Path |
4 | 5 | from unittest.mock import Mock, patch |
5 | 6 |
|
6 | 7 | import jwt |
|
11 | 12 | from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes |
12 | 13 | from azure.ai.ml._restclient.v2023_04_01_preview import models |
13 | 14 | from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope |
14 | | -from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR, AzureMLResourceType, GitProperties |
| 15 | +from azure.ai.ml.constants._common import ( |
| 16 | + AZUREML_PRIVATE_FEATURES_ENV_VAR, |
| 17 | + DEFAULT_ARTIFACT_STORE_OUTPUT_NAME, |
| 18 | + AzureMLResourceType, |
| 19 | + GitProperties, |
| 20 | +) |
15 | 21 | from azure.ai.ml.entities._builders import Command |
16 | 22 | from azure.ai.ml.entities._job.job import Job |
17 | 23 | from azure.ai.ml.operations import DatastoreOperations, EnvironmentOperations, JobOperations, WorkspaceOperations |
18 | 24 | from azure.ai.ml.operations._code_operations import CodeOperations |
19 | 25 | from azure.ai.ml.operations._job_ops_helper import get_git_properties |
| 26 | +from azure.ai.ml.operations._run_history_constants import JobStatus |
20 | 27 | from azure.ai.ml.operations._run_operations import RunOperations |
21 | 28 | from azure.core.credentials import AccessToken |
22 | 29 | from azure.identity import DefaultAzureCredential |
@@ -266,6 +273,38 @@ def test_restore(self, mock_method, mock_job_operation: JobOperations) -> None: |
266 | 273 | mock_job_operation.service_client_01_2024_preview.jobs.get.assert_called_once() |
267 | 274 | mock_job_operation._operation_2023_02_preview.create_or_update.assert_called_once() |
268 | 275 |
|
| 276 | + def test_download_skips_output_name_path_traversal( |
| 277 | + self, mocker: MockFixture, mock_job_operation: JobOperations, tmp_path |
| 278 | + ) -> None: |
| 279 | + # Output names come from the run's service response; a malicious name must not |
| 280 | + # let the download destination escape download_path. |
| 281 | + job_details = Mock() |
| 282 | + job_details.status = JobStatus.COMPLETED |
| 283 | + job_details.properties = {} |
| 284 | + job_details.tags = {} |
| 285 | + mock_job_operation.get = Mock(return_value=job_details) |
| 286 | + |
| 287 | + download_path = tmp_path / "download" |
| 288 | + mock_job_operation._get_named_output_uri = Mock( |
| 289 | + return_value={ |
| 290 | + DEFAULT_ARTIFACT_STORE_OUTPUT_NAME: "azureml://datastores/store/paths/logs", |
| 291 | + "../../../../evil": "azureml://datastores/store/paths/evil", |
| 292 | + "/tmp/evil-abs": "azureml://datastores/store/paths/evil-abs", |
| 293 | + } |
| 294 | + ) |
| 295 | + mocked_download = mocker.patch("azure.ai.ml.operations._job_operations.download_artifact_from_aml_uri") |
| 296 | + |
| 297 | + mock_job_operation.download("job-1", download_path=str(download_path), all=True) |
| 298 | + |
| 299 | + resolved_root = download_path.resolve() |
| 300 | + for call in mocked_download.call_args_list: |
| 301 | + destination = Path(call.kwargs["destination"]).resolve() |
| 302 | + destination.relative_to(resolved_root) # raises ValueError if it escaped |
| 303 | + downloaded_uris = {call.kwargs["uri"] for call in mocked_download.call_args_list} |
| 304 | + assert "azureml://datastores/store/paths/evil" not in downloaded_uris |
| 305 | + assert "azureml://datastores/store/paths/evil-abs" not in downloaded_uris |
| 306 | + assert "azureml://datastores/store/paths/logs" in downloaded_uris |
| 307 | + |
269 | 308 | @pytest.mark.parametrize( |
270 | 309 | "corrupt_job_data", |
271 | 310 | [ |
|
0 commit comments