Skip to content

Commit 238c019

Browse files
Fix handling of training step dependencies to allow successful pipeline creation (aws#5618)
* Add docker-compose path * Check for MacOS * Fix model registration with a model card * Account for both ModelCard and ModelPackageModelCard objects * Add unit tests for model card during model registration * Fix handling of dependencies in get_training_code_hash workflow utility * Update docstring * Add unit tests
1 parent b86faba commit 238c019

File tree

2 files changed

+45
-19
lines changed

2 files changed

+45
-19
lines changed

sagemaker-core/src/sagemaker/core/workflow/utilities.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from _hashlib import HASH as Hash
2929
except ImportError:
3030
import typing
31+
3132
Hash = typing.Any
3233

3334
from sagemaker.core.common_utils import base_from_name
@@ -227,7 +228,9 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
227228
return None
228229

229230

230-
def get_training_code_hash(entry_point: str, source_dir: str, dependencies: List[str]) -> str:
231+
def get_training_code_hash(
232+
entry_point: str, source_dir: str, dependencies: Optional[str] = None
233+
) -> str:
231234
"""Get the hash of a training step's code artifact(s).
232235
233236
Args:
@@ -236,9 +239,9 @@ def get_training_code_hash(entry_point: str, source_dir: str, dependencies: List
236239
training
237240
source_dir (str): Path to a directory with any other training source
238241
code dependencies aside from the entry point file
239-
dependencies (str): A list of paths to directories (absolute
240-
or relative) with any additional libraries that will be exported
241-
to the container
242+
dependencies Optional[str]: The relative path within ``source_dir`` to a
243+
``requirements.txt`` file with any additional libraries that
244+
will be exported to the container
242245
Returns:
243246
str: A hash string representing the unique code artifact(s) for the step
244247
"""
@@ -248,11 +251,17 @@ def get_training_code_hash(entry_point: str, source_dir: str, dependencies: List
248251
if source_dir:
249252
source_dir_url = urlparse(source_dir)
250253
if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
251-
return hash_files_or_dirs([source_dir] + dependencies)
254+
if dependencies:
255+
return hash_files_or_dirs([source_dir] + [dependencies])
256+
else:
257+
return hash_files_or_dirs([source_dir])
252258
elif entry_point:
253259
entry_point_url = urlparse(entry_point)
254260
if entry_point_url.scheme == "" or entry_point_url.scheme == "file":
255-
return hash_files_or_dirs([entry_point] + dependencies)
261+
if dependencies:
262+
return hash_files_or_dirs([entry_point] + [dependencies])
263+
else:
264+
return hash_files_or_dirs([entry_point])
256265
return None
257266

258267

sagemaker-core/tests/unit/workflow/test_utilities.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -269,27 +269,44 @@ def test_get_training_code_hash_with_source_dir(self):
269269
with tempfile.TemporaryDirectory() as temp_dir:
270270
entry_file = Path(temp_dir, "train.py")
271271
entry_file.write_text("print('training')")
272+
requirements_file = Path(temp_dir, "requirements.txt")
273+
requirements_file.write_text("numpy==1.21.0")
272274

273-
result = get_training_code_hash(
274-
entry_point=str(entry_file), source_dir=temp_dir, dependencies=[]
275+
result_no_deps = get_training_code_hash(
276+
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
277+
)
278+
result_with_deps = get_training_code_hash(
279+
entry_point=str(entry_file), source_dir=temp_dir, dependencies=str(requirements_file)
275280
)
276281

277-
assert result is not None
278-
assert len(result) == 64
282+
assert result_no_deps is not None
283+
assert result_with_deps is not None
284+
assert len(result_no_deps) == 64
285+
assert len(result_with_deps) == 64
286+
assert result_no_deps != result_with_deps
279287

280288
def test_get_training_code_hash_entry_point_only(self):
281289
"""Test get_training_code_hash with entry_point only"""
282-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
283-
f.write("print('training')")
284-
temp_file = f.name
290+
with tempfile.TemporaryDirectory() as temp_dir:
291+
entry_file = Path(temp_dir, "train.py")
292+
entry_file.write_text("print('training')")
293+
requirements_file = Path(temp_dir, "requirements.txt")
294+
requirements_file.write_text("numpy==1.21.0")
285295

286-
try:
287-
result = get_training_code_hash(entry_point=temp_file, source_dir=None, dependencies=[])
296+
# Without dependencies
297+
result_no_deps = get_training_code_hash(
298+
entry_point=str(entry_file), source_dir=None, dependencies=None
299+
)
300+
# With dependencies
301+
result_with_deps = get_training_code_hash(
302+
entry_point=str(entry_file), source_dir=None, dependencies=str(requirements_file)
303+
)
288304

289-
assert result is not None
290-
assert len(result) == 64
291-
finally:
292-
os.unlink(temp_file)
305+
assert result_no_deps is not None
306+
assert result_with_deps is not None
307+
assert len(result_no_deps) == 64
308+
assert len(result_with_deps) == 64
309+
assert result_no_deps != result_with_deps
293310

294311
def test_get_training_code_hash_s3_uri(self):
295312
"""Test get_training_code_hash with S3 URI returns None"""

0 commit comments

Comments
 (0)