1212# language governing permissions and limitations under the License.
1313"""Tests for PipelineVariable support in ModelTrainer.
1414
15- Verifies that ModelTrainer fields accept PipelineVariable objects
15+ Verify that ModelTrainer fields accept PipelineVariable objects
1616(e.g., ParameterString) in addition to their concrete types, following
1717the existing V3 pattern established by SourceCode and OutputDataConfig.
1818
19- Also verifies that safe_serialize correctly handles PipelineVariable objects
19+ Also verify that safe_serialize correctly handles PipelineVariable objects
2020in hyperparameters (returning them as-is instead of attempting json.dumps),
2121and that _create_training_job_args preserves PipelineVariable objects through
2222the serialization pipeline.
23+
24+ See: https://github.com/aws/sagemaker-python-sdk/issues/5504
2325"""
2426from __future__ import absolute_import
2527
5658)
5759
5860
59- @pytest .fixture (scope = "module" , autouse = True )
61+ @pytest .fixture (scope = "module" )
6062def modules_session ():
6163 with patch ("sagemaker.train.Session" , spec = Session ) as session_mock :
6264 session_instance = session_mock .return_value
@@ -72,7 +74,7 @@ class TestModelTrainerPipelineVariableAcceptance:
7274 """Test that ModelTrainer fields accept PipelineVariable objects."""
7375
7476 def test_training_image_accepts_parameter_string (self ):
75- """ModelTrainer.training_image should accept ParameterString (GH#5524 )."""
77+ """Verify ModelTrainer.training_image accepts ParameterString (GH#5504 )."""
7678 param = ParameterString (name = "TrainingImage" , default_value = DEFAULT_IMAGE )
7779 trainer = ModelTrainer (
7880 training_image = param ,
@@ -85,7 +87,7 @@ def test_training_image_accepts_parameter_string(self):
8587 assert trainer .training_image is param
8688
8789 def test_algorithm_name_accepts_parameter_string (self ):
88- """ModelTrainer.algorithm_name should accept ParameterString."""
90+ """Verify ModelTrainer.algorithm_name accepts ParameterString."""
8991 param = ParameterString (name = "AlgorithmName" , default_value = "my-algo-arn" )
9092 trainer = ModelTrainer (
9193 algorithm_name = param ,
@@ -98,7 +100,7 @@ def test_algorithm_name_accepts_parameter_string(self):
98100 assert trainer .algorithm_name is param
99101
100102 def test_training_input_mode_accepts_parameter_string (self ):
101- """ModelTrainer.training_input_mode should accept ParameterString."""
103+ """Verify ModelTrainer.training_input_mode accepts ParameterString."""
102104 param = ParameterString (name = "InputMode" , default_value = "File" )
103105 trainer = ModelTrainer (
104106 training_image = DEFAULT_IMAGE ,
@@ -111,7 +113,7 @@ def test_training_input_mode_accepts_parameter_string(self):
111113 assert trainer .training_input_mode is param
112114
113115 def test_environment_values_accept_parameter_string (self ):
114- """ModelTrainer.environment dict values should accept ParameterString."""
116+ """Verify ModelTrainer.environment dict values accept ParameterString."""
115117 param = ParameterString (name = "DatasetVersion" , default_value = "v1" )
116118 trainer = ModelTrainer (
117119 training_image = DEFAULT_IMAGE ,
@@ -129,7 +131,7 @@ class TestModelTrainerRealValuesStillWork:
129131 """Regression tests: verify that passing real values still works after the change."""
130132
131133 def test_training_image_accepts_real_string (self ):
132- """ModelTrainer.training_image should still accept a plain string."""
134+ """Verify ModelTrainer.training_image still accepts a plain string."""
133135 trainer = ModelTrainer (
134136 training_image = DEFAULT_IMAGE ,
135137 role = DEFAULT_ROLE ,
@@ -140,7 +142,7 @@ def test_training_image_accepts_real_string(self):
140142 assert trainer .training_image == DEFAULT_IMAGE
141143
142144 def test_algorithm_name_accepts_real_string (self ):
143- """ModelTrainer.algorithm_name should still accept a plain string."""
145+ """Verify ModelTrainer.algorithm_name still accepts a plain string."""
144146 trainer = ModelTrainer (
145147 algorithm_name = "arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo" ,
146148 role = DEFAULT_ROLE ,
@@ -151,7 +153,7 @@ def test_algorithm_name_accepts_real_string(self):
151153 assert trainer .algorithm_name == "arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo"
152154
153155 def test_training_input_mode_accepts_real_string (self ):
154- """ModelTrainer.training_input_mode should still accept a plain string."""
156+ """Verify ModelTrainer.training_input_mode still accepts a plain string."""
155157 trainer = ModelTrainer (
156158 training_image = DEFAULT_IMAGE ,
157159 training_input_mode = "Pipe" ,
@@ -163,7 +165,7 @@ def test_training_input_mode_accepts_real_string(self):
163165 assert trainer .training_input_mode == "Pipe"
164166
165167 def test_environment_accepts_real_string_values (self ):
166- """ModelTrainer.environment should still accept plain string values."""
168+ """Verify ModelTrainer.environment still accepts plain string values."""
167169 trainer = ModelTrainer (
168170 training_image = DEFAULT_IMAGE ,
169171 environment = {"KEY1" : "value1" , "KEY2" : "value2" },
@@ -175,7 +177,7 @@ def test_environment_accepts_real_string_values(self):
175177 assert trainer .environment == {"KEY1" : "value1" , "KEY2" : "value2" }
176178
177179 def test_training_image_rejects_invalid_type (self ):
178- """ModelTrainer.training_image should still reject invalid types (e.g., int)."""
180+ """Verify ModelTrainer.training_image still rejects invalid types (e.g., int)."""
179181 with pytest .raises (ValidationError ):
180182 ModelTrainer (
181183 training_image = 12345 ,
@@ -187,64 +189,46 @@ def test_training_image_rejects_invalid_type(self):
187189
188190
189191class TestSafeSerializeWithPipelineVariables :
190- """Tests that safe_serialize handles PipelineVariable objects correctly.
192+ """Verify that safe_serialize handles PipelineVariable objects correctly.
191193
192194 The safe_serialize function must return PipelineVariable objects as-is
193195 instead of attempting json.dumps(), which would raise TypeError.
196+ See: https://github.com/aws/sagemaker-python-sdk/issues/5504
194197 """
195198
196- def test_safe_serialize_with_parameter_integer_returns_pipeline_variable (self ):
197- """safe_serialize should return ParameterInteger as-is."""
198- param = ParameterInteger (name = "MaxDepth" , default_value = 5 )
199- result = safe_serialize (param )
200- assert result is param
201- assert isinstance (result , PipelineVariable )
202-
203- def test_safe_serialize_with_parameter_string_returns_pipeline_variable (self ):
204- """safe_serialize should return ParameterString as-is."""
205- param = ParameterString (name = "Optimizer" , default_value = "adam" )
199+ @pytest .mark .parametrize ("param" , [
200+ ParameterInteger (name = "MaxDepth" , default_value = 5 ),
201+ ParameterString (name = "Optimizer" , default_value = "adam" ),
202+ ParameterFloat (name = "LearningRate" , default_value = 0.01 ),
203+ ])
204+ def test_safe_serialize_returns_pipeline_variable_as_is (self , param ):
205+ """Verify safe_serialize returns PipelineVariable objects as-is."""
206206 result = safe_serialize (param )
207207 assert result is param
208208 assert isinstance (result , PipelineVariable )
209209
210- def test_safe_serialize_with_parameter_float_returns_pipeline_variable (self ):
211- """safe_serialize should return ParameterFloat as-is."""
212- param = ParameterFloat (name = "LearningRate" , default_value = 0.01 )
213- result = safe_serialize (param )
214- assert result is param
215- assert isinstance (result , PipelineVariable )
216-
217- def test_safe_serialize_still_handles_strings (self ):
218- """safe_serialize should return plain strings as-is (no quotes wrapping)."""
219- result = safe_serialize ("hello" )
220- assert result == "hello"
221-
222- def test_safe_serialize_still_handles_integers (self ):
223- """safe_serialize should JSON-encode integers."""
224- result = safe_serialize (42 )
225- assert result == "42"
226-
227- def test_safe_serialize_still_handles_dicts (self ):
228- """safe_serialize should JSON-encode dicts."""
229- result = safe_serialize ({"key" : "value" })
230- assert result == '{"key": "value"}'
231-
232- def test_safe_serialize_still_handles_floats (self ):
233- """safe_serialize should JSON-encode floats."""
234- result = safe_serialize (0.01 )
235- assert result == "0.01"
236-
237- def test_safe_serialize_still_handles_booleans (self ):
238- """safe_serialize should JSON-encode booleans."""
239- assert safe_serialize (True ) == "true"
240- assert safe_serialize (False ) == "false"
210+ @pytest .mark .parametrize ("input_val,expected" , [
211+ ("hello" , "hello" ),
212+ (42 , "42" ),
213+ ({"key" : "value" }, '{"key": "value"}' ),
214+ (0.01 , "0.01" ),
215+ (True , "true" ),
216+ (False , "false" ),
217+ ])
218+ def test_safe_serialize_handles_normal_types (self , input_val , expected ):
219+ """Verify safe_serialize correctly serializes normal (non-PipelineVariable) types."""
220+ result = safe_serialize (input_val )
221+ assert result == expected
241222
242223
243224class TestModelTrainerHyperparametersWithPipelineVariables :
244- """Tests that ModelTrainer accepts PipelineVariable objects in hyperparameters."""
225+ """Verify that ModelTrainer accepts PipelineVariable objects in hyperparameters.
226+
227+ See: https://github.com/aws/sagemaker-python-sdk/issues/5504
228+ """
245229
246230 def test_hyperparameters_accept_pipeline_variable_values (self ):
247- """ModelTrainer should accept PipelineVariable objects as hyperparameter values."""
231+ """Verify ModelTrainer accepts PipelineVariable objects as hyperparameter values."""
248232 max_depth = ParameterInteger (name = "MaxDepth" , default_value = 5 )
249233 learning_rate = ParameterFloat (name = "LearningRate" , default_value = 0.01 )
250234 optimizer = ParameterString (name = "Optimizer" , default_value = "adam" )
@@ -267,8 +251,10 @@ def test_hyperparameters_accept_pipeline_variable_values(self):
267251 assert trainer .hyperparameters ["optimizer" ] is optimizer
268252 assert trainer .hyperparameters ["static_param" ] == 10
269253
270- def test_create_training_job_args_with_pipeline_variable_hyperparameters (self ):
271- """_create_training_job_args should preserve PipelineVariable in hyper_parameters."""
254+ def test_create_training_job_args_with_pipeline_variable_hyperparameters (
255+ self , modules_session
256+ ):
257+ """Verify _create_training_job_args preserves PipelineVariable in hyper_parameters."""
272258 max_depth = ParameterInteger (name = "MaxDepth" , default_value = 5 )
273259 learning_rate = ParameterFloat (name = "LearningRate" , default_value = 0.01 )
274260
@@ -278,6 +264,7 @@ def test_create_training_job_args_with_pipeline_variable_hyperparameters(self):
278264 compute = DEFAULT_COMPUTE ,
279265 stopping_condition = DEFAULT_STOPPING ,
280266 output_data_config = DEFAULT_OUTPUT ,
267+ sagemaker_session = modules_session ,
281268 hyperparameters = {
282269 "max_depth" : max_depth ,
283270 "learning_rate" : learning_rate ,
0 commit comments