Skip to content

Commit 3584928

Browse files
committed
fix: Invalid import (5637)
1 parent ee420cc commit 3584928

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _find_dependencies_in_step_arguments(
205205
else:
206206
dependencies.add(self._get_step_name_from_str(referenced_step, step_map))
207207

208-
from sagemaker.core.workflow.function_step import DelayedReturn
208+
from sagemaker.mlops.workflow.function_step import DelayedReturn
209209

210210
# TODO: we can remove the if-elif once move the validators to JsonGet constructor
211211
if isinstance(pipeline_variable, JsonGet):

sagemaker-mlops/tests/unit/workflow/test_steps.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,38 @@ def test_step_validate_json_get_function_with_property_file():
386386
Step._validate_json_get_function(step, json_get, step_map)
387387

388388

389+
def test_delayed_return_import_from_correct_module():
390+
"""Verify that DelayedReturn can be imported from sagemaker.mlops.workflow.function_step."""
391+
from sagemaker.mlops.workflow.function_step import DelayedReturn
392+
assert DelayedReturn is not None
393+
# Verify it's a class
394+
assert isinstance(DelayedReturn, type)
395+
396+
397+
def test_find_dependencies_in_step_arguments_with_delayed_return_uses_correct_import():
398+
"""Verify _find_dependencies_in_step_arguments uses the mlops import path for DelayedReturn."""
399+
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
400+
from sagemaker.mlops.workflow.function_step import DelayedReturn, _FunctionStep
401+
from sagemaker.core.workflow.functions import JsonGet
402+
from unittest.mock import patch, Mock
403+
404+
# Create a mock function step
405+
mock_func_step = Mock(spec=_FunctionStep)
406+
mock_func_step.name = "func-step"
407+
408+
# Create a DelayedReturn that acts as a PipelineVariable
409+
delayed = Mock(spec=DelayedReturn)
410+
delayed._referenced_steps = [mock_func_step]
411+
delayed._step = mock_func_step
412+
413+
mock_json_get = Mock(spec=JsonGet)
414+
mock_json_get.property_file = None
415+
delayed._to_json_get = Mock(return_value=mock_json_get)
416+
417+
# Verify isinstance check works with the correct import
418+
assert isinstance(delayed, DelayedReturn)
419+
420+
389421
def test_step_find_dependencies_in_step_arguments_with_json_get():
390422
from unittest.mock import patch
391423
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
@@ -415,6 +447,34 @@ def test_step_find_dependencies_in_step_arguments_with_json_get():
415447

416448

417449
def test_step_find_dependencies_in_step_arguments_with_delayed_return():
450+
from unittest.mock import patch
451+
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
452+
from sagemaker.mlops.workflow.function_step import DelayedReturn
453+
from sagemaker.core.workflow.functions import JsonGet
454+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
455+
456+
step1 = Mock(spec=Step)
457+
step1.name = "step1"
458+
step1.step_type = StepTypeEnum.PROCESSING
459+
step1.property_files = []
460+
step1.arguments = {}
461+
462+
json_get = Mock(spec=JsonGet)
463+
json_get.property_file = None
464+
465+
delayed_return = Mock(spec=DelayedReturn)
466+
delayed_return._referenced_steps = [step1]
467+
delayed_return._to_json_get = Mock(return_value=json_get)
468+
469+
step2 = Mock(spec=Step)
470+
step2.name = "step2"
471+
step2._validate_json_get_function = Mock()
472+
step2._get_step_name_from_str = Step._get_step_name_from_str
473+
474+
obj = {"key": delayed_return}
475+
476+
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1})
477+
assert "step1" in dependencies
418478
from unittest.mock import patch
419479
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum
420480
from sagemaker.core.workflow.functions import JsonGet
@@ -450,6 +510,7 @@ def test_step_find_dependencies_in_step_arguments_with_delayed_return():
450510
assert "step1" in dependencies
451511

452512

513+
453514
def test_step_find_dependencies_in_step_arguments_with_string_reference():
454515
from unittest.mock import patch
455516
from sagemaker.mlops.workflow.steps import Step

0 commit comments

Comments
 (0)