Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
from sagemaker.core.jumpstart.utils import get_eula_url
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
from sagemaker.core.helper.pipeline_variable import StrPipeVar
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar

from sagemaker.train.local.local_container import _LocalContainer

Expand Down Expand Up @@ -410,14 +410,19 @@ def __del__(self):
self._temp_code_dir.cleanup()

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
self,
training_image: "str | PipelineVariable | None",
algorithm_name: "str | PipelineVariable | None",
):
Comment thread
aviruthen marked this conversation as resolved.
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
Comment thread
aviruthen marked this conversation as resolved.
if not training_image and not algorithm_name:
# PipelineVariables are truthy for validation purposes
has_image = isinstance(training_image, PipelineVariable) or bool(training_image)
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression isinstance(training_image, PipelineVariable) or bool(training_image) is safe as written: because or short-circuits, bool() is never called on a PipelineVariable, so it cannot raise TypeError even if PipelineVariable does not support __bool__. However, consider simplifying to:

has_image = training_image is not None and training_image != ""
has_algo = algorithm_name is not None and algorithm_name != ""

This avoids calling bool() entirely and is more explicit about what "not provided" means (None or empty string). The is not None check naturally handles PipelineVariable objects correctly.

has_algo = isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name)
if not has_image and not has_algo:
Comment thread
aviruthen marked this conversation as resolved.
raise ValueError(
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
)
if training_image and algorithm_name:
if has_image and has_algo:
raise ValueError(
"Only one of 'training_image' or 'algorithm_name' must be provided.",
)
Expand Down Expand Up @@ -546,7 +551,6 @@ def model_post_init(self, __context: Any):
)

if self.training_image:
Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.core.helper.pipeline_variable import PipelineVariable
if isinstance(self.training_image, PipelineVariable):
logger.info("Training image URI: (PipelineVariable - resolved at pipeline execution)")
else:
Expand Down
8 changes: 6 additions & 2 deletions sagemaker-train/src/sagemaker/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from sagemaker.train import logger
from sagemaker.core.workflow.parameters import PipelineVariable

_PIPELINE_VARIABLE_IMAGE_PLACEHOLDER = "pipeline-variable-image"


def _default_bucket_and_prefix(session: Session) -> str:
"""Helper function to get the bucket name with the corresponding prefix if applicable
Expand Down Expand Up @@ -142,7 +144,7 @@ def _get_unique_name(base, max_length=63):
return unique_name


def _get_repo_name_from_image(image: str) -> str:
def _get_repo_name_from_image(image: "str | PipelineVariable") -> str:
"""Get the repository name from the image URI.
Comment thread
aviruthen marked this conversation as resolved.

