Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.configs import (
Compute,
StoppingCondition,
OutputDataConfig,
)
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
from sagemaker.train.utils import safe_serialize


DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
Expand Down Expand Up @@ -117,6 +118,121 @@ def test_environment_values_accept_parameter_string(self):
assert trainer.environment["STATIC_VAR"] == "hello"


class TestSafeSerializePipelineVariable:
    """Verify safe_serialize passes PipelineVariable objects through untouched (GH#5504)."""

    def test_safe_serialize_returns_pipeline_variable_as_is(self):
        """A ParameterInteger must come back as the very same object, not JSON text."""
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        assert safe_serialize(depth_param) is depth_param

    def test_safe_serialize_returns_parameter_string_as_is(self):
        """A ParameterString must come back as the very same object, not JSON text."""
        algo_param = ParameterString(name="Algorithm", default_value="xgboost")
        assert safe_serialize(algo_param) is algo_param

    def test_safe_serialize_returns_parameter_float_as_is(self):
        """A ParameterFloat must come back as the very same object, not JSON text."""
        lr_param = ParameterFloat(name="LearningRate", default_value=0.01)
        assert safe_serialize(lr_param) is lr_param

    def test_safe_serialize_still_handles_plain_string(self):
        """Plain strings are returned unchanged."""
        assert safe_serialize("hello") == "hello"

    def test_safe_serialize_still_handles_int(self):
        """Integers are JSON-encoded to their string form."""
        assert safe_serialize(42) == "42"

    def test_safe_serialize_still_handles_dict(self):
        """Dicts are JSON-encoded."""
        assert safe_serialize({"key": "value"}) == '{"key": "value"}'


class TestModelTrainerHyperparametersPipelineVariable:
    """Test that ModelTrainer hyperparameters accept PipelineVariable objects (GH#5504)."""

    def test_hyperparameters_accept_parameter_integer(self):
        """ModelTrainer hyperparameters should accept ParameterInteger (GH#5504).

        This is the exact bug scenario: ParameterInteger in hyperparameters
        caused TypeError in safe_serialize before the fix.
        """
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters={"max_depth": max_depth},
        )
        # Identity check: the parameter object must be stored as-is, not copied
        # or serialized at construction time.
        assert trainer.hyperparameters["max_depth"] is max_depth

    def test_hyperparameters_accept_parameter_string(self):
        """ModelTrainer hyperparameters should accept ParameterString (GH#5504)."""
        objective = ParameterString(name="Objective", default_value="reg:squarederror")
        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters={"objective": objective},
        )
        assert trainer.hyperparameters["objective"] is objective

    def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self):
        """ModelTrainer hyperparameters should accept a mix of PipelineVariable and plain values."""
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters={
                "max_depth": max_depth,
                "eta": 0.1,
                "objective": "reg:squarederror",
            },
        )
        assert trainer.hyperparameters["max_depth"] is max_depth
        assert trainer.hyperparameters["eta"] == 0.1
        assert trainer.hyperparameters["objective"] == "reg:squarederror"

    @patch("sagemaker.train.model_trainer._get_unique_name", return_value="test-job-20240101")
    def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters(
        self, mock_unique_name
    ):
        """_create_training_job_args should preserve PipelineVariable in hyper_parameters dict.

        This exercises the code path that was broken in the original bug report:
        when safe_serialize is called on a PipelineVariable inside
        _create_training_job_args, it should return the PipelineVariable object
        as-is, not attempt JSON serialization.
        """
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters={"max_depth": max_depth, "eta": 0.1},
        )
        args = trainer._create_training_job_args()
        # PipelineVariable should be preserved as-is by safe_serialize
        assert args["hyper_parameters"]["max_depth"] is max_depth
        assert isinstance(args["hyper_parameters"]["max_depth"], PipelineVariable)
        # Plain values should be JSON-serialized to strings
        assert args["hyper_parameters"]["eta"] == "0.1"


class TestModelTrainerRealValuesStillWork:
"""Regression tests: verify that passing real values still works after the change."""

Expand Down
Loading