Skip to content

Commit 17bcbac

Browse files
committed
fix: address review comments (iteration #1)
1 parent 3a78790 commit 17bcbac

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_safe_serialize_still_handles_dict(self):
158158
class TestModelTrainerHyperparametersPipelineVariable:
159159
"""Test that ModelTrainer hyperparameters accept PipelineVariable objects (GH#5504)."""
160160

161-
def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self):
161+
def test_hyperparameters_accept_parameter_integer(self):
162162
"""ModelTrainer hyperparameters should accept ParameterInteger (GH#5504).
163163
164164
This is the exact bug scenario: ParameterInteger in hyperparameters
@@ -175,7 +175,7 @@ def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self):
175175
)
176176
assert trainer.hyperparameters["max_depth"] is max_depth
177177

178-
def test_hyperparameters_accept_parameter_string_via_safe_serialize(self):
178+
def test_hyperparameters_accept_parameter_string(self):
179179
"""ModelTrainer hyperparameters should accept ParameterString (GH#5504)."""
180180
objective = ParameterString(name="Objective", default_value="reg:squarederror")
181181
trainer = ModelTrainer(
@@ -228,6 +228,7 @@ def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters
228228
args = trainer._create_training_job_args()
229229
# PipelineVariable should be preserved as-is by safe_serialize
230230
assert args["hyper_parameters"]["max_depth"] is max_depth
231+
assert isinstance(args["hyper_parameters"]["max_depth"], PipelineVariable)
231232
# Plain values should be JSON-serialized to strings
232233
assert args["hyper_parameters"]["eta"] == "0.1"
233234

0 commit comments

Comments
 (0)