From 41eb3f5332ebd8fb180f89463f4f505f4f66b2cf Mon Sep 17 00:00:00 2001
From: aviruthen <91846056+aviruthen@users.noreply.github.com>
Date: Tue, 7 Apr 2026 14:16:41 -0400
Subject: [PATCH] fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)

---
 .../test_model_trainer_pipeline_variable.py | 92 ++++++++++++++++++-
 1 file changed, 91 insertions(+), 1 deletion(-)

diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
index 3fd34fa47b..465650b310 100644
--- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
+++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
@@ -26,8 +26,9 @@
 
 from sagemaker.core.helper.session_helper import Session
 from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
-from sagemaker.core.workflow.parameters import ParameterString
+from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
 from sagemaker.train.model_trainer import ModelTrainer, Mode
+from sagemaker.train.utils import safe_serialize
 from sagemaker.train.configs import (
     Compute,
     StoppingCondition,
@@ -116,6 +117,54 @@ def test_environment_values_accept_parameter_string(self):
         assert trainer.environment["DATASET_VERSION"] is param
         assert trainer.environment["STATIC_VAR"] == "hello"
 
+    def test_hyperparameters_accept_parameter_integer(self):
+        """ModelTrainer.hyperparameters should accept ParameterInteger values (GH#5504)."""
+        param = ParameterInteger(name="MaxDepth", default_value=5)
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            hyperparameters={"max_depth": param},
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+        )
+        assert trainer.hyperparameters["max_depth"] is param  # same object, not a copy
+
+    def test_hyperparameters_accept_parameter_string(self):
+        """ModelTrainer.hyperparameters should accept ParameterString values (GH#5504)."""
+        param = ParameterString(name="LearningRate", default_value="0.01")
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            hyperparameters={"learning_rate": param},
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+        )
+        assert trainer.hyperparameters["learning_rate"] is param  # same object, not a copy
+
+    def test_hyperparameters_accept_mixed_pipeline_and_static_values(self):
+        """ModelTrainer.hyperparameters should accept a mix of PipelineVariable and static values."""
+        param_int = ParameterInteger(name="MaxDepth", default_value=5)
+        param_str = ParameterString(name="Objective", default_value="reg:squarederror")
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            hyperparameters={
+                "max_depth": param_int,
+                "objective": param_str,
+                "eta": 0.1,
+                "num_round": 100,
+            },
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+        )
+        assert trainer.hyperparameters["max_depth"] is param_int
+        assert trainer.hyperparameters["objective"] is param_str
+        assert trainer.hyperparameters["eta"] == 0.1
+        assert trainer.hyperparameters["num_round"] == 100
+
 
 class TestModelTrainerRealValuesStillWork:
     """Regression tests: verify that passing real values still works after the change."""
@@ -176,3 +225,44 @@ def test_training_image_rejects_invalid_type(self):
                 stopping_condition=DEFAULT_STOPPING,
                 output_data_config=DEFAULT_OUTPUT,
             )
+
+
+class TestSafeSerializePipelineVariable:
+    """Test that safe_serialize correctly handles PipelineVariable objects (GH#5504).
+
+    The bug was that safe_serialize would call json.dumps() on PipelineVariable objects,
+    causing TypeError. The fix returns PipelineVariable objects as-is.
+    """
+
+    def test_safe_serialize_with_parameter_integer(self):
+        """safe_serialize should return ParameterInteger as-is, not attempt JSON serialization."""
+        param = ParameterInteger(name="MaxDepth", default_value=5)
+        result = safe_serialize(param)
+        assert result is param  # identity: the parameter object itself comes back
+
+    def test_safe_serialize_with_parameter_string(self):
+        """safe_serialize should return ParameterString as-is, not attempt JSON serialization."""
+        param = ParameterString(name="Objective", default_value="reg:squarederror")
+        result = safe_serialize(param)
+        assert result is param  # identity: the parameter object itself comes back
+
+    def test_safe_serialize_with_parameter_float(self):
+        """safe_serialize should return ParameterFloat as-is, not attempt JSON serialization."""
+        param = ParameterFloat(name="LearningRate", default_value=0.01)
+        result = safe_serialize(param)
+        assert result is param  # identity: the parameter object itself comes back
+
+    def test_safe_serialize_with_plain_string(self):
+        """safe_serialize should return plain strings unchanged."""
+        result = safe_serialize("hello")
+        assert result == "hello"  # no extra JSON quoting around the string
+
+    def test_safe_serialize_with_int(self):
+        """safe_serialize should JSON-dump integers to their string representation."""
+        result = safe_serialize(5)
+        assert result == "5"  # a str, not the int 5
+
+    def test_safe_serialize_with_dict(self):
+        """safe_serialize should JSON-dump dicts to their string representation."""
+        result = safe_serialize({"key": "value"})
+        assert result == '{"key": "value"}'