Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion sagemaker-core/src/sagemaker/core/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
which is used for Amazon SageMaker Processing Jobs. These jobs let users perform
data pre-processing, post-processing, feature engineering, data validation, and model evaluation,
and interpretation on Amazon SageMaker.

``ProcessingInput`` supports local file paths via the ``s3_input.s3_uri`` field.
When a local path is provided, it is automatically uploaded to S3 during input
normalization. For a convenient way to create ``ProcessingInput`` objects from
local sources, use the :func:`processing_input_from_local` helper function.
"""
from __future__ import absolute_import
from __future__ import annotations

import json
import logging
Expand Down Expand Up @@ -85,6 +90,110 @@
logger = logging.getLogger(__name__)


def processing_input_from_local(
    source: str,
    destination: str,
    input_name: str | None = None,
    s3_data_type: str = "S3Prefix",
    s3_input_mode: str = "File",
    s3_data_distribution_type: str | None = None,
    s3_compression_type: str | None = None,
) -> ProcessingInput:
    """Creates a ProcessingInput from a local file/directory path or S3 URI.

    This is a convenience factory function that provides V2-like ergonomics
    for creating ``ProcessingInput`` objects. When a local path is provided
    as ``source``, it will be automatically uploaded to S3 during input
    normalization in ``Processor.run()``.

    Args:
        source: Local file/directory path or S3 URI. If a local path is
            provided, it must exist and will be uploaded to S3 automatically
            when the processing job is run.
        destination: The container path where the input data will be
            available, e.g. ``/opt/ml/processing/input/data``.
        input_name: A name for the processing input. If not specified,
            one will be auto-generated during normalization.
        s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``,
            ``'ManifestFile'`` (default: ``'S3Prefix'``).
        s3_input_mode: The input mode. Valid values: ``'File'``,
            ``'Pipe'`` (default: ``'File'``).
        s3_data_distribution_type: The data distribution type for
            distributed processing. Valid values:
            ``'FullyReplicated'``, ``'ShardedByS3Key'``
            (default: None).
        s3_compression_type: The compression type. Valid values:
            ``'None'``, ``'Gzip'`` (default: None).

    Returns:
        ProcessingInput: A ``ProcessingInput`` object configured with the
        given source and destination.

    Raises:
        ValueError: If ``source`` is not a non-empty string, or if it is a
            local path that does not exist.

    Examples:
        Create an input from a local directory::

            from sagemaker.core.processing import processing_input_from_local

            input_data = processing_input_from_local(
                source="/local/data/training",
                destination="/opt/ml/processing/input/data",
                input_name="training-data",
            )
            processor.run(inputs=[input_data])

        Create an input from an S3 URI::

            input_data = processing_input_from_local(
                source="s3://my-bucket/data/training",
                destination="/opt/ml/processing/input/data",
            )
    """
    # A non-string (including None) or empty source is always an error;
    # ValueError is raised for both cases so callers have one failure mode.
    if not isinstance(source, str) or not source:
        raise ValueError(
            f"source must be a non-empty string containing a valid local path "
            f"or S3 URI, got: {source!r}"
        )

    # Check if source is a local path (not a remote URI).
    # Note: On Windows, absolute paths like C:\data\file.csv will have
    # parse_result.scheme == 'c', which correctly falls into the local path branch.
    parse_result = urlparse(source)
    if parse_result.scheme not in ("s3", "http", "https"):
        # Treat as local path - validate existence
        local_path = source
        if parse_result.scheme == "file":
            local_path = url2pathname(parse_result.path)
        if not os.path.exists(local_path):
            raise ValueError(
                f"Input source path does not exist: {source!r}. "
                f"Please provide a valid local path or S3 URI."
            )

    s3_input_kwargs = {
        "s3_uri": source,
        # local_path in ProcessingS3Input maps to the container destination path
        # where the input data will be made available inside the processing container.
        "local_path": destination,
        "s3_data_type": s3_data_type,
        "s3_input_mode": s3_input_mode,
    }
    # Only forward the optional S3 settings when explicitly provided, so the
    # generated shape keeps its own defaults otherwise.
    if s3_data_distribution_type is not None:
        s3_input_kwargs["s3_data_distribution_type"] = s3_data_distribution_type
    if s3_compression_type is not None:
        s3_input_kwargs["s3_compression_type"] = s3_compression_type

    s3_input = ProcessingS3Input(**s3_input_kwargs)

    return ProcessingInput(
        input_name=input_name,
        s3_input=s3_input,
    )


class Processor(object):
"""Handles Amazon SageMaker Processing tasks."""

Expand Down Expand Up @@ -424,6 +533,17 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
# If the s3_uri is not an s3_uri, create one.
parse_result = urlparse(file_input.s3_input.s3_uri)
if parse_result.scheme != "s3":
# Validate that local path exists before attempting upload
local_source = file_input.s3_input.s3_uri
if parse_result.scheme == "file":
local_source = url2pathname(parse_result.path)
if not os.path.exists(local_source):
raise ValueError(
f"Input source path does not exist: "
f"{file_input.s3_input.s3_uri!r}. "
f"Please provide a valid local path or S3 URI "
f"for input '{file_input.input_name}'."
)
if _pipeline_config:
desired_s3_uri = s3.s3_path_join(
"s3://",
Expand All @@ -443,6 +563,13 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
"input",
file_input.input_name,
)
# Log after both pipeline/non-pipeline branches have set desired_s3_uri
logger.info(
"Uploading local input '%s' from %s to %s",
file_input.input_name,
file_input.s3_input.s3_uri,
desired_s3_uri,
)
s3_uri = s3.S3Uploader.upload(
local_path=file_input.s3_input.s3_uri,
desired_s3_uri=desired_s3_uri,
Expand Down
217 changes: 217 additions & 0 deletions sagemaker-core/tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_processing_output_to_request_dict,
_get_process_request,
logs_for_processing_job,
processing_input_from_local,
)
from sagemaker.core.shapes import (
ProcessingInput,
Expand All @@ -46,6 +47,222 @@ def mock_session():
return session


class TestProcessingInputFromLocal:
    """Tests for the processing_input_from_local() factory function."""

    def test_processing_input_from_local_with_file_path_creates_valid_input(self, tmp_path):
        """A local file path should produce a valid ProcessingInput."""
        temp_file = tmp_path / "data.csv"
        temp_file.write_text("col1,col2\n1,2\n")

        result = processing_input_from_local(
            source=str(temp_file),
            destination="/opt/ml/processing/input/data",
            input_name="my-data",
        )
        assert isinstance(result, ProcessingInput)
        assert result.input_name == "my-data"
        assert result.s3_input.s3_uri == str(temp_file)
        # local_path in ProcessingS3Input maps to the container destination path
        assert result.s3_input.local_path == "/opt/ml/processing/input/data"
        assert result.s3_input.s3_data_type == "S3Prefix"
        assert result.s3_input.s3_input_mode == "File"

    def test_processing_input_from_local_with_directory_path_creates_valid_input(self, tmp_path):
        """A local directory path should produce a valid ProcessingInput."""
        result = processing_input_from_local(
            source=str(tmp_path),
            destination="/opt/ml/processing/input/data",
            input_name="dir-data",
        )
        assert isinstance(result, ProcessingInput)
        assert result.s3_input.s3_uri == str(tmp_path)

    def test_processing_input_from_local_with_s3_uri_passes_through(self):
        """An S3 URI should pass through without local path validation."""
        result = processing_input_from_local(
            source="s3://my-bucket/data/training",
            destination="/opt/ml/processing/input/data",
            input_name="s3-data",
        )
        assert isinstance(result, ProcessingInput)
        assert result.s3_input.s3_uri == "s3://my-bucket/data/training"

    def test_processing_input_from_local_with_nonexistent_path_raises_value_error(self):
        """A nonexistent local path should raise ValueError."""
        with pytest.raises(ValueError, match="Input source path does not exist"):
            processing_input_from_local(
                source="/nonexistent/path/to/data",
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_custom_input_name(self, tmp_path):
        """Custom input_name should be set on the ProcessingInput."""
        result = processing_input_from_local(
            source=str(tmp_path),
            destination="/opt/ml/processing/input/data",
            input_name="custom-name",
        )
        assert result.input_name == "custom-name"

    def test_processing_input_from_local_default_parameters(self, tmp_path):
        """Default parameters should be applied correctly."""
        result = processing_input_from_local(
            source=str(tmp_path),
            destination="/opt/ml/processing/input/data",
        )
        assert result.input_name is None
        assert result.s3_input.s3_data_type == "S3Prefix"
        assert result.s3_input.s3_input_mode == "File"

    def test_processing_input_from_local_with_empty_source_raises_value_error(self):
        """Empty source should raise ValueError."""
        with pytest.raises(ValueError, match="source must be a non-empty string"):
            processing_input_from_local(
                source="",
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_none_source_raises_value_error(self):
        """None source should raise ValueError."""
        with pytest.raises(ValueError, match="source must be a non-empty string"):
            processing_input_from_local(
                source=None,
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_optional_s3_params(self, tmp_path):
        """Optional S3 parameters should be passed through."""
        result = processing_input_from_local(
            source=str(tmp_path),
            destination="/opt/ml/processing/input/data",
            s3_data_distribution_type="FullyReplicated",
            s3_compression_type="Gzip",
        )
        assert result.s3_input.s3_data_distribution_type == "FullyReplicated"
        assert result.s3_input.s3_compression_type == "Gzip"

    def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session, tmp_path):
        """ProcessingInput from processing_input_from_local should work with _normalize_inputs."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        inp = processing_input_from_local(
            source=str(tmp_path),
            destination="/opt/ml/processing/input/data",
            input_name="local-data",
        )
        # Patch S3Uploader where it is used (processing.py imports the s3
        # module), not where it is defined, so the mock is actually applied.
        with patch(
            "sagemaker.core.processing.s3.S3Uploader.upload",
            return_value="s3://bucket/uploaded",
        ):
            result = processor._normalize_inputs([inp])
        assert len(result) == 1
        assert result[0].s3_input.s3_uri == "s3://bucket/uploaded"

