120 changes: 120 additions & 0 deletions sagemaker-core/src/sagemaker/core/processing.py
@@ -15,6 +15,11 @@
which is used for Amazon SageMaker Processing Jobs. These jobs let users perform
data pre-processing, post-processing, feature engineering, data validation, and model
evaluation and interpretation on Amazon SageMaker.

``ProcessingInput`` supports local file paths via the ``s3_input.s3_uri`` field.
When a local path is provided, it is automatically uploaded to S3 during input
normalization. For a convenient way to create ``ProcessingInput`` objects from
local sources, use the :func:`processing_input_from_local` helper function.
"""
from __future__ import absolute_import

@@ -85,6 +90,105 @@
logger = logging.getLogger(__name__)


def processing_input_from_local(
Review comment:

The source parameter is typed as str but the function explicitly handles None (line 155). The type annotation should be str | None to match the actual behavior, or better yet, keep it as str and remove the None handling since None is not a meaningful input.

def processing_input_from_local(
    source: str | None,
    ...

Alternatively, if you want to keep str as the type, remove the None check and let it fail naturally with a TypeError — that's more Pythonic for a required parameter.

Review comment:

Per SDK V3 conventions, user-facing data/config classes should use Pydantic BaseModel. While a factory function is a reasonable approach here (since ProcessingInput is auto-generated), consider whether a thin Pydantic wrapper class like LocalProcessingInput(BaseModel) with a .to_processing_input() method would be more consistent with V3 patterns. This is a design suggestion, not a blocker — the factory function approach is pragmatic given that ProcessingInput is auto-generated.
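To make the suggestion concrete, here is a minimal sketch of what such a wrapper could look like. `LocalProcessingInput` and its `to_processing_input()` method are hypothetical; a stdlib dataclass stands in for `pydantic.BaseModel`, and the method returns plain kwargs rather than a real `ProcessingInput` so the sketch stays self-contained.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LocalProcessingInput:
    """Sketch of a thin wrapper class; in the V3 codebase this would subclass
    pydantic.BaseModel (a stdlib dataclass is used here as a stand-in)."""

    source: str
    destination: str
    input_name: Optional[str] = None
    s3_data_type: str = "S3Prefix"
    s3_input_mode: str = "File"

    def to_processing_input(self) -> dict:
        # Stand-in for building a real ProcessingInput; returns the kwargs
        # that processing_input_from_local() would forward.
        return {
            "input_name": self.input_name,
            "s3_input": {
                "s3_uri": self.source,
                "local_path": self.destination,
                "s3_data_type": self.s3_data_type,
                "s3_input_mode": self.s3_input_mode,
            },
        }
```

This keeps the V2-style ergonomics while matching the V3 preference for declarative config objects; the factory function remains the pragmatic choice while `ProcessingInput` is auto-generated.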

    source: str,
    destination: str,
    input_name: Optional[str] = None,
    s3_data_type: str = "S3Prefix",
    s3_input_mode: str = "File",
    s3_data_distribution_type: Optional[str] = None,
    s3_compression_type: Optional[str] = None,
) -> ProcessingInput:
    """Creates a ProcessingInput from a local file/directory path or S3 URI.

    This is a convenience factory function that provides V2-like ergonomics
    for creating ``ProcessingInput`` objects. When a local path is provided
    as ``source``, it will be automatically uploaded to S3 during input
    normalization in ``Processor.run()``.

    Args:
        source: Local file/directory path or S3 URI. If a local path is
            provided, it must exist and will be uploaded to S3 automatically
            when the processing job is run.
        destination: The container path where the input data will be
            available, e.g. ``/opt/ml/processing/input/data``.
        input_name: A name for the processing input. If not specified,
            one will be auto-generated during normalization.
        s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``,
            ``'ManifestFile'`` (default: ``'S3Prefix'``).
        s3_input_mode: The input mode. Valid values: ``'File'``,
            ``'Pipe'`` (default: ``'File'``).
        s3_data_distribution_type: The data distribution type for
            distributed processing. Valid values:
            ``'FullyReplicated'``, ``'ShardedByS3Key'``
            (default: None).
        s3_compression_type: The compression type. Valid values:
            ``'None'``, ``'Gzip'`` (default: None).

    Returns:
        ProcessingInput: A ``ProcessingInput`` object configured with the
            given source and destination.

    Raises:
        ValueError: If ``source`` is a local path that does not exist.
        ValueError: If ``source`` is empty or None.

    Examples:
        Create an input from a local directory::

            from sagemaker.core.processing import processing_input_from_local

            input_data = processing_input_from_local(
                source="/local/data/training",
                destination="/opt/ml/processing/input/data",
                input_name="training-data",
            )
            processor.run(inputs=[input_data])

        Create an input from an S3 URI::

            input_data = processing_input_from_local(
                source="s3://my-bucket/data/training",
                destination="/opt/ml/processing/input/data",
            )
    """
    if not source:
        raise ValueError(
            f"source must be a valid local path or S3 URI, got: {source!r}"
Review comment:

The f-string formatting here uses !r, which is good, and the call is already wrapped to fit the line limit. One note: this validation doesn't distinguish between None and an empty string in the error message. Consider separate messages for better debugging.

        )

    # Check if source is a local path (not an S3 URI)
    parse_result = urlparse(source)
    if parse_result.scheme not in ("s3", "http", "https"):
Review comment:

The URL scheme check parse_result.scheme not in ("s3", "http", "https") doesn't account for other valid remote schemes that might be used (e.g., gs:// for GCS if ever supported). More importantly, on Windows, absolute paths like C:\data\file.csv will have parse_result.scheme == 'c', which would correctly fall into the local path branch. However, this should be documented or tested. Consider adding a comment about this behavior.
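For reference, the behavior the comment describes can be seen directly with `urllib.parse`; `is_remote_uri` below is a hypothetical stand-in for the scheme check in the diff.

```python
from urllib.parse import urlparse


def is_remote_uri(source: str) -> bool:
    """Hypothetical stand-in for the scheme check used in the diff."""
    return urlparse(source).scheme in ("s3", "http", "https")


print(is_remote_uri("s3://my-bucket/data"))    # True
print(is_remote_uri("/local/data/train.csv"))  # False
# On a Windows-style path the drive letter parses as a (lowercased) scheme,
# so such paths fall into the local-path branch as the comment notes:
print(urlparse(r"C:\data\file.csv").scheme)    # 'c'
```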

        # Treat as local path - validate existence
        local_path = source
        if parse_result.scheme == "file":
            local_path = url2pathname(parse_result.path)
        if not os.path.exists(local_path):
            raise ValueError(
                f"Input source path does not exist: {source!r}. "
                f"Please provide a valid local path or S3 URI."
Review comment:

The mapping of destination to local_path on ProcessingS3Input is semantically confusing and potentially incorrect at first glance. In the ProcessingS3Input shape, local_path refers to the container path where data is made available, while s3_uri is the S3 location. Setting local_path=destination (the container path) is correct for the SageMaker API, but the parameter name local_path is misleading in this context. Please add a comment clarifying this mapping:

s3_input_kwargs = {
    "s3_uri": source,
    # local_path in ProcessingS3Input maps to the container destination path
    "local_path": destination,
    "s3_data_type": s3_data_type,
    "s3_input_mode": s3_input_mode,
}

            )

    s3_input_kwargs = {
        "s3_uri": source,
        "local_path": destination,
        "s3_data_type": s3_data_type,
        "s3_input_mode": s3_input_mode,
    }
    if s3_data_distribution_type is not None:
        s3_input_kwargs["s3_data_distribution_type"] = s3_data_distribution_type
    if s3_compression_type is not None:
        s3_input_kwargs["s3_compression_type"] = s3_compression_type

    s3_input = ProcessingS3Input(**s3_input_kwargs)

    return ProcessingInput(
        input_name=input_name,
        s3_input=s3_input,
    )


class Processor(object):
    """Handles Amazon SageMaker Processing tasks."""

@@ -424,6 +528,16 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
            # If the s3_uri is not an s3_uri, create one.
            parse_result = urlparse(file_input.s3_input.s3_uri)
            if parse_result.scheme != "s3":
                # Validate that local path exists before attempting upload
                local_source = file_input.s3_input.s3_uri
                if parse_result.scheme == "file":
                    local_source = url2pathname(parse_result.path)
                if not os.path.exists(local_source):
                    raise ValueError(
                        f"Input source path does not exist: {file_input.s3_input.s3_uri!r}. "
Review comment:

The logger.info() call must sit after both branches that set desired_s3_uri (the if _pipeline_config branch and the else branch), but still inside the if parse_result.scheme != "s3": block, so the message is emitted for both the pipeline and non-pipeline paths. The diff appears to do this correctly: the logger.info() is placed right before s3.S3Uploader.upload(...), which is after both branches. Please verify the indentation level matches that intent.

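A stripped-down control-flow sketch of the intended placement (names and paths are stand-ins, not the real _normalize_inputs logic): the log statement sits after both branches, so it runs for pipeline and non-pipeline paths alike.

```python
# Stand-in values only: this mirrors the branch structure, not the real
# _normalize_inputs logic.
def upload_target(pipeline_name, job_name="test-job"):
    if pipeline_name:  # stand-in for the _pipeline_config branch
        desired_s3_uri = f"s3://bucket/{pipeline_name}/input"
    else:
        desired_s3_uri = f"s3://bucket/{job_name}/input"
    # Placed after both branches, so it fires for either path
    print(f"Uploading to {desired_s3_uri}")
    return desired_s3_uri
```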
                        f"Please provide a valid local path or S3 URI for input "
                        f"'{file_input.input_name}'."
                    )
                if _pipeline_config:
                    desired_s3_uri = s3.s3_path_join(
                        "s3://",
@@ -443,6 +557,12 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
                        "input",
                        file_input.input_name,
                    )
                logger.info(
                    "Uploading local input '%s' from %s to %s",
                    file_input.input_name,
                    file_input.s3_input.s3_uri,
                    desired_s3_uri,
                )
                s3_uri = s3.S3Uploader.upload(
                    local_path=file_input.s3_input.s3_uri,
                    desired_s3_uri=desired_s3_uri,
216 changes: 216 additions & 0 deletions sagemaker-core/tests/unit/test_processing.py
@@ -23,6 +23,7 @@
    _processing_output_to_request_dict,
    _get_process_request,
    logs_for_processing_job,
    processing_input_from_local,
)
from sagemaker.core.shapes import (
    ProcessingInput,
@@ -46,6 +47,221 @@ def mock_session():
    return session


class TestProcessingInputFromLocal:
    """Tests for the processing_input_from_local() factory function."""

    def test_processing_input_from_local_with_file_path_creates_valid_input(self):
        """A local file path should produce a valid ProcessingInput."""
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f:
            f.write("col1,col2\n1,2\n")
            temp_file = f.name

        try:
            result = processing_input_from_local(
                source=temp_file,
                destination="/opt/ml/processing/input/data",
                input_name="my-data",
            )
            assert isinstance(result, ProcessingInput)
            assert result.input_name == "my-data"
            assert result.s3_input.s3_uri == temp_file
            assert result.s3_input.local_path == "/opt/ml/processing/input/data"
Review comment:

Good use of tempfile.NamedTemporaryFile with proper cleanup in finally. However, consider using tempfile.TemporaryDirectory with a file created inside it as a fixture to reduce boilerplate across tests. Also, delete=False on Windows can leave files behind if the test crashes before finally. A pytest.fixture with tmp_path would be cleaner:

def test_processing_input_from_local_with_file_path(self, tmp_path):
    temp_file = tmp_path / "data.csv"
    temp_file.write_text("col1,col2\n1,2\n")
    result = processing_input_from_local(
        source=str(temp_file),
        destination="/opt/ml/processing/input/data",
        input_name="my-data",
    )
    ...

            assert result.s3_input.s3_data_type == "S3Prefix"
            assert result.s3_input.s3_input_mode == "File"
        finally:
            os.unlink(temp_file)

    def test_processing_input_from_local_with_directory_path_creates_valid_input(self):
        """A local directory path should produce a valid ProcessingInput."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = processing_input_from_local(
                source=tmpdir,
                destination="/opt/ml/processing/input/data",
                input_name="dir-data",
            )
            assert isinstance(result, ProcessingInput)
            assert result.s3_input.s3_uri == tmpdir

    def test_processing_input_from_local_with_s3_uri_passes_through(self):
        """An S3 URI should pass through without local path validation."""
        result = processing_input_from_local(
            source="s3://my-bucket/data/training",
            destination="/opt/ml/processing/input/data",
            input_name="s3-data",
        )
        assert isinstance(result, ProcessingInput)
        assert result.s3_input.s3_uri == "s3://my-bucket/data/training"

    def test_processing_input_from_local_with_nonexistent_path_raises_value_error(self):
        """A nonexistent local path should raise ValueError."""
        with pytest.raises(ValueError, match="Input source path does not exist"):
            processing_input_from_local(
                source="/nonexistent/path/to/data",
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_custom_input_name(self):
        """Custom input_name should be set on the ProcessingInput."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = processing_input_from_local(
                source=tmpdir,
                destination="/opt/ml/processing/input/data",
                input_name="custom-name",
            )
            assert result.input_name == "custom-name"

    def test_processing_input_from_local_default_parameters(self):
        """Default parameters should be applied correctly."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = processing_input_from_local(
                source=tmpdir,
                destination="/opt/ml/processing/input/data",
            )
            assert result.input_name is None
            assert result.s3_input.s3_data_type == "S3Prefix"
            assert result.s3_input.s3_input_mode == "File"

    def test_processing_input_from_local_with_empty_source_raises_value_error(self):
        """Empty source should raise ValueError."""
        with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"):
            processing_input_from_local(
                source="",
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_none_source_raises_value_error(self):
        """None source should raise ValueError."""
        with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"):
            processing_input_from_local(
                source=None,
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_optional_s3_params(self):
        """Optional S3 parameters should be passed through."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = processing_input_from_local(
                source=tmpdir,
                destination="/opt/ml/processing/input/data",
                s3_data_distribution_type="FullyReplicated",
                s3_compression_type="Gzip",
            )
            assert result.s3_input.s3_data_distribution_type == "FullyReplicated"
            assert result.s3_input.s3_compression_type == "Gzip"

    def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session):
        """ProcessingInput from processing_input_from_local should work with _normalize_inputs."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
Review comment:

This test patches sagemaker.core.s3.S3Uploader.upload but the actual code in _normalize_inputs calls s3.S3Uploader.upload where s3 is imported as from sagemaker.core import s3. The patch target should be sagemaker.core.processing.s3.S3Uploader.upload (patching where it's used, not where it's defined) to ensure the mock is applied correctly. Please verify this works — if the import in processing.py is from sagemaker.core import s3, then the correct patch target is sagemaker.core.processing.s3.S3Uploader.upload.

            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        with tempfile.TemporaryDirectory() as tmpdir:
            inp = processing_input_from_local(
                source=tmpdir,
                destination="/opt/ml/processing/input/data",
                input_name="local-data",
            )
            with patch(
                "sagemaker.core.s3.S3Uploader.upload", return_value="s3://bucket/uploaded"
            ):
                result = processor._normalize_inputs([inp])
                assert len(result) == 1
                assert result[0].s3_input.s3_uri == "s3://bucket/uploaded"


class TestNormalizeInputsLocalPathValidation:
    """Tests for local path validation in _normalize_inputs()."""

    def test_normalize_inputs_with_nonexistent_local_path_raises_value_error(self, mock_session):
        """A nonexistent local path in s3_input.s3_uri should raise ValueError."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        s3_input = ProcessingS3Input(
            s3_uri="/nonexistent/path/to/data",
            local_path="/opt/ml/processing/input",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
        )
        inputs = [ProcessingInput(input_name="bad-input", s3_input=s3_input)]

        with pytest.raises(ValueError, match="Input source path does not exist"):
            processor._normalize_inputs(inputs)

    def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session):
        """A valid local path should be uploaded to S3."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        with tempfile.TemporaryDirectory() as tmpdir:
            s3_input = ProcessingS3Input(
                s3_uri=tmpdir,
                local_path="/opt/ml/processing/input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
            )
            inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]

            with patch(
                "sagemaker.core.s3.S3Uploader.upload",
                return_value="s3://test-bucket/sagemaker/test-job/input/local-input",
            ) as mock_upload:
                result = processor._normalize_inputs(inputs)
                assert len(result) == 1
                assert result[0].s3_input.s3_uri.startswith("s3://")
                mock_upload.assert_called_once()

    def test_normalize_inputs_local_path_logs_upload_info(self, mock_session):
        """Uploading a local path should log an info message."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        with tempfile.TemporaryDirectory() as tmpdir:
            s3_input = ProcessingS3Input(
                s3_uri=tmpdir,
                local_path="/opt/ml/processing/input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
Review comment:

The f-string in the assertion f"s3://test-bucket/sagemaker/test-job/input/local-input" hardcodes the expected S3 path structure. This is brittle — if the S3 path construction logic changes in _normalize_inputs, this test will break. Consider using mock_logger.info.assert_called() and then inspecting the call args more flexibly, or at minimum add a comment explaining the expected path structure.
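One way to make the assertion less brittle, sketched with a plain MagicMock standing in for the patched logger: inspect the call arguments instead of matching the full hardcoded path.

```python
from unittest.mock import MagicMock

# Stand-in for the patched sagemaker.core.processing.logger
mock_logger = MagicMock()
mock_logger.info(
    "Uploading local input '%s' from %s to %s",
    "local-input",
    "/tmp/data",  # stand-in for tmpdir
    "s3://test-bucket/sagemaker/test-job/input/local-input",
)

# Assert on the shape of the call rather than the exact S3 path
msg, name, src, dest = mock_logger.info.call_args.args
assert msg.startswith("Uploading local input")
assert name == "local-input"
assert dest.startswith("s3://test-bucket/") and dest.endswith("/input/local-input")
```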

            )
            inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]

            with patch(
                "sagemaker.core.s3.S3Uploader.upload",
                return_value="s3://test-bucket/uploaded",
            ):
                with patch("sagemaker.core.processing.logger") as mock_logger:
                    processor._normalize_inputs(inputs)
                    mock_logger.info.assert_any_call(
                        "Uploading local input '%s' from %s to %s",
                        "local-input",
                        tmpdir,
                        "s3://test-bucket/sagemaker/test-job/input/local-input",
                    )


class TestProcessorNormalizeArgs:
    def test_normalize_args_with_pipeline_variable_code(self, mock_session):
        from sagemaker.core.workflow.pipeline_context import PipelineSession