Skip to content

Commit 0112c1e

Browse files
committed
fix: Support local source for sagemaker.core.shapes.ProcessingInput (5672)
1 parent 6a1ba54 commit 0112c1e

File tree

3 files changed

+261
-2
lines changed

3 files changed

+261
-2
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,27 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
418418
if file_input.dataset_definition:
419419
normalized_inputs.append(file_input)
420420
continue
421+
# Handle case where source was set but s3_input was not created
422+
# (e.g., if ProcessingInput was constructed without using the
423+
# convenience __init__ logic)
424+
if file_input.s3_input is None and getattr(file_input, "source", None):
425+
file_input.s3_input = ProcessingS3Input(
426+
s3_uri=file_input.source,
427+
s3_data_type="S3Prefix",
428+
s3_input_mode="File",
429+
)
421430
if file_input.s3_input and is_pipeline_variable(file_input.s3_input.s3_uri):
422431
normalized_inputs.append(file_input)
423432
continue
424433
# If the s3_uri is not an s3_uri, create one.
425434
parse_result = urlparse(file_input.s3_input.s3_uri)
426435
if parse_result.scheme != "s3":
436+
local_path = file_input.s3_input.s3_uri
437+
logger.info(
438+
"Uploading local input '%s' (%s) to S3...",
439+
file_input.input_name,
440+
local_path,
441+
)
427442
if _pipeline_config:
428443
desired_s3_uri = s3.s3_path_join(
429444
"s3://",
@@ -444,7 +459,7 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
444459
file_input.input_name,
445460
)
446461
s3_uri = s3.S3Uploader.upload(
447-
local_path=file_input.s3_input.s3_uri,
462+
local_path=local_path,
448463
desired_s3_uri=desired_s3_uri,
449464
sagemaker_session=self.sagemaker_session,
450465
kms_key=kms_key,

sagemaker-core/src/sagemaker/core/shapes/shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8577,7 +8577,7 @@ class InferenceComponentComputeResourceRequirements(Base):
85778577
max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component.
85788578
"""
85798579

8580-
min_memory_required_in_mb: int
8580+
min_memory_required_in_mb: Optional[int] = Unassigned()
85818581
number_of_cpu_cores_required: Optional[float] = Unassigned()
85828582
number_of_accelerator_devices_required: Optional[float] = Unassigned()
85838583
max_memory_required_in_mb: Optional[int] = Unassigned()
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests for ProcessingInput source parameter and local file upload behavior."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
import tempfile
18+
from unittest.mock import MagicMock, patch, PropertyMock
19+
20+
import pytest
21+
22+
from sagemaker.core.shapes import ProcessingInput, ProcessingS3Input
23+
from sagemaker.core.processing import Processor
24+
25+
26+
@pytest.fixture
def sagemaker_session():
    """Return a fully mocked SageMaker session with the defaults the tests rely on.

    The mock reports a default bucket of "my-bucket", no bucket prefix, a fixed
    IAM role ARN, and local_mode=False (exposed as a property on the mock's type
    so attribute access behaves like the real session).
    """
    mock_session = MagicMock()
    mock_session.default_bucket.return_value = "my-bucket"
    mock_session.default_bucket_prefix = None
    mock_session.expand_role.return_value = "arn:aws:iam::012345678901:role/SageMakerRole"
    mock_session.sagemaker_client = MagicMock()
    mock_session.sagemaker_config = None
    # local_mode is read as a property on the real session, so it must be
    # attached to the mock's type, not the instance.
    type(mock_session).local_mode = PropertyMock(return_value=False)
    return mock_session
36+
37+
38+
@pytest.fixture
def processor(sagemaker_session):
    """Build a minimal Processor wired to the mocked SageMaker session."""
    processor_kwargs = {
        "role": "arn:aws:iam::012345678901:role/SageMakerRole",
        "image_uri": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
        "instance_count": 1,
        "instance_type": "ml.m5.xlarge",
        "sagemaker_session": sagemaker_session,
    }
    return Processor(**processor_kwargs)
47+
48+
49+
class TestProcessingInputSourceParameter:
    """Tests covering the 'source' convenience parameter on ProcessingInput."""

    def test_processing_input_source_parameter_creates_s3_input(self):
        """A plain local 'source' should auto-create a matching ProcessingS3Input."""
        local_source = "/local/path/to/data"
        result = ProcessingInput(input_name="my-input", source=local_source)
        created = result.s3_input
        assert created is not None
        assert created.s3_uri == local_source
        assert created.s3_data_type == "S3Prefix"
        assert created.s3_input_mode == "File"
        assert result.source == local_source

    def test_processing_input_source_with_s3_uri_passthrough(self):
        """An S3 URI passed as 'source' is carried into s3_input unchanged."""
        s3_source = "s3://my-bucket/my-prefix/data"
        result = ProcessingInput(input_name="my-input", source=s3_source)
        assert result.s3_input is not None
        assert result.s3_input.s3_uri == s3_source

    def test_processing_input_source_and_s3_input_raises_error(self):
        """Supplying both 'source' and 's3_input' must raise a ValueError."""
        explicit_s3_input = ProcessingS3Input(
            s3_uri="s3://my-bucket/data",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
        )
        with pytest.raises(ValueError, match="Cannot specify both 'source' and 's3_input'"):
            ProcessingInput(
                input_name="my-input",
                source="/local/path/to/data",
                s3_input=explicit_s3_input,
            )

    def test_processing_input_without_source_works_as_before(self):
        """Omitting 'source' keeps the legacy s3_input-only construction intact."""
        result = ProcessingInput(
            input_name="my-input",
            s3_input=ProcessingS3Input(
                s3_uri="s3://my-bucket/data",
                local_path="/opt/ml/processing/input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
            ),
        )
        assert result.s3_input.s3_uri == "s3://my-bucket/data"
        assert result.source is None
99+
100+
101+
class TestNormalizeInputsLocalUpload:
    """Tests for _normalize_inputs handling of local file paths.

    S3Uploader.upload is patched throughout so no network traffic occurs; the
    tests assert on the mock's call pattern and on the s3_uri that
    _normalize_inputs writes back into each input.
    """

    @patch("sagemaker.core.processing.s3.S3Uploader.upload")
    def test_normalize_inputs_with_local_file_path_uploads_to_s3(
        self, mock_upload, processor
    ):
        """Test that a local file path in s3_uri triggers upload to S3."""
        mock_upload.return_value = "s3://my-bucket/job-name/input/my-input/data.csv"
        processor._current_job_name = "my-job"

        # delete=False so the file survives the with-block; removed in finally.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
            local_path = f.name
            f.write(b"col1,col2\n1,2\n")

        try:
            inputs = [
                ProcessingInput(
                    input_name="my-input",
                    s3_input=ProcessingS3Input(
                        s3_uri=local_path,  # local path deliberately placed in s3_uri
                        s3_data_type="S3Prefix",
                        s3_input_mode="File",
                    ),
                )
            ]

            normalized = processor._normalize_inputs(inputs)

            # The local path must be replaced by the (mocked) uploaded S3 URI.
            assert len(normalized) == 1
            assert normalized[0].s3_input.s3_uri == "s3://my-bucket/job-name/input/my-input/data.csv"
            mock_upload.assert_called_once()
        finally:
            os.unlink(local_path)

    @patch("sagemaker.core.processing.s3.S3Uploader.upload")
    def test_normalize_inputs_with_source_local_path_uploads_to_s3(
        self, mock_upload, processor
    ):
        """Test that using 'source' with a local path triggers upload to S3."""
        mock_upload.return_value = "s3://my-bucket/job-name/input/my-input/data.csv"
        processor._current_job_name = "my-job"

        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
            local_path = f.name
            f.write(b"col1,col2\n1,2\n")

        try:
            # Unlike the previous test, only 'source' is given; s3_input is
            # expected to be synthesized before the upload path runs.
            inputs = [
                ProcessingInput(
                    input_name="my-input",
                    source=local_path,
                )
            ]

            normalized = processor._normalize_inputs(inputs)

            assert len(normalized) == 1
            assert normalized[0].s3_input.s3_uri == "s3://my-bucket/job-name/input/my-input/data.csv"
            mock_upload.assert_called_once()
        finally:
            os.unlink(local_path)

    def test_normalize_inputs_with_s3_uri_does_not_upload(self, processor):
        """Test that an S3 URI in s3_uri does not trigger upload."""
        processor._current_job_name = "my-job"

        inputs = [
            ProcessingInput(
                input_name="my-input",
                s3_input=ProcessingS3Input(
                    s3_uri="s3://my-bucket/existing-data",
                    s3_data_type="S3Prefix",
                    s3_input_mode="File",
                ),
            )
        ]

        # Patch inside the test body: the upload must never be reached here.
        with patch("sagemaker.core.processing.s3.S3Uploader.upload") as mock_upload:
            normalized = processor._normalize_inputs(inputs)

            # Existing S3 URI passes through untouched, no upload call.
            assert len(normalized) == 1
            assert normalized[0].s3_input.s3_uri == "s3://my-bucket/existing-data"
            mock_upload.assert_not_called()

    @patch("sagemaker.core.processing.s3.S3Uploader.upload")
    def test_normalize_inputs_with_local_dir_path_uploads_to_s3(
        self, mock_upload, processor
    ):
        """Test that a local directory path in s3_uri triggers upload to S3."""
        mock_upload.return_value = "s3://my-bucket/job-name/input/my-input"
        processor._current_job_name = "my-job"

        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a file in the directory
            with open(os.path.join(tmpdir, "data.csv"), "w") as f:
                f.write("col1,col2\n1,2\n")

            inputs = [
                ProcessingInput(
                    input_name="my-input",
                    source=tmpdir,
                )
            ]

            normalized = processor._normalize_inputs(inputs)

            assert len(normalized) == 1
            assert normalized[0].s3_input.s3_uri == "s3://my-bucket/job-name/input/my-input"
            mock_upload.assert_called_once()

    # NOTE: decorator order is significant — mocks are injected bottom-up, so
    # the _pipeline_config patch (closest to the function) arrives first.
    @patch("sagemaker.core.processing.s3.S3Uploader.upload")
    @patch("sagemaker.core.workflow.utilities._pipeline_config")
    def test_normalize_inputs_with_pipeline_config_generates_correct_s3_path(
        self, mock_pipeline_config, mock_upload, processor
    ):
        """Test that pipeline config generates the correct S3 path."""
        mock_pipeline_config.pipeline_name = "my-pipeline"
        mock_pipeline_config.step_name = "my-step"
        mock_upload.return_value = "s3://my-bucket/my-pipeline/my-step/input/my-input"
        processor._current_job_name = "my-job"

        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
            local_path = f.name
            f.write(b"col1,col2\n1,2\n")

        try:
            inputs = [
                ProcessingInput(
                    input_name="my-input",
                    source=local_path,
                )
            ]

            normalized = processor._normalize_inputs(inputs)

            assert len(normalized) == 1
            mock_upload.assert_called_once()
            # Verify the desired_s3_uri contains pipeline path components
            call_kwargs = mock_upload.call_args[1]
            assert "my-pipeline" in call_kwargs["desired_s3_uri"]
            assert "my-step" in call_kwargs["desired_s3_uri"]
        finally:
            os.unlink(local_path)

0 commit comments

Comments
 (0)