Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sagemaker-core/src/sagemaker/core/workflow/utilities.py
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work! Can you omit the change to requirements (source_code.requirements or None)? Our team is okay not having this change

Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def get_code_hash(step: Entity) -> str:
source_code = model_trainer.source_code
if source_code:
source_dir = source_code.source_dir
requirements = source_code.requirements
requirements = source_code.requirements or None
entry_point = source_code.entry_script
return get_training_code_hash(entry_point, source_dir, requirements)
return None
Expand Down Expand Up @@ -248,6 +248,7 @@ def get_training_code_hash(
from sagemaker.core.workflow import is_pipeline_variable

if not is_pipeline_variable(source_dir) and not is_pipeline_variable(entry_point):
dependencies = dependencies or None
if source_dir:
source_dir_url = urlparse(source_dir)
if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
Expand Down
99 changes: 99 additions & 0 deletions sagemaker-core/tests/unit/workflow/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,102 @@ def reinit(self, param1):
obj.reinit("new_value")

assert obj.param1 == "new_value"

def test_get_training_code_hash_with_source_dir_and_none_dependencies(self):
"""Test get_training_code_hash returns valid hash when source_dir is set and dependencies is None"""
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

result = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
)

assert result is not None
assert len(result) == 64

def test_get_training_code_hash_with_entry_point_and_none_dependencies(self):
"""Test get_training_code_hash returns valid hash when entry_point is set and dependencies is None"""
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

result = get_training_code_hash(
entry_point=str(entry_file), source_dir=None, dependencies=None
)

assert result is not None
assert len(result) == 64

def test_get_training_code_hash_with_source_dir_and_empty_string_dependencies(self):
"""Test get_training_code_hash returns valid hash when dependencies is empty string (falsy)"""
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

# Empty string is falsy and should be treated as no dependencies
result = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=""
)

assert result is not None
assert len(result) == 64

@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
def test_get_code_hash_with_training_step_no_requirements(self):
"""Test get_code_hash works when SourceCode.requirements is None"""
from sagemaker.core.workflow.utilities import get_code_hash

with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

mock_source_code = Mock()
mock_source_code.source_dir = temp_dir
mock_source_code.requirements = None
mock_source_code.entry_script = str(entry_file)

mock_model_trainer = Mock()
mock_model_trainer.source_code = mock_source_code

mock_step_args = Mock()
mock_step_args.func_args = [mock_model_trainer]

mock_step = Mock()
mock_step.step_args = mock_step_args

# Patch isinstance to make our mock look like a TrainingStep
from sagemaker.mlops.workflow.steps import TrainingStep
mock_step.__class__ = TrainingStep

result = get_code_hash(mock_step)
assert result is not None

@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
def test_get_code_hash_with_training_step_empty_requirements(self):
"""Test get_code_hash works when SourceCode.requirements is empty string"""
from sagemaker.core.workflow.utilities import get_code_hash

with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

mock_source_code = Mock()
mock_source_code.source_dir = temp_dir
mock_source_code.requirements = ""
mock_source_code.entry_script = str(entry_file)

mock_model_trainer = Mock()
mock_model_trainer.source_code = mock_source_code

mock_step_args = Mock()
mock_step_args.func_args = [mock_model_trainer]

mock_step = Mock()
mock_step.step_args = mock_step_args

from sagemaker.mlops.workflow.steps import TrainingStep
mock_step.__class__ = TrainingStep

result = get_code_hash(mock_step)
assert result is not None
Loading