Skip to content

Commit b05d156

Browse files
committed
fix: address review comments (iteration #1)
1 parent aae602b commit b05d156

File tree

1 file changed

+45
-58
lines changed

1 file changed

+45
-58
lines changed

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 45 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
# language governing permissions and limitations under the License.
1313
"""Tests for PipelineVariable support in ModelTrainer.
1414
15-
Verifies that ModelTrainer fields accept PipelineVariable objects
15+
Verify that ModelTrainer fields accept PipelineVariable objects
1616
(e.g., ParameterString) in addition to their concrete types, following
1717
the existing V3 pattern established by SourceCode and OutputDataConfig.
1818
19-
Also verifies that safe_serialize correctly handles PipelineVariable objects
19+
Also verify that safe_serialize correctly handles PipelineVariable objects
2020
in hyperparameters (returning them as-is instead of attempting json.dumps),
2121
and that _create_training_job_args preserves PipelineVariable objects through
2222
the serialization pipeline.
23+
24+
See: https://github.com/aws/sagemaker-python-sdk/issues/5504
2325
"""
2426
from __future__ import absolute_import
2527

@@ -56,7 +58,7 @@
5658
)
5759

5860

59-
@pytest.fixture(scope="module", autouse=True)
61+
@pytest.fixture(scope="module")
6062
def modules_session():
6163
with patch("sagemaker.train.Session", spec=Session) as session_mock:
6264
session_instance = session_mock.return_value
@@ -72,7 +74,7 @@ class TestModelTrainerPipelineVariableAcceptance:
7274
"""Test that ModelTrainer fields accept PipelineVariable objects."""
7375

