@@ -50,37 +50,33 @@ def mock_session():
5050class TestProcessingInputFromLocal :
5151 """Tests for the processing_input_from_local() factory function."""
5252
53- def test_processing_input_from_local_with_file_path_creates_valid_input (self ):
53+ def test_processing_input_from_local_with_file_path_creates_valid_input (self , tmp_path ):
5454 """A local file path should produce a valid ProcessingInput."""
55- with tempfile .NamedTemporaryFile (mode = "w" , delete = False , suffix = ".csv" ) as f :
56- f .write ("col1,col2\n 1,2\n " )
57- temp_file = f .name
58-
59- try :
60- result = processing_input_from_local (
61- source = temp_file ,
62- destination = "/opt/ml/processing/input/data" ,
63- input_name = "my-data" ,
64- )
65- assert isinstance (result , ProcessingInput )
66- assert result .input_name == "my-data"
67- assert result .s3_input .s3_uri == temp_file
68- assert result .s3_input .local_path == "/opt/ml/processing/input/data"
69- assert result .s3_input .s3_data_type == "S3Prefix"
70- assert result .s3_input .s3_input_mode == "File"
71- finally :
72- os .unlink (temp_file )
55+ temp_file = tmp_path / "data.csv"
56+ temp_file .write_text ("col1,col2\n 1,2\n " )
7357
74- def test_processing_input_from_local_with_directory_path_creates_valid_input (self ):
58+ result = processing_input_from_local (
59+ source = str (temp_file ),
60+ destination = "/opt/ml/processing/input/data" ,
61+ input_name = "my-data" ,
62+ )
63+ assert isinstance (result , ProcessingInput )
64+ assert result .input_name == "my-data"
65+ assert result .s3_input .s3_uri == str (temp_file )
66+ # local_path in ProcessingS3Input maps to the container destination path
67+ assert result .s3_input .local_path == "/opt/ml/processing/input/data"
68+ assert result .s3_input .s3_data_type == "S3Prefix"
69+ assert result .s3_input .s3_input_mode == "File"
70+
71+ def test_processing_input_from_local_with_directory_path_creates_valid_input (self , tmp_path ):
7572 """A local directory path should produce a valid ProcessingInput."""
76- with tempfile .TemporaryDirectory () as tmpdir :
77- result = processing_input_from_local (
78- source = tmpdir ,
79- destination = "/opt/ml/processing/input/data" ,
80- input_name = "dir-data" ,
81- )
82- assert isinstance (result , ProcessingInput )
83- assert result .s3_input .s3_uri == tmpdir
73+ result = processing_input_from_local (
74+ source = str (tmp_path ),
75+ destination = "/opt/ml/processing/input/data" ,
76+ input_name = "dir-data" ,
77+ )
78+ assert isinstance (result , ProcessingInput )
79+ assert result .s3_input .s3_uri == str (tmp_path )
8480
8581 def test_processing_input_from_local_with_s3_uri_passes_through (self ):
8682 """An S3 URI should pass through without local path validation."""
@@ -100,56 +96,53 @@ def test_processing_input_from_local_with_nonexistent_path_raises_value_error(se
10096 destination = "/opt/ml/processing/input/data" ,
10197 )
10298
103- def test_processing_input_from_local_with_custom_input_name (self ):
99+ def test_processing_input_from_local_with_custom_input_name (self , tmp_path ):
104100 """Custom input_name should be set on the ProcessingInput."""
105- with tempfile .TemporaryDirectory () as tmpdir :
106- result = processing_input_from_local (
107- source = tmpdir ,
108- destination = "/opt/ml/processing/input/data" ,
109- input_name = "custom-name" ,
110- )
111- assert result .input_name == "custom-name"
101+ result = processing_input_from_local (
102+ source = str (tmp_path ),
103+ destination = "/opt/ml/processing/input/data" ,
104+ input_name = "custom-name" ,
105+ )
106+ assert result .input_name == "custom-name"
112107
113- def test_processing_input_from_local_default_parameters (self ):
108+ def test_processing_input_from_local_default_parameters (self , tmp_path ):
114109 """Default parameters should be applied correctly."""
115- with tempfile .TemporaryDirectory () as tmpdir :
116- result = processing_input_from_local (
117- source = tmpdir ,
118- destination = "/opt/ml/processing/input/data" ,
119- )
120- assert result .input_name is None
121- assert result .s3_input .s3_data_type == "S3Prefix"
122- assert result .s3_input .s3_input_mode == "File"
110+ result = processing_input_from_local (
111+ source = str (tmp_path ),
112+ destination = "/opt/ml/processing/input/data" ,
113+ )
114+ assert result .input_name is None
115+ assert result .s3_input .s3_data_type == "S3Prefix"
116+ assert result .s3_input .s3_input_mode == "File"
123117
124118 def test_processing_input_from_local_with_empty_source_raises_value_error (self ):
125119 """Empty source should raise ValueError."""
126- with pytest .raises (ValueError , match = "source must be a valid local path or S3 URI " ):
120+ with pytest .raises (ValueError , match = "source must be a non-empty string " ):
127121 processing_input_from_local (
128122 source = "" ,
129123 destination = "/opt/ml/processing/input/data" ,
130124 )
131125
132126 def test_processing_input_from_local_with_none_source_raises_value_error (self ):
133127 """None source should raise ValueError."""
134- with pytest .raises (ValueError , match = "source must be a valid local path or S3 URI " ):
128+ with pytest .raises (ValueError , match = "source must be a non-empty string " ):
135129 processing_input_from_local (
136130 source = None ,
137131 destination = "/opt/ml/processing/input/data" ,
138132 )
139133
140- def test_processing_input_from_local_with_optional_s3_params (self ):
134+ def test_processing_input_from_local_with_optional_s3_params (self , tmp_path ):
141135 """Optional S3 parameters should be passed through."""
142- with tempfile .TemporaryDirectory () as tmpdir :
143- result = processing_input_from_local (
144- source = tmpdir ,
145- destination = "/opt/ml/processing/input/data" ,
146- s3_data_distribution_type = "FullyReplicated" ,
147- s3_compression_type = "Gzip" ,
148- )
149- assert result .s3_input .s3_data_distribution_type == "FullyReplicated"
150- assert result .s3_input .s3_compression_type == "Gzip"
136+ result = processing_input_from_local (
137+ source = str (tmp_path ),
138+ destination = "/opt/ml/processing/input/data" ,
139+ s3_data_distribution_type = "FullyReplicated" ,
140+ s3_compression_type = "Gzip" ,
141+ )
142+ assert result .s3_input .s3_data_distribution_type == "FullyReplicated"
143+ assert result .s3_input .s3_compression_type == "Gzip"
151144
152- def test_processing_input_from_local_used_in_normalize_inputs (self , mock_session ):
145+ def test_processing_input_from_local_used_in_normalize_inputs (self , mock_session , tmp_path ):
153146 """ProcessingInput from processing_input_from_local should work with _normalize_inputs."""
154147 processor = Processor (
155148 role = "arn:aws:iam::123456789012:role/SageMakerRole" ,
@@ -160,18 +153,18 @@ def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session
160153 )
161154 processor ._current_job_name = "test-job"
162155
163- with tempfile .TemporaryDirectory () as tmpdir :
164- inp = processing_input_from_local (
165- source = tmpdir ,
166- destination = "/opt/ml/processing/input/data" ,
167- input_name = "local-data" ,
168- )
169- with patch (
170- "sagemaker.core.s3.S3Uploader.upload" , return_value = "s3://bucket/uploaded"
171- ):
172- result = processor ._normalize_inputs ([inp ])
173- assert len (result ) == 1
174- assert result [0 ].s3_input .s3_uri == "s3://bucket/uploaded"
156+ inp = processing_input_from_local (
157+ source = str (tmp_path ),
158+ destination = "/opt/ml/processing/input/data" ,
159+ input_name = "local-data" ,
160+ )
161+ with patch (
162+ "sagemaker.core.processing.s3.S3Uploader.upload" ,
163+ return_value = "s3://bucket/uploaded" ,
164+ ):
165+ result = processor ._normalize_inputs ([inp ])
166+ assert len (result ) == 1
167+ assert result [0 ].s3_input .s3_uri == "s3://bucket/uploaded"
175168
176169
177170class TestNormalizeInputsLocalPathValidation :
@@ -199,7 +192,7 @@ def test_normalize_inputs_with_nonexistent_local_path_raises_value_error(self, m
199192 with pytest .raises (ValueError , match = "Input source path does not exist" ):
200193 processor ._normalize_inputs (inputs )
201194
202- def test_normalize_inputs_with_local_source_uploads_to_s3 (self , mock_session ):
195+ def test_normalize_inputs_with_local_source_uploads_to_s3 (self , mock_session , tmp_path ):
203196 """A valid local path should be uploaded to S3."""
204197 processor = Processor (
205198 role = "arn:aws:iam::123456789012:role/SageMakerRole" ,
@@ -210,25 +203,24 @@ def test_normalize_inputs_with_local_source_uploads_to_s3(self, mock_session):
210203 )
211204 processor ._current_job_name = "test-job"
212205
213- with tempfile .TemporaryDirectory () as tmpdir :
214- s3_input = ProcessingS3Input (
215- s3_uri = tmpdir ,
216- local_path = "/opt/ml/processing/input" ,
217- s3_data_type = "S3Prefix" ,
218- s3_input_mode = "File" ,
219- )
220- inputs = [ProcessingInput (input_name = "local-input" , s3_input = s3_input )]
206+ s3_input = ProcessingS3Input (
207+ s3_uri = str (tmp_path ),
208+ local_path = "/opt/ml/processing/input" ,
209+ s3_data_type = "S3Prefix" ,
210+ s3_input_mode = "File" ,
211+ )
212+ inputs = [ProcessingInput (input_name = "local-input" , s3_input = s3_input )]
221213
222- with patch (
223- "sagemaker.core.s3.S3Uploader.upload" ,
224- return_value = "s3://test-bucket/sagemaker/test-job/input/local-input" ,
225- ) as mock_upload :
226- result = processor ._normalize_inputs (inputs )
227- assert len (result ) == 1
228- assert result [0 ].s3_input .s3_uri .startswith ("s3://" )
229- mock_upload .assert_called_once ()
214+ with patch (
215+ "sagemaker.core.processing.s3.S3Uploader.upload" ,
216+ return_value = "s3://test-bucket/sagemaker/test-job/input/local-input" ,
217+ ) as mock_upload :
218+ result = processor ._normalize_inputs (inputs )
219+ assert len (result ) == 1
220+ assert result [0 ].s3_input .s3_uri .startswith ("s3://" )
221+ mock_upload .assert_called_once ()
230222
231- def test_normalize_inputs_local_path_logs_upload_info (self , mock_session ):
223+ def test_normalize_inputs_local_path_logs_upload_info (self , mock_session , tmp_path ):
232224 """Uploading a local path should log an info message."""
233225 processor = Processor (
234226 role = "arn:aws:iam::123456789012:role/SageMakerRole" ,
@@ -239,27 +231,36 @@ def test_normalize_inputs_local_path_logs_upload_info(self, mock_session):
239231 )
240232 processor ._current_job_name = "test-job"
241233
242- with tempfile .TemporaryDirectory () as tmpdir :
243- s3_input = ProcessingS3Input (
244- s3_uri = tmpdir ,
245- local_path = "/opt/ml/processing/input" ,
246- s3_data_type = "S3Prefix" ,
247- s3_input_mode = "File" ,
248- )
249- inputs = [ProcessingInput (input_name = "local-input" , s3_input = s3_input )]
234+ s3_input = ProcessingS3Input (
235+ s3_uri = str (tmp_path ),
236+ local_path = "/opt/ml/processing/input" ,
237+ s3_data_type = "S3Prefix" ,
238+ s3_input_mode = "File" ,
239+ )
240+ inputs = [ProcessingInput (input_name = "local-input" , s3_input = s3_input )]
250241
251- with patch (
252- "sagemaker.core.s3.S3Uploader.upload" ,
253- return_value = "s3://test-bucket/uploaded" ,
254- ):
255- with patch ("sagemaker.core.processing.logger" ) as mock_logger :
256- processor ._normalize_inputs (inputs )
257- mock_logger .info .assert_any_call (
258- "Uploading local input '%s' from %s to %s" ,
259- "local-input" ,
260- tmpdir ,
261- f"s3://test-bucket/sagemaker/test-job/input/local-input" ,
262- )
242+ with patch (
243+ "sagemaker.core.processing.s3.S3Uploader.upload" ,
244+ return_value = "s3://test-bucket/uploaded" ,
245+ ):
246+ with patch ("sagemaker.core.processing.logger" ) as mock_logger :
247+ processor ._normalize_inputs (inputs )
248+ # Verify the upload log message was emitted with the correct format.
249+ # The exact S3 path depends on default_bucket/prefix/job_name/input_name.
250+ mock_logger .info .assert_called ()
251+ log_calls = mock_logger .info .call_args_list
252+ upload_log_found = any (
253+ len (call .args ) >= 4
254+ and call .args [0 ] == "Uploading local input '%s' from %s to %s"
255+ and call .args [1 ] == "local-input"
256+ and call .args [2 ] == str (tmp_path )
257+ and call .args [3 ].startswith ("s3://" )
258+ for call in log_calls
259+ )
260+ assert upload_log_found , (
261+ f"Expected upload log message not found in logger.info calls: "
262+ f"{ log_calls } "
263+ )
263264
264265
265266class TestProcessorNormalizeArgs :
0 commit comments