class TestNormalizeInputsLocalPathValidation:
    """Tests for local path validation in _normalize_inputs()."""

    def test_normalize_inputs_with_nonexistent_local_path_raises_value_error(self, mock_session):
        """A nonexistent local path in s3_input.s3_uri should raise ValueError."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        s3_input = ProcessingS3Input(
            s3_uri="/nonexistent/path/to/data",
            local_path="/opt/ml/processing/input",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
        )
        inputs = [ProcessingInput(input_name="bad-input", s3_input=s3_input)]

        with pytest.raises(ValueError, match="Input source path does not exist"):
            processor._normalize_inputs(inputs)

    def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session, tmp_path):
        """A valid local path should be uploaded to S3."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        s3_input = ProcessingS3Input(
            s3_uri=str(tmp_path),
            local_path="/opt/ml/processing/input",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
        )
        inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]

        with patch(
            "sagemaker.core.processing.s3.S3Uploader.upload",
            return_value="s3://test-bucket/sagemaker/test-job/input/local-input",
        ) as mock_upload:
            result = processor._normalize_inputs(inputs)
        assert len(result) == 1
        assert result[0].s3_input.s3_uri.startswith("s3://")
        mock_upload.assert_called_once()

    def test_normalize_inputs_local_path_logs_upload_info(self, mock_session, tmp_path):
        """Uploading a local path should log an info message."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        s3_input = ProcessingS3Input(
            s3_uri=str(tmp_path),
            local_path="/opt/ml/processing/input",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
        )
        inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]

        with patch(
            "sagemaker.core.processing.s3.S3Uploader.upload",
            return_value="s3://test-bucket/uploaded",
        ):
            with patch("sagemaker.core.processing.logger") as mock_logger:
                processor._normalize_inputs(inputs)
        # Verify the upload log message was emitted with the correct format.
        # The exact S3 path depends on default_bucket/prefix/job_name/input_name.
        mock_logger.info.assert_called()
        log_calls = mock_logger.info.call_args_list
        upload_log_found = any(
            len(call.args) >= 4
            and call.args[0] == "Uploading local input '%s' from %s to %s"
            and call.args[1] == "local-input"
            and call.args[2] == str(tmp_path)
            and call.args[3].startswith("s3://")
            for call in log_calls
        )
        assert upload_log_found, (
            f"Expected upload log message not found in logger.info calls: "
            f"{log_calls}"
        )


class TestProcessorNormalizeArgs:
def test_normalize_args_with_pipeline_variable_code(self, mock_session):
from sagemaker.core.workflow.pipeline_context import PipelineSession
Expand Down
Loading