|
17 | 17 | from pathlib import Path |
18 | 18 | from unittest.mock import Mock, patch, MagicMock |
19 | 19 | from sagemaker.core.workflow.utilities import ( |
| 20 | + get_code_hash, |
20 | 21 | list_to_request, |
21 | 22 | hash_file, |
22 | 23 | hash_files_or_dirs, |
@@ -225,86 +226,99 @@ def test_get_processing_code_hash_with_source_dir(self): |
225 | 226 | ) |
226 | 227 |
|
227 | 228 | assert result is not None |
| 229 | + |
| 230 | + def test_get_processing_code_hash_with_source_dir_and_none_dependencies(self): |
| 231 | + """Test get_processing_code_hash with source_dir and dependencies=None""" |
| 232 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 233 | + code_file = Path(temp_dir, "script.py") |
| 234 | + code_file.write_text("print('hello')") |
| 235 | + |
| 236 | + result = get_processing_code_hash( |
| 237 | + code=str(code_file), source_dir=temp_dir, dependencies=None |
| 238 | + ) |
| 239 | + |
| 240 | + assert result is not None |
228 | 241 | assert len(result) == 64 |
229 | 242 |
|
230 | | - def test_get_processing_code_hash_code_only(self): |
231 | | - """Test get_processing_code_hash with code only""" |
| 243 | + def test_get_processing_code_hash_with_code_only_and_none_dependencies(self): |
| 244 | + """Test get_processing_code_hash with code only and dependencies=None""" |
232 | 245 | with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: |
233 | 246 | f.write("print('hello')") |
234 | 247 | temp_file = f.name |
235 | 248 |
|
236 | 249 | try: |
237 | | - result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[]) |
| 250 | + result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None) |
238 | 251 |
|
239 | 252 | assert result is not None |
240 | 253 | assert len(result) == 64 |
241 | 254 | finally: |
242 | 255 | os.unlink(temp_file) |
243 | 256 |
|
244 | | - def test_get_processing_code_hash_s3_uri(self): |
245 | | - """Test get_processing_code_hash with S3 URI returns None""" |
246 | | - result = get_processing_code_hash( |
247 | | - code="s3://bucket/script.py", source_dir=None, dependencies=[] |
248 | | - ) |
249 | | - |
250 | | - assert result is None |
| 257 | + @pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests") |
| 258 | + def test_get_code_hash_with_training_step_none_requirements(self): |
| 259 | + """Test get_code_hash with a TrainingStep whose source_code.requirements is None""" |
| 260 | + from sagemaker.mlops.workflow.steps import TrainingStep |
251 | 261 |
|
252 | | - def test_get_processing_code_hash_with_dependencies(self): |
253 | | - """Test get_processing_code_hash with dependencies""" |
254 | 262 | with tempfile.TemporaryDirectory() as temp_dir: |
255 | | - code_file = Path(temp_dir, "script.py") |
256 | | - code_file.write_text("print('hello')") |
| 263 | + entry_file = Path(temp_dir, "train.py") |
| 264 | + entry_file.write_text("print('training')") |
257 | 265 |
|
258 | | - dep_file = Path(temp_dir, "utils.py") |
259 | | - dep_file.write_text("def helper(): pass") |
| 266 | + mock_source_code = Mock() |
| 267 | + mock_source_code.source_dir = temp_dir |
| 268 | + mock_source_code.requirements = None |
| 269 | + mock_source_code.entry_script = str(entry_file) |
260 | 270 |
|
261 | | - result = get_processing_code_hash( |
262 | | - code=str(code_file), source_dir=temp_dir, dependencies=[str(dep_file)] |
263 | | - ) |
| 271 | + mock_model_trainer = Mock() |
| 272 | + mock_model_trainer.source_code = mock_source_code |
| 273 | + |
| 274 | + mock_step_args = Mock() |
| 275 | + mock_step_args.func_args = [mock_model_trainer] |
| 276 | + |
| 277 | + step = Mock(spec=TrainingStep) |
| 278 | + step.step_args = mock_step_args |
| 279 | + |
| 280 | + result = get_code_hash(step) |
264 | 281 |
|
265 | 282 | assert result is not None |
| 283 | + assert len(result) == 64 |
| 284 | + assert len(result) == 64 |
266 | 285 |
|
267 | | - def test_get_processing_code_hash_with_none_dependencies(self): |
268 | | - """Test get_processing_code_hash with None dependencies does not raise TypeError""" |
| 286 | + def test_get_processing_code_hash_code_only(self): |
| 287 | + """Test get_processing_code_hash with code only""" |
269 | 288 | with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: |
270 | 289 | f.write("print('hello')") |
271 | 290 | temp_file = f.name |
272 | 291 |
|
273 | 292 | try: |
274 | | - result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None) |
| 293 | + result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[]) |
275 | 294 |
|
276 | 295 | assert result is not None |
277 | 296 | assert len(result) == 64 |
278 | 297 | finally: |
279 | 298 | os.unlink(temp_file) |
280 | 299 |
|
281 | | - def test_get_processing_code_hash_with_source_dir_and_none_dependencies(self): |
282 | | - """Test get_processing_code_hash with source_dir and None dependencies""" |
| 300 | + def test_get_processing_code_hash_s3_uri(self): |
| 301 | + """Test get_processing_code_hash with S3 URI returns None""" |
| 302 | + result = get_processing_code_hash( |
| 303 | + code="s3://bucket/script.py", source_dir=None, dependencies=[] |
| 304 | + ) |
| 305 | + |
| 306 | + assert result is None |
| 307 | + |
| 308 | + def test_get_processing_code_hash_with_dependencies(self): |
| 309 | + """Test get_processing_code_hash with dependencies""" |
283 | 310 | with tempfile.TemporaryDirectory() as temp_dir: |
284 | 311 | code_file = Path(temp_dir, "script.py") |
285 | 312 | code_file.write_text("print('hello')") |
286 | 313 |
|
| 314 | + dep_file = Path(temp_dir, "utils.py") |
| 315 | + dep_file.write_text("def helper(): pass") |
| 316 | + |
287 | 317 | result = get_processing_code_hash( |
288 | | - code=str(code_file), source_dir=temp_dir, dependencies=None |
| 318 | + code=str(code_file), source_dir=temp_dir, dependencies=[str(dep_file)] |
289 | 319 | ) |
290 | 320 |
|
291 | 321 | assert result is not None |
292 | | - assert len(result) == 64 |
293 | | - |
294 | | - def test_get_processing_code_hash_with_code_only_and_none_dependencies(self): |
295 | | - """Test get_processing_code_hash with code only and None dependencies""" |
296 | | - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: |
297 | | - f.write("print('hello')") |
298 | | - temp_file = f.name |
299 | | - |
300 | | - try: |
301 | | - # Ensure no TypeError when dependencies is None |
302 | | - result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None) |
303 | | - |
304 | | - assert result is not None |
305 | | - assert len(result) == 64 |
306 | | - finally: |
307 | | - os.unlink(temp_file) |
308 | 322 |
|
309 | 323 | def test_get_training_code_hash_with_source_dir(self): |
310 | 324 | """Test get_training_code_hash with source_dir""" |
|
0 commit comments