Skip to content

Commit 3a78790

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrainer (#5504)
1 parent ee420cc commit 3a78790

File tree

1 file changed

+116
-1
lines changed

1 file changed

+116
-1
lines changed

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626

2727
from sagemaker.core.helper.session_helper import Session
2828
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29-
from sagemaker.core.workflow.parameters import ParameterString
29+
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
3030
from sagemaker.train.model_trainer import ModelTrainer, Mode
3131
from sagemaker.train.configs import (
3232
Compute,
3333
StoppingCondition,
3434
OutputDataConfig,
3535
)
3636
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
37+
from sagemaker.train.utils import safe_serialize
3738

3839

3940
DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
@@ -117,6 +118,120 @@ def test_environment_values_accept_parameter_string(self):
117118
assert trainer.environment["STATIC_VAR"] == "hello"
118119

119120

121+
class TestSafeSerializePipelineVariable:
    """Checks on ``safe_serialize`` for pipeline parameters and plain values (GH#5504)."""

    def test_safe_serialize_returns_pipeline_variable_as_is(self):
        """A ParameterInteger must come back as the very same object."""
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        assert safe_serialize(max_depth) is max_depth

    def test_safe_serialize_returns_parameter_string_as_is(self):
        """A ParameterString must come back as the very same object."""
        algorithm = ParameterString(name="Algorithm", default_value="xgboost")
        assert safe_serialize(algorithm) is algorithm

    def test_safe_serialize_returns_parameter_float_as_is(self):
        """A ParameterFloat must come back as the very same object."""
        learning_rate = ParameterFloat(name="LearningRate", default_value=0.01)
        assert safe_serialize(learning_rate) is learning_rate

    def test_safe_serialize_still_handles_plain_string(self):
        """A plain string is expected back unchanged."""
        assert safe_serialize("hello") == "hello"

    def test_safe_serialize_still_handles_int(self):
        """The integer 42 is expected to be encoded as the string "42"."""
        assert safe_serialize(42) == "42"

    def test_safe_serialize_still_handles_dict(self):
        """A dict is expected to be encoded as its JSON text."""
        expected_json = '{"key": "value"}'
        assert safe_serialize({"key": "value"}) == expected_json
156+
157+
158+
class TestModelTrainerHyperparametersPipelineVariable:
    """Checks that ModelTrainer hyperparameters may hold PipelineVariable objects (GH#5504)."""

    @staticmethod
    def _build_trainer(hyperparameters):
        """Create a ModelTrainer from the module-level defaults plus *hyperparameters*.

        Not collected by pytest (no ``test`` prefix); it only removes the
        constructor boilerplate shared by every test in this class.
        """
        return ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters=hyperparameters,
        )

    def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self):
        """A ParameterInteger hyperparameter is accepted and stored untouched (GH#5504).

        This reproduces the reported bug: a ParameterInteger among the
        hyperparameters raised TypeError in safe_serialize before the fix.
        """
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        model_trainer = self._build_trainer({"max_depth": depth_param})
        assert model_trainer.hyperparameters["max_depth"] is depth_param

    def test_hyperparameters_accept_parameter_string_via_safe_serialize(self):
        """A ParameterString hyperparameter is accepted and stored untouched (GH#5504)."""
        objective_param = ParameterString(name="Objective", default_value="reg:squarederror")
        model_trainer = self._build_trainer({"objective": objective_param})
        assert model_trainer.hyperparameters["objective"] is objective_param

    def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self):
        """PipelineVariable and plain hyperparameter values can be supplied together."""
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        supplied = {
            "max_depth": depth_param,
            "eta": 0.1,
            "objective": "reg:squarederror",
        }
        stored = self._build_trainer(supplied).hyperparameters
        assert stored["max_depth"] is depth_param
        assert stored["eta"] == 0.1
        assert stored["objective"] == "reg:squarederror"

    @patch("sagemaker.train.model_trainer._get_unique_name", return_value="test-job-20240101")
    def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters(
        self, mock_unique_name
    ):
        """_create_training_job_args keeps PipelineVariable entries in ``hyper_parameters``.

        safe_serialize is expected to hand a PipelineVariable back unchanged
        rather than attempting to JSON-encode it.
        """
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        model_trainer = self._build_trainer({"max_depth": depth_param, "eta": 0.1})
        job_hyperparameters = model_trainer._create_training_job_args()["hyper_parameters"]
        # The pipeline parameter must survive as the identical object.
        assert job_hyperparameters["max_depth"] is depth_param
        # Ordinary values are expected as their JSON string form.
        assert job_hyperparameters["eta"] == "0.1"
233+
234+
120235
class TestModelTrainerRealValuesStillWork:
121236
"""Regression tests: verify that passing real values still works after the change."""
122237

0 commit comments

Comments
 (0)