Skip to content

Commit aae602b

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)
1 parent ee420cc commit aae602b

File tree

1 file changed

+122
-3
lines changed

1 file changed

+122
-3
lines changed

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Tests for PipelineVariable support in ModelTrainer (GH#5524).
13+
"""Tests for PipelineVariable support in ModelTrainer.
1414
1515
Verifies that ModelTrainer fields accept PipelineVariable objects
1616
(e.g., ParameterString) in addition to their concrete types, following
1717
the existing V3 pattern established by SourceCode and OutputDataConfig.
1818
19-
See: https://github.com/aws/sagemaker-python-sdk/issues/5524
19+
Also verifies that safe_serialize correctly handles PipelineVariable objects
20+
in hyperparameters (returning them as-is instead of attempting json.dumps),
21+
and that _create_training_job_args preserves PipelineVariable objects through
22+
the serialization pipeline.
2023
"""
2124
from __future__ import absolute_import
2225

@@ -26,13 +29,18 @@
2629

2730
from sagemaker.core.helper.session_helper import Session
2831
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29-
from sagemaker.core.workflow.parameters import ParameterString
32+
from sagemaker.core.workflow.parameters import (
33+
ParameterString,
34+
ParameterInteger,
35+
ParameterFloat,
36+
)
3037
from sagemaker.train.model_trainer import ModelTrainer, Mode
3138
from sagemaker.train.configs import (
3239
Compute,
3340
StoppingCondition,
3441
OutputDataConfig,
3542
)
43+
from sagemaker.train.utils import safe_serialize
3644
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
3745

3846

@@ -176,3 +184,114 @@ def test_training_image_rejects_invalid_type(self):
176184
stopping_condition=DEFAULT_STOPPING,
177185
output_data_config=DEFAULT_OUTPUT,
178186
)
187+
188+
189+
class TestSafeSerializeWithPipelineVariables:
    """Verify that safe_serialize passes PipelineVariable objects through untouched.

    json.dumps() cannot encode PipelineVariable instances, so safe_serialize
    must return them unchanged instead of attempting JSON serialization
    (which would raise TypeError). Plain Python values must still be
    JSON-encoded as before.
    """

    def test_safe_serialize_with_parameter_integer_returns_pipeline_variable(self):
        """A ParameterInteger must come back as the identical object."""
        depth = ParameterInteger(name="MaxDepth", default_value=5)
        serialized = safe_serialize(depth)
        assert serialized is depth
        assert isinstance(serialized, PipelineVariable)

    def test_safe_serialize_with_parameter_string_returns_pipeline_variable(self):
        """A ParameterString must come back as the identical object."""
        optimizer = ParameterString(name="Optimizer", default_value="adam")
        serialized = safe_serialize(optimizer)
        assert serialized is optimizer
        assert isinstance(serialized, PipelineVariable)

    def test_safe_serialize_with_parameter_float_returns_pipeline_variable(self):
        """A ParameterFloat must come back as the identical object."""
        rate = ParameterFloat(name="LearningRate", default_value=0.01)
        serialized = safe_serialize(rate)
        assert serialized is rate
        assert isinstance(serialized, PipelineVariable)

    def test_safe_serialize_still_handles_strings(self):
        """Plain strings are returned as-is, without JSON quote wrapping."""
        assert safe_serialize("hello") == "hello"

    def test_safe_serialize_still_handles_integers(self):
        """Integers are JSON-encoded to their string form."""
        assert safe_serialize(42) == "42"

    def test_safe_serialize_still_handles_dicts(self):
        """Dicts are JSON-encoded."""
        assert safe_serialize({"key": "value"}) == '{"key": "value"}'

    def test_safe_serialize_still_handles_floats(self):
        """Floats are JSON-encoded."""
        assert safe_serialize(0.01) == "0.01"

    def test_safe_serialize_still_handles_booleans(self):
        """Booleans are JSON-encoded to lowercase literals."""
        assert safe_serialize(True) == "true"
        assert safe_serialize(False) == "false"
241+
242+
243+
class TestModelTrainerHyperparametersWithPipelineVariables:
    """Verify that ModelTrainer accepts PipelineVariable hyperparameter values.

    Pipeline parameters (ParameterInteger/Float/String) must survive both
    construction of the trainer and the _create_training_job_args
    serialization step without being coerced or rejected.
    """

    def test_hyperparameters_accept_pipeline_variable_values(self):
        """PipelineVariable values are stored unchanged alongside plain values."""
        depth = ParameterInteger(name="MaxDepth", default_value=5)
        rate = ParameterFloat(name="LearningRate", default_value=0.01)
        optimizer = ParameterString(name="Optimizer", default_value="adam")

        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters={
                "max_depth": depth,
                "learning_rate": rate,
                "optimizer": optimizer,
                "static_param": 10,
            },
        )

        stored = trainer.hyperparameters
        assert stored["max_depth"] is depth
        assert stored["learning_rate"] is rate
        assert stored["optimizer"] is optimizer
        assert stored["static_param"] == 10

    def test_create_training_job_args_with_pipeline_variable_hyperparameters(self):
        """_create_training_job_args keeps PipelineVariables and stringifies the rest."""
        depth = ParameterInteger(name="MaxDepth", default_value=5)
        rate = ParameterFloat(name="LearningRate", default_value=0.01)

        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters={
                "max_depth": depth,
                "learning_rate": rate,
                "epochs": 10,
                "verbose": "true",
            },
        )

        serialized = trainer._create_training_job_args()["hyper_parameters"]

        # PipelineVariable objects must be preserved as-is by safe_serialize.
        assert serialized["max_depth"] is depth
        assert serialized["learning_rate"] is rate
        # Regular values must be serialized to strings.
        assert serialized["epochs"] == "10"
        assert serialized["verbose"] == "true"

0 commit comments

Comments
 (0)