Skip to content

Commit 95a19c5

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)
1 parent ee420cc commit 95a19c5

File tree

3 files changed

+160
-2
lines changed

3 files changed

+160
-2
lines changed

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,16 @@ def safe_serialize(data):
192192
try:
193193
return json.dumps(data)
194194
except TypeError:
195-
return str(data)
195+
try:
196+
return str(data)
197+
except TypeError:
198+
# PipelineVariable.__str__ raises TypeError by design.
199+
# If the isinstance check above didn't catch it (e.g. import
200+
# path mismatch), fall back to returning the object directly
201+
# when it looks like a PipelineVariable (has an ``expr`` property).
202+
if hasattr(data, "expr"):
203+
return data
204+
raise
196205

197206

198207
def _run_clone_command_silent(repo_url, dest_dir):

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626

2727
from sagemaker.core.helper.session_helper import Session
2828
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29-
from sagemaker.core.workflow.parameters import ParameterString
29+
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
3030
from sagemaker.train.model_trainer import ModelTrainer, Mode
3131
from sagemaker.train.configs import (
3232
Compute,
3333
StoppingCondition,
3434
OutputDataConfig,
3535
)
36+
from sagemaker.core.workflow.pipeline_context import PipelineSession
3637
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
3738

3839

@@ -176,3 +177,61 @@ def test_training_image_rejects_invalid_type(self):
176177
stopping_condition=DEFAULT_STOPPING,
177178
output_data_config=DEFAULT_OUTPUT,
178179
)
180+
181+
182+
class TestModelTrainerHyperparametersPipelineVariable:
    """Exercise safe_serialize on values used as ModelTrainer hyperparameters.

    Every test first constructs a ModelTrainer whose ``hyperparameters`` dict
    contains the value under test, then calls ``safe_serialize`` on that value
    directly and checks what comes back.
    """

    @staticmethod
    def _build_trainer(hyperparameters):
        """Return a ModelTrainer built from the module defaults and *hyperparameters*."""
        return ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters=hyperparameters,
        )

    def test_hyperparameters_with_pipeline_variable_integer(self):
        """safe_serialize hands back the very same ParameterInteger instance."""
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        # Construction must succeed with a pipeline parameter as a value.
        self._build_trainer({"max_depth": depth_param})

        from sagemaker.train.utils import safe_serialize

        assert safe_serialize(depth_param) is depth_param

    def test_hyperparameters_with_pipeline_variable_string(self):
        """safe_serialize hands back the very same ParameterString instance."""
        optimizer_param = ParameterString(name="Optimizer", default_value="sgd")
        self._build_trainer({"optimizer": optimizer_param})

        from sagemaker.train.utils import safe_serialize

        assert safe_serialize(optimizer_param) is optimizer_param

    def test_hyperparameters_with_mixed_pipeline_and_regular_values(self):
        """Pipeline parameters pass through; plain values are serialized."""
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        mixed_hyperparameters = {
            "max_depth": depth_param,
            "eta": 0.1,
            "objective": "binary:logistic",
        }
        self._build_trainer(mixed_hyperparameters)

        from sagemaker.train.utils import safe_serialize

        # The pipeline parameter is returned untouched (identity, not equality).
        assert safe_serialize(depth_param) is depth_param
        # A float becomes its JSON text form.
        assert safe_serialize(0.1) == "0.1"
        # A plain string is returned unchanged.
        assert safe_serialize("binary:logistic") == "binary:logistic"
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for sagemaker.train.utils – specifically safe_serialize."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
18+
from sagemaker.train.utils import safe_serialize
19+
from sagemaker.core.workflow.parameters import (
20+
ParameterInteger,
21+
ParameterString,
22+
ParameterFloat,
23+
)
24+
25+
26+
# ---------------------------------------------------------------------------
27+
# PipelineVariable inputs – should be returned as-is (identity)
28+
# ---------------------------------------------------------------------------
29+
30+
def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly():
    """A ParameterInteger must come back from safe_serialize as the same object."""
    max_depth = ParameterInteger(name="MaxDepth", default_value=5)
    assert safe_serialize(max_depth) is max_depth
34+
35+
36+
def test_safe_serialize_with_pipeline_variable_string_returns_object_directly():
    """A ParameterString must come back from safe_serialize as the same object."""
    optimizer = ParameterString(name="Optimizer", default_value="sgd")
    assert safe_serialize(optimizer) is optimizer
40+
41+
42+
def test_safe_serialize_with_pipeline_variable_float_returns_object_directly():
    """A ParameterFloat must come back from safe_serialize as the same object."""
    learning_rate = ParameterFloat(name="LearningRate", default_value=0.01)
    assert safe_serialize(learning_rate) is learning_rate
46+
47+
48+
# ---------------------------------------------------------------------------
49+
# Regular / primitive inputs
50+
# ---------------------------------------------------------------------------
51+
52+
def test_safe_serialize_with_string_returns_string_as_is():
    """Plain strings are expected back unchanged, even when they look numeric."""
    for text in ("hello", "12345"):
        assert safe_serialize(text) == text
55+
56+
57+
def test_safe_serialize_with_int_returns_json_string():
    """Integers are expected back as their JSON text form."""
    for number, expected in ((5, "5"), (0, "0")):
        assert safe_serialize(number) == expected
60+
61+
62+
def test_safe_serialize_with_dict_returns_json_string():
    """A dict is expected back as the same text json.dumps produces for it."""
    payload = {"key": "value", "num": 1}
    actual = safe_serialize(payload)
    expected = json.dumps(payload)
    assert actual == expected
65+
66+
67+
def test_safe_serialize_with_bool_returns_json_string():
    """Booleans are expected back as the lowercase JSON literals."""
    for flag, expected in ((True, "true"), (False, "false")):
        assert safe_serialize(flag) == expected
70+
71+
72+
def test_safe_serialize_with_custom_object_returns_str():
    """An arbitrary object with __str__ is expected back as that string."""

    class Stringifiable:
        def __str__(self):
            return "CustomObject"

    instance = Stringifiable()
    serialized = safe_serialize(instance)
    assert serialized == "CustomObject"
79+
80+
81+
def test_safe_serialize_with_none_returns_json_null():
    """None is expected back as the JSON literal ``null``."""
    serialized = safe_serialize(None)
    assert serialized == "null"
83+
84+
85+
def test_safe_serialize_with_list_returns_json_string():
    """A list of ints is expected back as its JSON array text."""
    numbers = [1, 2, 3]
    serialized = safe_serialize(numbers)
    assert serialized == "[1, 2, 3]"
87+
88+
89+
def test_safe_serialize_with_empty_string():
    """The empty string is expected back as the empty string."""
    serialized = safe_serialize("")
    assert serialized == ""

0 commit comments

Comments
 (0)