Example:
Expand All @@ -152,11 +154,13 @@ def _get_repo_name_from_image(image: str) -> str:
```

Args:
image (str): The image URI
image (str or PipelineVariable): The image URI

Returns:
Comment thread
aviruthen marked this conversation as resolved.
str: The repository name
"""
if isinstance(image, PipelineVariable):
Comment thread
aviruthen marked this conversation as resolved.
return _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
return image.split("/")[-1].split(":")[0].split("@")[0]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for PipelineVariable support in ModelTrainer."""
from __future__ import absolute_import

Comment thread
aviruthen marked this conversation as resolved.
import pytest
from unittest.mock import MagicMock, patch

from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.train.utils import safe_serialize, _get_repo_name_from_image, _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER

Comment thread
aviruthen marked this conversation as resolved.

class TestSafeSerializeWithPipelineVariable:
"""Tests for safe_serialize handling of PipelineVariable objects."""

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TestSafeSerializeWithPipelineVariable tests verify safe_serialize behavior with PipelineVariable, but the PR diff doesn't show any changes to safe_serialize. If safe_serialize already handled PipelineVariable correctly, these tests are documenting existing behavior (which is fine), but it would be good to note that in the test class docstring. If safe_serialize needed changes, those changes should be included in this PR.

def test_safe_serialize_string(self):
    """A plain ``str`` input must come back unchanged."""
    serialized = safe_serialize("hello")
    assert serialized == "hello"

def test_safe_serialize_int(self):
    """An ``int`` input is expected to come back as its JSON text."""
    serialized = safe_serialize(5)
    assert serialized == "5"

def test_safe_serialize_float(self):
    """A ``float`` input is expected to come back as its JSON text."""
    serialized = safe_serialize(3.14)
    assert serialized == "3.14"

def test_safe_serialize_dict(self):
    """A ``dict`` input is expected to come back as its JSON text."""
    payload = {"key": "value"}
    assert safe_serialize(payload) == '{"key": "value"}'

def test_safe_serialize_pipeline_variable_parameter_string(self):
    """A ``ParameterString`` must pass through as the very same object."""
    pipeline_param = ParameterString(name="MyParam", default_value="test")
    serialized = safe_serialize(pipeline_param)
    # The call must hand back a PipelineVariable instead of raising TypeError,
    # and identity (not just type) is checked so a copy would also fail.
    assert isinstance(serialized, PipelineVariable)
    assert serialized is pipeline_param

def test_safe_serialize_pipeline_variable_parameter_integer(self):
    """A ``ParameterInteger`` must pass through as the very same object."""
    pipeline_param = ParameterInteger(name="MaxDepth", default_value=5)
    serialized = safe_serialize(pipeline_param)
    # The call must hand back a PipelineVariable instead of raising TypeError,
    # and identity (not just type) is checked so a copy would also fail.
    assert isinstance(serialized, PipelineVariable)
    assert serialized is pipeline_param


class TestGetRepoNameFromImage:
    """Checks ``_get_repo_name_from_image`` for string URIs and PipelineVariables."""

    def test_get_repo_name_from_image_string(self):
        """A full ECR URI with a tag yields only the repository name."""
        ecr_uri = "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
        assert _get_repo_name_from_image(ecr_uri) == "sagemaker-xgboost"

    def test_get_repo_name_from_image_pipeline_variable(self):
        """A PipelineVariable input yields the module's placeholder constant."""
        image_param = ParameterString(name="TrainingImage", default_value="some-image")
        repo_name = _get_repo_name_from_image(image_param)
        assert repo_name == _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER

    def test_get_repo_name_from_image_simple_string(self):
        """A bare ``name:tag`` string yields the name without the tag."""
        repo_name = _get_repo_name_from_image("my-repo:latest")
        assert repo_name == "my-repo"

    def test_get_repo_name_from_image_with_digest(self):
        """A URI pinned by ``@sha256`` digest yields only the repository name."""
        digest_uri = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo@sha256:abc123"
        assert _get_repo_name_from_image(digest_uri) == "my-repo"


class TestModelTrainerValidationWithPipelineVariable:
"""Tests for ModelTrainer validation with PipelineVariable objects."""

@patch("sagemaker.train.model_trainer.TrainDefaults")
def test_training_image_accepts_parameter_string(self, mock_defaults):
"""Test that training_image accepts ParameterString."""
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import Compute

mock_session = MagicMock()
mock_session.boto_region_name = "us-east-1"
mock_session.default_bucket.return_value = "my-bucket"
mock_session.default_bucket_prefix = None
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mock session and mock_defaults setup is duplicated across 4 test methods (test_training_image_accepts_parameter_string, test_algorithm_name_accepts_parameter_string, test_environment_values_accept_parameter_string, test_plain_string_values_still_work). Extract this into a @pytest.fixture to reduce duplication and improve maintainability:

@pytest.fixture
def mock_session():
    session = MagicMock()
    session.boto_region_name = "us-east-1"
    session.default_bucket.return_value = "my-bucket"
    session.default_bucket_prefix = None
    return session

And similarly for the mock_defaults patching.


mock_defaults.get_sagemaker_session.return_value = mock_session
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
mock_defaults.get_base_job_name.return_value = "test-job"
mock_defaults.get_compute.return_value = Compute(
instance_type="ml.m5.xlarge", instance_count=1
)
mock_defaults.get_stopping_condition.return_value = MagicMock()
mock_defaults.get_output_data_config.return_value = MagicMock()

param = ParameterString(name="TrainingImage", default_value="some-image-uri")

# Should not raise
trainer = ModelTrainer(
training_image=param,
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
sagemaker_session=mock_session,
role="arn:aws:iam::123456789012:role/SageMakerRole",
)
assert trainer.training_image is param

@patch("sagemaker.train.model_trainer.TrainDefaults")
def test_algorithm_name_accepts_parameter_string(self, mock_defaults):
    """A ``ParameterString`` given as ``algorithm_name`` is stored unchanged."""
    from sagemaker.train.model_trainer import ModelTrainer
    from sagemaker.train.configs import Compute

    role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"

    # Stand-in session exposing only the attributes this test configures.
    fake_session = MagicMock()
    fake_session.boto_region_name = "us-east-1"
    fake_session.default_bucket.return_value = "my-bucket"
    fake_session.default_bucket_prefix = None

    # Configure every patched TrainDefaults getter in a single call.
    mock_defaults.configure_mock(
        **{
            "get_sagemaker_session.return_value": fake_session,
            "get_role.return_value": role_arn,
            "get_base_job_name.return_value": "test-job",
            "get_compute.return_value": Compute(
                instance_type="ml.m5.xlarge", instance_count=1
            ),
            "get_stopping_condition.return_value": MagicMock(),
            "get_output_data_config.return_value": MagicMock(),
        }
    )

    algorithm_param = ParameterString(name="AlgorithmName", default_value="some-algo")

    # Construction itself is part of the check: it must not raise.
    trainer = ModelTrainer(
        algorithm_name=algorithm_param,
        compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
        sagemaker_session=fake_session,
        role=role_arn,
    )
    assert trainer.algorithm_name is algorithm_param

@patch("sagemaker.train.model_trainer.TrainDefaults")
def test_environment_values_accept_parameter_string(self, mock_defaults):
"""Test that environment dict values accept ParameterString."""
Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import Compute

mock_session = MagicMock()
mock_session.boto_region_name = "us-east-1"
mock_session.default_bucket.return_value = "my-bucket"
mock_session.default_bucket_prefix = None

mock_defaults.get_sagemaker_session.return_value = mock_session
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
mock_defaults.get_base_job_name.return_value = "test-job"
mock_defaults.get_compute.return_value = Compute(
instance_type="ml.m5.xlarge", instance_count=1
)
mock_defaults.get_stopping_condition.return_value = MagicMock()
mock_defaults.get_output_data_config.return_value = MagicMock()

env_param = ParameterString(name="EnvValue", default_value="val")

# Should not raise
trainer = ModelTrainer(
training_image="683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3",
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
sagemaker_session=mock_session,
role="arn:aws:iam::123456789012:role/SageMakerRole",
environment={"MY_VAR": env_param},
)
assert trainer.environment["MY_VAR"] is env_param

@patch("sagemaker.train.model_trainer.TrainDefaults")
def test_plain_string_values_still_work(self, mock_defaults):
"""Regression test: plain string values continue to work."""
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import Compute

mock_session = MagicMock()
mock_session.boto_region_name = "us-east-1"
mock_session.default_bucket.return_value = "my-bucket"
mock_session.default_bucket_prefix = None

mock_defaults.get_sagemaker_session.return_value = mock_session
Comment thread
aviruthen marked this conversation as resolved.
Outdated
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
mock_defaults.get_base_job_name.return_value = "test-job"
mock_defaults.get_compute.return_value = Compute(
instance_type="ml.m5.xlarge", instance_count=1
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a test case for _validate_training_image_and_algorithm_name where one argument is a PipelineVariable and the other is None — this is the primary success case the fix enables. The current tests test_training_image_accepts_parameter_string and test_algorithm_name_accepts_parameter_string test this indirectly through full ModelTrainer construction, but a direct unit test of the validation method (like the rejection tests at lines 199-238) would be more focused and faster.

)
mock_defaults.get_stopping_condition.return_value = MagicMock()
mock_defaults.get_output_data_config.return_value = MagicMock()

# Should not raise
trainer = ModelTrainer(
training_image="683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3",
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
sagemaker_session=mock_session,
role="arn:aws:iam::123456789012:role/SageMakerRole",
)
assert trainer.training_image == "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"

def test_validation_rejects_no_image_or_algorithm(self):
    """Passing ``None`` for both arguments must raise ``ValueError``."""
    from sagemaker.train.model_trainer import ModelTrainer

    # __new__ creates the instance without running __init__.
    bare_trainer = ModelTrainer.__new__(ModelTrainer)
    expected_message = "Atleast one of"
    with pytest.raises(ValueError, match=expected_message):
        bare_trainer._validate_training_image_and_algorithm_name(None, None)

def test_validation_rejects_both_image_and_algorithm(self):
    """Passing plain strings for both arguments must raise ``ValueError``."""
    from sagemaker.train.model_trainer import ModelTrainer

    # __new__ creates the instance without running __init__.
    bare_trainer = ModelTrainer.__new__(ModelTrainer)
    expected_message = "Only one of"
    with pytest.raises(ValueError, match=expected_message):
        bare_trainer._validate_training_image_and_algorithm_name("image", "algo")

def test_validation_rejects_both_pipeline_variables(self):
    """Passing PipelineVariables for both arguments must raise ``ValueError``."""
    from sagemaker.train.model_trainer import ModelTrainer

    # __new__ creates the instance without running __init__.
    bare_trainer = ModelTrainer.__new__(ModelTrainer)
    image_param = ParameterString(name="Image", default_value="img")
    algorithm_param = ParameterString(name="Algo", default_value="algo")
    expected_message = "Only one of"
    with pytest.raises(ValueError, match=expected_message):
        bare_trainer._validate_training_image_and_algorithm_name(
            image_param, algorithm_param
        )
Loading