@@ -158,7 +158,7 @@ def test_safe_serialize_still_handles_dict(self):
158158class TestModelTrainerHyperparametersPipelineVariable :
159159 """Test that ModelTrainer hyperparameters accept PipelineVariable objects (GH#5504)."""
160160
161- def test_hyperparameters_accept_parameter_integer_via_safe_serialize (self ):
161+ def test_hyperparameters_accept_parameter_integer (self ):
162162 """ModelTrainer hyperparameters should accept ParameterInteger (GH#5504).
163163
164164 This is the exact bug scenario: ParameterInteger in hyperparameters
@@ -175,7 +175,7 @@ def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self):
175175 )
176176 assert trainer .hyperparameters ["max_depth" ] is max_depth
177177
178- def test_hyperparameters_accept_parameter_string_via_safe_serialize (self ):
178+ def test_hyperparameters_accept_parameter_string (self ):
179179 """ModelTrainer hyperparameters should accept ParameterString (GH#5504)."""
180180 objective = ParameterString (name = "Objective" , default_value = "reg:squarederror" )
181181 trainer = ModelTrainer (
@@ -228,6 +228,7 @@ def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters
228228 args = trainer ._create_training_job_args ()
229229 # PipelineVariable should be preserved as-is by safe_serialize
230230 assert args ["hyper_parameters" ]["max_depth" ] is max_depth
231+ assert isinstance (args ["hyper_parameters" ]["max_depth" ], PipelineVariable )
231232 # Plain values should be JSON-serialized to strings
232233 assert args ["hyper_parameters" ]["eta" ] == "0.1"
233234
0 commit comments