|
17 | 17 | from pathlib import Path |
18 | 18 | from unittest.mock import Mock, patch, MagicMock |
19 | 19 | from sagemaker.core.workflow.utilities import ( |
20 | | - get_code_hash, |
21 | 20 | list_to_request, |
22 | 21 | hash_file, |
23 | 22 | hash_files_or_dirs, |
@@ -226,77 +225,50 @@ def test_get_processing_code_hash_with_source_dir(self): |
226 | 225 | ) |
227 | 226 |
|
228 | 227 | assert result is not None |
229 | | - |
230 | | - def test_get_processing_code_hash_with_source_dir_and_none_dependencies(self): |
231 | | - """Test get_processing_code_hash with source_dir and dependencies=None""" |
232 | | - with tempfile.TemporaryDirectory() as temp_dir: |
233 | | - code_file = Path(temp_dir, "script.py") |
234 | | - code_file.write_text("print('hello')") |
235 | | - |
236 | | - result = get_processing_code_hash( |
237 | | - code=str(code_file), source_dir=temp_dir, dependencies=None |
238 | | - ) |
239 | | - |
240 | | - assert result is not None |
241 | 228 | assert len(result) == 64 |
242 | 229 |
|
243 | | - def test_get_processing_code_hash_with_code_only_and_none_dependencies(self): |
244 | | - """Test get_processing_code_hash with code only and dependencies=None""" |
| 230 | + def test_get_processing_code_hash_code_only(self): |
| 231 | + """Test get_processing_code_hash with code only""" |
245 | 232 | with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: |
246 | 233 | f.write("print('hello')") |
247 | 234 | temp_file = f.name |
248 | 235 |
|
249 | 236 | try: |
250 | | - result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None) |
| 237 | + result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[]) |
251 | 238 |
|
252 | 239 | assert result is not None |
253 | 240 | assert len(result) == 64 |
254 | 241 | finally: |
255 | 242 | os.unlink(temp_file) |
256 | 243 |
|
257 | | - @pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests") |
258 | | - def test_get_code_hash_with_training_step_none_requirements(self): |
259 | | - """Test get_code_hash with a TrainingStep whose source_code.requirements is None""" |
260 | | - from sagemaker.mlops.workflow.steps import TrainingStep |
261 | | - |
| 244 | + def test_get_processing_code_hash_with_none_dependencies_and_source_dir(self): |
| 245 | + """Test get_processing_code_hash with None dependencies and source_dir""" |
262 | 246 | with tempfile.TemporaryDirectory() as temp_dir: |
263 | | - entry_file = Path(temp_dir, "train.py") |
264 | | - entry_file.write_text("print('training')") |
265 | | - |
266 | | - mock_source_code = Mock() |
267 | | - mock_source_code.source_dir = temp_dir |
268 | | - mock_source_code.requirements = None |
269 | | - mock_source_code.entry_script = str(entry_file) |
270 | | - |
271 | | - mock_model_trainer = Mock() |
272 | | - mock_model_trainer.source_code = mock_source_code |
273 | | - |
274 | | - mock_step_args = Mock() |
275 | | - mock_step_args.func_args = [mock_model_trainer] |
276 | | - |
277 | | - step = Mock(spec=TrainingStep) |
278 | | - step.step_args = mock_step_args |
| 247 | + code_file = Path(temp_dir, "script.py") |
| 248 | + code_file.write_text("print('hello')") |
279 | 249 |
|
280 | | - result = get_code_hash(step) |
| 250 | + result = get_processing_code_hash( |
| 251 | + code=str(code_file), source_dir=temp_dir, dependencies=None |
| 252 | + ) |
281 | 253 |
|
282 | 254 | assert result is not None |
283 | 255 | assert len(result) == 64 |
284 | | - assert len(result) == 64 |
285 | 256 |
|
286 | | - def test_get_processing_code_hash_code_only(self): |
287 | | - """Test get_processing_code_hash with code only""" |
| 257 | + def test_get_processing_code_hash_with_none_dependencies_and_code_only(self): |
| 258 | + """Test get_processing_code_hash with None dependencies and code only""" |
288 | 259 | with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: |
289 | 260 | f.write("print('hello')") |
290 | 261 | temp_file = f.name |
291 | 262 |
|
292 | 263 | try: |
293 | | - result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[]) |
| 264 | + result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None) |
294 | 265 |
|
295 | 266 | assert result is not None |
296 | 267 | assert len(result) == 64 |
297 | 268 | finally: |
298 | 269 | os.unlink(temp_file) |
299 | 270 |
|
| 271 | + |
300 | 272 | def test_get_processing_code_hash_s3_uri(self): |
301 | 273 | """Test get_processing_code_hash with S3 URI returns None""" |
302 | 274 | result = get_processing_code_hash( |
@@ -364,6 +336,47 @@ def test_get_training_code_hash_entry_point_only(self): |
364 | 336 | assert len(result_with_deps) == 64 |
365 | 337 | assert result_no_deps != result_with_deps |
366 | 338 |
|
| 339 | + def test_get_code_hash_training_step_with_none_requirements(self): |
| 340 | + """Test get_code_hash with TrainingStep whose source_code has requirements=None""" |
| 341 | + from sagemaker.core.workflow.utilities import get_code_hash |
| 342 | + |
| 343 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 344 | + entry_file = Path(temp_dir, "train.py") |
| 345 | + entry_file.write_text("print('training')") |
| 346 | + |
| 347 | + mock_source_code = Mock() |
| 348 | + mock_source_code.source_dir = temp_dir |
| 349 | + mock_source_code.requirements = None |
| 350 | + mock_source_code.entry_script = str(entry_file) |
| 351 | + |
| 352 | + mock_model_trainer = Mock() |
| 353 | + mock_model_trainer.source_code = mock_source_code |
| 354 | + |
| 355 | + mock_step_args = Mock() |
| 356 | + mock_step_args.func_args = [mock_model_trainer] |
| 357 | + |
| 358 | + mock_step = Mock() |
| 359 | + mock_step.step_args = mock_step_args |
| 360 | + |
| 361 | + with patch("sagemaker.core.workflow.utilities.isinstance") as mock_isinstance: |
| 362 | + def isinstance_side_effect(obj, cls): |
| 363 | + from sagemaker.mlops.workflow.steps import TrainingStep, ProcessingStep |
| 364 | + if cls is ProcessingStep: |
| 365 | + return False |
| 366 | + if cls is TrainingStep: |
| 367 | + return obj is mock_step |
| 368 | + return builtins_isinstance(obj, cls) |
| 369 | + |
| 370 | + import builtins |
| 371 | + builtins_isinstance = builtins.isinstance |
| 372 | + mock_isinstance.side_effect = isinstance_side_effect |
| 373 | + |
| 374 | + result = get_code_hash(mock_step) |
| 375 | + |
| 376 | + assert result is not None |
| 377 | + assert len(result) == 64 |
| 378 | + |
| 379 | + |
367 | 380 | def test_get_training_code_hash_s3_uri(self): |
368 | 381 | """Test get_training_code_hash with S3 URI returns None""" |
369 | 382 | result = get_training_code_hash( |
|
0 commit comments