Skip to content

Commit ba484da

Browse files
committed
fix: address review comments (iteration #1)
1 parent f5abeac commit ba484da

File tree

2 files changed

+123
-53
lines changed

2 files changed

+123
-53
lines changed

sagemaker-core/src/sagemaker/core/workflow/utilities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ def get_code_hash(step: Entity) -> str:
175175
source_dir = source_code.source_dir
176176
requirements = source_code.requirements
177177
entry_point = source_code.entry_script
178-
return get_training_code_hash(entry_point, source_dir, requirements)
178+
return get_training_code_hash(
179+
entry_point, source_dir, requirements
180+
)
179181
return None
180182

181183

sagemaker-core/tests/unit/workflow/test_utilities.py

Lines changed: 120 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,14 @@ def test_get_training_code_hash_with_source_dir(self):
274274
requirements_file.write_text("numpy==1.21.0")
275275

276276
result_no_deps = get_training_code_hash(
277-
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
277+
entry_point=str(entry_file),
278+
source_dir=temp_dir,
279+
dependencies=None,
278280
)
279281
result_with_deps = get_training_code_hash(
280-
entry_point=str(entry_file), source_dir=temp_dir, dependencies=str(requirements_file)
282+
entry_point=str(entry_file),
283+
source_dir=temp_dir,
284+
dependencies=str(requirements_file),
281285
)
282286

283287
assert result_no_deps is not None
@@ -286,19 +290,32 @@ def test_get_training_code_hash_with_source_dir(self):
286290
assert len(result_with_deps) == 64
287291
assert result_no_deps != result_with_deps
288292

289-
def test_get_training_code_hash_with_source_dir_and_none_dependencies(self):
290-
"""Test get_training_code_hash with source_dir and None dependencies does not raise TypeError"""
293+
def test_get_training_code_hash_source_dir_none_deps(
294+
self,
295+
):
296+
"""Test get_training_code_hash with source_dir
297+
and None dependencies does not raise TypeError.
298+
"""
291299
with tempfile.TemporaryDirectory() as temp_dir:
292300
entry_file = Path(temp_dir, "train.py")
293301
entry_file.write_text("print('training')")
294302

295-
# This should NOT raise TypeError: can only concatenate list (not "NoneType") to list
296-
result = get_training_code_hash(
297-
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
303+
# Should NOT raise TypeError
304+
result_none = get_training_code_hash(
305+
entry_point=str(entry_file),
306+
source_dir=temp_dir,
307+
dependencies=None,
308+
)
309+
# Empty list should be equivalent to None
310+
result_empty = get_training_code_hash(
311+
entry_point=str(entry_file),
312+
source_dir=temp_dir,
313+
dependencies=[],
298314
)
299315

300-
assert result is not None
301-
assert len(result) == 64
316+
assert result_none is not None
317+
assert len(result_none) == 64
318+
assert result_none == result_empty
302319

303320
def test_get_training_code_hash_entry_point_only(self):
304321
"""Test get_training_code_hash with entry_point only"""
@@ -310,11 +327,15 @@ def test_get_training_code_hash_entry_point_only(self):
310327

311328
# Without dependencies
312329
result_no_deps = get_training_code_hash(
313-
entry_point=str(entry_file), source_dir=None, dependencies=None
330+
entry_point=str(entry_file),
331+
source_dir=None,
332+
dependencies=None,
314333
)
315334
# With dependencies
316335
result_with_deps = get_training_code_hash(
317-
entry_point=str(entry_file), source_dir=None, dependencies=str(requirements_file)
336+
entry_point=str(entry_file),
337+
source_dir=None,
338+
dependencies=str(requirements_file),
318339
)
319340

320341
assert result_no_deps is not None
@@ -323,19 +344,32 @@ def test_get_training_code_hash_entry_point_only(self):
323344
assert len(result_with_deps) == 64
324345
assert result_no_deps != result_with_deps
325346

326-
def test_get_training_code_hash_entry_point_only_and_none_dependencies(self):
327-
"""Test get_training_code_hash with entry_point only and None dependencies does not raise TypeError"""
347+
def test_get_training_code_hash_entry_point_none_deps(
348+
self,
349+
):
350+
"""Test get_training_code_hash with entry_point
351+
and None dependencies does not raise TypeError.
352+
"""
328353
with tempfile.TemporaryDirectory() as temp_dir:
329354
entry_file = Path(temp_dir, "train.py")
330355
entry_file.write_text("print('training')")
331356

332-
# This should NOT raise TypeError: can only concatenate list (not "NoneType") to list
333-
result = get_training_code_hash(
334-
entry_point=str(entry_file), source_dir=None, dependencies=None
357+
# Should NOT raise TypeError
358+
result_none = get_training_code_hash(
359+
entry_point=str(entry_file),
360+
source_dir=None,
361+
dependencies=None,
362+
)
363+
# Empty list should be equivalent to None
364+
result_empty = get_training_code_hash(
365+
entry_point=str(entry_file),
366+
source_dir=None,
367+
dependencies=[],
335368
)
336369

337-
assert result is not None
338-
assert len(result) == 64
370+
assert result_none is not None
371+
assert len(result_none) == 64
372+
assert result_none == result_empty
339373

340374
def test_get_training_code_hash_s3_uri(self):
341375
"""Test get_training_code_hash with S3 URI returns None"""
@@ -354,72 +388,106 @@ def test_get_training_code_hash_pipeline_variable(self):
354388

355389
assert result is None
356390

357-
@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
358-
def test_get_code_hash_with_training_step_and_no_requirements(self):
359-
"""Test get_code_hash with TrainingStep where SourceCode has requirements=None"""
360-
from sagemaker.mlops.workflow.steps import TrainingStep
391+
def test_get_code_hash_training_step_no_requirements(
392+
self,
393+
):
394+
"""Test get_code_hash with TrainingStep where
395+
SourceCode has requirements=None.
396+
"""
397+
# Create a fake TrainingStep class to patch isinstance
398+
FakeTrainingStep = type(
399+
"TrainingStep", (), {}
400+
)
361401

362402
with tempfile.TemporaryDirectory() as temp_dir:
363403
entry_file = Path(temp_dir, "train.py")
364404
entry_file.write_text("print('training')")
365405

366406
mock_source_code = Mock()
367407
mock_source_code.source_dir = temp_dir
368-
mock_source_code.requirements = None # This is the key: requirements is None
408+
mock_source_code.requirements = None
369409
mock_source_code.entry_script = str(entry_file)
370410

371411
mock_model_trainer = Mock()
372412
mock_model_trainer.source_code = mock_source_code
373413

374414
mock_step_args = Mock()
375-
mock_step_args.func_args = [mock_model_trainer]
415+
mock_step_args.func_args = [
416+
mock_model_trainer
417+
]
376418

377-
mock_step = Mock(spec=TrainingStep)
419+
mock_step = MagicMock(spec=FakeTrainingStep)
378420
mock_step.step_args = mock_step_args
379421

380-
# This should NOT raise TypeError
381-
result = get_code_hash(mock_step)
422+
with patch(
423+
"sagemaker.core.workflow.utilities"
424+
".TrainingStep",
425+
new=FakeTrainingStep,
426+
):
427+
result = get_code_hash(mock_step)
382428

383429
assert result is not None
384430
assert len(result) == 64
385431

386-
@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
387-
def test_get_code_hash_with_training_step_and_requirements(self):
388-
"""Test get_code_hash with TrainingStep where SourceCode has valid requirements"""
389-
from sagemaker.mlops.workflow.steps import TrainingStep
432+
def test_get_code_hash_training_step_with_requirements(
433+
self,
434+
):
435+
"""Test get_code_hash with TrainingStep where
436+
SourceCode has valid requirements.
437+
"""
438+
FakeTrainingStep = type(
439+
"TrainingStep", (), {}
440+
)
390441

391442
with tempfile.TemporaryDirectory() as temp_dir:
392443
entry_file = Path(temp_dir, "train.py")
393444
entry_file.write_text("print('training')")
394-
requirements_file = Path(temp_dir, "requirements.txt")
395-
requirements_file.write_text("numpy==1.21.0")
445+
req_file = Path(temp_dir, "requirements.txt")
446+
req_file.write_text("numpy==1.21.0")
396447

397-
mock_source_code_no_req = Mock()
398-
mock_source_code_no_req.source_dir = temp_dir
399-
mock_source_code_no_req.requirements = None
400-
mock_source_code_no_req.entry_script = str(entry_file)
448+
mock_sc_no_req = Mock()
449+
mock_sc_no_req.source_dir = temp_dir
450+
mock_sc_no_req.requirements = None
451+
mock_sc_no_req.entry_script = str(entry_file)
401452

402-
mock_source_code_with_req = Mock()
403-
mock_source_code_with_req.source_dir = temp_dir
404-
mock_source_code_with_req.requirements = str(requirements_file)
405-
mock_source_code_with_req.entry_script = str(entry_file)
453+
mock_sc_with_req = Mock()
454+
mock_sc_with_req.source_dir = temp_dir
455+
mock_sc_with_req.requirements = str(req_file)
456+
mock_sc_with_req.entry_script = str(entry_file)
406457

407-
mock_model_trainer_no_req = Mock()
408-
mock_model_trainer_no_req.source_code = mock_source_code_no_req
458+
mock_mt_no_req = Mock()
459+
mock_mt_no_req.source_code = mock_sc_no_req
409460

410-
mock_model_trainer_with_req = Mock()
411-
mock_model_trainer_with_req.source_code = mock_source_code_with_req
461+
mock_mt_with_req = Mock()
462+
mock_mt_with_req.source_code = mock_sc_with_req
412463

413-
mock_step_no_req = Mock(spec=TrainingStep)
464+
mock_step_no_req = MagicMock(
465+
spec=FakeTrainingStep
466+
)
414467
mock_step_no_req.step_args = Mock()
415-
mock_step_no_req.step_args.func_args = [mock_model_trainer_no_req]
468+
mock_step_no_req.step_args.func_args = [
469+
mock_mt_no_req
470+
]
416471

417-
mock_step_with_req = Mock(spec=TrainingStep)
472+
mock_step_with_req = MagicMock(
473+
spec=FakeTrainingStep
474+
)
418475
mock_step_with_req.step_args = Mock()
419-
mock_step_with_req.step_args.func_args = [mock_model_trainer_with_req]
420-
421-
result_no_req = get_code_hash(mock_step_no_req)
422-
result_with_req = get_code_hash(mock_step_with_req)
476+
mock_step_with_req.step_args.func_args = [
477+
mock_mt_with_req
478+
]
479+
480+
with patch(
481+
"sagemaker.core.workflow.utilities"
482+
".TrainingStep",
483+
new=FakeTrainingStep,
484+
):
485+
result_no_req = get_code_hash(
486+
mock_step_no_req
487+
)
488+
result_with_req = get_code_hash(
489+
mock_step_with_req
490+
)
423491

424492
assert result_no_req is not None
425493
assert result_with_req is not None

0 commit comments

Comments
 (0)