Skip to content

Commit 0d1195d

Browse files
committed
fix: address review comments (iteration #2)
1 parent 109214a commit 0d1195d

File tree

2 files changed

+14
-44
lines changed

2 files changed

+14
-44
lines changed

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,13 @@ class TestModelTrainerPipelineVariableHyperparameters:
182182
"""Test that PipelineVariable objects work correctly in ModelTrainer hyperparameters."""
183183

184184
def test_hyperparameters_with_parameter_integer(self):
185-
"""ParameterInteger in hyperparameters should be preserved through _create_training_job_args."""
185+
"""ParameterInteger in hyperparameters should be preserved through _create_training_job_args.
186+
187+
This test documents the exact bug scenario from GH#5504: safe_serialize
188+
would fall back to str(data) for PipelineVariable objects, but
189+
PipelineVariable.__str__ intentionally raises TypeError.
190+
Before the fix, this call would have raised TypeError.
191+
"""
186192
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
187193
trainer = ModelTrainer(
188194
training_image=DEFAULT_IMAGE,
@@ -192,6 +198,7 @@ def test_hyperparameters_with_parameter_integer(self):
192198
output_data_config=DEFAULT_OUTPUT,
193199
hyperparameters={"max_depth": max_depth},
194200
)
201+
# This call would have raised TypeError before the fix (GH#5504)
195202
args = trainer._create_training_job_args()
196203
# PipelineVariable should be preserved as-is, not stringified
197204
assert args["hyper_parameters"]["max_depth"] is max_depth
@@ -210,32 +217,6 @@ def test_hyperparameters_with_parameter_string(self):
210217
args = trainer._create_training_job_args()
211218
assert args["hyper_parameters"]["algorithm"] is algo
212219

213-
def test_hyperparameters_with_parameter_integer_does_not_raise(self):
214-
"""Verify ParameterInteger in hyperparameters does NOT raise TypeError.
215-
216-
This test documents the exact bug scenario from GH#5504: safe_serialize
217-
would fall back to str(data) for PipelineVariable objects, but
218-
PipelineVariable.__str__ intentionally raises TypeError.
219-
"""
220-
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
221-
trainer = ModelTrainer(
222-
training_image=DEFAULT_IMAGE,
223-
role=DEFAULT_ROLE,
224-
compute=DEFAULT_COMPUTE,
225-
stopping_condition=DEFAULT_STOPPING,
226-
output_data_config=DEFAULT_OUTPUT,
227-
hyperparameters={"max_depth": max_depth},
228-
)
229-
# This call would have raised TypeError before the fix
230-
try:
231-
args = trainer._create_training_job_args()
232-
except TypeError:
233-
pytest.fail(
234-
"safe_serialize raised TypeError on PipelineVariable - "
235-
"this is the bug described in GH#5504"
236-
)
237-
assert args["hyper_parameters"]["max_depth"] is max_depth
238-
239220
def test_hyperparameters_with_mixed_pipeline_and_static_values(self):
240221
"""Mixed PipelineVariable and static values should both be handled correctly."""
241222
max_depth = ParameterInteger(name="MaxDepth", default_value=5)

sagemaker-train/tests/unit/train/test_safe_serialize.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,16 @@
3131
class TestSafeSerializeWithPipelineVariables:
3232
"""Test safe_serialize handles PipelineVariable objects correctly."""
3333

34-
def test_safe_serialize_with_parameter_integer(self):
35-
"""ParameterInteger should be returned as-is (identity preserved)."""
36-
param = ParameterInteger(name="MaxDepth", default_value=5)
34+
@pytest.mark.parametrize("param", [
35+
ParameterInteger(name="MaxDepth", default_value=5),
36+
ParameterString(name="Algorithm", default_value="xgboost"),
37+
])
38+
def test_safe_serialize_returns_pipeline_variable_as_is(self, param):
39+
"""PipelineVariable objects should be returned as-is (identity preserved)."""
3740
result = safe_serialize(param)
3841
assert result is param
3942
assert isinstance(result, PipelineVariable)
4043

41-
def test_safe_serialize_with_parameter_string(self):
42-
"""ParameterString should be returned as-is (identity preserved)."""
43-
param = ParameterString(name="Algorithm", default_value="xgboost")
44-
result = safe_serialize(param)
45-
assert result is param
46-
assert isinstance(result, PipelineVariable)
47-
48-
def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
49-
"""Verify that PipelineVariable.__str__ is never invoked (would raise TypeError)."""
50-
param = ParameterInteger(name="TestParam", default_value=10)
51-
# This should NOT raise TypeError
52-
result = safe_serialize(param)
53-
assert result is param
54-
5544
def test_pipeline_variable_str_raises_type_error(self):
5645
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
5746
param = ParameterInteger(name="TestParam", default_value=10)

0 commit comments

Comments
 (0)