Skip to content

Commit 3c68306

Browse files
committed
fix: address review comments (iteration #1)
1 parent 0112c1e commit 3c68306

File tree

3 files changed

+160
-50
lines changed

3 files changed

+160
-50
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from sagemaker.core.local.local_session import LocalSession
5353
from sagemaker.core.helper.session_helper import Session
5454
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input
55+
from sagemaker.core.shapes.shapes import _DEFAULT_S3_DATA_TYPE, _DEFAULT_S3_INPUT_MODE
5556
from sagemaker.core.resources import ProcessingJob
5657
from sagemaker.core.workflow.pipeline_context import PipelineSession
5758
from sagemaker.core.common_utils import (
@@ -418,16 +419,13 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
418419
if file_input.dataset_definition:
419420
normalized_inputs.append(file_input)
420421
continue
421-
# Handle case where source was set but s3_input was not created
422-
# (e.g., if ProcessingInput was constructed without using the
423-
# convenience __init__ logic)
424-
if file_input.s3_input is None and getattr(file_input, "source", None):
425-
file_input.s3_input = ProcessingS3Input(
426-
s3_uri=file_input.source,
427-
s3_data_type="S3Prefix",
428-
s3_input_mode="File",
422+
if file_input.s3_input is None:
423+
raise ValueError(
424+
f"ProcessingInput '{file_input.input_name}' has no "
425+
"s3_input or dataset_definition. Provide 'source', "
426+
"'s3_input', or 'dataset_definition'."
429427
)
430-
if file_input.s3_input and is_pipeline_variable(file_input.s3_input.s3_uri):
428+
if is_pipeline_variable(file_input.s3_input.s3_uri):
431429
normalized_inputs.append(file_input)
432430
continue
433431
# If the s3_uri is not an s3_uri, create one.

sagemaker-core/src/sagemaker/core/shapes/shapes.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6354,7 +6354,7 @@ class ProcessingS3Input(Base):
63546354

63556355
Attributes
63566356
----------------------
6357-
s3_uri: The URI of the Amazon S3 prefix Amazon SageMaker downloads data required to run a processing job.
6357+
s3_uri: The URI of the Amazon S3 prefix Amazon SageMaker downloads data required to run a processing job. Also accepts local file or directory paths, which will be automatically uploaded to S3 during job normalization.
63586358
local_path: The local path in your container where you want Amazon SageMaker to write input data to. LocalPath is an absolute path to the input data and must begin with /opt/ml/processing/. LocalPath is a required parameter when AppManaged is False (default).
63596359
s3_data_type: Whether you use an S3Prefix or a ManifestFile for the data type. If you choose S3Prefix, S3Uri identifies a key name prefix. Amazon SageMaker uses all objects with the specified key name prefix for the processing job. If you choose ManifestFile, S3Uri identifies an object that is a manifest file containing a list of object keys that you want Amazon SageMaker to use for the processing job.
63606360
s3_input_mode: Whether to use File or Pipe input mode. In File mode, Amazon SageMaker copies the data from the input source onto the local ML storage volume before starting your processing container. This is the most commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the source directly to your processing container into named pipes without using the ML storage volume.
@@ -6474,6 +6474,11 @@ class DatasetDefinition(Base):
64746474
snowflake_dataset_definition: Optional[SnowflakeDatasetDefinition] = Unassigned()
64756475

64766476

6477+
# Default constants for ProcessingS3Input creation from source parameter
6478+
_DEFAULT_S3_DATA_TYPE = "S3Prefix"
6479+
_DEFAULT_S3_INPUT_MODE = "File"
6480+
6481+
64776482
class ProcessingInput(Base):
64786483
"""
64796484
ProcessingInput
@@ -6485,12 +6490,54 @@ class ProcessingInput(Base):
64856490
app_managed: When True, input operations such as data download are managed natively by the processing job application. When False (default), input operations are managed by Amazon SageMaker.
64866491
s3_input: Configuration for downloading input data from Amazon S3 into the processing container.
64876492
dataset_definition: Configuration for a Dataset Definition input.
6493+
source: Convenience parameter that accepts a local file/directory path or S3 URI.
6494+
When provided (and s3_input is not), a ProcessingS3Input is automatically created.
6495+
Local paths will be uploaded to S3 during job normalization.
6496+
Cannot be specified together with s3_input.
64886497
"""
64896498

6499+
model_config = ConfigDict(
6500+
protected_namespaces=(),
6501+
validate_assignment=True,
6502+
extra="forbid",
6503+
json_schema_extra={"exclude": {"source"}},
6504+
)
6505+
64906506
input_name: StrPipeVar
64916507
app_managed: Optional[bool] = Unassigned()
64926508
s3_input: Optional[ProcessingS3Input] = Unassigned()
64936509
dataset_definition: Optional[DatasetDefinition] = Unassigned()
6510+
source: Optional[StrPipeVar] = Field(default=None, exclude=True)
6511+
6512+
@classmethod
6513+
def _validate_source_and_s3_input(cls, values):
6514+
"""Validate and handle the source convenience parameter."""
6515+
source = values.get("source")
6516+
s3_input = values.get("s3_input")
6517+
6518+
if source is not None and s3_input is not None and not isinstance(
6519+
s3_input, type(Unassigned())
6520+
):
6521+
raise ValueError(
6522+
"Cannot specify both 'source' and 's3_input'. "
6523+
"Use 'source' for convenience (local paths or S3 URIs) "
6524+
"or 's3_input' for full control, but not both."
6525+
)
6526+
6527+
if source is not None and (
6528+
s3_input is None or isinstance(s3_input, type(Unassigned()))
6529+
):
6530+
values["s3_input"] = ProcessingS3Input(
6531+
s3_uri=source,
6532+
s3_data_type=_DEFAULT_S3_DATA_TYPE,
6533+
s3_input_mode=_DEFAULT_S3_INPUT_MODE,
6534+
)
6535+
6536+
return values
6537+
6538+
def __init__(self, **data):
6539+
data = ProcessingInput._validate_source_and_s3_input(data)
6540+
super().__init__(**data)
64946541

64956542

64966543
class EndpointInput(Base):

sagemaker-core/tests/unit/test_processing_local_input.py

Lines changed: 105 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Tests for ProcessingInput source parameter and local file upload behavior."""
14-
from __future__ import absolute_import
14+
from __future__ import annotations
1515

1616
import os
1717
import tempfile
@@ -22,13 +22,21 @@
2222
from sagemaker.core.shapes import ProcessingInput, ProcessingS3Input
2323
from sagemaker.core.processing import Processor
2424

25+
# Test constants
26+
FAKE_ACCOUNT_ID = "012345678901"
27+
FAKE_ROLE_ARN = f"arn:aws:iam::{FAKE_ACCOUNT_ID}:role/SageMakerRole"
28+
FAKE_IMAGE_URI = (
29+
f"{FAKE_ACCOUNT_ID}.dkr.ecr.us-west-2.amazonaws.com/my-image:latest"
30+
)
31+
FAKE_BUCKET = "my-bucket"
32+
2533

2634
@pytest.fixture
2735
def sagemaker_session():
2836
session = MagicMock()
29-
session.default_bucket.return_value = "my-bucket"
37+
session.default_bucket.return_value = FAKE_BUCKET
3038
session.default_bucket_prefix = None
31-
session.expand_role.return_value = "arn:aws:iam::012345678901:role/SageMakerRole"
39+
session.expand_role.return_value = FAKE_ROLE_ARN
3240
session.sagemaker_client = MagicMock()
3341
session.sagemaker_config = None
3442
type(session).local_mode = PropertyMock(return_value=False)
@@ -38,8 +46,8 @@ def sagemaker_session():
3846
@pytest.fixture
3947
def processor(sagemaker_session):
4048
return Processor(
41-
role="arn:aws:iam::012345678901:role/SageMakerRole",
42-
image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
49+
role=FAKE_ROLE_ARN,
50+
image_uri=FAKE_IMAGE_URI,
4351
instance_count=1,
4452
instance_type="ml.m5.xlarge",
4553
sagemaker_session=sagemaker_session,
@@ -49,7 +57,7 @@ def processor(sagemaker_session):
4957
class TestProcessingInputSourceParameter:
5058
"""Tests for the 'source' convenience parameter on ProcessingInput."""
5159

52-
def test_processing_input_source_parameter_creates_s3_input(self):
60+
def test_source_parameter_creates_s3_input(self):
5361
"""Test that providing 'source' auto-creates a ProcessingS3Input."""
5462
proc_input = ProcessingInput(
5563
input_name="my-input",
@@ -61,18 +69,20 @@ def test_processing_input_source_parameter_creates_s3_input(self):
6169
assert proc_input.s3_input.s3_input_mode == "File"
6270
assert proc_input.source == "/local/path/to/data"
6371

64-
def test_processing_input_source_with_s3_uri_passthrough(self):
65-
"""Test that providing 'source' with an S3 URI creates s3_input with that URI."""
72+
def test_source_with_s3_uri_passthrough(self):
73+
"""Test that 'source' with an S3 URI creates s3_input."""
6674
proc_input = ProcessingInput(
6775
input_name="my-input",
6876
source="s3://my-bucket/my-prefix/data",
6977
)
7078
assert proc_input.s3_input is not None
7179
assert proc_input.s3_input.s3_uri == "s3://my-bucket/my-prefix/data"
7280

73-
def test_processing_input_source_and_s3_input_raises_error(self):
81+
def test_source_and_s3_input_raises_error(self):
7482
"""Test that providing both 'source' and 's3_input' raises ValueError."""
75-
with pytest.raises(ValueError, match="Cannot specify both 'source' and 's3_input'"):
83+
with pytest.raises(
84+
ValueError, match="Cannot specify both 'source' and 's3_input'"
85+
):
7686
ProcessingInput(
7787
input_name="my-input",
7888
source="/local/path/to/data",
@@ -83,7 +93,7 @@ def test_processing_input_source_and_s3_input_raises_error(self):
8393
),
8494
)
8595

86-
def test_processing_input_without_source_works_as_before(self):
96+
def test_without_source_works_as_before(self):
8797
"""Test that ProcessingInput without 'source' works as before."""
8898
proc_input = ProcessingInput(
8999
input_name="my-input",
@@ -97,19 +107,34 @@ def test_processing_input_without_source_works_as_before(self):
97107
assert proc_input.s3_input.s3_uri == "s3://my-bucket/data"
98108
assert proc_input.source is None
99109

110+
def test_source_none_and_s3_input_none_no_dataset(self):
111+
"""Test ProcessingInput with neither source nor s3_input.
112+
113+
When neither source nor s3_input is provided (and no
114+
dataset_definition), _normalize_inputs should raise ValueError.
115+
"""
116+
proc_input = ProcessingInput(input_name="my-input")
117+
assert proc_input.s3_input is None or proc_input.s3_input.__class__.__name__ == "Unassigned"
118+
assert proc_input.source is None
119+
100120

101121
class TestNormalizeInputsLocalUpload:
102122
"""Tests for _normalize_inputs handling of local file paths."""
103123

104124
@patch("sagemaker.core.processing.s3.S3Uploader.upload")
105-
def test_normalize_inputs_with_local_file_path_uploads_to_s3(
125+
def test_local_file_path_uploads_to_s3(
106126
self, mock_upload, processor
107127
):
108-
"""Test that a local file path in s3_uri triggers upload to S3."""
109-
mock_upload.return_value = "s3://my-bucket/job-name/input/my-input/data.csv"
128+
"""Test that a local file path in s3_uri triggers upload."""
129+
expected_uri = (
130+
f"s3://{FAKE_BUCKET}/job-name/input/my-input/data.csv"
131+
)
132+
mock_upload.return_value = expected_uri
110133
processor._current_job_name = "my-job"
111134

112-
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
135+
with tempfile.NamedTemporaryFile(
136+
suffix=".csv", delete=False
137+
) as f:
113138
local_path = f.name
114139
f.write(b"col1,col2\n1,2\n")
115140

@@ -128,20 +153,25 @@ def test_normalize_inputs_with_local_file_path_uploads_to_s3(
128153
normalized = processor._normalize_inputs(inputs)
129154

130155
assert len(normalized) == 1
131-
assert normalized[0].s3_input.s3_uri == "s3://my-bucket/job-name/input/my-input/data.csv"
156+
assert normalized[0].s3_input.s3_uri == expected_uri
132157
mock_upload.assert_called_once()
133158
finally:
134159
os.unlink(local_path)
135160

136161
@patch("sagemaker.core.processing.s3.S3Uploader.upload")
137-
def test_normalize_inputs_with_source_local_path_uploads_to_s3(
162+
def test_source_local_path_uploads_to_s3(
138163
self, mock_upload, processor
139164
):
140-
"""Test that using 'source' with a local path triggers upload to S3."""
141-
mock_upload.return_value = "s3://my-bucket/job-name/input/my-input/data.csv"
165+
"""Test that using 'source' with a local path triggers upload."""
166+
expected_uri = (
167+
f"s3://{FAKE_BUCKET}/job-name/input/my-input/data.csv"
168+
)
169+
mock_upload.return_value = expected_uri
142170
processor._current_job_name = "my-job"
143171

144-
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
172+
with tempfile.NamedTemporaryFile(
173+
suffix=".csv", delete=False
174+
) as f:
145175
local_path = f.name
146176
f.write(b"col1,col2\n1,2\n")
147177

@@ -156,13 +186,13 @@ def test_normalize_inputs_with_source_local_path_uploads_to_s3(
156186
normalized = processor._normalize_inputs(inputs)
157187

158188
assert len(normalized) == 1
159-
assert normalized[0].s3_input.s3_uri == "s3://my-bucket/job-name/input/my-input/data.csv"
189+
assert normalized[0].s3_input.s3_uri == expected_uri
160190
mock_upload.assert_called_once()
161191
finally:
162192
os.unlink(local_path)
163193

164-
def test_normalize_inputs_with_s3_uri_does_not_upload(self, processor):
165-
"""Test that an S3 URI in s3_uri does not trigger upload."""
194+
def test_s3_uri_does_not_upload(self, processor):
195+
"""Test that an S3 URI does not trigger upload."""
166196
processor._current_job_name = "my-job"
167197

168198
inputs = [
@@ -176,24 +206,33 @@ def test_normalize_inputs_with_s3_uri_does_not_upload(self, processor):
176206
)
177207
]
178208

179-
with patch("sagemaker.core.processing.s3.S3Uploader.upload") as mock_upload:
209+
with patch(
210+
"sagemaker.core.processing.s3.S3Uploader.upload"
211+
) as mock_upload:
180212
normalized = processor._normalize_inputs(inputs)
181213

182214
assert len(normalized) == 1
183-
assert normalized[0].s3_input.s3_uri == "s3://my-bucket/existing-data"
215+
assert (
216+
normalized[0].s3_input.s3_uri
217+
== "s3://my-bucket/existing-data"
218+
)
184219
mock_upload.assert_not_called()
185220

186221
@patch("sagemaker.core.processing.s3.S3Uploader.upload")
187-
def test_normalize_inputs_with_local_dir_path_uploads_to_s3(
222+
def test_local_dir_path_uploads_to_s3(
188223
self, mock_upload, processor
189224
):
190-
"""Test that a local directory path in s3_uri triggers upload to S3."""
191-
mock_upload.return_value = "s3://my-bucket/job-name/input/my-input"
225+
"""Test that a local directory path triggers upload."""
226+
expected_uri = (
227+
f"s3://{FAKE_BUCKET}/job-name/input/my-input"
228+
)
229+
mock_upload.return_value = expected_uri
192230
processor._current_job_name = "my-job"
193231

194232
with tempfile.TemporaryDirectory() as tmpdir:
195-
# Create a file in the directory
196-
with open(os.path.join(tmpdir, "data.csv"), "w") as f:
233+
with open(
234+
os.path.join(tmpdir, "data.csv"), "w"
235+
) as f:
197236
f.write("col1,col2\n1,2\n")
198237

199238
inputs = [
@@ -206,21 +245,44 @@ def test_normalize_inputs_with_local_dir_path_uploads_to_s3(
206245
normalized = processor._normalize_inputs(inputs)
207246

208247
assert len(normalized) == 1
209-
assert normalized[0].s3_input.s3_uri == "s3://my-bucket/job-name/input/my-input"
248+
assert normalized[0].s3_input.s3_uri == expected_uri
210249
mock_upload.assert_called_once()
211250

251+
def test_no_s3_input_no_source_no_dataset_raises_error(
252+
self, processor
253+
):
254+
"""Test that missing s3_input, source, and dataset raises error."""
255+
processor._current_job_name = "my-job"
256+
257+
inputs = [
258+
ProcessingInput(input_name="my-input")
259+
]
260+
261+
with pytest.raises(ValueError, match="has no s3_input"):
262+
processor._normalize_inputs(inputs)
263+
212264
@patch("sagemaker.core.processing.s3.S3Uploader.upload")
213-
@patch("sagemaker.core.workflow.utilities._pipeline_config")
214-
def test_normalize_inputs_with_pipeline_config_generates_correct_s3_path(
215-
self, mock_pipeline_config, mock_upload, processor
265+
def test_pipeline_config_generates_correct_s3_path(
266+
self, mock_upload, processor
216267
):
217268
"""Test that pipeline config generates the correct S3 path."""
218-
mock_pipeline_config.pipeline_name = "my-pipeline"
219-
mock_pipeline_config.step_name = "my-step"
220-
mock_upload.return_value = "s3://my-bucket/my-pipeline/my-step/input/my-input"
269+
# _normalize_inputs imports _pipeline_config from
270+
# sagemaker.core.workflow.utilities at call time via:
271+
# from sagemaker.core.workflow.utilities import _pipeline_config
272+
# So we patch it at the module where it's looked up.
273+
mock_config = MagicMock()
274+
mock_config.pipeline_name = "my-pipeline"
275+
mock_config.step_name = "my-step"
276+
277+
expected_uri = (
278+
f"s3://{FAKE_BUCKET}/my-pipeline/my-step/input/my-input"
279+
)
280+
mock_upload.return_value = expected_uri
221281
processor._current_job_name = "my-job"
222282

223-
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
283+
with tempfile.NamedTemporaryFile(
284+
suffix=".csv", delete=False
285+
) as f:
224286
local_path = f.name
225287
f.write(b"col1,col2\n1,2\n")
226288

@@ -232,11 +294,14 @@ def test_normalize_inputs_with_pipeline_config_generates_correct_s3_path(
232294
)
233295
]
234296

235-
normalized = processor._normalize_inputs(inputs)
297+
with patch(
298+
"sagemaker.core.workflow.utilities._pipeline_config",
299+
mock_config,
300+
):
301+
normalized = processor._normalize_inputs(inputs)
236302

237303
assert len(normalized) == 1
238304
mock_upload.assert_called_once()
239-
# Verify the desired_s3_uri contains pipeline path components
240305
call_kwargs = mock_upload.call_args[1]
241306
assert "my-pipeline" in call_kwargs["desired_s3_uri"]
242307
assert "my-step" in call_kwargs["desired_s3_uri"]

0 commit comments

Comments (0)