Skip to content

Commit c517f24

Browse files
committed
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)
1 parent 272fdbf commit c517f24

File tree

4 files changed

+250
-3
lines changed

4 files changed

+250
-3
lines changed

sagemaker-core/src/sagemaker/core/modules/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from sagemaker.core.shapes import Unassigned
2626
from sagemaker.core.modules import logger
27+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
2728

2829

2930
def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
@@ -129,9 +130,11 @@ def safe_serialize(data):
129130
130131
This function handles the following cases:
131132
1. If `data` is a string, it returns the string as-is without wrapping in quotes.
132-
2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
133+
2. If `data` is of type `PipelineVariable`, it returns the PipelineVariable object
134+
as-is for pipeline serialization.
135+
3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
133136
the JSON-encoded string using `json.dumps()`.
134-
3. If `data` cannot be serialized (e.g., a custom object), it returns the string
137+
4. If `data` cannot be serialized (e.g., a custom object), it returns the string
135138
representation of the data using `str(data)`.
136139
137140
Args:
@@ -142,6 +145,8 @@ def safe_serialize(data):
142145
"""
143146
if isinstance(data, str):
144147
return data
148+
elif isinstance(data, PipelineVariable):
149+
return data
145150
try:
146151
return json.dumps(data)
147152
except TypeError:
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests for safe_serialize in sagemaker.core.modules.utils with PipelineVariable support.
14+
15+
Verifies that safe_serialize correctly handles PipelineVariable objects
16+
(e.g., ParameterInteger, ParameterString) by returning them as-is rather
17+
than attempting str() conversion which would raise TypeError.
18+
19+
See: https://github.com/aws/sagemaker-python-sdk/issues/5504
20+
"""
21+
from __future__ import absolute_import
22+
23+
import pytest
24+
25+
from sagemaker.core.modules.utils import safe_serialize
26+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
27+
from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString
28+
29+
30+
class TestSafeSerializeWithPipelineVariables:
    """Check that safe_serialize hands PipelineVariable objects back untouched."""

    def test_safe_serialize_with_parameter_integer(self):
        """A ParameterInteger is returned as the very same object."""
        param = ParameterInteger(name="MaxDepth", default_value=5)
        result = safe_serialize(param)
        assert result is param
        assert isinstance(result, PipelineVariable)

    def test_safe_serialize_with_parameter_string(self):
        """A ParameterString is returned as the very same object."""
        param = ParameterString(name="Algorithm", default_value="xgboost")
        result = safe_serialize(param)
        assert result is param
        assert isinstance(result, PipelineVariable)

    def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
        """safe_serialize must not fall back to str() for a PipelineVariable."""
        param = ParameterInteger(name="TestParam", default_value=10)
        # Precondition (issue 5504): str() on a pipeline variable raises
        # TypeError. Pinning it here makes the call below a real check that
        # safe_serialize never reaches its str() fallback for such objects.
        with pytest.raises(TypeError):
            str(param)
        # With the precondition established, this must NOT raise.
        result = safe_serialize(param)
        assert result is param
53+
54+
55+
class TestSafeSerializeBasicTypes:
    """Regression tests: basic types still serialize as before PipelineVariable support."""

    def test_safe_serialize_with_string(self):
        """Strings are returned as-is, without JSON quoting."""
        assert safe_serialize("hello") == "hello"
        # A numeric-looking string must stay a plain string, not be re-encoded.
        assert safe_serialize("12345") == "12345"

    def test_safe_serialize_with_int(self):
        """Integers are JSON-serialized to a string."""
        assert safe_serialize(42) == "42"

    def test_safe_serialize_with_float(self):
        """Floats are JSON-serialized to a string."""
        assert safe_serialize(3.14) == "3.14"

    def test_safe_serialize_with_dict(self):
        """Dicts are JSON-serialized."""
        result = safe_serialize({"key": "val"})
        assert result == '{"key": "val"}'

    def test_safe_serialize_with_bool(self):
        """Booleans are JSON-serialized to lowercase literals."""
        assert safe_serialize(True) == "true"
        assert safe_serialize(False) == "false"

    def test_safe_serialize_with_none(self):
        """None is JSON-serialized to 'null'."""
        assert safe_serialize(None) == "null"

    def test_safe_serialize_with_list(self):
        """Lists are JSON-serialized."""
        assert safe_serialize([1, 2, 3]) == "[1, 2, 3]"

    def test_safe_serialize_with_custom_object(self):
        """Objects json cannot encode fall back to str()."""

        class CustomObj:
            def __str__(self):
                return "custom"

        assert safe_serialize(CustomObj()) == "custom"

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from sagemaker.core.helper.session_helper import Session
2828
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29-
from sagemaker.core.workflow.parameters import ParameterString
29+
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
3030
from sagemaker.train.model_trainer import ModelTrainer, Mode
3131
from sagemaker.train.configs import (
3232
Compute,
@@ -176,3 +176,61 @@ def test_training_image_rejects_invalid_type(self):
176176
stopping_condition=DEFAULT_STOPPING,
177177
output_data_config=DEFAULT_OUTPUT,
178178
)
179+
180+
181+
class TestModelTrainerPipelineVariableHyperparameters:
    """Check that PipelineVariable hyperparameters survive ModelTrainer argument building."""

    @staticmethod
    def _training_job_args(hyperparameters):
        """Build a ModelTrainer with the module defaults and return its training job args."""
        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters=hyperparameters,
        )
        return trainer._create_training_job_args()

    def test_hyperparameters_with_parameter_integer(self):
        """A ParameterInteger hyperparameter comes out of _create_training_job_args unchanged."""
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        job_args = self._training_job_args({"max_depth": depth_param})
        # Identity check: the pipeline variable must not have been stringified.
        assert job_args["hyper_parameters"]["max_depth"] is depth_param

    def test_hyperparameters_with_parameter_string(self):
        """A ParameterString hyperparameter comes out of _create_training_job_args unchanged."""
        algo_param = ParameterString(name="Algorithm", default_value="xgboost")
        job_args = self._training_job_args({"algorithm": algo_param})
        assert job_args["hyper_parameters"]["algorithm"] is algo_param

    def test_hyperparameters_with_mixed_pipeline_and_static_values(self):
        """Pipeline variables pass through while static values become strings."""
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        job_args = self._training_job_args(
            {
                "max_depth": depth_param,
                "eta": 0.1,
                "objective": "binary:logistic",
                "num_round": 100,
            }
        )
        hyper_params = job_args["hyper_parameters"]
        # The pipeline variable is the same object that went in.
        assert hyper_params["max_depth"] is depth_param
        # Plain values are serialized to their string form.
        assert hyper_params["eta"] == "0.1"
        assert hyper_params["objective"] == "binary:logistic"
        assert hyper_params["num_round"] == "100"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests for safe_serialize with PipelineVariable support.
