-
Notifications
You must be signed in to change notification settings - Fork 0
fix: Support local source for sagemaker.core.shapes.ProcessingInput (5672) #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,8 +15,13 @@ | |
| which is used for Amazon SageMaker Processing Jobs. These jobs let users perform | ||
| data pre-processing, post-processing, feature engineering, data validation, and model evaluation, | ||
| and interpretation on Amazon SageMaker. | ||
|
|
||
| ``ProcessingInput`` supports local file paths via the ``s3_input.s3_uri`` field. | ||
| When a local path is provided, it is automatically uploaded to S3 during input | ||
| normalization. For a convenient way to create ``ProcessingInput`` objects from | ||
| local sources, use the :func:`processing_input_from_local` helper function. | ||
| """ | ||
| from __future__ import absolute_import | ||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import logging | ||
|
|
@@ -85,6 +90,110 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def processing_input_from_local( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Per SDK V3 conventions, user-facing data/config classes should use Pydantic |
||
| source: str, | ||
| destination: str, | ||
| input_name: str | None = None, | ||
| s3_data_type: str = "S3Prefix", | ||
| s3_input_mode: str = "File", | ||
| s3_data_distribution_type: str | None = None, | ||
| s3_compression_type: str | None = None, | ||
| ) -> ProcessingInput: | ||
| """Creates a ProcessingInput from a local file/directory path or S3 URI. | ||
|
|
||
| This is a convenience factory function that provides V2-like ergonomics | ||
| for creating ``ProcessingInput`` objects. When a local path is provided | ||
| as ``source``, it will be automatically uploaded to S3 during input | ||
| normalization in ``Processor.run()``. | ||
|
|
||
| Args: | ||
| source: Local file/directory path or S3 URI. If a local path is | ||
| provided, it must exist and will be uploaded to S3 automatically | ||
| when the processing job is run. | ||
| destination: The container path where the input data will be | ||
| available, e.g. ``/opt/ml/processing/input/data``. | ||
| input_name: A name for the processing input. If not specified, | ||
| one will be auto-generated during normalization. | ||
| s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``, | ||
| ``'ManifestFile'`` (default: ``'S3Prefix'``). | ||
| s3_input_mode: The input mode. Valid values: ``'File'``, | ||
| ``'Pipe'`` (default: ``'File'``). | ||
| s3_data_distribution_type: The data distribution type for | ||
| distributed processing. Valid values: | ||
| ``'FullyReplicated'``, ``'ShardedByS3Key'`` | ||
| (default: None). | ||
| s3_compression_type: The compression type. Valid values: | ||
| ``'None'``, ``'Gzip'`` (default: None). | ||
|
|
||
| Returns: | ||
| ProcessingInput: A ``ProcessingInput`` object configured with the | ||
| given source and destination. | ||
|
|
||
| Raises: | ||
| ValueError: If ``source`` is a local path that does not exist. | ||
| TypeError: If ``source`` is not a string. | ||
|
|
||
| Examples: | ||
| Create an input from a local directory:: | ||
|
|
||
| from sagemaker.core.processing import processing_input_from_local | ||
|
|
||
| input_data = processing_input_from_local( | ||
| source="/local/data/training", | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="training-data", | ||
| ) | ||
| processor.run(inputs=[input_data]) | ||
|
|
||
| Create an input from an S3 URI:: | ||
|
|
||
| input_data = processing_input_from_local( | ||
| source="s3://my-bucket/data/training", | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
| """ | ||
| if not isinstance(source, str) or not source: | ||
| raise ValueError( | ||
| f"source must be a non-empty string containing a valid local path " | ||
| f"or S3 URI, got: {source!r}" | ||
| ) | ||
|
|
||
| # Check if source is a local path (not a remote URI). | ||
| # Note: On Windows, absolute paths like C:\data\file.csv will have | ||
| # parse_result.scheme == 'c', which correctly falls into the local path branch. | ||
| parse_result = urlparse(source) | ||
| if parse_result.scheme not in ("s3", "http", "https"): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The URL scheme check |
||
| # Treat as local path - validate existence | ||
| local_path = source | ||
| if parse_result.scheme == "file": | ||
| local_path = url2pathname(parse_result.path) | ||
| if not os.path.exists(local_path): | ||
| raise ValueError( | ||
| f"Input source path does not exist: {source!r}. " | ||
| f"Please provide a valid local path or S3 URI." | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The mapping of s3_input_kwargs = {
"s3_uri": source,
# local_path in ProcessingS3Input maps to the container destination path
"local_path": destination,
"s3_data_type": s3_data_type,
"s3_input_mode": s3_input_mode,
} |
||
| ) | ||
|
|
||
| s3_input_kwargs = { | ||
| "s3_uri": source, | ||
| # local_path in ProcessingS3Input maps to the container destination path | ||
| # where the input data will be made available inside the processing container. | ||
| "local_path": destination, | ||
| "s3_data_type": s3_data_type, | ||
| "s3_input_mode": s3_input_mode, | ||
| } | ||
| if s3_data_distribution_type is not None: | ||
| s3_input_kwargs["s3_data_distribution_type"] = s3_data_distribution_type | ||
| if s3_compression_type is not None: | ||
| s3_input_kwargs["s3_compression_type"] = s3_compression_type | ||
|
|
||
| s3_input = ProcessingS3Input(**s3_input_kwargs) | ||
|
|
||
| return ProcessingInput( | ||
| input_name=input_name, | ||
| s3_input=s3_input, | ||
| ) | ||
|
|
||
|
|
||
| class Processor(object): | ||
| """Handles Amazon SageMaker Processing tasks.""" | ||
|
|
||
|
|
@@ -424,6 +533,17 @@ def _normalize_inputs(self, inputs=None, kms_key=None): | |
| # If the s3_uri is not an s3_uri, create one. | ||
| parse_result = urlparse(file_input.s3_input.s3_uri) | ||
| if parse_result.scheme != "s3": | ||
| # Validate that local path exists before attempting upload | ||
| local_source = file_input.s3_input.s3_uri | ||
| if parse_result.scheme == "file": | ||
| local_source = url2pathname(parse_result.path) | ||
| if not os.path.exists(local_source): | ||
| raise ValueError( | ||
| f"Input source path does not exist: " | ||
| f"{file_input.s3_input.s3_uri!r}. " | ||
| f"Please provide a valid local path or S3 URI " | ||
| f"for input '{file_input.input_name}'." | ||
| ) | ||
| if _pipeline_config: | ||
| desired_s3_uri = s3.s3_path_join( | ||
| "s3://", | ||
|
|
@@ -443,6 +563,13 @@ def _normalize_inputs(self, inputs=None, kms_key=None): | |
| "input", | ||
| file_input.input_name, | ||
| ) | ||
| # Log after both pipeline/non-pipeline branches have set desired_s3_uri | ||
| logger.info( | ||
| "Uploading local input '%s' from %s to %s", | ||
| file_input.input_name, | ||
| file_input.s3_input.s3_uri, | ||
| desired_s3_uri, | ||
| ) | ||
| s3_uri = s3.S3Uploader.upload( | ||
| local_path=file_input.s3_input.s3_uri, | ||
| desired_s3_uri=desired_s3_uri, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| _processing_output_to_request_dict, | ||
| _get_process_request, | ||
| logs_for_processing_job, | ||
| processing_input_from_local, | ||
| ) | ||
| from sagemaker.core.shapes import ( | ||
| ProcessingInput, | ||
|
|
@@ -46,6 +47,222 @@ def mock_session(): | |
| return session | ||
|
|
||
|
|
||
| class TestProcessingInputFromLocal: | ||
| """Tests for the processing_input_from_local() factory function.""" | ||
|
|
||
| def test_processing_input_from_local_with_file_path_creates_valid_input(self, tmp_path): | ||
| """A local file path should produce a valid ProcessingInput.""" | ||
| temp_file = tmp_path / "data.csv" | ||
| temp_file.write_text("col1,col2\n1,2\n") | ||
|
|
||
| result = processing_input_from_local( | ||
| source=str(temp_file), | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="my-data", | ||
| ) | ||
| assert isinstance(result, ProcessingInput) | ||
| assert result.input_name == "my-data" | ||
| assert result.s3_input.s3_uri == str(temp_file) | ||
| # local_path in ProcessingS3Input maps to the container destination path | ||
| assert result.s3_input.local_path == "/opt/ml/processing/input/data" | ||
| assert result.s3_input.s3_data_type == "S3Prefix" | ||
| assert result.s3_input.s3_input_mode == "File" | ||
|
|
||
| def test_processing_input_from_local_with_directory_path_creates_valid_input(self, tmp_path): | ||
| """A local directory path should produce a valid ProcessingInput.""" | ||
| result = processing_input_from_local( | ||
| source=str(tmp_path), | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="dir-data", | ||
| ) | ||
| assert isinstance(result, ProcessingInput) | ||
| assert result.s3_input.s3_uri == str(tmp_path) | ||
|
|
||
| def test_processing_input_from_local_with_s3_uri_passes_through(self): | ||
| """An S3 URI should pass through without local path validation.""" | ||
| result = processing_input_from_local( | ||
| source="s3://my-bucket/data/training", | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="s3-data", | ||
| ) | ||
| assert isinstance(result, ProcessingInput) | ||
| assert result.s3_input.s3_uri == "s3://my-bucket/data/training" | ||
|
|
||
| def test_processing_input_from_local_with_nonexistent_path_raises_value_error(self): | ||
| """A nonexistent local path should raise ValueError.""" | ||
| with pytest.raises(ValueError, match="Input source path does not exist"): | ||
| processing_input_from_local( | ||
| source="/nonexistent/path/to/data", | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
|
|
||
| def test_processing_input_from_local_with_custom_input_name(self, tmp_path): | ||
| """Custom input_name should be set on the ProcessingInput.""" | ||
| result = processing_input_from_local( | ||
| source=str(tmp_path), | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="custom-name", | ||
| ) | ||
| assert result.input_name == "custom-name" | ||
|
|
||
| def test_processing_input_from_local_default_parameters(self, tmp_path): | ||
| """Default parameters should be applied correctly.""" | ||
| result = processing_input_from_local( | ||
| source=str(tmp_path), | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
| assert result.input_name is None | ||
| assert result.s3_input.s3_data_type == "S3Prefix" | ||
| assert result.s3_input.s3_input_mode == "File" | ||
|
|
||
| def test_processing_input_from_local_with_empty_source_raises_value_error(self): | ||
| """Empty source should raise ValueError.""" | ||
| with pytest.raises(ValueError, match="source must be a non-empty string"): | ||
| processing_input_from_local( | ||
| source="", | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
|
|
||
| def test_processing_input_from_local_with_none_source_raises_value_error(self): | ||
| """None source should raise ValueError.""" | ||
| with pytest.raises(ValueError, match="source must be a non-empty string"): | ||
| processing_input_from_local( | ||
| source=None, | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
|
|
||
| def test_processing_input_from_local_with_optional_s3_params(self, tmp_path): | ||
| """Optional S3 parameters should be passed through.""" | ||
| result = processing_input_from_local( | ||
| source=str(tmp_path), | ||
| destination="/opt/ml/processing/input/data", | ||
| s3_data_distribution_type="FullyReplicated", | ||
| s3_compression_type="Gzip", | ||
| ) | ||
| assert result.s3_input.s3_data_distribution_type == "FullyReplicated" | ||
| assert result.s3_input.s3_compression_type == "Gzip" | ||
|
|
||
| def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session, tmp_path): | ||
| """ProcessingInput from processing_input_from_local should work with _normalize_inputs.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test patches |
||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| inp = processing_input_from_local( | ||
| source=str(tmp_path), | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="local-data", | ||
| ) | ||
| with patch( | ||
| "sagemaker.core.processing.s3.S3Uploader.upload", | ||
| return_value="s3://bucket/uploaded", | ||
| ): | ||
| result = processor._normalize_inputs([inp]) | ||
| assert len(result) == 1 | ||
| assert result[0].s3_input.s3_uri == "s3://bucket/uploaded" | ||
|
|
||
|
|
||
| class TestNormalizeInputsLocalPathValidation: | ||
| """Tests for local path validation in _normalize_inputs().""" | ||
|
|
||
| def test_normalize_inputs_with_nonexistent_local_path_raises_value_error(self, mock_session): | ||
| """A nonexistent local path in s3_input.s3_uri should raise ValueError.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| s3_input = ProcessingS3Input( | ||
| s3_uri="/nonexistent/path/to/data", | ||
| local_path="/opt/ml/processing/input", | ||
| s3_data_type="S3Prefix", | ||
| s3_input_mode="File", | ||
| ) | ||
| inputs = [ProcessingInput(input_name="bad-input", s3_input=s3_input)] | ||
|
|
||
| with pytest.raises(ValueError, match="Input source path does not exist"): | ||
| processor._normalize_inputs(inputs) | ||
|
|
||
| def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session, tmp_path): | ||
| """A valid local path should be uploaded to S3.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| s3_input = ProcessingS3Input( | ||
| s3_uri=str(tmp_path), | ||
| local_path="/opt/ml/processing/input", | ||
| s3_data_type="S3Prefix", | ||
| s3_input_mode="File", | ||
| ) | ||
| inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)] | ||
|
|
||
| with patch( | ||
| "sagemaker.core.processing.s3.S3Uploader.upload", | ||
| return_value="s3://test-bucket/sagemaker/test-job/input/local-input", | ||
| ) as mock_upload: | ||
| result = processor._normalize_inputs(inputs) | ||
| assert len(result) == 1 | ||
| assert result[0].s3_input.s3_uri.startswith("s3://") | ||
| mock_upload.assert_called_once() | ||
|
|
||
| def test_normalize_inputs_local_path_logs_upload_info(self, mock_session, tmp_path): | ||
| """Uploading a local path should log an info message.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| s3_input = ProcessingS3Input( | ||
| s3_uri=str(tmp_path), | ||
| local_path="/opt/ml/processing/input", | ||
| s3_data_type="S3Prefix", | ||
| s3_input_mode="File", | ||
| ) | ||
| inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)] | ||
|
|
||
| with patch( | ||
| "sagemaker.core.processing.s3.S3Uploader.upload", | ||
| return_value="s3://test-bucket/uploaded", | ||
| ): | ||
| with patch("sagemaker.core.processing.logger") as mock_logger: | ||
| processor._normalize_inputs(inputs) | ||
| # Verify the upload log message was emitted with the correct format. | ||
| # The exact S3 path depends on default_bucket/prefix/job_name/input_name. | ||
| mock_logger.info.assert_called() | ||
| log_calls = mock_logger.info.call_args_list | ||
| upload_log_found = any( | ||
| len(call.args) >= 4 | ||
| and call.args[0] == "Uploading local input '%s' from %s to %s" | ||
| and call.args[1] == "local-input" | ||
| and call.args[2] == str(tmp_path) | ||
| and call.args[3].startswith("s3://") | ||
| for call in log_calls | ||
| ) | ||
| assert upload_log_found, ( | ||
| f"Expected upload log message not found in logger.info calls: " | ||
| f"{log_calls}" | ||
| ) | ||
|
|
||
|
|
||
| class TestProcessorNormalizeArgs: | ||
| def test_normalize_args_with_pipeline_variable_code(self, mock_session): | ||
| from sagemaker.core.workflow.pipeline_context import PipelineSession | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
The `source` parameter is typed as `str`, but the function explicitly handles `None` (line 155). The type annotation should be `str | None` to match the actual behavior — or, better yet, keep it as `str` and remove the `None` handling, since `None` is not a meaningful input.
Alternatively, if you want to keep `str` as the type, remove the `None` check and let it fail naturally with a `TypeError` — that's more Pythonic for a required parameter.