Skip to content

Commit 41eb3f5

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)
1 parent 6497a94 commit 41eb3f5

1 file changed

Lines changed: 91 additions & 1 deletion

File tree

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626

2727
from sagemaker.core.helper.session_helper import Session
2828
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29-
from sagemaker.core.workflow.parameters import ParameterString
29+
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
3030
from sagemaker.train.model_trainer import ModelTrainer, Mode
31+
from sagemaker.train.utils import safe_serialize
3132
from sagemaker.train.configs import (
3233
Compute,
3334
StoppingCondition,
@@ -116,6 +117,54 @@ def test_environment_values_accept_parameter_string(self):
116117
assert trainer.environment["DATASET_VERSION"] is param
117118
assert trainer.environment["STATIC_VAR"] == "hello"
118119

120+
def test_hyperparameters_accept_parameter_integer(self):
121+
"""ModelTrainer.hyperparameters should accept ParameterInteger values (GH#5504)."""
122+
param = ParameterInteger(name="MaxDepth", default_value=5)
123+
trainer = ModelTrainer(
124+
training_image=DEFAULT_IMAGE,
125+
hyperparameters={"max_depth": param},
126+
role=DEFAULT_ROLE,
127+
compute=DEFAULT_COMPUTE,
128+
stopping_condition=DEFAULT_STOPPING,
129+
output_data_config=DEFAULT_OUTPUT,
130+
)
131+
assert trainer.hyperparameters["max_depth"] is param
132+
133+
def test_hyperparameters_accept_parameter_string(self):
134+
"""ModelTrainer.hyperparameters should accept ParameterString values (GH#5504)."""
135+
param = ParameterString(name="LearningRate", default_value="0.01")
136+
trainer = ModelTrainer(
137+
training_image=DEFAULT_IMAGE,
138+
hyperparameters={"learning_rate": param},
139+
role=DEFAULT_ROLE,
140+
compute=DEFAULT_COMPUTE,
141+
stopping_condition=DEFAULT_STOPPING,
142+
output_data_config=DEFAULT_OUTPUT,
143+
)
144+
assert trainer.hyperparameters["learning_rate"] is param
145+
146+
def test_hyperparameters_accept_mixed_pipeline_and_static_values(self):
147+
"""ModelTrainer.hyperparameters should accept a mix of PipelineVariable and static values."""
148+
param_int = ParameterInteger(name="MaxDepth", default_value=5)
149+
param_str = ParameterString(name="Objective", default_value="reg:squarederror")
150+
trainer = ModelTrainer(
151+
training_image=DEFAULT_IMAGE,
152+
hyperparameters={
153+
"max_depth": param_int,
154+
"objective": param_str,
155+
"eta": 0.1,
156+
"num_round": 100,
157+
},
158+
role=DEFAULT_ROLE,
159+
compute=DEFAULT_COMPUTE,
160+
stopping_condition=DEFAULT_STOPPING,
161+
output_data_config=DEFAULT_OUTPUT,
162+
)
163+
assert trainer.hyperparameters["max_depth"] is param_int
164+
assert trainer.hyperparameters["objective"] is param_str
165+
assert trainer.hyperparameters["eta"] == 0.1
166+
assert trainer.hyperparameters["num_round"] == 100
167+
119168

120169
class TestModelTrainerRealValuesStillWork:
121170
"""Regression tests: verify that passing real values still works after the change."""
@@ -176,3 +225,44 @@ def test_training_image_rejects_invalid_type(self):
176225
stopping_condition=DEFAULT_STOPPING,
177226
output_data_config=DEFAULT_OUTPUT,
178227
)
228+
229+
230+
class TestSafeSerializePipelineVariable:
231+
"""Test that safe_serialize correctly handles PipelineVariable objects (GH#5504).
232+
233+
The bug was that safe_serialize would call json.dumps() on PipelineVariable objects,
234+
causing TypeError. The fix returns PipelineVariable objects as-is.
235+
"""
236+
237+
def test_safe_serialize_with_parameter_integer(self):
238+
"""safe_serialize should return ParameterInteger as-is, not attempt JSON serialization."""
239+
param = ParameterInteger(name="MaxDepth", default_value=5)
240+
result = safe_serialize(param)
241+
assert result is param
242+
243+
def test_safe_serialize_with_parameter_string(self):
244+
"""safe_serialize should return ParameterString as-is, not attempt JSON serialization."""
245+
param = ParameterString(name="Objective", default_value="reg:squarederror")
246+
result = safe_serialize(param)
247+
assert result is param
248+
249+
def test_safe_serialize_with_parameter_float(self):
250+
"""safe_serialize should return ParameterFloat as-is, not attempt JSON serialization."""
251+
param = ParameterFloat(name="LearningRate", default_value=0.01)
252+
result = safe_serialize(param)
253+
assert result is param
254+
255+
def test_safe_serialize_with_plain_string(self):
256+
"""safe_serialize should return plain strings unchanged."""
257+
result = safe_serialize("hello")
258+
assert result == "hello"
259+
260+
def test_safe_serialize_with_int(self):
261+
"""safe_serialize should JSON-dump integers to their string representation."""
262+
result = safe_serialize(5)
263+
assert result == "5"
264+
265+
def test_safe_serialize_with_dict(self):
266+
"""safe_serialize should JSON-dump dicts to their string representation."""
267+
result = safe_serialize({"key": "value"})
268+
assert result == '{"key": "value"}'

0 commit comments

Comments
 (0)