Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
from sagemaker.core.jumpstart.utils import get_eula_url
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
from sagemaker.core.helper.pipeline_variable import StrPipeVar
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar

from sagemaker.train.local.local_container import _LocalContainer

Expand Down Expand Up @@ -410,14 +410,21 @@ def __del__(self):
self._temp_code_dir.cleanup()

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
self,
training_image: Optional[StrPipeVar],
algorithm_name: Optional[StrPipeVar],
):
Comment thread
aviruthen marked this conversation as resolved.
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
Comment thread
aviruthen marked this conversation as resolved.
if not training_image and not algorithm_name:
# PipelineVariable objects do not support standard boolean coercion
# (__bool__ raises TypeError), so we use isinstance checks to detect
# them as truthy values during validation.
has_image = isinstance(training_image, PipelineVariable) or bool(training_image)
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling bool(training_image) on a PipelineVariable would raise TypeError, because PipelineVariable does not support __bool__. In the expression isinstance(training_image, PipelineVariable) or bool(training_image), however, the or short-circuits as soon as the isinstance check is True, so bool() is never evaluated for a PipelineVariable and the expression is safe. That said, consider simplifying to:

has_image = training_image is not None and training_image != ""
has_algo = algorithm_name is not None and algorithm_name != ""

This avoids calling bool() entirely and is more explicit about what "not provided" means (None or empty string). The is not None check naturally handles PipelineVariable objects correctly.

has_algo = isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name)
if not has_image and not has_algo:
Comment thread
aviruthen marked this conversation as resolved.
raise ValueError(
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
)
if training_image and algorithm_name:
if has_image and has_algo:
raise ValueError(
"Only one of 'training_image' or 'algorithm_name' must be provided.",
)
Expand Down Expand Up @@ -546,9 +553,11 @@ def model_post_init(self, __context: Any):
)

if self.training_image:
Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.core.helper.pipeline_variable import PipelineVariable
if isinstance(self.training_image, PipelineVariable):
logger.info("Training image URI: (PipelineVariable - resolved at pipeline execution)")
logger.info(
"Training image URI: "
"(PipelineVariable - resolved at pipeline execution)"
)
else:
logger.info(f"Training image URI: {self.training_image}")

Expand Down
10 changes: 8 additions & 2 deletions sagemaker-train/src/sagemaker/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@
from datetime import datetime
from typing import Literal, Any

from typing import Union
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate import: Union is imported here from typing, but there's already a from typing import Literal, Any on line 24. Consolidate into a single import statement:

from typing import Literal, Any, Union

Also, this module does not use from __future__ import annotations. Consider adding that import so the annotation can use PEP 604 union syntax (str | PipelineVariable) per SDK conventions; PipelineVariable itself is already imported from sagemaker.core.workflow.parameters on line 30.


from sagemaker.core.helper.session_helper import Session
from sagemaker.core.shapes import Unassigned
from sagemaker.train import logger
from sagemaker.core.workflow.parameters import PipelineVariable

_PIPELINE_VARIABLE_IMAGE_PLACEHOLDER = "pipeline-variable-image"


def _default_bucket_and_prefix(session: Session) -> str:
"""Helper function to get the bucket name with the corresponding prefix if applicable
Expand Down Expand Up @@ -142,7 +146,7 @@ def _get_unique_name(base, max_length=63):
return unique_name


def _get_repo_name_from_image(image: str) -> str:
def _get_repo_name_from_image(image: Union[str, PipelineVariable]) -> str:
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The -> str return annotation remains correct: when a PipelineVariable is passed, the function returns a string placeholder rather than the variable itself. However, consider documenting in the docstring that the placeholder _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER is returned for PipelineVariable inputs, so downstream callers understand the behavior.

"""Get the repository name from the image URI.
Comment thread
aviruthen marked this conversation as resolved.

