Skip to content

Commit 1fdf836

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)
1 parent 6a1ba54 commit 1fdf836

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

sagemaker-train/src/sagemaker/train/defaults.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def get_role(
7878
@staticmethod
7979
def get_base_job_name(
8080
base_job_name: Optional[str] = None,
81-
algorithm_name: Optional[str] = None,
82-
training_image: Optional[str] = None,
81+
algorithm_name=None,
82+
training_image=None,
8383
) -> str:
8484
"""Get the default base job name."""
8585
if base_job_name is None:

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,11 +413,19 @@ def _validate_training_image_and_algorithm_name(
413413
self, training_image: Optional[str], algorithm_name: Optional[str]
414414
):
415415
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
416-
if not training_image and not algorithm_name:
416+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
417+
418+
has_training_image = training_image is not None and (
419+
isinstance(training_image, PipelineVariable) or bool(training_image)
420+
)
421+
has_algorithm_name = algorithm_name is not None and (
422+
isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name)
423+
)
424+
if not has_training_image and not has_algorithm_name:
417425
raise ValueError(
418426
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
419427
)
420-
if training_image and algorithm_name:
428+
if has_training_image and has_algorithm_name:
421429
raise ValueError(
422430
"Only one of 'training_image' or 'algorithm_name' must be provided.",
423431
)

0 commit comments

Comments
 (0)