|
17 | 17 | from pathlib import Path |
18 | 18 | from unittest.mock import Mock, patch, MagicMock |
19 | 19 | from sagemaker.core.workflow.utilities import ( |
| 20 | + list_to_request, |
| 21 | + hash_file, |
| 22 | + hash_files_or_dirs, |
| 23 | + hash_object, |
| 24 | + get_processing_dependencies, |
| 25 | + get_processing_code_hash, |
| 26 | + get_training_code_hash, |
| 27 | + get_code_hash, |
| 28 | + validate_step_args_input, |
| 29 | + override_pipeline_parameter_var, |
| 30 | + trim_request_dict, |
| 31 | + _collect_parameters, |
20 | 32 | list_to_request, |
21 | 33 | hash_file, |
22 | 34 | hash_files_or_dirs, |
|
31 | 43 | ) |
32 | 44 | from sagemaker.core.workflow.entities import Entity |
33 | 45 | from sagemaker.core.workflow.parameters import Parameter |
34 | | -from sagemaker.core.workflow.pipeline_context import _StepArguments |
| 46 | +from sagemaker.core.workflow.pipeline_context import _StepArguments, _JobStepArguments |
35 | 47 |
|
36 | 48 |
|
37 | 49 | class MockEntity(Entity): |
@@ -430,6 +442,170 @@ def __init__(self, param1, param2, param3=None): |
430 | 442 | obj = TestClass("value1", "value2", param3="value3") |
431 | 443 |
|
432 | 444 | assert obj.param1 == "value1" |
| 445 | + |
| 446 | + |
def test_get_training_code_hash_source_dir_with_none_dependencies(self):
    """Hash a training job given ``source_dir`` while ``dependencies`` is ``None``.

    Regression test for https://github.com/aws/sagemaker-python-sdk/issues/5181:
    the call is expected to complete without raising ``TypeError``.
    """
    with tempfile.TemporaryDirectory() as workdir:
        # A minimal entry script living inside the source directory.
        script = Path(workdir) / "train.py"
        script.write_text("print('training')")

        digest = get_training_code_hash(
            entry_point=str(script),
            source_dir=workdir,
            dependencies=None,
        )

        # A hash must be produced, and it must be 64 characters long.
        assert digest is not None
        assert len(digest) == 64
| 462 | + |
def test_get_training_code_hash_entry_point_with_none_dependencies(self):
    """Hash a training job given only ``entry_point``; ``dependencies`` is ``None``.

    Regression test for https://github.com/aws/sagemaker-python-sdk/issues/5181:
    the call is expected to complete without raising ``TypeError``.
    """
    with tempfile.TemporaryDirectory() as workdir:
        # The entry script is the only input; no source_dir is supplied.
        script = Path(workdir) / "train.py"
        script.write_text("print('training')")

        digest = get_training_code_hash(
            entry_point=str(script),
            source_dir=None,
            dependencies=None,
        )

        # A hash must be produced, and it must be 64 characters long.
        assert digest is not None
        assert len(digest) == 64
| 478 | + |
def test_get_code_hash_training_step_with_source_code_no_requirements(self):
    """Check get_code_hash on a TrainingStep whose SourceCode has no requirements.

    Regression test for https://github.com/aws/sagemaker-python-sdk/issues/5181
    ``SourceCode.requirements`` is left as ``None`` here, and get_code_hash
    must not raise a ``TypeError`` in that case.
    """
    import builtins

    # Keep a handle on the genuine builtin before the module-level name is patched.
    real_isinstance = builtins.isinstance

    with tempfile.TemporaryDirectory() as workdir:
        script = Path(workdir) / "train.py"
        script.write_text("print('training')")

        # Stand-in SourceCode: source_dir and entry_script set, requirements unset.
        fake_source = Mock()
        fake_source.source_dir = workdir
        fake_source.requirements = None
        fake_source.entry_script = str(script)

        # Stand-in model trainer carrying that source code.
        trainer = Mock()
        trainer.source_code = fake_source

        # Stand-in TrainingStep whose step_args expose the trainer.
        step = Mock()
        step.__class__ = type('TrainingStep', (), {})
        step.step_args = Mock()
        step.step_args.func_args = [trainer]

        def fake_isinstance(candidate, classinfo):
            # Imported lazily, exactly when the patched isinstance is consulted.
            from sagemaker.mlops.workflow.steps import TrainingStep, ProcessingStep

            if classinfo in (ProcessingStep, (ProcessingStep,)):
                return False
            if classinfo in (TrainingStep, (TrainingStep,)):
                # Only our mock step is treated as a TrainingStep.
                return candidate is step
            return real_isinstance(candidate, classinfo)

        # Replace isinstance as seen from the utilities module so the mock
        # step is recognised as a TrainingStep.
        with patch(
            'sagemaker.core.workflow.utilities.isinstance',
            side_effect=fake_isinstance,
        ):
            # Must complete without raising TypeError.
            digest = get_code_hash(step)

        # A hash must be produced, and it must be 64 characters long.
        assert digest is not None
        assert len(digest) == 64
