Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.utils import safe_serialize
from sagemaker.train.configs import (
Compute,
StoppingCondition,
Expand Down Expand Up @@ -116,6 +117,54 @@ def test_environment_values_accept_parameter_string(self):
assert trainer.environment["DATASET_VERSION"] is param
assert trainer.environment["STATIC_VAR"] == "hello"

def test_hyperparameters_accept_parameter_integer(self):
"""ModelTrainer.hyperparameters should accept ParameterInteger values (GH#5504)."""
param = ParameterInteger(name="MaxDepth", default_value=5)
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
hyperparameters={"max_depth": param},
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)
assert trainer.hyperparameters["max_depth"] is param

def test_hyperparameters_accept_parameter_string(self):
"""ModelTrainer.hyperparameters should accept ParameterString values (GH#5504)."""
param = ParameterString(name="LearningRate", default_value="0.01")
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
hyperparameters={"learning_rate": param},
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)
assert trainer.hyperparameters["learning_rate"] is param

def test_hyperparameters_accept_mixed_pipeline_and_static_values(self):
"""ModelTrainer.hyperparameters should accept a mix of PipelineVariable and static values."""
param_int = ParameterInteger(name="MaxDepth", default_value=5)
param_str = ParameterString(name="Objective", default_value="reg:squarederror")
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
hyperparameters={
"max_depth": param_int,
"objective": param_str,
"eta": 0.1,
"num_round": 100,
},
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)
assert trainer.hyperparameters["max_depth"] is param_int
assert trainer.hyperparameters["objective"] is param_str
assert trainer.hyperparameters["eta"] == 0.1
assert trainer.hyperparameters["num_round"] == 100


class TestModelTrainerRealValuesStillWork:
"""Regression tests: verify that passing real values still works after the change."""
Expand Down Expand Up @@ -176,3 +225,44 @@ def test_training_image_rejects_invalid_type(self):
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)


class TestSafeSerializePipelineVariable:
"""Test that safe_serialize correctly handles PipelineVariable objects (GH#5504).

The bug was that safe_serialize would call json.dumps() on PipelineVariable objects,
causing TypeError. The fix returns PipelineVariable objects as-is.
"""

def test_safe_serialize_with_parameter_integer(self):
"""safe_serialize should return ParameterInteger as-is, not attempt JSON serialization."""
param = ParameterInteger(name="MaxDepth", default_value=5)
result = safe_serialize(param)
assert result is param

def test_safe_serialize_with_parameter_string(self):
"""safe_serialize should return ParameterString as-is, not attempt JSON serialization."""
param = ParameterString(name="Objective", default_value="reg:squarederror")
result = safe_serialize(param)
assert result is param

def test_safe_serialize_with_parameter_float(self):
"""safe_serialize should return ParameterFloat as-is, not attempt JSON serialization."""
param = ParameterFloat(name="LearningRate", default_value=0.01)
result = safe_serialize(param)
assert result is param

def test_safe_serialize_with_plain_string(self):
"""safe_serialize should return plain strings unchanged."""
result = safe_serialize("hello")
assert result == "hello"

def test_safe_serialize_with_int(self):
"""safe_serialize should JSON-dump integers to their string representation."""
result = safe_serialize(5)
assert result == "5"

def test_safe_serialize_with_dict(self):
"""safe_serialize should JSON-dump dicts to their string representation."""
result = safe_serialize({"key": "value"})
assert result == '{"key": "value"}'
Loading