Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,18 @@ def __del__(self):
self._temp_code_dir.cleanup()

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
self, training_image, algorithm_name
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing type annotations: The type annotations for training_image and algorithm_name were removed entirely. Per SDK coding standards (PEP 484), all public/private methods must retain type annotations. Since these parameters now accept both str and PipelineVariable, please use the appropriate union type:

def _validate_training_image_and_algorithm_name(
    self, training_image: str | PipelineVariable | None, algorithm_name: str | PipelineVariable | None
):

Or if StrPipeVar is already defined as a type alias in the codebase, use that.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test file in the diff: The PR description references test_model_trainer_pipeline_variable.py but this file is not included in the changed files. Please ensure the test file is included in the PR. Without tests, we cannot verify the fix works or guard against regressions.

):
Comment thread
aviruthen marked this conversation as resolved.
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
Comment thread
aviruthen marked this conversation as resolved.
if not training_image and not algorithm_name:
from sagemaker.core.helper.pipeline_variable import PipelineVariable as _PV
# PipelineVariables are truthy for validation purposes
has_image = isinstance(training_image, _PV) or bool(training_image)
has_algo = isinstance(algorithm_name, _PV) or bool(algorithm_name)
if not has_image and not has_algo:
Comment thread
aviruthen marked this conversation as resolved.
raise ValueError(
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
)
if training_image and algorithm_name:
if has_image and has_algo:
raise ValueError(
"Only one of 'training_image' or 'algorithm_name' must be provided.",
)
Expand Down
6 changes: 4 additions & 2 deletions sagemaker-train/src/sagemaker/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
return unique_name


def _get_repo_name_from_image(image: str) -> str:
def _get_repo_name_from_image(image) -> str:
"""Get the repository name from the image URI.
Comment thread
aviruthen marked this conversation as resolved.

Example:
Expand All @@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str:
```

Args:
image (str): The image URI
image: The image URI (str or PipelineVariable)
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed type annotation: Same issue here — the image parameter lost its type annotation. Please restore it with the correct union type:

def _get_repo_name_from_image(image: str | PipelineVariable) -> str:


Returns:
Comment thread
aviruthen marked this conversation as resolved.
str: The repository name
"""
if isinstance(image, PipelineVariable):
Comment thread
aviruthen marked this conversation as resolved.
return "pipeline-variable-image"
return image.split("/")[-1].split(":")[0].split("@")[0]


Expand Down
Loading