| 526 | + |
def test_get_code_hash_training_step_with_source_code_with_requirements(self):
    """Check get_code_hash on a TrainingStep whose SourceCode names a requirements file."""
    import builtins

    # Keep a handle on the genuine builtin before the module-level name is patched.
    real_isinstance = builtins.isinstance

    with tempfile.TemporaryDirectory() as workdir:
        script = Path(workdir) / "train.py"
        script.write_text("print('training')")
        requirements = Path(workdir) / "requirements.txt"
        requirements.write_text("numpy==1.21.0")

        # Stand-in SourceCode with every field populated.
        fake_source = Mock()
        fake_source.source_dir = workdir
        fake_source.requirements = str(requirements)
        fake_source.entry_script = str(script)

        # Stand-in model trainer carrying that source code.
        trainer = Mock()
        trainer.source_code = fake_source

        # Stand-in TrainingStep whose step_args expose the trainer.
        step = Mock()
        step.step_args = Mock()
        step.step_args.func_args = [trainer]

        def fake_isinstance(candidate, classinfo):
            # Imported lazily, exactly when the patched isinstance is consulted.
            from sagemaker.mlops.workflow.steps import TrainingStep, ProcessingStep

            if classinfo in (ProcessingStep, (ProcessingStep,)):
                return False
            if classinfo in (TrainingStep, (TrainingStep,)):
                # Only our mock step is treated as a TrainingStep.
                return candidate is step
            return real_isinstance(candidate, classinfo)

        with patch(
            'sagemaker.core.workflow.utilities.isinstance',
            side_effect=fake_isinstance,
        ):
            digest = get_code_hash(step)

        # A hash must be produced, and it must be 64 characters long.
        assert digest is not None
        assert len(digest) == 64
| 567 | + |
def test_get_code_hash_training_step_entry_script_only_no_requirements(self):
    """Check get_code_hash on a TrainingStep whose SourceCode has only an entry script.

    Regression test for https://github.com/aws/sagemaker-python-sdk/issues/5181
    Both ``source_dir`` and ``requirements`` are ``None``; get_code_hash is
    expected to return a hash rather than raise.
    """
    import builtins

    # Keep a handle on the genuine builtin before the module-level name is patched.
    real_isinstance = builtins.isinstance

    with tempfile.TemporaryDirectory() as workdir:
        script = Path(workdir) / "train.py"
        script.write_text("print('training')")

        # Stand-in SourceCode exposing nothing but the entry script.
        fake_source = Mock()
        fake_source.source_dir = None
        fake_source.requirements = None
        fake_source.entry_script = str(script)

        # Stand-in model trainer carrying that source code.
        trainer = Mock()
        trainer.source_code = fake_source

        # Stand-in TrainingStep whose step_args expose the trainer.
        step = Mock()
        step.step_args = Mock()
        step.step_args.func_args = [trainer]

        def fake_isinstance(candidate, classinfo):
            # Imported lazily, exactly when the patched isinstance is consulted.
            from sagemaker.mlops.workflow.steps import TrainingStep, ProcessingStep

            if classinfo in (ProcessingStep, (ProcessingStep,)):
                return False
            if classinfo in (TrainingStep, (TrainingStep,)):
                # Only our mock step is treated as a TrainingStep.
                return candidate is step
            return real_isinstance(candidate, classinfo)

        with patch(
            'sagemaker.core.workflow.utilities.isinstance',
            side_effect=fake_isinstance,
        ):
            digest = get_code_hash(step)

        # A hash must be produced, and it must be 64 characters long.
        assert digest is not None
        assert len(digest) == 64
435 | 611 |
|
|
0 commit comments