diff --git a/sagemaker-train/src/sagemaker/train/defaults.py b/sagemaker-train/src/sagemaker/train/defaults.py index 3c36b817ec..359828004e 100644 --- a/sagemaker-train/src/sagemaker/train/defaults.py +++ b/sagemaker-train/src/sagemaker/train/defaults.py @@ -78,15 +78,21 @@ def get_role( @staticmethod def get_base_job_name( base_job_name: Optional[str] = None, - algorithm_name: Optional[str] = None, - training_image: Optional[str] = None, + algorithm_name=None, + training_image=None, ) -> str: """Get the default base job name.""" if base_job_name is None: - if algorithm_name: + if algorithm_name and isinstance(algorithm_name, str): base_job_name = f"{algorithm_name}-job" elif training_image: - base_job_name = f"{_get_repo_name_from_image(training_image)}-job" + repo_name = _get_repo_name_from_image(training_image) + if repo_name: + base_job_name = f"{repo_name}-job" + else: + base_job_name = "training-job" + if base_job_name is None: + base_job_name = "training-job" logger.info(f"Base name not provided. Using default name:\n{base_job_name}") return base_job_name diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..a3fd052474 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -410,14 +410,22 @@ def __del__(self): self._temp_code_dir.cleanup() def _validate_training_image_and_algorithm_name( - self, training_image: Optional[str], algorithm_name: Optional[str] + self, training_image, algorithm_name ): """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" - if not training_image and not algorithm_name: + from sagemaker.core.helper.pipeline_variable import PipelineVariable + + has_training_image = training_image is not None and ( + isinstance(training_image, PipelineVariable) or bool(training_image) + ) + has_algorithm_name = algorithm_name is not None and ( + isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name) + ) + if not has_training_image and not has_algorithm_name: raise ValueError( "Atleast one of 'training_image' or 'algorithm_name' must be provided.", ) - if training_image and algorithm_name: + if has_training_image and has_algorithm_name: raise ValueError( "Only one of 'training_image' or 'algorithm_name' must be provided.", ) diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 0abd7596b5..994933f151 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63): return unique_name -def _get_repo_name_from_image(image: str) -> str: +def _get_repo_name_from_image(image) -> str: """Get the repository name from the image URI. Example: @@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str: ``` Args: - image (str): The image URI + image: The image URI (str or PipelineVariable) Returns: - str: The repository name + str: The repository name, or None if image is not a plain string """ + if not isinstance(image, str): + return None return image.split("/")[-1].split(":")[0].split("@")[0]