Skip to content

Commit f1cfcfb

Browse files
committed
fix: Support local source for sagemaker.core.shapes.ProcessingInput (5672)
1 parent 6a1ba54 commit f1cfcfb

File tree

2 files changed

+336
-0
lines changed

2 files changed

+336
-0
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
which is used for Amazon SageMaker Processing Jobs. These jobs let users perform
1616
data pre-processing, post-processing, feature engineering, data validation, and model evaluation,
1717
and interpretation on Amazon SageMaker.
18+
19+
``ProcessingInput`` supports local file paths via the ``s3_input.s3_uri`` field.
20+
When a local path is provided, it is automatically uploaded to S3 during input
21+
normalization. For a convenient way to create ``ProcessingInput`` objects from
22+
local sources, use the :func:`processing_input_from_local` helper function.
1823
"""
1924
from __future__ import absolute_import
2025

@@ -85,6 +90,105 @@
8590
logger = logging.getLogger(__name__)
8691

8792

93+
def processing_input_from_local(
    source: str,
    destination: str,
    input_name: Optional[str] = None,
    s3_data_type: str = "S3Prefix",
    s3_input_mode: str = "File",
    s3_data_distribution_type: Optional[str] = None,
    s3_compression_type: Optional[str] = None,
) -> ProcessingInput:
    """Creates a ProcessingInput from a local file/directory path or S3 URI.

    This is a convenience factory function that provides V2-like ergonomics
    for creating ``ProcessingInput`` objects. When a local path is provided
    as ``source``, it will be automatically uploaded to S3 during input
    normalization in ``Processor.run()``.

    Args:
        source: Local file/directory path, ``file://`` URL, or S3 URI. If a
            local path is provided, it must exist and will be uploaded to S3
            automatically when the processing job is run.
        destination: The container path where the input data will be
            available, e.g. ``/opt/ml/processing/input/data``.
        input_name: A name for the processing input. If not specified,
            one will be auto-generated during normalization.
        s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``,
            ``'ManifestFile'`` (default: ``'S3Prefix'``).
        s3_input_mode: The input mode. Valid values: ``'File'``,
            ``'Pipe'`` (default: ``'File'``).
        s3_data_distribution_type: The data distribution type for
            distributed processing. Valid values: ``'FullyReplicated'``,
            ``'ShardedByS3Key'`` (default: None).
        s3_compression_type: The compression type. Valid values:
            ``'None'``, ``'Gzip'`` (default: None).

    Returns:
        ProcessingInput: A ``ProcessingInput`` object configured with the
        given source and destination.

    Raises:
        ValueError: If ``source`` is empty or None, or is a local path
            that does not exist.

    Examples:
        Create an input from a local directory::

            from sagemaker.core.processing import processing_input_from_local

            input_data = processing_input_from_local(
                source="/local/data/training",
                destination="/opt/ml/processing/input/data",
                input_name="training-data",
            )
            processor.run(inputs=[input_data])

        Create an input from an S3 URI::

            input_data = processing_input_from_local(
                source="s3://my-bucket/data/training",
                destination="/opt/ml/processing/input/data",
            )
    """
    if not source:
        raise ValueError(
            f"source must be a valid local path or S3 URI, got: {source!r}"
        )

    # Anything that is not an s3:// or http(s):// URI is treated as a local
    # source and must exist on disk before the job can upload it.
    parse_result = urlparse(source)
    if parse_result.scheme not in ("s3", "http", "https"):
        local_path = source
        if parse_result.scheme == "file":
            # Convert file:// URLs to a plain filesystem path.
            local_path = url2pathname(parse_result.path)
        if not os.path.exists(local_path):
            raise ValueError(
                f"Input source path does not exist: {source!r}. "
                f"Please provide a valid local path or S3 URI."
            )
        # Fix: store the resolved filesystem path rather than the original
        # file:// URL. Downstream, Processor._normalize_inputs() passes
        # s3_uri directly to S3Uploader.upload(local_path=...), which needs
        # a real path, not a URL. (No-op for plain local paths.)
        source = local_path

    s3_input_kwargs = {
        "s3_uri": source,
        "local_path": destination,
        "s3_data_type": s3_data_type,
        "s3_input_mode": s3_input_mode,
    }
    # Only forward the optional S3 settings when explicitly provided so the
    # shape's own defaults are preserved.
    if s3_data_distribution_type is not None:
        s3_input_kwargs["s3_data_distribution_type"] = s3_data_distribution_type
    if s3_compression_type is not None:
        s3_input_kwargs["s3_compression_type"] = s3_compression_type

    s3_input = ProcessingS3Input(**s3_input_kwargs)

    return ProcessingInput(
        input_name=input_name,
        s3_input=s3_input,
    )
190+
191+
88192
class Processor(object):
89193
"""Handles Amazon SageMaker Processing tasks."""
90194

@@ -424,6 +528,16 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
424528
# If the s3_uri is not an s3_uri, create one.
425529
parse_result = urlparse(file_input.s3_input.s3_uri)
426530
if parse_result.scheme != "s3":
531+
# Validate that local path exists before attempting upload
532+
local_source = file_input.s3_input.s3_uri
533+
if parse_result.scheme == "file":
534+
local_source = url2pathname(parse_result.path)
535+
if not os.path.exists(local_source):
536+
raise ValueError(
537+
f"Input source path does not exist: {file_input.s3_input.s3_uri!r}. "
538+
f"Please provide a valid local path or S3 URI for input "
539+
f"'{file_input.input_name}'."
540+
)
427541
if _pipeline_config:
428542
desired_s3_uri = s3.s3_path_join(
429543
"s3://",
@@ -443,6 +557,12 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
443557
"input",
444558
file_input.input_name,
445559
)
560+
logger.info(
561+
"Uploading local input '%s' from %s to %s",
562+
file_input.input_name,
563+
file_input.s3_input.s3_uri,
564+
desired_s3_uri,
565+
)
446566
s3_uri = s3.S3Uploader.upload(
447567
local_path=file_input.s3_input.s3_uri,
448568
desired_s3_uri=desired_s3_uri,

sagemaker-core/tests/unit/test_processing.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_processing_output_to_request_dict,
2424
_get_process_request,
2525
logs_for_processing_job,
26+
processing_input_from_local,
2627
)
2728
from sagemaker.core.shapes import (
2829
ProcessingInput,
@@ -46,6 +47,221 @@ def mock_session():
4647
return session
4748

4849

50+
class TestProcessingInputFromLocal:
    """Unit tests for the processing_input_from_local() factory helper."""

    def test_processing_input_from_local_with_file_path_creates_valid_input(self):
        """A local file path should produce a valid ProcessingInput."""
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f:
            f.write("col1,col2\n1,2\n")
            csv_path = f.name

        try:
            inp = processing_input_from_local(
                source=csv_path,
                destination="/opt/ml/processing/input/data",
                input_name="my-data",
            )
            # The returned object must be a fully populated ProcessingInput.
            assert isinstance(inp, ProcessingInput)
            assert inp.input_name == "my-data"
            assert inp.s3_input.s3_uri == csv_path
            assert inp.s3_input.local_path == "/opt/ml/processing/input/data"
            assert inp.s3_input.s3_data_type == "S3Prefix"
            assert inp.s3_input.s3_input_mode == "File"
        finally:
            os.unlink(csv_path)

    def test_processing_input_from_local_with_directory_path_creates_valid_input(self):
        """A local directory path should produce a valid ProcessingInput."""
        with tempfile.TemporaryDirectory() as workdir:
            inp = processing_input_from_local(
                source=workdir,
                destination="/opt/ml/processing/input/data",
                input_name="dir-data",
            )
            assert isinstance(inp, ProcessingInput)
            assert inp.s3_input.s3_uri == workdir

    def test_processing_input_from_local_with_s3_uri_passes_through(self):
        """An S3 URI should pass through without local path validation."""
        inp = processing_input_from_local(
            source="s3://my-bucket/data/training",
            destination="/opt/ml/processing/input/data",
            input_name="s3-data",
        )
        assert isinstance(inp, ProcessingInput)
        assert inp.s3_input.s3_uri == "s3://my-bucket/data/training"

    def test_processing_input_from_local_with_nonexistent_path_raises_value_error(self):
        """A nonexistent local path should raise ValueError."""
        with pytest.raises(ValueError, match="Input source path does not exist"):
            processing_input_from_local(
                source="/nonexistent/path/to/data",
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_custom_input_name(self):
        """Custom input_name should be set on the ProcessingInput."""
        with tempfile.TemporaryDirectory() as workdir:
            inp = processing_input_from_local(
                source=workdir,
                destination="/opt/ml/processing/input/data",
                input_name="custom-name",
            )
            assert inp.input_name == "custom-name"

    def test_processing_input_from_local_default_parameters(self):
        """Default parameters should be applied correctly."""
        with tempfile.TemporaryDirectory() as workdir:
            inp = processing_input_from_local(
                source=workdir,
                destination="/opt/ml/processing/input/data",
            )
            # No name given -> left unset for auto-generation downstream.
            assert inp.input_name is None
            assert inp.s3_input.s3_data_type == "S3Prefix"
            assert inp.s3_input.s3_input_mode == "File"

    def test_processing_input_from_local_with_empty_source_raises_value_error(self):
        """Empty source should raise ValueError."""
        with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"):
            processing_input_from_local(
                source="",
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_none_source_raises_value_error(self):
        """None source should raise ValueError."""
        with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"):
            processing_input_from_local(
                source=None,
                destination="/opt/ml/processing/input/data",
            )

    def test_processing_input_from_local_with_optional_s3_params(self):
        """Optional S3 parameters should be passed through."""
        with tempfile.TemporaryDirectory() as workdir:
            inp = processing_input_from_local(
                source=workdir,
                destination="/opt/ml/processing/input/data",
                s3_data_distribution_type="FullyReplicated",
                s3_compression_type="Gzip",
            )
            assert inp.s3_input.s3_data_distribution_type == "FullyReplicated"
            assert inp.s3_input.s3_compression_type == "Gzip"

    def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session):
        """ProcessingInput from processing_input_from_local should work with _normalize_inputs."""
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=mock_session,
        )
        processor._current_job_name = "test-job"

        with tempfile.TemporaryDirectory() as workdir:
            inp = processing_input_from_local(
                source=workdir,
                destination="/opt/ml/processing/input/data",
                input_name="local-data",
            )
            # Stub out the actual S3 transfer; we only care that the local
            # source is rewritten to the uploader's returned S3 URI.
            with patch(
                "sagemaker.core.s3.S3Uploader.upload", return_value="s3://bucket/uploaded"
            ):
                normalized = processor._normalize_inputs([inp])
            assert len(normalized) == 1
            assert normalized[0].s3_input.s3_uri == "s3://bucket/uploaded"
175+
176+
177+
class TestNormalizeInputsLocalPathValidation:
    """Tests for local path validation in _normalize_inputs()."""

    @staticmethod
    def _make_processor(session):
        """Build a minimally-configured Processor with a fixed job name.

        Extracted because all three tests previously duplicated this setup
        verbatim; keeping it in one place keeps the tests consistent.
        """
        processor = Processor(
            role="arn:aws:iam::123456789012:role/SageMakerRole",
            image_uri="test-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=session,
        )
        processor._current_job_name = "test-job"
        return processor

    def test_normalize_inputs_with_nonexistent_local_path_raises_value_error(self, mock_session):
        """A nonexistent local path in s3_input.s3_uri should raise ValueError."""
        processor = self._make_processor(mock_session)

        s3_input = ProcessingS3Input(
            s3_uri="/nonexistent/path/to/data",
            local_path="/opt/ml/processing/input",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
        )
        inputs = [ProcessingInput(input_name="bad-input", s3_input=s3_input)]

        with pytest.raises(ValueError, match="Input source path does not exist"):
            processor._normalize_inputs(inputs)

    def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session):
        """A valid local path should be uploaded to S3."""
        processor = self._make_processor(mock_session)

        with tempfile.TemporaryDirectory() as tmpdir:
            s3_input = ProcessingS3Input(
                s3_uri=tmpdir,
                local_path="/opt/ml/processing/input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
            )
            inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]

            with patch(
                "sagemaker.core.s3.S3Uploader.upload",
                return_value="s3://test-bucket/sagemaker/test-job/input/local-input",
            ) as mock_upload:
                result = processor._normalize_inputs(inputs)
                assert len(result) == 1
                assert result[0].s3_input.s3_uri.startswith("s3://")
                mock_upload.assert_called_once()

    def test_normalize_inputs_local_path_logs_upload_info(self, mock_session):
        """Uploading a local path should log an info message."""
        processor = self._make_processor(mock_session)

        with tempfile.TemporaryDirectory() as tmpdir:
            s3_input = ProcessingS3Input(
                s3_uri=tmpdir,
                local_path="/opt/ml/processing/input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
            )
            inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]

            with patch(
                "sagemaker.core.s3.S3Uploader.upload",
                return_value="s3://test-bucket/uploaded",
            ):
                with patch("sagemaker.core.processing.logger") as mock_logger:
                    processor._normalize_inputs(inputs)
                    # NOTE(review): the destination URI below is built by
                    # _normalize_inputs *before* the (mocked) upload, so the
                    # mock's return value does not influence it; it assumes
                    # mock_session's default bucket is "test-bucket" — verify
                    # against the fixture. Dropped the redundant f-prefix on
                    # this placeholder-free literal (ruff F541).
                    mock_logger.info.assert_any_call(
                        "Uploading local input '%s' from %s to %s",
                        "local-input",
                        tmpdir,
                        "s3://test-bucket/sagemaker/test-job/input/local-input",
                    )
263+
264+
49265
class TestProcessorNormalizeArgs:
50266
def test_normalize_args_with_pipeline_variable_code(self, mock_session):
51267
from sagemaker.core.workflow.pipeline_context import PipelineSession

0 commit comments

Comments
 (0)