Skip to content

Commit 109214a

Browse files
committed
fix: address review comments (iteration #1)
1 parent c517f24 commit 109214a

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

sagemaker-core/src/sagemaker/core/modules/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424

2525
from sagemaker.core.shapes import Unassigned
2626
from sagemaker.core.modules import logger
27-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
27+
28+
try:
29+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
30+
except ImportError:
31+
PipelineVariable = None
2832

2933

3034
def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
@@ -145,7 +149,7 @@ def safe_serialize(data):
145149
"""
146150
if isinstance(data, str):
147151
return data
148-
elif isinstance(data, PipelineVariable):
152+
elif PipelineVariable is not None and isinstance(data, PipelineVariable):
149153
return data
150154
try:
151155
return json.dumps(data)

sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
1919
See: https://github.com/aws/sagemaker-python-sdk/issues/5504
2020
"""
21-
from __future__ import absolute_import
21+
from __future__ import annotations
2222

2323
import pytest
2424

@@ -51,6 +51,12 @@ def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
5151
result = safe_serialize(param)
5252
assert result is param
5353

54+
def test_pipeline_variable_str_raises_type_error(self):
55+
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
56+
param = ParameterInteger(name="TestParam", default_value=10)
57+
with pytest.raises(TypeError):
58+
str(param)
59+
5460

5561
class TestSafeSerializeBasicTypes:
5662
"""Regression tests: verify basic types still work after PipelineVariable support."""

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,32 @@ def test_hyperparameters_with_parameter_string(self):
210210
args = trainer._create_training_job_args()
211211
assert args["hyper_parameters"]["algorithm"] is algo
212212

213+
def test_hyperparameters_with_parameter_integer_does_not_raise(self):
214+
"""Verify ParameterInteger in hyperparameters does NOT raise TypeError.
215+
216+
This test documents the exact bug scenario from GH#5504: safe_serialize
217+
would fall back to str(data) for PipelineVariable objects, but
218+
PipelineVariable.__str__ intentionally raises TypeError.
219+
"""
220+
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
221+
trainer = ModelTrainer(
222+
training_image=DEFAULT_IMAGE,
223+
role=DEFAULT_ROLE,
224+
compute=DEFAULT_COMPUTE,
225+
stopping_condition=DEFAULT_STOPPING,
226+
output_data_config=DEFAULT_OUTPUT,
227+
hyperparameters={"max_depth": max_depth},
228+
)
229+
# This call would have raised TypeError before the fix
230+
try:
231+
args = trainer._create_training_job_args()
232+
except TypeError:
233+
pytest.fail(
234+
"safe_serialize raised TypeError on PipelineVariable - "
235+
"this is the bug described in GH#5504"
236+
)
237+
assert args["hyper_parameters"]["max_depth"] is max_depth
238+
213239
def test_hyperparameters_with_mixed_pipeline_and_static_values(self):
214240
"""Mixed PipelineVariable and static values should both be handled correctly."""
215241
max_depth = ParameterInteger(name="MaxDepth", default_value=5)

sagemaker-train/tests/unit/train/test_safe_serialize.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
See: https://github.com/aws/sagemaker-python-sdk/issues/5504
2121
"""
22-
from __future__ import absolute_import
22+
from __future__ import annotations
2323

2424
import pytest
2525

@@ -52,6 +52,12 @@ def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
5252
result = safe_serialize(param)
5353
assert result is param
5454

55+
def test_pipeline_variable_str_raises_type_error(self):
56+
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
57+
param = ParameterInteger(name="TestParam", default_value=10)
58+
with pytest.raises(TypeError):
59+
str(param)
60+
5561

5662
class TestSafeSerializeBasicTypes:
5763
"""Regression tests: verify basic types still work after PipelineVariable support."""

0 commit comments

Comments
 (0)