Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 120 additions & 14 deletions sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for PipelineVariable support in ModelTrainer (GH#5524).
"""Tests for PipelineVariable support in ModelTrainer.
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add back the GitHub issue number in this particular comment? You can also add it back on line 19 where it previously said "See: aws#5524"


Verifies that ModelTrainer fields accept PipelineVariable objects
Verify that ModelTrainer fields accept PipelineVariable objects
(e.g., ParameterString) in addition to their concrete types, following
the existing V3 pattern established by SourceCode and OutputDataConfig.

See: https://github.com/aws/sagemaker-python-sdk/issues/5524
Also verify that safe_serialize correctly handles PipelineVariable objects
in hyperparameters (returning them as-is instead of attempting json.dumps),
and that _create_training_job_args preserves PipelineVariable objects through
the serialization pipeline.

See: https://github.com/aws/sagemaker-python-sdk/issues/5504
"""
from __future__ import absolute_import

Expand All @@ -26,13 +31,18 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import (
ParameterString,
ParameterInteger,
ParameterFloat,
)
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.configs import (
Compute,
StoppingCondition,
OutputDataConfig,
)
from sagemaker.train.utils import safe_serialize
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE


Expand All @@ -48,7 +58,7 @@
)


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="module")
def modules_session():
with patch("sagemaker.train.Session", spec=Session) as session_mock:
session_instance = session_mock.return_value
Expand All @@ -64,7 +74,7 @@ class TestModelTrainerPipelineVariableAcceptance:
"""Test that ModelTrainer fields accept PipelineVariable objects."""

