Skip to content

Commit 2e31aa5

Browse files
committed
fix: address review comments (iteration #1)
1 parent 97579f3 commit 2e31aa5

File tree

1 file changed

+55
-41
lines changed

1 file changed

+55
-41
lines changed

sagemaker-core/tests/unit/workflow/test_utilities.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pathlib import Path
1818
from unittest.mock import Mock, patch, MagicMock
1919
from sagemaker.core.workflow.utilities import (
20+
get_code_hash,
2021
list_to_request,
2122
hash_file,
2223
hash_files_or_dirs,
@@ -225,86 +226,99 @@ def test_get_processing_code_hash_with_source_dir(self):
225226
)
226227

227228
assert result is not None
229+
230+
def test_get_processing_code_hash_with_source_dir_and_none_dependencies(self):
231+
"""Test get_processing_code_hash with source_dir and dependencies=None"""
232+
with tempfile.TemporaryDirectory() as temp_dir:
233+
code_file = Path(temp_dir, "script.py")
234+
code_file.write_text("print('hello')")
235+
236+
result = get_processing_code_hash(
237+
code=str(code_file), source_dir=temp_dir, dependencies=None
238+
)
239+
240+
assert result is not None
228241
assert len(result) == 64
229242

230-
def test_get_processing_code_hash_code_only(self):
231-
"""Test get_processing_code_hash with code only"""
243+
def test_get_processing_code_hash_with_code_only_and_none_dependencies(self):
244+
"""Test get_processing_code_hash with code only and dependencies=None"""
232245
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
233246
f.write("print('hello')")
234247
temp_file = f.name
235248

236249
try:
237-
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[])
250+
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None)
238251

239252
assert result is not None
240253
assert len(result) == 64
241254
finally:
242255
os.unlink(temp_file)
243256

244-
def test_get_processing_code_hash_s3_uri(self):
245-
"""Test get_processing_code_hash with S3 URI returns None"""
246-
result = get_processing_code_hash(
247-
code="s3://bucket/script.py", source_dir=None, dependencies=[]
248-
)
249-
250-
assert result is None
257+
@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
258+
def test_get_code_hash_with_training_step_none_requirements(self):
259+
"""Test get_code_hash with a TrainingStep whose source_code.requirements is None"""
260+
from sagemaker.mlops.workflow.steps import TrainingStep
251261

252-
def test_get_processing_code_hash_with_dependencies(self):
253-
"""Test get_processing_code_hash with dependencies"""
254262
with tempfile.TemporaryDirectory() as temp_dir:
255-
code_file = Path(temp_dir, "script.py")
256-
code_file.write_text("print('hello')")
263+
entry_file = Path(temp_dir, "train.py")
264+
entry_file.write_text("print('training')")
257265

258-
dep_file = Path(temp_dir, "utils.py")
259-
dep_file.write_text("def helper(): pass")
266+
mock_source_code = Mock()
267+
mock_source_code.source_dir = temp_dir
268+
mock_source_code.requirements = None
269+
mock_source_code.entry_script = str(entry_file)
260270

261-
result = get_processing_code_hash(
262-
code=str(code_file), source_dir=temp_dir, dependencies=[str(dep_file)]
263-
)
271+
mock_model_trainer = Mock()
272+
mock_model_trainer.source_code = mock_source_code
273+
274+
mock_step_args = Mock()
275+
mock_step_args.func_args = [mock_model_trainer]
276+
277+
step = Mock(spec=TrainingStep)
278+
step.step_args = mock_step_args
279+
280+
result = get_code_hash(step)
264281

265282
assert result is not None
283+
assert len(result) == 64
284+
assert len(result) == 64
266285

267-
def test_get_processing_code_hash_with_none_dependencies(self):
268-
"""Test get_processing_code_hash with None dependencies does not raise TypeError"""
286+
def test_get_processing_code_hash_code_only(self):
287+
"""Test get_processing_code_hash with code only"""
269288
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
270289
f.write("print('hello')")
271290
temp_file = f.name
272291

273292
try:
274-
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None)
293+
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=[])
275294

276295
assert result is not None
277296
assert len(result) == 64
278297
finally:
279298
os.unlink(temp_file)
280299

281-
def test_get_processing_code_hash_with_source_dir_and_none_dependencies(self):
282-
"""Test get_processing_code_hash with source_dir and None dependencies"""
300+
def test_get_processing_code_hash_s3_uri(self):
301+
"""Test get_processing_code_hash with S3 URI returns None"""
302+
result = get_processing_code_hash(
303+
code="s3://bucket/script.py", source_dir=None, dependencies=[]
304+
)
305+
306+
assert result is None
307+
308+
def test_get_processing_code_hash_with_dependencies(self):
309+
"""Test get_processing_code_hash with dependencies"""
283310
with tempfile.TemporaryDirectory() as temp_dir:
284311
code_file = Path(temp_dir, "script.py")
285312
code_file.write_text("print('hello')")
286313

314+
dep_file = Path(temp_dir, "utils.py")
315+
dep_file.write_text("def helper(): pass")
316+
287317
result = get_processing_code_hash(
288-
code=str(code_file), source_dir=temp_dir, dependencies=None
318+
code=str(code_file), source_dir=temp_dir, dependencies=[str(dep_file)]
289319
)
290320

291321
assert result is not None
292-
assert len(result) == 64
293-
294-
def test_get_processing_code_hash_with_code_only_and_none_dependencies(self):
295-
"""Test get_processing_code_hash with code only and None dependencies"""
296-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
297-
f.write("print('hello')")
298-
temp_file = f.name
299-
300-
try:
301-
# Ensure no TypeError when dependencies is None
302-
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None)
303-
304-
assert result is not None
305-
assert len(result) == 64
306-
finally:
307-
os.unlink(temp_file)
308322

309323
def test_get_training_code_hash_with_source_dir(self):
310324
"""Test get_training_code_hash with source_dir"""

0 commit comments

Comments
 (0)