2626
2727from sagemaker .core .helper .session_helper import Session
2828from sagemaker .core .helper .pipeline_variable import PipelineVariable , StrPipeVar
29- from sagemaker .core .workflow .parameters import ParameterString , ParameterInteger , ParameterFloat
29+ from sagemaker .core .workflow .parameters import ParameterString , ParameterInteger
3030from sagemaker .train .model_trainer import ModelTrainer , Mode
3131from sagemaker .train .configs import (
3232 Compute ,
3333 StoppingCondition ,
3434 OutputDataConfig ,
3535)
36- from sagemaker .core .workflow .pipeline_context import PipelineSession
3736from sagemaker .train .defaults import DEFAULT_INSTANCE_TYPE
3837
3938
@@ -180,10 +179,10 @@ def test_training_image_rejects_invalid_type(self):
180179
181180
182181class TestModelTrainerHyperparametersPipelineVariable :
183- """Test that PipelineVariable objects in hyperparameters survive safe_serialize ."""
182+ """Test that ModelTrainer correctly preserves PipelineVariable objects in hyperparameters."""
184183
185- def test_hyperparameters_with_pipeline_variable_integer (self ):
186- """ParameterInteger in hyperparameters should be passed through as-is ."""
184+ def test_hyperparameters_preserves_pipeline_variable_integer (self ):
185+ """ParameterInteger in hyperparameters should be preserved in ModelTrainer ."""
187186 max_depth = ParameterInteger (name = "MaxDepth" , default_value = 5 )
188187 trainer = ModelTrainer (
189188 training_image = DEFAULT_IMAGE ,
@@ -193,13 +192,10 @@ def test_hyperparameters_with_pipeline_variable_integer(self):
193192 output_data_config = DEFAULT_OUTPUT ,
194193 hyperparameters = {"max_depth" : max_depth },
195194 )
196- # safe_serialize should return the PipelineVariable object directly
197- from sagemaker .train .utils import safe_serialize
198- result = safe_serialize (max_depth )
199- assert result is max_depth
195+ assert trainer .hyperparameters ["max_depth" ] is max_depth
200196
201- def test_hyperparameters_with_pipeline_variable_string (self ):
202- """ParameterString in hyperparameters should be passed through as-is ."""
197+ def test_hyperparameters_preserves_pipeline_variable_string (self ):
198+ """ParameterString in hyperparameters should be preserved in ModelTrainer ."""
203199 optimizer = ParameterString (name = "Optimizer" , default_value = "sgd" )
204200 trainer = ModelTrainer (
205201 training_image = DEFAULT_IMAGE ,
@@ -209,12 +205,10 @@ def test_hyperparameters_with_pipeline_variable_string(self):
209205 output_data_config = DEFAULT_OUTPUT ,
210206 hyperparameters = {"optimizer" : optimizer },
211207 )
212- from sagemaker .train .utils import safe_serialize
213- result = safe_serialize (optimizer )
214- assert result is optimizer
208+ assert trainer .hyperparameters ["optimizer" ] is optimizer
215209
216- def test_hyperparameters_with_mixed_pipeline_and_regular_values (self ):
217- """Mixed PipelineVariable and regular values should both serialize correctly ."""
210+ def test_hyperparameters_preserves_mixed_pipeline_and_regular_values (self ):
211+ """Mixed PipelineVariable and regular values should all be preserved ."""
218212 max_depth = ParameterInteger (name = "MaxDepth" , default_value = 5 )
219213 trainer = ModelTrainer (
220214 training_image = DEFAULT_IMAGE ,
@@ -228,10 +222,8 @@ def test_hyperparameters_with_mixed_pipeline_and_regular_values(self):
228222 "objective" : "binary:logistic" ,
229223 },
230224 )
231- from sagemaker .train .utils import safe_serialize
232- # PipelineVariable should be returned as-is
233- assert safe_serialize (max_depth ) is max_depth
234- # Float should be JSON-serialized
235- assert safe_serialize (0.1 ) == "0.1"
236- # String should be returned as-is
237- assert safe_serialize ("binary:logistic" ) == "binary:logistic"
225+ # PipelineVariable should be preserved as-is
226+ assert trainer .hyperparameters ["max_depth" ] is max_depth
227+ # Regular values should also be preserved
228+ assert trainer .hyperparameters ["eta" ] == 0.1
229+ assert trainer .hyperparameters ["objective" ] == "binary:logistic"
0 commit comments