fix: Support local source for sagemaker.core.shapes.ProcessingInput (5672) #31
Changes from 1 commit
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,11 @@ | |
| which is used for Amazon SageMaker Processing Jobs. These jobs let users perform | ||
| data pre-processing, post-processing, feature engineering, data validation, and model evaluation, | ||
| and interpretation on Amazon SageMaker. | ||
|
|
||
| ``ProcessingInput`` supports local file paths via the ``s3_input.s3_uri`` field. | ||
| When a local path is provided, it is automatically uploaded to S3 during input | ||
| normalization. For a convenient way to create ``ProcessingInput`` objects from | ||
| local sources, use the :func:`processing_input_from_local` helper function. | ||
| """ | ||
| from __future__ import absolute_import | ||
|
|
||
|
|
@@ -85,6 +90,105 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def processing_input_from_local( | ||
|
Review comment: Per SDK V3 conventions, user-facing data/config classes should use Pydantic
||
| source: str, | ||
| destination: str, | ||
| input_name: Optional[str] = None, | ||
| s3_data_type: str = "S3Prefix", | ||
| s3_input_mode: str = "File", | ||
| s3_data_distribution_type: Optional[str] = None, | ||
| s3_compression_type: Optional[str] = None, | ||
| ) -> ProcessingInput: | ||
| """Creates a ProcessingInput from a local file/directory path or S3 URI. | ||
|
|
||
| This is a convenience factory function that provides V2-like ergonomics | ||
| for creating ``ProcessingInput`` objects. When a local path is provided | ||
| as ``source``, it will be automatically uploaded to S3 during input | ||
| normalization in ``Processor.run()``. | ||
|
|
||
| Args: | ||
| source: Local file/directory path or S3 URI. If a local path is | ||
| provided, it must exist and will be uploaded to S3 automatically | ||
| when the processing job is run. | ||
| destination: The container path where the input data will be | ||
| available, e.g. ``/opt/ml/processing/input/data``. | ||
| input_name: A name for the processing input. If not specified, | ||
| one will be auto-generated during normalization. | ||
| s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``, | ||
| ``'ManifestFile'`` (default: ``'S3Prefix'``). | ||
| s3_input_mode: The input mode. Valid values: ``'File'``, | ||
| ``'Pipe'`` (default: ``'File'``). | ||
| s3_data_distribution_type: The data distribution type for | ||
| distributed processing. Valid values: | ||
| ``'FullyReplicated'``, ``'ShardedByS3Key'`` | ||
| (default: None). | ||
| s3_compression_type: The compression type. Valid values: | ||
| ``'None'``, ``'Gzip'`` (default: None). | ||
|
|
||
| Returns: | ||
| ProcessingInput: A ``ProcessingInput`` object configured with the | ||
| given source and destination. | ||
|
|
||
| Raises: | ||
| ValueError: If ``source`` is a local path that does not exist. | ||
| ValueError: If ``source`` is empty or None. | ||
|
|
||
| Examples: | ||
| Create an input from a local directory:: | ||
|
|
||
| from sagemaker.core.processing import processing_input_from_local | ||
|
|
||
| input_data = processing_input_from_local( | ||
| source="/local/data/training", | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="training-data", | ||
| ) | ||
| processor.run(inputs=[input_data]) | ||
|
|
||
| Create an input from an S3 URI:: | ||
|
|
||
| input_data = processing_input_from_local( | ||
| source="s3://my-bucket/data/training", | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
| """ | ||
| if not source: | ||
| raise ValueError( | ||
| f"source must be a valid local path or S3 URI, got: {source!r}" | ||
|
Review comment: The f-string formatting here uses

    raise ValueError(
        f"source must be a valid local path or S3 URI, got: {source!r}"
    )

Actually, looking more carefully, this is already formatted this way. Disregard the formatting concern, but note that this validation doesn't distinguish between
||
| ) | ||
|
|
||
| # Check if source is a local path (not an S3 URI) | ||
| parse_result = urlparse(source) | ||
| if parse_result.scheme not in ("s3", "http", "https"): | ||
|
Review comment: The URL scheme check
||
| # Treat as local path - validate existence | ||
| local_path = source | ||
| if parse_result.scheme == "file": | ||
| local_path = url2pathname(parse_result.path) | ||
| if not os.path.exists(local_path): | ||
| raise ValueError( | ||
| f"Input source path does not exist: {source!r}. " | ||
| f"Please provide a valid local path or S3 URI." | ||
|
Review comment: The mapping of

    s3_input_kwargs = {
        "s3_uri": source,
        # local_path in ProcessingS3Input maps to the container destination path
        "local_path": destination,
        "s3_data_type": s3_data_type,
        "s3_input_mode": s3_input_mode,
    }
||
| ) | ||
|
|
||
| s3_input_kwargs = { | ||
| "s3_uri": source, | ||
| "local_path": destination, | ||
| "s3_data_type": s3_data_type, | ||
| "s3_input_mode": s3_input_mode, | ||
| } | ||
| if s3_data_distribution_type is not None: | ||
| s3_input_kwargs["s3_data_distribution_type"] = s3_data_distribution_type | ||
| if s3_compression_type is not None: | ||
| s3_input_kwargs["s3_compression_type"] = s3_compression_type | ||
|
|
||
| s3_input = ProcessingS3Input(**s3_input_kwargs) | ||
|
|
||
| return ProcessingInput( | ||
| input_name=input_name, | ||
| s3_input=s3_input, | ||
| ) | ||
|
|
||
|
|
||
| class Processor(object): | ||
| """Handles Amazon SageMaker Processing tasks.""" | ||
|
|
||
|
|
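A minimal sketch of how the scheme check in `processing_input_from_local` sorts sources, using only `urllib.parse.urlparse` from the standard library. The helper name `classify_source` and its return labels are hypothetical and not part of this PR; they only mirror the `scheme not in ("s3", "http", "https")` test shown in the hunk above.

```python
from urllib.parse import urlparse


def classify_source(source: str) -> str:
    """Hypothetical helper: label a source the way the scheme check would."""
    scheme = urlparse(source).scheme
    if scheme == "s3":
        return "s3"
    if scheme in ("http", "https"):
        return "remote"
    # Everything else is treated as a local path: plain paths (empty scheme),
    # file:// URLs, and Windows drive paths, which parse with a one-letter
    # scheme such as "c".
    return "local"


for candidate in (
    "s3://my-bucket/data/training",
    "https://example.com/data.csv",
    "/local/data/training",
    "file:///local/data/training",
    "C:\\data\\training",
):
    print(f"{candidate!r:40} -> {classify_source(candidate)}")
```

Under this classification an `http://` or `https://` source skips the existence check in the factory, while `_normalize_inputs` in this diff treats every scheme other than `s3` as a local path, so such a source would likely be rejected later with the "does not exist" error rather than at construction time.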
@@ -424,6 +528,16 @@ def _normalize_inputs(self, inputs=None, kms_key=None): | |
| # If the s3_uri is not an s3_uri, create one. | ||
| parse_result = urlparse(file_input.s3_input.s3_uri) | ||
| if parse_result.scheme != "s3": | ||
| # Validate that local path exists before attempting upload | ||
| local_source = file_input.s3_input.s3_uri | ||
| if parse_result.scheme == "file": | ||
| local_source = url2pathname(parse_result.path) | ||
| if not os.path.exists(local_source): | ||
| raise ValueError( | ||
| f"Input source path does not exist: {file_input.s3_input.s3_uri!r}. " | ||
|
Review comment: The

    # After both if/else branches for _pipeline_config
    logger.info(
        "Uploading local input '%s' from %s to %s",
        file_input.input_name,
        file_input.s3_input.s3_uri,
        desired_s3_uri,
    )

Wait, looking at the diff more carefully, the
||
| f"Please provide a valid local path or S3 URI for input " | ||
| f"'{file_input.input_name}'." | ||
| ) | ||
| if _pipeline_config: | ||
| desired_s3_uri = s3.s3_path_join( | ||
| "s3://", | ||
|
|
@@ -443,6 +557,12 @@ def _normalize_inputs(self, inputs=None, kms_key=None): | |
| "input", | ||
| file_input.input_name, | ||
| ) | ||
| logger.info( | ||
| "Uploading local input '%s' from %s to %s", | ||
| file_input.input_name, | ||
| file_input.s3_input.s3_uri, | ||
| desired_s3_uri, | ||
| ) | ||
| s3_uri = s3.S3Uploader.upload( | ||
| local_path=file_input.s3_input.s3_uri, | ||
| desired_s3_uri=desired_s3_uri, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| _processing_output_to_request_dict, | ||
| _get_process_request, | ||
| logs_for_processing_job, | ||
| processing_input_from_local, | ||
| ) | ||
| from sagemaker.core.shapes import ( | ||
| ProcessingInput, | ||
|
|
@@ -46,6 +47,221 @@ def mock_session(): | |
| return session | ||
|
|
||
|
|
||
| class TestProcessingInputFromLocal: | ||
| """Tests for the processing_input_from_local() factory function.""" | ||
|
|
||
| def test_processing_input_from_local_with_file_path_creates_valid_input(self): | ||
| """A local file path should produce a valid ProcessingInput.""" | ||
| with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f: | ||
| f.write("col1,col2\n1,2\n") | ||
| temp_file = f.name | ||
|
|
||
| try: | ||
| result = processing_input_from_local( | ||
| source=temp_file, | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="my-data", | ||
| ) | ||
| assert isinstance(result, ProcessingInput) | ||
| assert result.input_name == "my-data" | ||
| assert result.s3_input.s3_uri == temp_file | ||
| assert result.s3_input.local_path == "/opt/ml/processing/input/data" | ||
|
Review comment: Good use of

    def test_processing_input_from_local_with_file_path(self, tmp_path):
        temp_file = tmp_path / "data.csv"
        temp_file.write_text("col1,col2\n1,2\n")
        result = processing_input_from_local(
            source=str(temp_file),
            destination="/opt/ml/processing/input/data",
            input_name="my-data",
        )
        ...
||
| assert result.s3_input.s3_data_type == "S3Prefix" | ||
| assert result.s3_input.s3_input_mode == "File" | ||
| finally: | ||
| os.unlink(temp_file) | ||
|
|
||
| def test_processing_input_from_local_with_directory_path_creates_valid_input(self): | ||
| """A local directory path should produce a valid ProcessingInput.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| result = processing_input_from_local( | ||
| source=tmpdir, | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="dir-data", | ||
| ) | ||
| assert isinstance(result, ProcessingInput) | ||
| assert result.s3_input.s3_uri == tmpdir | ||
|
|
||
| def test_processing_input_from_local_with_s3_uri_passes_through(self): | ||
| """An S3 URI should pass through without local path validation.""" | ||
| result = processing_input_from_local( | ||
| source="s3://my-bucket/data/training", | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="s3-data", | ||
| ) | ||
| assert isinstance(result, ProcessingInput) | ||
| assert result.s3_input.s3_uri == "s3://my-bucket/data/training" | ||
|
|
||
| def test_processing_input_from_local_with_nonexistent_path_raises_value_error(self): | ||
| """A nonexistent local path should raise ValueError.""" | ||
| with pytest.raises(ValueError, match="Input source path does not exist"): | ||
| processing_input_from_local( | ||
| source="/nonexistent/path/to/data", | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
|
|
||
| def test_processing_input_from_local_with_custom_input_name(self): | ||
| """Custom input_name should be set on the ProcessingInput.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| result = processing_input_from_local( | ||
| source=tmpdir, | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="custom-name", | ||
| ) | ||
| assert result.input_name == "custom-name" | ||
|
|
||
| def test_processing_input_from_local_default_parameters(self): | ||
| """Default parameters should be applied correctly.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| result = processing_input_from_local( | ||
| source=tmpdir, | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
| assert result.input_name is None | ||
| assert result.s3_input.s3_data_type == "S3Prefix" | ||
| assert result.s3_input.s3_input_mode == "File" | ||
|
|
||
| def test_processing_input_from_local_with_empty_source_raises_value_error(self): | ||
| """Empty source should raise ValueError.""" | ||
| with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"): | ||
| processing_input_from_local( | ||
| source="", | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
|
|
||
| def test_processing_input_from_local_with_none_source_raises_value_error(self): | ||
| """None source should raise ValueError.""" | ||
| with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"): | ||
| processing_input_from_local( | ||
| source=None, | ||
| destination="/opt/ml/processing/input/data", | ||
| ) | ||
|
|
||
| def test_processing_input_from_local_with_optional_s3_params(self): | ||
| """Optional S3 parameters should be passed through.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| result = processing_input_from_local( | ||
| source=tmpdir, | ||
| destination="/opt/ml/processing/input/data", | ||
| s3_data_distribution_type="FullyReplicated", | ||
| s3_compression_type="Gzip", | ||
| ) | ||
| assert result.s3_input.s3_data_distribution_type == "FullyReplicated" | ||
| assert result.s3_input.s3_compression_type == "Gzip" | ||
|
|
||
| def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session): | ||
| """ProcessingInput from processing_input_from_local should work with _normalize_inputs.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
|
Review comment: This test patches
||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| inp = processing_input_from_local( | ||
| source=tmpdir, | ||
| destination="/opt/ml/processing/input/data", | ||
| input_name="local-data", | ||
| ) | ||
| with patch( | ||
| "sagemaker.core.s3.S3Uploader.upload", return_value="s3://bucket/uploaded" | ||
| ): | ||
| result = processor._normalize_inputs([inp]) | ||
| assert len(result) == 1 | ||
| assert result[0].s3_input.s3_uri == "s3://bucket/uploaded" | ||
|
|
||
|
|
||
| class TestNormalizeInputsLocalPathValidation: | ||
| """Tests for local path validation in _normalize_inputs().""" | ||
|
|
||
| def test_normalize_inputs_with_nonexistent_local_path_raises_value_error(self, mock_session): | ||
| """A nonexistent local path in s3_input.s3_uri should raise ValueError.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| s3_input = ProcessingS3Input( | ||
| s3_uri="/nonexistent/path/to/data", | ||
| local_path="/opt/ml/processing/input", | ||
| s3_data_type="S3Prefix", | ||
| s3_input_mode="File", | ||
| ) | ||
| inputs = [ProcessingInput(input_name="bad-input", s3_input=s3_input)] | ||
|
|
||
| with pytest.raises(ValueError, match="Input source path does not exist"): | ||
| processor._normalize_inputs(inputs) | ||
|
|
||
| def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session): | ||
| """A valid local path should be uploaded to S3.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| s3_input = ProcessingS3Input( | ||
| s3_uri=tmpdir, | ||
| local_path="/opt/ml/processing/input", | ||
| s3_data_type="S3Prefix", | ||
| s3_input_mode="File", | ||
| ) | ||
| inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)] | ||
|
|
||
| with patch( | ||
| "sagemaker.core.s3.S3Uploader.upload", | ||
| return_value="s3://test-bucket/sagemaker/test-job/input/local-input", | ||
| ) as mock_upload: | ||
| result = processor._normalize_inputs(inputs) | ||
| assert len(result) == 1 | ||
| assert result[0].s3_input.s3_uri.startswith("s3://") | ||
| mock_upload.assert_called_once() | ||
|
|
||
| def test_normalize_inputs_local_path_logs_upload_info(self, mock_session): | ||
| """Uploading a local path should log an info message.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| s3_input = ProcessingS3Input( | ||
| s3_uri=tmpdir, | ||
| local_path="/opt/ml/processing/input", | ||
| s3_data_type="S3Prefix", | ||
| s3_input_mode="File", | ||
|
Review comment: The f-string in the assertion
||
| ) | ||
| inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)] | ||
|
|
||
| with patch( | ||
| "sagemaker.core.s3.S3Uploader.upload", | ||
| return_value="s3://test-bucket/uploaded", | ||
| ): | ||
| with patch("sagemaker.core.processing.logger") as mock_logger: | ||
| processor._normalize_inputs(inputs) | ||
| mock_logger.info.assert_any_call( | ||
| "Uploading local input '%s' from %s to %s", | ||
| "local-input", | ||
| tmpdir, | ||
| f"s3://test-bucket/sagemaker/test-job/input/local-input", | ||
| ) | ||
|
|
||
|
|
||
| class TestProcessorNormalizeArgs: | ||
| def test_normalize_args_with_pipeline_variable_code(self, mock_session): | ||
| from sagemaker.core.workflow.pipeline_context import PipelineSession | ||
|
|
||
Review comment: The `source` parameter is typed as `str`, but the function explicitly handles `None` (line 155). The type annotation should be `str | None` to match the actual behavior, or better yet, keep it as `str` and remove the `None` handling, since `None` is not a meaningful input.

Alternatively, if you want to keep `str` as the type, remove the `None` check and let it fail naturally with a `TypeError`; that's more Pythonic for a required parameter.
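One way to make the distinction raised in this comment explicit is sketched below. It is not what the PR does: the function name `validate_source` and both error messages are hypothetical. The `str` annotation is kept, a non-string value such as `None` raises `TypeError`, and an empty or whitespace-only string raises `ValueError`.

```python
def validate_source(source: str) -> str:
    """Hypothetical validator separating 'wrong type' from 'empty value'."""
    if not isinstance(source, str):
        # None and other non-string values are a caller bug, not a bad value.
        raise TypeError(f"source must be a str, got {type(source).__name__}")
    if not source.strip():
        raise ValueError("source must be a non-empty local path or S3 URI")
    return source


for bad in (None, "", "   "):
    try:
        validate_source(bad)
    except (TypeError, ValueError) as exc:
        print(f"{bad!r}: {type(exc).__name__}: {exc}")
```

If a change along these lines were adopted, `test_processing_input_from_local_with_none_source_raises_value_error` in this PR would need to expect `TypeError` instead of `ValueError`.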