14+
15+
Verifies that safe_serialize in sagemaker.train.utils correctly handles
16+
PipelineVariable objects (e.g., ParameterInteger, ParameterString) by
17+
returning them as-is rather than attempting str() conversion which would
18+
raise TypeError.
19+
20+
See: https://github.com/aws/sagemaker-python-sdk/issues/5504
21+
"""
22+
from __future__ import absolute_import
23+
24+
import pytest
25+
26+
from sagemaker.train.utils import safe_serialize
27+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
28+
from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString
29+
30+
31+
class TestSafeSerializeWithPipelineVariables:
    """Check that safe_serialize hands PipelineVariable objects back untouched."""

    def test_safe_serialize_with_parameter_integer(self):
        """A ParameterInteger is returned as the very same object."""
        param = ParameterInteger(name="MaxDepth", default_value=5)
        result = safe_serialize(param)
        assert result is param
        assert isinstance(result, PipelineVariable)

    def test_safe_serialize_with_parameter_string(self):
        """A ParameterString is returned as the very same object."""
        param = ParameterString(name="Algorithm", default_value="xgboost")
        result = safe_serialize(param)
        assert result is param
        assert isinstance(result, PipelineVariable)

    def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
        """safe_serialize must not fall back to str() for a PipelineVariable."""
        param = ParameterInteger(name="TestParam", default_value=10)
        # Precondition (issue 5504): str() on a pipeline variable raises
        # TypeError. Pinning it here makes the call below a real check that
        # safe_serialize never reaches its str() fallback for such objects.
        with pytest.raises(TypeError):
            str(param)
        # With the precondition established, this must NOT raise.
        result = safe_serialize(param)
        assert result is param
54+
55+
56+
class TestSafeSerializeBasicTypes:
    """Regression tests: basic types still serialize as before PipelineVariable support."""

    def test_safe_serialize_with_string(self):
        """Strings pass through unchanged, with no JSON quoting added."""
        plain = safe_serialize("hello")
        assert plain == "hello"
        digits = safe_serialize("12345")
        assert digits == "12345"

    def test_safe_serialize_with_int(self):
        """An int becomes its JSON string form."""
        encoded = safe_serialize(42)
        assert encoded == "42"

    def test_safe_serialize_with_float(self):
        """A float becomes its JSON string form."""
        encoded = safe_serialize(3.14)
        assert encoded == "3.14"

    def test_safe_serialize_with_dict(self):
        """A dict becomes a JSON object string."""
        assert safe_serialize({"key": "val"}) == '{"key": "val"}'

    def test_safe_serialize_with_bool(self):
        """Booleans become the lowercase JSON literals."""
        truthy = safe_serialize(True)
        falsy = safe_serialize(False)
        assert truthy == "true"
        assert falsy == "false"

    def test_safe_serialize_with_none(self):
        """None becomes the JSON literal 'null'."""
        encoded = safe_serialize(None)
        assert encoded == "null"

    def test_safe_serialize_with_list(self):
        """A list becomes a JSON array string."""
        encoded = safe_serialize([1, 2, 3])
        assert encoded == "[1, 2, 3]"

    def test_safe_serialize_with_custom_object(self):
        """An object json cannot encode falls back to its str() form."""

        class _Stringable:
            def __str__(self):
                return "custom"

        instance = _Stringable()
        assert safe_serialize(instance) == "custom"

0 commit comments

Comments
 (0)