14 changes: 10 additions & 4 deletions sagemaker-train/src/sagemaker/train/defaults.py
@@ -78,15 +78,21 @@ def get_role(
@staticmethod
def get_base_job_name(
base_job_name: Optional[str] = None,
- algorithm_name: Optional[str] = None,
- training_image: Optional[str] = None,
+ algorithm_name=None,

Type annotations removed instead of broadened. Removing the Optional[str] annotations silently drops type safety. Instead, broaden the type hints to accept PipelineVariable:

from sagemaker.core.helper.pipeline_variable import PipelineVariable

@staticmethod
def get_base_job_name(
    base_job_name: Optional[str] = None,
    algorithm_name: str | PipelineVariable | None = None,
    training_image: str | PipelineVariable | None = None,
) -> str:

Also, the PR description says this method was updated to "skip PipelineVariable values when deriving the base job name," but the diff shows no such logic was added. The body of get_base_job_name likely calls something like base_name_from_image(training_image) or uses algorithm_name as a string, so passing a PipelineVariable there will still fail. Please add guard logic that skips PipelineVariable values and falls back to a generic default name.
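
For illustration, here is a minimal sketch of the kind of guard being requested. It is not the SDK's implementation: the PipelineVariable class below is a hypothetical stand-in (the real one lives in sagemaker.core.helper.pipeline_variable), and the helper mirrors _get_repo_name_from_image only as far as this diff shows it.

```python
from typing import Optional


class PipelineVariable:
    """Hypothetical stand-in for the SDK class; only mimics the __str__ failure."""

    def __str__(self) -> str:
        raise TypeError("PipelineVariable cannot be converted to str")


def _get_repo_name_from_image(image) -> Optional[str]:
    """Return the repository name, or None if image is not a plain string."""
    if not isinstance(image, str):
        return None
    return image.split("/")[-1].split(":")[0].split("@")[0]


def get_base_job_name(base_job_name=None, algorithm_name=None, training_image=None) -> str:
    """Derive a base job name, skipping values that are not plain strings."""
    if base_job_name is not None:
        return base_job_name
    # Only a non-empty plain string can be formatted into a name safely.
    if isinstance(algorithm_name, str) and algorithm_name:
        return f"{algorithm_name}-job"
    repo_name = _get_repo_name_from_image(training_image)
    if repo_name:
        return f"{repo_name}-job"
    # Generic fallback for PipelineVariable (or missing) inputs.
    return "training-job"
```

With this shape, a PipelineVariable passed as either argument never reaches an f-string, so the TypeError path is avoided and the caller still gets a usable name.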

+ training_image=None,
) -> str:
"""Get the default base job name."""
if base_job_name is None:
- if algorithm_name:
+ if algorithm_name and isinstance(algorithm_name, str):
base_job_name = f"{algorithm_name}-job"
elif training_image:
- base_job_name = f"{_get_repo_name_from_image(training_image)}-job"
+ repo_name = _get_repo_name_from_image(training_image)
+ if repo_name:
+ base_job_name = f"{repo_name}-job"
+ else:
+ base_job_name = "training-job"
+ if base_job_name is None:
+ base_job_name = "training-job"
logger.info(f"Base name not provided. Using default name:\n{base_job_name}")
return base_job_name

14 changes: 11 additions & 3 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
@@ -410,14 +410,22 @@ def __del__(self):
self._temp_code_dir.cleanup()

def _validate_training_image_and_algorithm_name(
- self, training_image: Optional[str], algorithm_name: Optional[str]
+ self, training_image, algorithm_name
):

Type annotations on the method signature should be broadened to reflect that PipelineVariable is now accepted. Currently the signature still says Optional[str] for both parameters, which is misleading:

def _validate_training_image_and_algorithm_name(
    self, training_image: str | PipelineVariable | None, algorithm_name: str | PipelineVariable | None
):

"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
- if not training_image and not algorithm_name:
+ from sagemaker.core.helper.pipeline_variable import PipelineVariable

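Move import to module level. Placing the import inside the method means the import statement executes on every call (after the first call this is only a cached module lookup, but it still hides the dependency from readers). Unless there's a circular import issue, prefer a top-of-file import: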

from sagemaker.core.helper.pipeline_variable import PipelineVariable

If there is a circular dependency, please add a comment explaining why the local import is necessary.


+ has_training_image = training_image is not None and (

Simplify the truthiness check. The is not None and (isinstance(...) or bool(...)) pattern is a bit convoluted. Since PipelineVariable.__str__() raises TypeError, the only problematic path is Python's implicit bool() / truthiness. You can simplify:

def _is_provided(value):
    if value is None:
        return False
    if isinstance(value, PipelineVariable):
        return True
    return bool(value)

has_training_image = _is_provided(training_image)
has_algorithm_name = _is_provided(algorithm_name)

This is easier to read and self-documenting. Consider adding a brief comment explaining why PipelineVariable needs special handling (its __str__ raises TypeError).

+ isinstance(training_image, PipelineVariable) or bool(training_image)
+ )
+ has_algorithm_name = algorithm_name is not None and (
+ isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name)
+ )
+ if not has_training_image and not has_algorithm_name:
raise ValueError(
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
)
- if training_image and algorithm_name:
+ if has_training_image and has_algorithm_name:
raise ValueError(
"Only one of 'training_image' or 'algorithm_name' must be provided.",
)
8 changes: 5 additions & 3 deletions sagemaker-train/src/sagemaker/train/utils.py
@@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
return unique_name


- def _get_repo_name_from_image(image: str) -> str:
+ def _get_repo_name_from_image(image) -> str:
"""Get the repository name from the image URI.

Example:
@@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str:
```

Args:
- image (str): The image URI
+ image: The image URI (str or PipelineVariable)

Returns:
- str: The repository name
+ str: The repository name, or None if image is not a plain string
"""
+ if not isinstance(image, str):
+ return None
return image.split("/")[-1].split(":")[0].split("@")[0]
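
One thing worth noting about this hunk: the function can now return None, but the signature still declares "-> str". A small sketch of what an annotated version might look like follows, with the split chain broken into named steps. The public function name and the example URIs are illustrative assumptions, not taken from the repository.

```python
from typing import Optional


def get_repo_name_from_image(image: object) -> Optional[str]:
    """Return the repository name from an image URI, or None for non-string input."""
    if not isinstance(image, str):
        return None
    # Keep only the last path segment, dropping the registry host and any prefix.
    last_segment = image.split("/")[-1]
    # Drop a ":tag" suffix (for a digest reference this leaves "repo@sha256").
    without_tag = last_segment.split(":")[0]
    # Drop the "@sha256" remainder of a digest reference, if present.
    return without_tag.split("@")[0]


if __name__ == "__main__":
    print(get_repo_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest"))
    print(get_repo_name_from_image("example.com/team/my-repo@sha256:0123abcd"))
    print(get_repo_name_from_image(object()))
```

Declaring Optional[str] makes the None path visible to callers and type checkers, which is what forces get_base_job_name to check the result before formatting it.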

