Skip to content

Commit 7891ab1

Browse files
committed
fix: address review comments (iteration #1)
1 parent 1fdf836 commit 7891ab1

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

sagemaker-train/src/sagemaker/train/defaults.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,16 @@ def get_base_job_name(
8383
) -> str:
8484
"""Get the default base job name."""
8585
if base_job_name is None:
86-
if algorithm_name:
86+
if algorithm_name and isinstance(algorithm_name, str):
8787
base_job_name = f"{algorithm_name}-job"
8888
elif training_image:
89-
base_job_name = f"{_get_repo_name_from_image(training_image)}-job"
89+
repo_name = _get_repo_name_from_image(training_image)
90+
if repo_name:
91+
base_job_name = f"{repo_name}-job"
92+
else:
93+
base_job_name = "training-job"
94+
if base_job_name is None:
95+
base_job_name = "training-job"
9096
logger.info(f"Base name not provided. Using default name:\n{base_job_name}")
9197
return base_job_name
9298

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def __del__(self):
410410
self._temp_code_dir.cleanup()
411411

412412
def _validate_training_image_and_algorithm_name(
413-
self, training_image: Optional[str], algorithm_name: Optional[str]
413+
self, training_image, algorithm_name
414414
):
415415
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
416416
from sagemaker.core.helper.pipeline_variable import PipelineVariable

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
142142
return unique_name
143143

144144

145-
def _get_repo_name_from_image(image: str) -> str:
145+
def _get_repo_name_from_image(image) -> str:
146146
"""Get the repository name from the image URI.
147147
148148
Example:
@@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str:
152152
```
153153
154154
Args:
155-
image (str): The image URI
155+
image: The image URI (str or PipelineVariable)
156156
157157
Returns:
158-
str: The repository name
158+
str: The repository name, or None if image is not a plain string
159159
"""
160+
if not isinstance(image, str):
161+
return None
160162
return image.split("/")[-1].split(":")[0].split("@")[0]
161163

162164

0 commit comments

Comments
 (0)