Skip to content

Commit 9450aee

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)
1 parent 6a1ba54 commit 9450aee

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,14 +410,18 @@ def __del__(self):
410410
self._temp_code_dir.cleanup()
411411

412412
def _validate_training_image_and_algorithm_name(
413-
self, training_image: Optional[str], algorithm_name: Optional[str]
413+
self, training_image, algorithm_name
414414
):
415415
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
416-
if not training_image and not algorithm_name:
416+
from sagemaker.core.helper.pipeline_variable import PipelineVariable as _PV
417+
# PipelineVariables are truthy for validation purposes
418+
has_image = isinstance(training_image, _PV) or bool(training_image)
419+
has_algo = isinstance(algorithm_name, _PV) or bool(algorithm_name)
420+
if not has_image and not has_algo:
417421
raise ValueError(
418422
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
419423
)
420-
if training_image and algorithm_name:
424+
if has_image and has_algo:
421425
raise ValueError(
422426
"Only one of 'training_image' or 'algorithm_name' must be provided.",
423427
)

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
142142
return unique_name
143143

144144

145-
def _get_repo_name_from_image(image: str) -> str:
145+
def _get_repo_name_from_image(image) -> str:
146146
"""Get the repository name from the image URI.
147147
148148
Example:
@@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str:
152152
```
153153
154154
Args:
155-
image (str): The image URI
155+
image: The image URI (str or PipelineVariable)
156156
157157
Returns:
158158
str: The repository name
159159
"""
160+
if isinstance(image, PipelineVariable):
161+
return "pipeline-variable-image"
160162
return image.split("/")[-1].split(":")[0].split("@")[0]
161163

162164

0 commit comments

Comments
 (0)