7476
def test_training_image_accepts_parameter_string(self):
75-
"""ModelTrainer.training_image should accept ParameterString (GH#5524)."""
77+
"""Verify ModelTrainer.training_image accepts ParameterString (GH#5504)."""
7678
param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
7779
trainer = ModelTrainer(
7880
training_image=param,
@@ -85,7 +87,7 @@ def test_training_image_accepts_parameter_string(self):
8587
assert trainer.training_image is param
8688

8789
def test_algorithm_name_accepts_parameter_string(self):
88-
"""ModelTrainer.algorithm_name should accept ParameterString."""
90+
"""Verify ModelTrainer.algorithm_name accepts ParameterString."""
8991
param = ParameterString(name="AlgorithmName", default_value="my-algo-arn")
9092
trainer = ModelTrainer(
9193
algorithm_name=param,
@@ -98,7 +100,7 @@ def test_algorithm_name_accepts_parameter_string(self):
98100
assert trainer.algorithm_name is param
99101

100102
def test_training_input_mode_accepts_parameter_string(self):
101-
"""ModelTrainer.training_input_mode should accept ParameterString."""
103+
"""Verify ModelTrainer.training_input_mode accepts ParameterString."""
102104
param = ParameterString(name="InputMode", default_value="File")
103105
trainer = ModelTrainer(
104106
training_image=DEFAULT_IMAGE,
@@ -111,7 +113,7 @@ def test_training_input_mode_accepts_parameter_string(self):
111113
assert trainer.training_input_mode is param
112114

113115
def test_environment_values_accept_parameter_string(self):
114-
"""ModelTrainer.environment dict values should accept ParameterString."""
116+
"""Verify ModelTrainer.environment dict values accept ParameterString."""
115117
param = ParameterString(name="DatasetVersion", default_value="v1")
116118
trainer = ModelTrainer(
117119
training_image=DEFAULT_IMAGE,
@@ -129,7 +131,7 @@ class TestModelTrainerRealValuesStillWork:
129131
"""Regression tests: verify that passing real values still works after the change."""
130132

131133
def test_training_image_accepts_real_string(self):
132-
"""ModelTrainer.training_image should still accept a plain string."""
134+
"""Verify ModelTrainer.training_image still accepts a plain string."""
133135
trainer = ModelTrainer(
134136
training_image=DEFAULT_IMAGE,
135137
role=DEFAULT_ROLE,
@@ -140,7 +142,7 @@ def test_training_image_accepts_real_string(self):
140142
assert trainer.training_image == DEFAULT_IMAGE
141143

142144
def test_algorithm_name_accepts_real_string(self):
143-
"""ModelTrainer.algorithm_name should still accept a plain string."""
145+
"""Verify ModelTrainer.algorithm_name still accepts a plain string."""
144146
trainer = ModelTrainer(
145147
algorithm_name="arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo",
146148
role=DEFAULT_ROLE,
@@ -151,7 +153,7 @@ def test_algorithm_name_accepts_real_string(self):
151153
assert trainer.algorithm_name == "arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo"
152154

153155
def test_training_input_mode_accepts_real_string(self):
154-
"""ModelTrainer.training_input_mode should still accept a plain string."""
156+
"""Verify ModelTrainer.training_input_mode still accepts a plain string."""
155157
trainer = ModelTrainer(
156158
training_image=DEFAULT_IMAGE,
157159
training_input_mode="Pipe",
@@ -163,7 +165,7 @@ def test_training_input_mode_accepts_real_string(self):
163165
assert trainer.training_input_mode == "Pipe"
164166

165167
def test_environment_accepts_real_string_values(self):
166-
"""ModelTrainer.environment should still accept plain string values."""
168+
"""Verify ModelTrainer.environment still accepts plain string values."""
167169
trainer = ModelTrainer(
168170
training_image=DEFAULT_IMAGE,
169171
environment={"KEY1": "value1", "KEY2": "value2"},
@@ -175,7 +177,7 @@ def test_environment_accepts_real_string_values(self):
175177
assert trainer.environment == {"KEY1": "value1", "KEY2": "value2"}
176178

177179
def test_training_image_rejects_invalid_type(self):
178-
"""ModelTrainer.training_image should still reject invalid types (e.g., int)."""
180+
"""Verify ModelTrainer.training_image still rejects invalid types (e.g., int)."""
179181
with pytest.raises(ValidationError):
180182
ModelTrainer(
181183
training_image=12345,
@@ -187,64 +189,46 @@ def test_training_image_rejects_invalid_type(self):
187189

188190

189191
class TestSafeSerializeWithPipelineVariables:
190-
"""Tests that safe_serialize handles PipelineVariable objects correctly.
192+
"""Verify that safe_serialize handles PipelineVariable objects correctly.
191193
192194
The safe_serialize function must return PipelineVariable objects as-is
193195
instead of attempting json.dumps(), which would raise TypeError.
196+
See: https://github.com/aws/sagemaker-python-sdk/issues/5504
194197
"""
195198

196-
def test_safe_serialize_with_parameter_integer_returns_pipeline_variable(self):
197-
"""safe_serialize should return ParameterInteger as-is."""
198-
param = ParameterInteger(name="MaxDepth", default_value=5)
199-
result = safe_serialize(param)
200-
assert result is param
201-
assert isinstance(result, PipelineVariable)
202-
203-
def test_safe_serialize_with_parameter_string_returns_pipeline_variable(self):
204-
"""safe_serialize should return ParameterString as-is."""
205-
param = ParameterString(name="Optimizer", default_value="adam")
199+
@pytest.mark.parametrize("param", [
200+
ParameterInteger(name="MaxDepth", default_value=5),
201+
ParameterString(name="Optimizer", default_value="adam"),
202+
ParameterFloat(name="LearningRate", default_value=0.01),
203+
])
204+
def test_safe_serialize_returns_pipeline_variable_as_is(self, param):
205+
"""Verify safe_serialize returns PipelineVariable objects as-is."""
206206
result = safe_serialize(param)
207207
assert result is param
208208
assert isinstance(result, PipelineVariable)
209209

210-
def test_safe_serialize_with_parameter_float_returns_pipeline_variable(self):
211-
"""safe_serialize should return ParameterFloat as-is."""
212-
param = ParameterFloat(name="LearningRate", default_value=0.01)
213-
result = safe_serialize(param)
214-
assert result is param
215-
assert isinstance(result, PipelineVariable)
216-
217-
def test_safe_serialize_still_handles_strings(self):
218-
"""safe_serialize should return plain strings as-is (no quotes wrapping)."""
219-
result = safe_serialize("hello")
220-
assert result == "hello"
221-
222-
def test_safe_serialize_still_handles_integers(self):
223-
"""safe_serialize should JSON-encode integers."""
224-
result = safe_serialize(42)
225-
assert result == "42"
226-
227-
def test_safe_serialize_still_handles_dicts(self):
228-
"""safe_serialize should JSON-encode dicts."""
229-
result = safe_serialize({"key": "value"})
230-
assert result == '{"key": "value"}'
231-
232-
def test_safe_serialize_still_handles_floats(self):
233-
"""safe_serialize should JSON-encode floats."""
234-
result = safe_serialize(0.01)
235-
assert result == "0.01"
236-
237-
def test_safe_serialize_still_handles_booleans(self):
238-
"""safe_serialize should JSON-encode booleans."""
239-
assert safe_serialize(True) == "true"
240-
assert safe_serialize(False) == "false"
210+
@pytest.mark.parametrize("input_val,expected", [
211+
("hello", "hello"),
212+
(42, "42"),
213+
({"key": "value"}, '{"key": "value"}'),
214+
(0.01, "0.01"),
215+
(True, "true"),
216+
(False, "false"),
217+
])
218+
def test_safe_serialize_handles_normal_types(self, input_val, expected):
219+
"""Verify safe_serialize correctly serializes normal (non-PipelineVariable) types."""
220+
result = safe_serialize(input_val)
221+
assert result == expected
241222

242223

243224
class TestModelTrainerHyperparametersWithPipelineVariables:
244-
"""Tests that ModelTrainer accepts PipelineVariable objects in hyperparameters."""
225+
"""Verify that ModelTrainer accepts PipelineVariable objects in hyperparameters.
226+
227+
See: https://github.com/aws/sagemaker-python-sdk/issues/5504
228+
"""
245229

246230
def test_hyperparameters_accept_pipeline_variable_values(self):
247-
"""ModelTrainer should accept PipelineVariable objects as hyperparameter values."""
231+
"""Verify ModelTrainer accepts PipelineVariable objects as hyperparameter values."""
248232
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
249233
learning_rate = ParameterFloat(name="LearningRate", default_value=0.01)
250234
optimizer = ParameterString(name="Optimizer", default_value="adam")
@@ -267,8 +251,10 @@ def test_hyperparameters_accept_pipeline_variable_values(self):
267251
assert trainer.hyperparameters["optimizer"] is optimizer
268252
assert trainer.hyperparameters["static_param"] == 10
269253

270-
def test_create_training_job_args_with_pipeline_variable_hyperparameters(self):
271-
"""_create_training_job_args should preserve PipelineVariable in hyper_parameters."""
254+
def test_create_training_job_args_with_pipeline_variable_hyperparameters(
255+
self, modules_session
256+
):
257+
"""Verify _create_training_job_args preserves PipelineVariable in hyper_parameters."""
272258
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
273259
learning_rate = ParameterFloat(name="LearningRate", default_value=0.01)
274260

@@ -278,6 +264,7 @@ def test_create_training_job_args_with_pipeline_variable_hyperparameters(self):
278264
compute=DEFAULT_COMPUTE,
279265
stopping_condition=DEFAULT_STOPPING,
280266
output_data_config=DEFAULT_OUTPUT,
267+
sagemaker_session=modules_session,
281268
hyperparameters={
282269
"max_depth": max_depth,
283270
"learning_rate": learning_rate,

0 commit comments

Comments
 (0)