Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Features Added

### Bugs Fixed
- Fixed `JobOperations.download` writing outside the requested `download_path` when a job output name returned by the service contained `..` segments or an absolute path. Such outputs are now skipped instead of escaping the download directory.

### Other Changes

Expand Down
10 changes: 10 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,16 @@ def log_missing_uri(what: str) -> None:
destination = download_path / artifact_directory_name
else:
destination = download_path / output_directory_name / item_name
# item_name is an output name taken from the run's service response; an entry
# containing ".." or an absolute path would resolve outside download_path.
try:
destination.resolve().relative_to(download_path.resolve())
except ValueError:
module_logger.warning(
"Skipping output '%s': resolved path is outside the download directory.",
item_name,
)
continue

module_logger.info("Downloading artifact %s to %s", uri, destination)
download_artifact_from_aml_uri(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import platform
from pathlib import Path
from unittest.mock import Mock, patch

import jwt
Expand All @@ -11,12 +12,18 @@
from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes
from azure.ai.ml._restclient.v2023_04_01_preview import models
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope
from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR, AzureMLResourceType, GitProperties
from azure.ai.ml.constants._common import (
AZUREML_PRIVATE_FEATURES_ENV_VAR,
DEFAULT_ARTIFACT_STORE_OUTPUT_NAME,
AzureMLResourceType,
GitProperties,
)
from azure.ai.ml.entities._builders import Command
from azure.ai.ml.entities._job.job import Job
from azure.ai.ml.operations import DatastoreOperations, EnvironmentOperations, JobOperations, WorkspaceOperations
from azure.ai.ml.operations._code_operations import CodeOperations
from azure.ai.ml.operations._job_ops_helper import get_git_properties
from azure.ai.ml.operations._run_history_constants import JobStatus
from azure.ai.ml.operations._run_operations import RunOperations
from azure.core.credentials import AccessToken
from azure.identity import DefaultAzureCredential
Expand Down Expand Up @@ -266,6 +273,38 @@ def test_restore(self, mock_method, mock_job_operation: JobOperations) -> None:
mock_job_operation.service_client_01_2024_preview.jobs.get.assert_called_once()
mock_job_operation._operation_2023_02_preview.create_or_update.assert_called_once()

def test_download_skips_output_name_path_traversal(
    self, mocker: MockFixture, mock_job_operation: JobOperations, tmp_path
) -> None:
    """Check that ``download`` never targets a path outside ``download_path``.

    The mocked named-output mapping mixes the default artifact store output with
    two hostile names: one built from ``..`` segments and one absolute path.
    Every destination handed to the patched download helper must resolve under
    the download directory, and only the default output's URI may be fetched.
    """
    logs_uri = "azureml://datastores/store/paths/logs"
    traversal_uri = "azureml://datastores/store/paths/evil"
    absolute_uri = "azureml://datastores/store/paths/evil-abs"

    # Stand-in for a finished job, returned by the patched ``get``.
    completed_job = Mock()
    completed_job.status = JobStatus.COMPLETED
    completed_job.properties = {}
    completed_job.tags = {}
    mock_job_operation.get = Mock(return_value=completed_job)

    target_dir = tmp_path / "download"
    # Output names are keys of this mapping; two of them try to leave target_dir.
    named_outputs = {
        DEFAULT_ARTIFACT_STORE_OUTPUT_NAME: logs_uri,
        "../../../../evil": traversal_uri,
        "/tmp/evil-abs": absolute_uri,
    }
    mock_job_operation._get_named_output_uri = Mock(return_value=named_outputs)
    download_spy = mocker.patch("azure.ai.ml.operations._job_operations.download_artifact_from_aml_uri")

    mock_job_operation.download("job-1", download_path=str(target_dir), all=True)

    recorded_calls = download_spy.call_args_list
    root = target_dir.resolve()
    for recorded in recorded_calls:
        resolved_destination = Path(recorded.kwargs["destination"]).resolve()
        # relative_to raises ValueError when the destination is not under root.
        resolved_destination.relative_to(root)

    requested_uris = {recorded.kwargs["uri"] for recorded in recorded_calls}
    assert traversal_uri not in requested_uris
    assert absolute_uri not in requested_uris
    assert logs_uri in requested_uris

@pytest.mark.parametrize(
"corrupt_job_data",
[
Expand Down
Loading