From 1fdf8364bf8604f2b927b13c497eb3302f242033 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 26 Mar 2026 18:40:51 -0400 Subject: [PATCH 1/2] fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) --- sagemaker-train/src/sagemaker/train/defaults.py | 4 ++-- sagemaker-train/src/sagemaker/train/model_trainer.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/defaults.py b/sagemaker-train/src/sagemaker/train/defaults.py index 3c36b817ec..9e9d643b7c 100644 --- a/sagemaker-train/src/sagemaker/train/defaults.py +++ b/sagemaker-train/src/sagemaker/train/defaults.py @@ -78,8 +78,8 @@ def get_role( @staticmethod def get_base_job_name( base_job_name: Optional[str] = None, - algorithm_name: Optional[str] = None, - training_image: Optional[str] = None, + algorithm_name=None, + training_image=None, ) -> str: """Get the default base job name.""" if base_job_name is None: diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..0b56404438 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -413,11 +413,19 @@ def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] ): """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" - if not training_image and not algorithm_name: + from sagemaker.core.helper.pipeline_variable import PipelineVariable + + has_training_image = training_image is not None and ( + isinstance(training_image, PipelineVariable) or bool(training_image) + ) + has_algorithm_name = algorithm_name is not None and ( + isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name) + ) + if not has_training_image and not has_algorithm_name: raise ValueError( "Atleast one of 'training_image' or 'algorithm_name' must be provided.", ) - if training_image and algorithm_name: + if has_training_image and has_algorithm_name: raise ValueError( "Only one of 'training_image' or 'algorithm_name' must be provided.", ) From 7891ab1c7822cb1835d4d528781fbf6a2d5ab127 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 26 Mar 2026 18:58:34 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- sagemaker-train/src/sagemaker/train/defaults.py | 10 ++++++++-- sagemaker-train/src/sagemaker/train/model_trainer.py | 2 +- sagemaker-train/src/sagemaker/train/utils.py | 8 +++++--- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/defaults.py b/sagemaker-train/src/sagemaker/train/defaults.py index 9e9d643b7c..359828004e 100644 --- a/sagemaker-train/src/sagemaker/train/defaults.py +++ b/sagemaker-train/src/sagemaker/train/defaults.py @@ -83,10 +83,16 @@ def get_base_job_name( ) -> str: """Get the default base job name.""" if base_job_name is None: - if algorithm_name: + if algorithm_name and isinstance(algorithm_name, str): base_job_name = f"{algorithm_name}-job" elif training_image: - base_job_name = f"{_get_repo_name_from_image(training_image)}-job" + repo_name = _get_repo_name_from_image(training_image) + if repo_name: + base_job_name = f"{repo_name}-job" + else: + base_job_name = "training-job" + if base_job_name is None: + base_job_name = "training-job" logger.info(f"Base name not provided. Using default name:\n{base_job_name}") return base_job_name diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index 0b56404438..a3fd052474 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -410,7 +410,7 @@ def __del__(self): self._temp_code_dir.cleanup() def _validate_training_image_and_algorithm_name( - self, training_image: Optional[str], algorithm_name: Optional[str] + self, training_image, algorithm_name ): """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" from sagemaker.core.helper.pipeline_variable import PipelineVariable diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 0abd7596b5..994933f151 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63): return unique_name -def _get_repo_name_from_image(image: str) -> str: +def _get_repo_name_from_image(image) -> str: """Get the repository name from the image URI. Example: @@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str: ``` Args: - image (str): The image URI + image: The image URI (str or PipelineVariable) Returns: - str: The repository name + str: The repository name, or None if image is not a plain string """ + if not isinstance(image, str): + return None return image.split("/")[-1].split(":")[0].split("@")[0]