Skip to content

Commit 88142fc

Browse files
committed
fix: address review comments (iteration #2)
1 parent 2e31aa5 commit 88142fc

File tree

2 files changed

+57
-45
lines changed

2 files changed

+57
-45
lines changed

sagemaker-core/src/sagemaker/core/workflow/utilities.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,6 @@ def get_code_hash(step: Entity) -> str:
149149
Returns:
150150
str: A hash string representing the unique code artifact(s) for the step
151151
"""
152-
153-
dependencies = dependencies or []
154152
from sagemaker.mlops.workflow.steps import ProcessingStep, TrainingStep
155153

156154
if isinstance(step, ProcessingStep) and step.step_args:
@@ -175,7 +173,7 @@ def get_code_hash(step: Entity) -> str:
175173
source_code = model_trainer.source_code
176174
if source_code:
177175
source_dir = source_code.source_dir
178-
requirements = source_code.requirements
176+
requirements = source_code.requirements or []
179177
entry_point = source_code.entry_script
180178
return get_training_code_hash(entry_point, source_dir, requirements)
181179
return None
@@ -211,6 +209,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
211209
Returns:
212210
str: A hash string representing the unique code artifact(s) for the step
213211
"""
212+
dependencies = dependencies or []
214213

215214
# FrameworkProcessor
216215
if source_dir:

sagemaker-core/tests/unit/workflow/test_utilities.py

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from pathlib import Path
1818
from unittest.mock import Mock, patch, MagicMock
1919
from sagemaker.core.workflow.utilities import (
20-
get_code_hash,
2120
list_to_request,
2221
hash_file,
2322
hash_files_or_dirs,
@@ -226,77 +225,50 @@ def test_get_processing_code_hash_with_source_dir(self):
226225
)
227226

228227
assert result is not None
229-
230-
def test_get_processing_code_hash_with_source_dir_and_none_dependencies(self):
231-
"""Test get_processing_code_hash with source_dir and dependencies=None"""
232-
with tempfile.TemporaryDirectory() as temp_dir:
233-
code_file = Path(temp_dir, "script.py")
234-
code_file.write_text("print('hello')")
235-
236-
result = get_processing_code_hash(
237-
code=str(code_file), source_dir=temp_dir, dependencies=None
238-
)
239-
240-
assert result is not None
241228
assert len(result) == 64
242229

243-
def test_get_processing_code_hash_with_code_only_and_none_dependencies(self):
244-
"""Test get_processing_code_hash with code only and dependencies=None"""
230+
def test_get_processing_code_hash_code_only(self):
231+
"""Test get_processing_code_hash with code only"""
245232
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
246233
f.write("print('hello')")
247234
temp_file = f.name
248235

249236
try:
250-
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None)
237+
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[])
251238

252239
assert result is not None
253240
assert len(result) == 64
254241
finally:
255242
os.unlink(temp_file)
256243

257-
@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
258-
def test_get_code_hash_with_training_step_none_requirements(self):
259-
"""Test get_code_hash with a TrainingStep whose source_code.requirements is None"""
260-
from sagemaker.mlops.workflow.steps import TrainingStep
261-
244+
def test_get_processing_code_hash_with_none_dependencies_and_source_dir(self):
245+
"""Test get_processing_code_hash with None dependencies and source_dir"""
262246
with tempfile.TemporaryDirectory() as temp_dir:
263-
entry_file = Path(temp_dir, "train.py")
264-
entry_file.write_text("print('training')")
265-
266-
mock_source_code = Mock()
267-
mock_source_code.source_dir = temp_dir
268-
mock_source_code.requirements = None
269-
mock_source_code.entry_script = str(entry_file)
270-
271-
mock_model_trainer = Mock()
272-
mock_model_trainer.source_code = mock_source_code
273-
274-
mock_step_args = Mock()
275-
mock_step_args.func_args = [mock_model_trainer]
276-
277-
step = Mock(spec=TrainingStep)
278-
step.step_args = mock_step_args
247+
code_file = Path(temp_dir, "script.py")
248+
code_file.write_text("print('hello')")
279249

280-
result = get_code_hash(step)
250+
result = get_processing_code_hash(
251+
code=str(code_file), source_dir=temp_dir, dependencies=None
252+
)
281253

282254
assert result is not None
283255
assert len(result) == 64
284-
assert len(result) == 64
285256

286-
def test_get_processing_code_hash_code_only(self):
287-
"""Test get_processing_code_hash with code only"""
257+
def test_get_processing_code_hash_with_none_dependencies_and_code_only(self):
258+
"""Test get_processing_code_hash with None dependencies and code only"""
288259
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
289260
f.write("print('hello')")
290261
temp_file = f.name
291262

292263
try:
293-
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[])
264+
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None)
294265

295266
assert result is not None
296267
assert len(result) == 64
297268
finally:
298269
os.unlink(temp_file)
299270

271+
300272
def test_get_processing_code_hash_s3_uri(self):
301273
"""Test get_processing_code_hash with S3 URI returns None"""
302274
result = get_processing_code_hash(
@@ -364,6 +336,47 @@ def test_get_training_code_hash_entry_point_only(self):
364336
assert len(result_with_deps) == 64
365337
assert result_no_deps != result_with_deps
366338

339+
def test_get_code_hash_training_step_with_none_requirements(self):
340+
"""Test get_code_hash with TrainingStep whose source_code has requirements=None"""
341+
from sagemaker.core.workflow.utilities import get_code_hash
342+
343+
with tempfile.TemporaryDirectory() as temp_dir:
344+
entry_file = Path(temp_dir, "train.py")
345+
entry_file.write_text("print('training')")
346+
347+
mock_source_code = Mock()
348+
mock_source_code.source_dir = temp_dir
349+
mock_source_code.requirements = None
350+
mock_source_code.entry_script = str(entry_file)
351+
352+
mock_model_trainer = Mock()
353+
mock_model_trainer.source_code = mock_source_code
354+
355+
mock_step_args = Mock()
356+
mock_step_args.func_args = [mock_model_trainer]
357+
358+
mock_step = Mock()
359+
mock_step.step_args = mock_step_args
360+
361+
with patch("sagemaker.core.workflow.utilities.isinstance") as mock_isinstance:
362+
def isinstance_side_effect(obj, cls):
363+
from sagemaker.mlops.workflow.steps import TrainingStep, ProcessingStep
364+
if cls is ProcessingStep:
365+
return False
366+
if cls is TrainingStep:
367+
return obj is mock_step
368+
return builtins_isinstance(obj, cls)
369+
370+
import builtins
371+
builtins_isinstance = builtins.isinstance
372+
mock_isinstance.side_effect = isinstance_side_effect
373+
374+
result = get_code_hash(mock_step)
375+
376+
assert result is not None
377+
assert len(result) == 64
378+
379+
367380
def test_get_training_code_hash_s3_uri(self):
368381
"""Test get_training_code_hash with S3 URI returns None"""
369382
result = get_training_code_hash(

0 commit comments

Comments
 (0)