Skip to content

Commit 4e87b6d

Browse files
committed
fix: address review comments (iteration #1)
1 parent f1cfcfb commit 4e87b6d

File tree

2 files changed

+127
-119
lines changed

2 files changed

+127
-119
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
normalization. For a convenient way to create ``ProcessingInput`` objects from
2222
local sources, use the :func:`processing_input_from_local` helper function.
2323
"""
24-
from __future__ import absolute_import
24+
from __future__ import annotations
2525

2626
import json
2727
import logging
@@ -93,11 +93,11 @@
9393
def processing_input_from_local(
9494
source: str,
9595
destination: str,
96-
input_name: Optional[str] = None,
96+
input_name: str | None = None,
9797
s3_data_type: str = "S3Prefix",
9898
s3_input_mode: str = "File",
99-
s3_data_distribution_type: Optional[str] = None,
100-
s3_compression_type: Optional[str] = None,
99+
s3_data_distribution_type: str | None = None,
100+
s3_compression_type: str | None = None,
101101
) -> ProcessingInput:
102102
"""Creates a ProcessingInput from a local file/directory path or S3 URI.
103103
@@ -131,7 +131,7 @@ def processing_input_from_local(
131131
132132
Raises:
133133
ValueError: If ``source`` is a local path that does not exist.
134-
ValueError: If ``source`` is empty or None.
134+
ValueError: If ``source`` is empty or is not a string.
135135
136136
Examples:
137137
Create an input from a local directory::
@@ -152,12 +152,15 @@ def processing_input_from_local(
152152
destination="/opt/ml/processing/input/data",
153153
)
154154
"""
155-
if not source:
155+
if not isinstance(source, str) or not source:
156156
raise ValueError(
157-
f"source must be a valid local path or S3 URI, got: {source!r}"
157+
f"source must be a non-empty string containing a valid local path "
158+
f"or S3 URI, got: {source!r}"
158159
)
159160

160-
# Check if source is a local path (not an S3 URI)
161+
# Check if source is a local path (not a remote URI).
162+
# Note: On Windows, absolute paths like C:\data\file.csv will have
163+
# parse_result.scheme == 'c', which correctly falls into the local path branch.
161164
parse_result = urlparse(source)
162165
if parse_result.scheme not in ("s3", "http", "https"):
163166
# Treat as local path - validate existence
@@ -172,6 +175,8 @@ def processing_input_from_local(
172175

173176
s3_input_kwargs = {
174177
"s3_uri": source,
178+
# local_path in ProcessingS3Input maps to the container destination path
179+
# where the input data will be made available inside the processing container.
175180
"local_path": destination,
176181
"s3_data_type": s3_data_type,
177182
"s3_input_mode": s3_input_mode,
@@ -534,9 +539,10 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
534539
local_source = url2pathname(parse_result.path)
535540
if not os.path.exists(local_source):
536541
raise ValueError(
537-
f"Input source path does not exist: {file_input.s3_input.s3_uri!r}. "
538-
f"Please provide a valid local path or S3 URI for input "
539-
f"'{file_input.input_name}'."
542+
f"Input source path does not exist: "
543+
f"{file_input.s3_input.s3_uri!r}. "
544+
f"Please provide a valid local path or S3 URI "
545+
f"for input '{file_input.input_name}'."
540546
)
541547
if _pipeline_config:
542548
desired_s3_uri = s3.s3_path_join(
@@ -557,6 +563,7 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
557563
"input",
558564
file_input.input_name,
559565
)
566+
# Log after both pipeline/non-pipeline branches have set desired_s3_uri
560567
logger.info(
561568
"Uploading local input '%s' from %s to %s",
562569
file_input.input_name,

sagemaker-core/tests/unit/test_processing.py

Lines changed: 109 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -50,37 +50,33 @@ def mock_session():
5050
class TestProcessingInputFromLocal:
5151
"""Tests for the processing_input_from_local() factory function."""
5252

53-
def test_processing_input_from_local_with_file_path_creates_valid_input(self):
53+
def test_processing_input_from_local_with_file_path_creates_valid_input(self, tmp_path):
5454
"""A local file path should produce a valid ProcessingInput."""
55-
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f:
56-
f.write("col1,col2\n1,2\n")
57-
temp_file = f.name
58-
59-
try:
60-
result = processing_input_from_local(
61-
source=temp_file,
62-
destination="/opt/ml/processing/input/data",
63-
input_name="my-data",
64-
)
65-
assert isinstance(result, ProcessingInput)
66-
assert result.input_name == "my-data"
67-
assert result.s3_input.s3_uri == temp_file
68-
assert result.s3_input.local_path == "/opt/ml/processing/input/data"
69-
assert result.s3_input.s3_data_type == "S3Prefix"
70-
assert result.s3_input.s3_input_mode == "File"
71-
finally:
72-
os.unlink(temp_file)
55+
temp_file = tmp_path / "data.csv"
56+
temp_file.write_text("col1,col2\n1,2\n")
7357

74-
def test_processing_input_from_local_with_directory_path_creates_valid_input(self):
58+
result = processing_input_from_local(
59+
source=str(temp_file),
60+
destination="/opt/ml/processing/input/data",
61+
input_name="my-data",
62+
)
63+
assert isinstance(result, ProcessingInput)
64+
assert result.input_name == "my-data"
65+
assert result.s3_input.s3_uri == str(temp_file)
66+
# local_path in ProcessingS3Input maps to the container destination path
67+
assert result.s3_input.local_path == "/opt/ml/processing/input/data"
68+
assert result.s3_input.s3_data_type == "S3Prefix"
69+
assert result.s3_input.s3_input_mode == "File"
70+
71+
def test_processing_input_from_local_with_directory_path_creates_valid_input(self, tmp_path):
7572
"""A local directory path should produce a valid ProcessingInput."""
76-
with tempfile.TemporaryDirectory() as tmpdir:
77-
result = processing_input_from_local(
78-
source=tmpdir,
79-
destination="/opt/ml/processing/input/data",
80-
input_name="dir-data",
81-
)
82-
assert isinstance(result, ProcessingInput)
83-
assert result.s3_input.s3_uri == tmpdir
73+
result = processing_input_from_local(
74+
source=str(tmp_path),
75+
destination="/opt/ml/processing/input/data",
76+
input_name="dir-data",
77+
)
78+
assert isinstance(result, ProcessingInput)
79+
assert result.s3_input.s3_uri == str(tmp_path)
8480

8581
def test_processing_input_from_local_with_s3_uri_passes_through(self):
8682
"""An S3 URI should pass through without local path validation."""
@@ -100,56 +96,53 @@ def test_processing_input_from_local_with_nonexistent_path_raises_value_error(se
10096
destination="/opt/ml/processing/input/data",
10197
)
10298

103-
def test_processing_input_from_local_with_custom_input_name(self):
99+
def test_processing_input_from_local_with_custom_input_name(self, tmp_path):
104100
"""Custom input_name should be set on the ProcessingInput."""
105-
with tempfile.TemporaryDirectory() as tmpdir:
106-
result = processing_input_from_local(
107-
source=tmpdir,
108-
destination="/opt/ml/processing/input/data",
109-
input_name="custom-name",
110-
)
111-
assert result.input_name == "custom-name"
101+
result = processing_input_from_local(
102+
source=str(tmp_path),
103+
destination="/opt/ml/processing/input/data",
104+
input_name="custom-name",
105+
)
106+
assert result.input_name == "custom-name"
112107

113-
def test_processing_input_from_local_default_parameters(self):
108+
def test_processing_input_from_local_default_parameters(self, tmp_path):
114109
"""Default parameters should be applied correctly."""
115-
with tempfile.TemporaryDirectory() as tmpdir:
116-
result = processing_input_from_local(
117-
source=tmpdir,
118-
destination="/opt/ml/processing/input/data",
119-
)
120-
assert result.input_name is None
121-
assert result.s3_input.s3_data_type == "S3Prefix"
122-
assert result.s3_input.s3_input_mode == "File"
110+
result = processing_input_from_local(
111+
source=str(tmp_path),
112+
destination="/opt/ml/processing/input/data",
113+
)
114+
assert result.input_name is None
115+
assert result.s3_input.s3_data_type == "S3Prefix"
116+
assert result.s3_input.s3_input_mode == "File"
123117

124118
def test_processing_input_from_local_with_empty_source_raises_value_error(self):
125119
"""Empty source should raise ValueError."""
126-
with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"):
120+
with pytest.raises(ValueError, match="source must be a non-empty string"):
127121
processing_input_from_local(
128122
source="",
129123
destination="/opt/ml/processing/input/data",
130124
)
131125

132126
def test_processing_input_from_local_with_none_source_raises_value_error(self):
133127
"""None source should raise ValueError."""
134-
with pytest.raises(ValueError, match="source must be a valid local path or S3 URI"):
128+
with pytest.raises(ValueError, match="source must be a non-empty string"):
135129
processing_input_from_local(
136130
source=None,
137131
destination="/opt/ml/processing/input/data",
138132
)
139133

140-
def test_processing_input_from_local_with_optional_s3_params(self):
134+
def test_processing_input_from_local_with_optional_s3_params(self, tmp_path):
141135
"""Optional S3 parameters should be passed through."""
142-
with tempfile.TemporaryDirectory() as tmpdir:
143-
result = processing_input_from_local(
144-
source=tmpdir,
145-
destination="/opt/ml/processing/input/data",
146-
s3_data_distribution_type="FullyReplicated",
147-
s3_compression_type="Gzip",
148-
)
149-
assert result.s3_input.s3_data_distribution_type == "FullyReplicated"
150-
assert result.s3_input.s3_compression_type == "Gzip"
136+
result = processing_input_from_local(
137+
source=str(tmp_path),
138+
destination="/opt/ml/processing/input/data",
139+
s3_data_distribution_type="FullyReplicated",
140+
s3_compression_type="Gzip",
141+
)
142+
assert result.s3_input.s3_data_distribution_type == "FullyReplicated"
143+
assert result.s3_input.s3_compression_type == "Gzip"
151144

152-
def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session):
145+
def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session, tmp_path):
153146
"""ProcessingInput from processing_input_from_local should work with _normalize_inputs."""
154147
processor = Processor(
155148
role="arn:aws:iam::123456789012:role/SageMakerRole",
@@ -160,18 +153,18 @@ def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session
160153
)
161154
processor._current_job_name = "test-job"
162155

163-
with tempfile.TemporaryDirectory() as tmpdir:
164-
inp = processing_input_from_local(
165-
source=tmpdir,
166-
destination="/opt/ml/processing/input/data",
167-
input_name="local-data",
168-
)
169-
with patch(
170-
"sagemaker.core.s3.S3Uploader.upload", return_value="s3://bucket/uploaded"
171-
):
172-
result = processor._normalize_inputs([inp])
173-
assert len(result) == 1
174-
assert result[0].s3_input.s3_uri == "s3://bucket/uploaded"
156+
inp = processing_input_from_local(
157+
source=str(tmp_path),
158+
destination="/opt/ml/processing/input/data",
159+
input_name="local-data",
160+
)
161+
with patch(
162+
"sagemaker.core.processing.s3.S3Uploader.upload",
163+
return_value="s3://bucket/uploaded",
164+
):
165+
result = processor._normalize_inputs([inp])
166+
assert len(result) == 1
167+
assert result[0].s3_input.s3_uri == "s3://bucket/uploaded"
175168

176169

177170
class TestNormalizeInputsLocalPathValidation:
@@ -199,7 +192,7 @@ def test_normalize_inputs_with_nonexistent_local_path_raises_value_error(self, m
199192
with pytest.raises(ValueError, match="Input source path does not exist"):
200193
processor._normalize_inputs(inputs)
201194

202-
def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session):
195+
def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session, tmp_path):
203196
"""A valid local path should be uploaded to S3."""
204197
processor = Processor(
205198
role="arn:aws:iam::123456789012:role/SageMakerRole",
@@ -210,25 +203,24 @@ def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session):
210203
)
211204
processor._current_job_name = "test-job"
212205

213-
with tempfile.TemporaryDirectory() as tmpdir:
214-
s3_input = ProcessingS3Input(
215-
s3_uri=tmpdir,
216-
local_path="/opt/ml/processing/input",
217-
s3_data_type="S3Prefix",
218-
s3_input_mode="File",
219-
)
220-
inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]
206+
s3_input = ProcessingS3Input(
207+
s3_uri=str(tmp_path),
208+
local_path="/opt/ml/processing/input",
209+
s3_data_type="S3Prefix",
210+
s3_input_mode="File",
211+
)
212+
inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]
221213

222-
with patch(
223-
"sagemaker.core.s3.S3Uploader.upload",
224-
return_value="s3://test-bucket/sagemaker/test-job/input/local-input",
225-
) as mock_upload:
226-
result = processor._normalize_inputs(inputs)
227-
assert len(result) == 1
228-
assert result[0].s3_input.s3_uri.startswith("s3://")
229-
mock_upload.assert_called_once()
214+
with patch(
215+
"sagemaker.core.processing.s3.S3Uploader.upload",
216+
return_value="s3://test-bucket/sagemaker/test-job/input/local-input",
217+
) as mock_upload:
218+
result = processor._normalize_inputs(inputs)
219+
assert len(result) == 1
220+
assert result[0].s3_input.s3_uri.startswith("s3://")
221+
mock_upload.assert_called_once()
230222

231-
def test_normalize_inputs_local_path_logs_upload_info(self, mock_session):
223+
def test_normalize_inputs_local_path_logs_upload_info(self, mock_session, tmp_path):
232224
"""Uploading a local path should log an info message."""
233225
processor = Processor(
234226
role="arn:aws:iam::123456789012:role/SageMakerRole",
@@ -239,27 +231,36 @@ def test_normalize_inputs_local_path_logs_upload_info(self, mock_session):
239231
)
240232
processor._current_job_name = "test-job"
241233

242-
with tempfile.TemporaryDirectory() as tmpdir:
243-
s3_input = ProcessingS3Input(
244-
s3_uri=tmpdir,
245-
local_path="/opt/ml/processing/input",
246-
s3_data_type="S3Prefix",
247-
s3_input_mode="File",
248-
)
249-
inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]
234+
s3_input = ProcessingS3Input(
235+
s3_uri=str(tmp_path),
236+
local_path="/opt/ml/processing/input",
237+
s3_data_type="S3Prefix",
238+
s3_input_mode="File",
239+
)
240+
inputs = [ProcessingInput(input_name="local-input", s3_input=s3_input)]
250241

251-
with patch(
252-
"sagemaker.core.s3.S3Uploader.upload",
253-
return_value="s3://test-bucket/uploaded",
254-
):
255-
with patch("sagemaker.core.processing.logger") as mock_logger:
256-
processor._normalize_inputs(inputs)
257-
mock_logger.info.assert_any_call(
258-
"Uploading local input '%s' from %s to %s",
259-
"local-input",
260-
tmpdir,
261-
f"s3://test-bucket/sagemaker/test-job/input/local-input",
262-
)
242+
with patch(
243+
"sagemaker.core.processing.s3.S3Uploader.upload",
244+
return_value="s3://test-bucket/uploaded",
245+
):
246+
with patch("sagemaker.core.processing.logger") as mock_logger:
247+
processor._normalize_inputs(inputs)
248+
# Verify the upload log message was emitted with the correct format.
249+
# The exact S3 path depends on default_bucket/prefix/job_name/input_name.
250+
mock_logger.info.assert_called()
251+
log_calls = mock_logger.info.call_args_list
252+
upload_log_found = any(
253+
len(call.args) >= 4
254+
and call.args[0] == "Uploading local input '%s' from %s to %s"
255+
and call.args[1] == "local-input"
256+
and call.args[2] == str(tmp_path)
257+
and call.args[3].startswith("s3://")
258+
for call in log_calls
259+
)
260+
assert upload_log_found, (
261+
f"Expected upload log message not found in logger.info calls: "
262+
f"{log_calls}"
263+
)
263264

264265

265266
class TestProcessorNormalizeArgs:

0 commit comments

Comments
 (0)