Example:
Expand All @@ -152,11 +156,13 @@ def _get_repo_name_from_image(image: str) -> str:
```

Args:
image (str): The image URI
image (str or PipelineVariable): The image URI.

Returns:
Comment thread
aviruthen marked this conversation as resolved.
str: The repository name
"""
if isinstance(image, PipelineVariable):
Comment thread
aviruthen marked this conversation as resolved.
return _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
return image.split("/")[-1].split(":")[0].split("@")[0]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for PipelineVariable support in ModelTrainer."""
from __future__ import annotations

Comment thread
aviruthen marked this conversation as resolved.
import pytest
from unittest.mock import MagicMock, patch

from sagemaker.core.workflow.parameters import (
ParameterString,
ParameterInteger,
)
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.train.utils import (
safe_serialize,
_get_repo_name_from_image,
_PIPELINE_VARIABLE_IMAGE_PLACEHOLDER,
)

# Plain-string image URI shared by the tests below; its repository
# component is "sagemaker-xgboost".
_TEST_IMAGE_URI = (
    "683313688378.dkr.ecr.us-east-1.amazonaws.com/"
    "sagemaker-xgboost:1.0-1-cpu-py3"
)
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This test image URI contains a hardcoded region (us-east-1) and account ID (683313688378). While this is acceptable for unit tests since it's just a string constant and not used to make actual API calls, consider using a clearly fake account ID (e.g., 123456789012) for consistency with the mock session fixture below.


Comment thread
aviruthen marked this conversation as resolved.

class TestSafeSerializeWithPipelineVariable:
"""Tests for safe_serialize handling of PipelineVariable objects."""

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TestSafeSerializeWithPipelineVariable tests verify safe_serialize behavior with PipelineVariable, but the PR diff doesn't show any changes to safe_serialize. If safe_serialize already handled PipelineVariable correctly, these tests are documenting existing behavior (which is fine), but it would be good to note that in the test class docstring. If safe_serialize needed changes, those changes should be included in this PR.

def test_safe_serialize_string(self):
    """A plain string comes back from safe_serialize unchanged."""
    plain_text = "hello"
    assert safe_serialize(plain_text) == plain_text

def test_safe_serialize_int(self):
    """An integer input yields its string form ("5" for 5)."""
    serialized = safe_serialize(5)
    assert serialized == "5"

def test_safe_serialize_float(self):
    """A float input yields its string form ("3.14" for 3.14)."""
    serialized = safe_serialize(3.14)
    assert serialized == "3.14"

def test_safe_serialize_dict(self):
    """A dict input yields its JSON text."""
    payload = {"key": "value"}
    assert safe_serialize(payload) == '{"key": "value"}'

def test_safe_serialize_pipeline_variable_parameter_string(self):
    """safe_serialize hands a ParameterString back as the very same object."""
    string_param = ParameterString(name="MyParam", default_value="test")
    serialized = safe_serialize(string_param)
    # The call must not raise TypeError; the parameter is passed through
    # untouched rather than converted to text.
    assert isinstance(serialized, PipelineVariable)
    assert serialized is string_param

def test_safe_serialize_pipeline_variable_parameter_integer(self):
    """safe_serialize hands a ParameterInteger back as the very same object."""
    integer_param = ParameterInteger(name="MaxDepth", default_value=5)
    serialized = safe_serialize(integer_param)
    # The call must not raise TypeError; the parameter is passed through
    # untouched rather than converted to text.
    assert isinstance(serialized, PipelineVariable)
    assert serialized is integer_param


class TestGetRepoNameFromImage:
    """Checks _get_repo_name_from_image for string and PipelineVariable inputs."""

    def test_get_repo_name_from_image_string(self):
        """A full ECR image URI with a tag yields just the repository name."""
        assert _get_repo_name_from_image(_TEST_IMAGE_URI) == "sagemaker-xgboost"

    def test_get_repo_name_from_image_pipeline_variable(self):
        """A PipelineVariable input yields the placeholder constant."""
        image_param = ParameterString(name="TrainingImage", default_value="some-image")
        repo_name = _get_repo_name_from_image(image_param)
        assert repo_name == _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER

    def test_get_repo_name_from_image_simple_string(self):
        """A bare "name:tag" string yields the name."""
        assert _get_repo_name_from_image("my-repo:latest") == "my-repo"

    def test_get_repo_name_from_image_with_digest(self):
        """An image URI pinned by digest yields the repository name."""
        registry = "123456789012.dkr.ecr.us-west-2.amazonaws.com/"
        digest_uri = registry + "my-repo@sha256:abc123"
        assert _get_repo_name_from_image(digest_uri) == "my-repo"


@pytest.fixture
def mock_session():
    """Provide a MagicMock standing in for a SageMaker session.

    The mock reports region "us-east-1", returns "my-bucket" from
    ``default_bucket()``, and has no default bucket prefix.
    """
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    fake_session = MagicMock(boto_region_name="us-east-1", default_bucket_prefix=None)
    fake_session.default_bucket.return_value = "my-bucket"
    return fake_session


@pytest.fixture
def mock_train_defaults():
    """Patch ``TrainDefaults`` in ``sagemaker.train.model_trainer`` for the test.

    Yields the patched mock, whose ``get_*`` helpers return canned values:
    a role ARN, a base job name, a ``Compute`` config, and MagicMocks for
    the session, stopping condition and output data config.
    """
    with patch("sagemaker.train.model_trainer.TrainDefaults") as patched_defaults:
        from sagemaker.train.configs import Compute

        role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"
        default_compute = Compute(instance_type="ml.m5.xlarge", instance_count=1)

        patched_defaults.get_sagemaker_session.return_value = MagicMock()
        patched_defaults.get_role.return_value = role_arn
        patched_defaults.get_base_job_name.return_value = "test-job"
        patched_defaults.get_compute.return_value = default_compute
        patched_defaults.get_stopping_condition.return_value = MagicMock()
        patched_defaults.get_output_data_config.return_value = MagicMock()
        yield patched_defaults


class TestModelTrainerValidationWithPipelineVariable:
"""Tests for ModelTrainer validation with PipelineVariable objects."""

def test_training_image_accepts_parameter_string(
self, mock_session, mock_train_defaults
):
"""Test that training_image accepts ParameterString."""
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import Compute

param = ParameterString(
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The from sagemaker.train.model_trainer import ModelTrainer import is repeated inside every test method in this class. Move it to the top of the file with the other imports. Inline imports in tests add unnecessary noise and are not consistent with SDK test conventions.

name="TrainingImage", default_value="some-image-uri"
)

# Should not raise
trainer = ModelTrainer(
training_image=param,
compute=Compute(
instance_type="ml.m5.xlarge", instance_count=1
),
sagemaker_session=mock_session,
role="arn:aws:iam::123456789012:role/SageMakerRole",
)
assert trainer.training_image is param

def test_algorithm_name_accepts_parameter_string(
    self, mock_session, mock_train_defaults
):
    """A ParameterString passed as algorithm_name is kept on the trainer as-is."""
    from sagemaker.train.model_trainer import ModelTrainer
    from sagemaker.train.configs import Compute

    algo_param = ParameterString(name="AlgorithmName", default_value="some-algo")
    compute_config = Compute(instance_type="ml.m5.xlarge", instance_count=1)

    # Construction itself is part of the check: it must not raise.
    model_trainer = ModelTrainer(
        algorithm_name=algo_param,
        compute=compute_config,
        sagemaker_session=mock_session,
        role="arn:aws:iam::123456789012:role/SageMakerRole",
    )
    assert model_trainer.algorithm_name is algo_param

def test_environment_values_accept_parameter_string(
self, mock_session, mock_train_defaults
):
"""Test that environment dict values accept ParameterString."""
Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import Compute

env_param = ParameterString(
name="EnvValue", default_value="val"
)

# Should not raise
trainer = ModelTrainer(
training_image=_TEST_IMAGE_URI,
compute=Compute(
instance_type="ml.m5.xlarge", instance_count=1
),
sagemaker_session=mock_session,
role="arn:aws:iam::123456789012:role/SageMakerRole",
environment={"MY_VAR": env_param},
)
assert trainer.environment["MY_VAR"] is env_param

def test_plain_string_values_still_work(
    self, mock_session, mock_train_defaults
):
    """Regression check: a plain-string training_image is still accepted."""
    from sagemaker.train.model_trainer import ModelTrainer
    from sagemaker.train.configs import Compute

    compute_config = Compute(instance_type="ml.m5.xlarge", instance_count=1)

    # Construction itself is part of the check: it must not raise.
    model_trainer = ModelTrainer(
        training_image=_TEST_IMAGE_URI,
        compute=compute_config,
        sagemaker_session=mock_session,
        role="arn:aws:iam::123456789012:role/SageMakerRole",
    )
    assert model_trainer.training_image == _TEST_IMAGE_URI

def test_validation_accepts_pipeline_variable_image_none_algo(self):
"""Test validation accepts PipelineVariable image with None algorithm."""
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using ModelTrainer.__new__(ModelTrainer) to bypass __init__ and directly test the validation method is fragile — it creates an uninitialized object. If _validate_training_image_and_algorithm_name ever accesses self attributes, these tests will break with confusing errors. Consider either:

  1. Making _validate_training_image_and_algorithm_name a @staticmethod (it doesn't use self), or
  2. Using the existing mock_train_defaults fixture to construct a proper instance and test through the public interface.

from sagemaker.train.model_trainer import ModelTrainer

trainer = ModelTrainer.__new__(ModelTrainer)
param = ParameterString(
name="Image", default_value="img"
)
# Should not raise
trainer._validate_training_image_and_algorithm_name(
param, None
)

def test_validation_accepts_none_image_pipeline_variable_algo(self):
    """Validation passes for a None image plus a PipelineVariable algorithm."""
    from sagemaker.train.model_trainer import ModelTrainer

    # __new__ skips __init__, giving a bare instance to call the validator on.
    bare_trainer = ModelTrainer.__new__(ModelTrainer)
    algo_param = ParameterString(name="Algo", default_value="algo")

    # Passing means the call returns without raising.
    bare_trainer._validate_training_image_and_algorithm_name(None, algo_param)

def test_validation_rejects_no_image_or_algorithm(self):
    """Validation raises ValueError when both arguments are None."""
    from sagemaker.train.model_trainer import ModelTrainer

    # __new__ skips __init__, giving a bare instance to call the validator on.
    bare_trainer = ModelTrainer.__new__(ModelTrainer)
    image, algorithm = None, None
    with pytest.raises(ValueError, match="Atleast one of"):
        bare_trainer._validate_training_image_and_algorithm_name(image, algorithm)

def test_validation_rejects_both_image_and_algorithm(self):
    """Validation raises ValueError when both plain-string values are given."""
    from sagemaker.train.model_trainer import ModelTrainer

    # __new__ skips __init__, giving a bare instance to call the validator on.
    bare_trainer = ModelTrainer.__new__(ModelTrainer)
    image, algorithm = "image", "algo"
    with pytest.raises(ValueError, match="Only one of"):
        bare_trainer._validate_training_image_and_algorithm_name(image, algorithm)

def test_validation_rejects_both_pipeline_variables(self):
    """Validation raises ValueError when both arguments are PipelineVariables."""
    from sagemaker.train.model_trainer import ModelTrainer

    # __new__ skips __init__, giving a bare instance to call the validator on.
    bare_trainer = ModelTrainer.__new__(ModelTrainer)
    image_param = ParameterString(name="Image", default_value="img")
    algorithm_param = ParameterString(name="Algo", default_value="algo")
    with pytest.raises(ValueError, match="Only one of"):
        bare_trainer._validate_training_image_and_algorithm_name(
            image_param, algorithm_param
        )
Loading