Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,16 @@ def __del__(self):
self._temp_code_dir.cleanup()

def _validate_training_image_and_algorithm_name(
Comment thread
aviruthen marked this conversation as resolved.
self, training_image: Optional[str], algorithm_name: Optional[str]
self, training_image, algorithm_name
):
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
if not training_image and not algorithm_name:
has_image = training_image is not None
has_algo = algorithm_name is not None
if not has_image and not has_algo:
raise ValueError(
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
)
if training_image and algorithm_name:
if has_image and has_algo:
raise ValueError(
"Only one of 'training_image' or 'algorithm_name' must be provided.",
)
Expand Down
10 changes: 7 additions & 3 deletions sagemaker-train/src/sagemaker/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
return unique_name


def _get_repo_name_from_image(image: str) -> str:
def _get_repo_name_from_image(image) -> str:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue — don't remove the type annotation, update it:

def _get_repo_name_from_image(image: str | PipelineVariable) -> str | None:

Note the return type should also be updated to str | None since you now return None for PipelineVariable inputs.

"""Get the repository name from the image URI.

Example:
Expand All @@ -152,11 +152,15 @@ def _get_repo_name_from_image(image: str) -> str:
```

Args:
image (str): The image URI
image: The image URI (str or PipelineVariable)

Returns:
str: The repository name
str: The repository name, or None if image is a PipelineVariable
"""
from sagemaker.core.helper.pipeline_variable import PipelineVariable

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the import to the top of the module (or at minimum to the top of the function). Inline imports inside functions are acceptable for avoiding circular dependencies, but please add a comment explaining why it's done here:

# Import here to avoid circular dependency
from sagemaker.core.helper.pipeline_variable import PipelineVariable

Also, is there actually a circular dependency risk? If not, this import should be at the module level with other imports.

if isinstance(image, PipelineVariable):
return None
Comment thread
aviruthen marked this conversation as resolved.
return image.split("/")[-1].split(":")[0].split("@")[0]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
StoppingCondition,
OutputDataConfig,
)
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TrainDefaults is imported but never used in the diff. Is this import used elsewhere in the file (not shown in the diff)? If not, this is an unused import that will fail linting. If it IS used in existing code not shown in the diff, please disregard this comment.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test changes seem insufficient for the scope of the fix. The PR modifies validation logic in _validate_training_image_and_algorithm_name and adds PipelineVariable handling in _get_repo_name_from_image, but the test diff only shows an import change. Where are the new test cases that:

  1. Pass a PipelineVariable as training_image and verify validation passes?
  2. Pass a PipelineVariable to _get_repo_name_from_image and verify it returns None?
  3. Verify that both PipelineVariable for training_image AND algorithm_name still raises ValueError?
  4. Verify that None for both still raises ValueError?

Please add explicit unit tests for the changed behavior. Target >90% coverage per SDK standards.



DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
Expand Down
Loading