Skip to content

Commit 04d768d

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)
1 parent ee420cc commit 04d768d

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,14 +410,16 @@ def __del__(self):
410410
self._temp_code_dir.cleanup()
411411

412412
def _validate_training_image_and_algorithm_name(
413-
self, training_image: Optional[str], algorithm_name: Optional[str]
413+
self, training_image, algorithm_name
414414
):
415415
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
416-
if not training_image and not algorithm_name:
416+
has_image = training_image is not None
417+
has_algo = algorithm_name is not None
418+
if not has_image and not has_algo:
417419
raise ValueError(
418420
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
419421
)
420-
if training_image and algorithm_name:
422+
if has_image and has_algo:
421423
raise ValueError(
422424
"Only one of 'training_image' or 'algorithm_name' must be provided.",
423425
)

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
142142
return unique_name
143143

144144

145-
def _get_repo_name_from_image(image: str) -> str:
145+
def _get_repo_name_from_image(image) -> str:
146146
"""Get the repository name from the image URI.
147147
148148
Example:
@@ -152,11 +152,15 @@ def _get_repo_name_from_image(image: str) -> str:
152152
```
153153
154154
Args:
155-
image (str): The image URI
155+
image: The image URI (str or PipelineVariable)
156156
157157
Returns:
158-
str: The repository name
158+
str: The repository name, or None if image is a PipelineVariable
159159
"""
160+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
161+
162+
if isinstance(image, PipelineVariable):
163+
return None
160164
return image.split("/")[-1].split(":")[0].split("@")[0]
161165

162166

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
StoppingCondition,
3434
OutputDataConfig,
3535
)
36-
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
36+
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults
3737

3838

3939
DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"

0 commit comments

Comments
 (0)