def test_training_image_accepts_parameter_string(self):
"""ModelTrainer.training_image should accept ParameterString (GH#5524)."""
"""Verify ModelTrainer.training_image accepts ParameterString (GH#5504)."""
param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
trainer = ModelTrainer(
training_image=param,
Expand All @@ -77,7 +87,7 @@ def test_training_image_accepts_parameter_string(self):
assert trainer.training_image is param

def test_algorithm_name_accepts_parameter_string(self):
"""ModelTrainer.algorithm_name should accept ParameterString."""
"""Verify ModelTrainer.algorithm_name accepts ParameterString."""
param = ParameterString(name="AlgorithmName", default_value="my-algo-arn")
trainer = ModelTrainer(
algorithm_name=param,
Expand All @@ -90,7 +100,7 @@ def test_algorithm_name_accepts_parameter_string(self):
assert trainer.algorithm_name is param

def test_training_input_mode_accepts_parameter_string(self):
"""ModelTrainer.training_input_mode should accept ParameterString."""
"""Verify ModelTrainer.training_input_mode accepts ParameterString."""
param = ParameterString(name="InputMode", default_value="File")
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
Expand All @@ -103,7 +113,7 @@ def test_training_input_mode_accepts_parameter_string(self):
assert trainer.training_input_mode is param

def test_environment_values_accept_parameter_string(self):
"""ModelTrainer.environment dict values should accept ParameterString."""
"""Verify ModelTrainer.environment dict values accept ParameterString."""
param = ParameterString(name="DatasetVersion", default_value="v1")
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
Expand All @@ -121,7 +131,7 @@ class TestModelTrainerRealValuesStillWork:
"""Regression tests: verify that passing real values still works after the change."""

def test_training_image_accepts_real_string(self):
"""ModelTrainer.training_image should still accept a plain string."""
"""Verify ModelTrainer.training_image still accepts a plain string."""
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
Expand All @@ -132,7 +142,7 @@ def test_training_image_accepts_real_string(self):
assert trainer.training_image == DEFAULT_IMAGE

def test_algorithm_name_accepts_real_string(self):
"""ModelTrainer.algorithm_name should still accept a plain string."""
"""Verify ModelTrainer.algorithm_name still accepts a plain string."""
trainer = ModelTrainer(
algorithm_name="arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo",
role=DEFAULT_ROLE,
Expand All @@ -143,7 +153,7 @@ def test_algorithm_name_accepts_real_string(self):
assert trainer.algorithm_name == "arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo"

def test_training_input_mode_accepts_real_string(self):
"""ModelTrainer.training_input_mode should still accept a plain string."""
"""Verify ModelTrainer.training_input_mode still accepts a plain string."""
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
training_input_mode="Pipe",
Expand All @@ -155,7 +165,7 @@ def test_training_input_mode_accepts_real_string(self):
assert trainer.training_input_mode == "Pipe"

def test_environment_accepts_real_string_values(self):
"""ModelTrainer.environment should still accept plain string values."""
"""Verify ModelTrainer.environment still accepts plain string values."""
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
environment={"KEY1": "value1", "KEY2": "value2"},
Expand All @@ -167,7 +177,7 @@ def test_environment_accepts_real_string_values(self):
assert trainer.environment == {"KEY1": "value1", "KEY2": "value2"}

def test_training_image_rejects_invalid_type(self):
"""ModelTrainer.training_image should still reject invalid types (e.g., int)."""
"""Verify ModelTrainer.training_image still rejects invalid types (e.g., int)."""
with pytest.raises(ValidationError):
ModelTrainer(
training_image=12345,
Expand All @@ -176,3 +186,99 @@ def test_training_image_rejects_invalid_type(self):
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)


class TestSafeSerializeWithPipelineVariables:
    """safe_serialize must pass PipelineVariable objects through untouched.

    Calling json.dumps() on a PipelineVariable would raise TypeError, so
    safe_serialize is expected to return such objects as-is while still
    JSON-serializing ordinary values.
    See: https://github.com/aws/sagemaker-python-sdk/issues/5504
    """

    # Pipeline parameters that must survive serialization by identity.
    _PIPELINE_PARAMS = [
        ParameterInteger(name="MaxDepth", default_value=5),
        ParameterString(name="Optimizer", default_value="adam"),
        ParameterFloat(name="LearningRate", default_value=0.01),
    ]

    # (input, expected JSON-string form) pairs for concrete values.
    _PLAIN_CASES = [
        ("hello", "hello"),
        (42, "42"),
        ({"key": "value"}, '{"key": "value"}'),
        (0.01, "0.01"),
        (True, "true"),
        (False, "false"),
    ]

    @pytest.mark.parametrize("param", _PIPELINE_PARAMS)
    def test_safe_serialize_returns_pipeline_variable_as_is(self, param):
        """A PipelineVariable handed to safe_serialize comes back unchanged."""
        serialized = safe_serialize(param)
        assert serialized is param
        assert isinstance(serialized, PipelineVariable)

    @pytest.mark.parametrize("input_val,expected", _PLAIN_CASES)
    def test_safe_serialize_handles_normal_types(self, input_val, expected):
        """Non-PipelineVariable values serialize to their string/JSON form."""
        serialized = safe_serialize(input_val)
        assert serialized == expected


class TestModelTrainerHyperparametersWithPipelineVariables:
    """Verify that ModelTrainer accepts PipelineVariable objects in hyperparameters.

    See: https://github.com/aws/sagemaker-python-sdk/issues/5504
    """

    def test_hyperparameters_accept_pipeline_variable_values(self):
        """Verify ModelTrainer accepts PipelineVariable objects as hyperparameter values."""
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        learning_rate = ParameterFloat(name="LearningRate", default_value=0.01)
        optimizer = ParameterString(name="Optimizer", default_value="adam")

        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters={
                "max_depth": max_depth,
                "learning_rate": learning_rate,
                "optimizer": optimizer,
                "static_param": 10,
            },
        )
        # PipelineVariable values must be stored by identity (not copied or
        # coerced); concrete values must also pass through validation intact.
        assert trainer.hyperparameters["max_depth"] is max_depth
        assert trainer.hyperparameters["learning_rate"] is learning_rate
        assert trainer.hyperparameters["optimizer"] is optimizer
        assert trainer.hyperparameters["static_param"] == 10

    def test_create_training_job_args_with_pipeline_variable_hyperparameters(
        self, modules_session
    ):
        """Verify _create_training_job_args preserves PipelineVariable in hyper_parameters."""
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        learning_rate = ParameterFloat(name="LearningRate", default_value=0.01)

        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            # Mocked session fixture keeps this test offline — no real AWS calls.
            sagemaker_session=modules_session,
            hyperparameters={
                "max_depth": max_depth,
                "learning_rate": learning_rate,
                "epochs": 10,
                "verbose": "true",
            },
        )

        training_args = trainer._create_training_job_args()
        hyper_params = training_args["hyper_parameters"]

        # PipelineVariable objects should be preserved as-is by safe_serialize
        assert hyper_params["max_depth"] is max_depth
        assert hyper_params["learning_rate"] is learning_rate
        # Regular values should be serialized to strings
        assert hyper_params["epochs"] == "10"
        assert hyper_params["verbose"] == "true"
Loading