Skip to content

Commit 9ce5304

Browse files
committed
fix: address review comments (iteration #1)
1 parent 3584928 commit 9ce5304

File tree

1 file changed

+40
-82
lines changed

1 file changed

+40
-82
lines changed

sagemaker-mlops/tests/unit/workflow/test_steps.py

Lines changed: 40 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def test_step_find_dependencies_in_depends_on_list_with_string():
252252

253253
def test_step_validate_json_get_property_file_reference_invalid_step_type():
254254
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
255-
from sagemaker.core.workflow.functions import JsonGet
255+
256256

257257
step = Mock(spec=Step)
258258
step.name = "current-step"
@@ -386,36 +386,37 @@ def test_step_validate_json_get_function_with_property_file():
386386
Step._validate_json_get_function(step, json_get, step_map)
387387

388388

389-
def test_delayed_return_import_from_correct_module():
390-
"""Verify that DelayedReturn can be imported from sagemaker.mlops.workflow.function_step."""
391-
from sagemaker.mlops.workflow.function_step import DelayedReturn
392-
assert DelayedReturn is not None
393-
# Verify it's a class
394-
assert isinstance(DelayedReturn, type)
395389

396390

397-
def test_find_dependencies_in_step_arguments_with_delayed_return_uses_correct_import():
398-
"""Verify _find_dependencies_in_step_arguments uses the mlops import path for DelayedReturn."""
391+
392+
393+
394+
395+
396+
397+
398+
399399
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
400-
from sagemaker.mlops.workflow.function_step import DelayedReturn, _FunctionStep
400+
401401
from sagemaker.core.workflow.functions import JsonGet
402-
from unittest.mock import patch, Mock
403402

404-
# Create a mock function step
405-
mock_func_step = Mock(spec=_FunctionStep)
406-
mock_func_step.name = "func-step"
407403

408-
# Create a DelayedReturn that acts as a PipelineVariable
409-
delayed = Mock(spec=DelayedReturn)
410-
delayed._referenced_steps = [mock_func_step]
411-
delayed._step = mock_func_step
412404

413-
mock_json_get = Mock(spec=JsonGet)
414-
mock_json_get.property_file = None
415-
delayed._to_json_get = Mock(return_value=mock_json_get)
416405

417-
# Verify isinstance check works with the correct import
418-
assert isinstance(delayed, DelayedReturn)
406+
407+
408+
409+
410+
411+
412+
413+
414+
415+
416+
417+
418+
419+
419420

420421

421422
def test_step_find_dependencies_in_step_arguments_with_json_get():
@@ -440,103 +441,60 @@ def test_step_find_dependencies_in_step_arguments_with_json_get():
440441

441442
obj = {"key": json_get}
442443

443-
with patch('sagemaker.mlops.workflow.steps.TYPE_CHECKING', False):
444-
with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': Mock()}):
445-
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1})
446-
assert "step1" in dependencies
444+
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1})
445+
assert "step1" in dependencies
447446

448447

449448
def test_step_find_dependencies_in_step_arguments_with_delayed_return():
450-
from unittest.mock import patch
451449
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
452450
from sagemaker.mlops.workflow.function_step import DelayedReturn
453451
from sagemaker.core.workflow.functions import JsonGet
454-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
455-
452+
456453
step1 = Mock(spec=Step)
457454
step1.name = "step1"
458455
step1.step_type = StepTypeEnum.PROCESSING
459456
step1.property_files = []
460457
step1.arguments = {}
461-
458+
462459
json_get = Mock(spec=JsonGet)
463460
json_get.property_file = None
464-
461+
465462
delayed_return = Mock(spec=DelayedReturn)
466463
delayed_return._referenced_steps = [step1]
467464
delayed_return._to_json_get = Mock(return_value=json_get)
468-
465+
469466
step2 = Mock(spec=Step)
470467
step2.name = "step2"
471468
step2._validate_json_get_function = Mock()
472469
step2._get_step_name_from_str = Step._get_step_name_from_str
473-
470+
474471
obj = {"key": delayed_return}
475-
472+
476473
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1})
477474
assert "step1" in dependencies
478-
from unittest.mock import patch
479-
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
480-
from sagemaker.core.workflow.functions import JsonGet
481-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
482-
483-
step1 = Mock(spec=Step)
484-
step1.name = "step1"
485-
step1.step_type = StepTypeEnum.PROCESSING
486-
step1.property_files = []
487-
step1.arguments = {}
488-
489-
json_get = Mock(spec=JsonGet)
490-
json_get.property_file = None
491-
492-
delayed_return_class = type('DelayedReturn', (PipelineVariable,), {})
493-
delayed_return = Mock(spec=delayed_return_class)
494-
delayed_return._referenced_steps = [step1]
495-
delayed_return._to_json_get = Mock(return_value=json_get)
496-
delayed_return.__class__ = delayed_return_class
497-
498-
step2 = Mock(spec=Step)
499-
step2.name = "step2"
500-
step2._validate_json_get_function = Mock()
501-
step2._get_step_name_from_str = Step._get_step_name_from_str
502-
503-
obj = {"key": delayed_return}
504-
505-
mock_module = Mock()
506-
mock_module.DelayedReturn = delayed_return_class
507-
508-
with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': mock_module}):
509-
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1})
510-
assert "step1" in dependencies
511475

512476

513477

514478
def test_step_find_dependencies_in_step_arguments_with_string_reference():
515-
from unittest.mock import patch
516479
from sagemaker.mlops.workflow.steps import Step
517480
from sagemaker.core.helper.pipeline_variable import PipelineVariable
518-
481+
519482
step1 = Mock(spec=Step)
520483
step1.name = "step1"
521-
484+
522485
pipeline_var = Mock(spec=PipelineVariable)
523486
pipeline_var._referenced_steps = ["step1"]
524-
487+
525488
step2 = Mock(spec=Step)
526489
step2.name = "step2"
527490
step2._get_step_name_from_str = Step._get_step_name_from_str
528-
491+
529492
obj = {"key": pipeline_var}
530-
493+
531494
step_map = {"step1": step1}
532-
533-
delayed_return_class = type('DelayedReturn', (PipelineVariable,), {})
534-
mock_module = Mock()
535-
mock_module.DelayedReturn = delayed_return_class
536-
537-
with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': mock_module}):
538-
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, step_map)
539-
assert "step1" in dependencies
495+
496+
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, step_map)
497+
assert "step1" in dependencies
540498

541499

542500
def test_tuning_step_requires_step_args():

0 commit comments

Comments
 (0)