Skip to content

Commit 7a36f43

Browse files
committed
fix: address review comments (iteration #1)
1 parent 1b80f49 commit 7a36f43

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

sagemaker-core/src/sagemaker/core/workflow/utilities.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ def get_code_hash(step: Entity) -> str:
173173
source_code = model_trainer.source_code
174174
if source_code:
175175
source_dir = source_code.source_dir
176-
# requirements may be None when SourceCode.requirements is not set;
177-
# get_training_code_hash handles None dependencies gracefully
178176
requirements = source_code.requirements
179177
entry_point = source_code.entry_script
180178
return get_training_code_hash(entry_point, source_dir, requirements)
@@ -199,7 +197,6 @@ def get_processing_dependencies(dependency_args: List[List[str]]) -> List[str]:
199197

200198

201199
def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]) -> str:
202-
dependencies = dependencies or []
203200
"""Get the hash of a processing step's code artifact(s).
204201
205202
Args:
@@ -213,6 +210,12 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
213210
str: A hash string representing the unique code artifact(s) for the step
214211
"""
215212

213+
# SourceCode.requirements and other upstream dependency fields default to None
214+
# when not explicitly set. Since this function concatenates dependencies via list
215+
# addition (e.g. [source_dir] + dependencies), we default None to an empty list
216+
# to prevent TypeError.
217+
dependencies = dependencies or []
218+
216219
# FrameworkProcessor
217220
if source_dir:
218221
source_dir_url = urlparse(source_dir)

sagemaker-core/tests/unit/workflow/test_utilities.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,13 @@ def test_get_processing_dependencies_multiple_lists(self):
214214

215215
assert result == ["dep1", "dep2", "dep3", "dep4", "dep5"]
216216

217-
def test_get_processing_code_hash_with_none_dependencies(self):
218-
"""Test get_processing_code_hash does not raise TypeError when dependencies is None"""
217+
def test_get_processing_code_hash_with_none_dependencies_and_code_only(self):
218+
"""Test get_processing_code_hash with None dependencies and code only"""
219219
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
220220
f.write("print('hello')")
221221
temp_file = f.name
222222

223223
try:
224-
# Should not raise TypeError
225224
result = get_processing_code_hash(code=temp_file, source_dir=None, dependencies=None)
226225

227226
assert result is not None
@@ -230,12 +229,11 @@ def test_get_processing_code_hash_with_none_dependencies(self):
230229
os.unlink(temp_file)
231230

232231
def test_get_processing_code_hash_with_none_dependencies_and_source_dir(self):
233-
"""Test get_processing_code_hash with source_dir and None dependencies"""
232+
"""Test get_processing_code_hash with None dependencies and source_dir"""
234233
with tempfile.TemporaryDirectory() as temp_dir:
235234
code_file = Path(temp_dir, "script.py")
236235
code_file.write_text("print('hello')")
237236

238-
# Should not raise TypeError
239237
result = get_processing_code_hash(
240238
code=str(code_file), source_dir=temp_dir, dependencies=None
241239
)
@@ -294,12 +292,11 @@ def test_get_processing_code_hash_with_dependencies(self):
294292
assert result is not None
295293

296294
def test_get_training_code_hash_with_none_dependencies_and_source_dir(self):
297-
"""Test get_training_code_hash with source_dir and None dependencies does not raise"""
295+
"""Test get_training_code_hash with None dependencies and source_dir"""
298296
with tempfile.TemporaryDirectory() as temp_dir:
299297
entry_file = Path(temp_dir, "train.py")
300298
entry_file.write_text("print('training')")
301299

302-
# Should not raise TypeError
303300
result = get_training_code_hash(
304301
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
305302
)
@@ -308,18 +305,20 @@ def test_get_training_code_hash_with_none_dependencies_and_source_dir(self):
308305
assert len(result) == 64
309306

310307
def test_get_training_code_hash_with_none_dependencies_and_entry_point(self):
311-
"""Test get_training_code_hash with entry_point only and None dependencies does not raise"""
312-
with tempfile.TemporaryDirectory() as temp_dir:
313-
entry_file = Path(temp_dir, "train.py")
314-
entry_file.write_text("print('training')")
308+
"""Test get_training_code_hash with None dependencies and entry_point only"""
309+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
310+
f.write("print('training')")
311+
temp_file = f.name
315312

316-
# Should not raise TypeError
313+
try:
317314
result = get_training_code_hash(
318-
entry_point=str(entry_file), source_dir=None, dependencies=None
315+
entry_point=temp_file, source_dir=None, dependencies=None
319316
)
320317

321318
assert result is not None
322319
assert len(result) == 64
320+
finally:
321+
os.unlink(temp_file)
323322

324323
def test_get_training_code_hash_with_source_dir(self):
325324
"""Test get_training_code_hash with source_dir"""

0 commit comments

Comments
 (0)