From 3a7879079c8fff6d955fff93d40cd90146b7eaa1 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 30 Mar 2026 17:10:36 -0400 Subject: [PATCH 1/2] fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) --- .../test_model_trainer_pipeline_variable.py | 117 +++++++++++++++++- 1 file changed, 116 insertions(+), 1 deletion(-) diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 3fd34fa47b..e18630adb3 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -26,7 +26,7 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar -from sagemaker.core.workflow.parameters import ParameterString +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat from sagemaker.train.model_trainer import ModelTrainer, Mode from sagemaker.train.configs import ( Compute, @@ -34,6 +34,7 @@ OutputDataConfig, ) from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE +from sagemaker.train.utils import safe_serialize DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" @@ -117,6 +118,120 @@ def test_environment_values_accept_parameter_string(self): assert trainer.environment["STATIC_VAR"] == "hello" +class TestSafeSerializePipelineVariable: + """Test that safe_serialize correctly handles PipelineVariable objects (GH#5504).""" + + def test_safe_serialize_returns_pipeline_variable_as_is(self): + """safe_serialize should return PipelineVariable objects without JSON serialization.""" + param = ParameterInteger(name="MaxDepth", default_value=5) + result = safe_serialize(param) + assert result is param + + def test_safe_serialize_returns_parameter_string_as_is(self): + """safe_serialize should return ParameterString objects without JSON serialization.""" + param = ParameterString(name="Algorithm", default_value="xgboost") + result = safe_serialize(param) + assert result is param + + def test_safe_serialize_returns_parameter_float_as_is(self): + """safe_serialize should return ParameterFloat objects without JSON serialization.""" + param = ParameterFloat(name="LearningRate", default_value=0.01) + result = safe_serialize(param) + assert result is param + + def test_safe_serialize_still_handles_plain_string(self): + """safe_serialize should return plain strings as-is.""" + result = safe_serialize("hello") + assert result == "hello" + + def test_safe_serialize_still_handles_int(self): + """safe_serialize should JSON-encode integers.""" + result = safe_serialize(42) + assert result == "42" + + def test_safe_serialize_still_handles_dict(self): + """safe_serialize should JSON-encode dicts.""" + result = safe_serialize({"key": "value"}) + assert result == '{"key": "value"}' + + +class TestModelTrainerHyperparametersPipelineVariable: + """Test that ModelTrainer hyperparameters accept PipelineVariable objects (GH#5504).""" + + def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self): + """ModelTrainer hyperparameters should accept ParameterInteger (GH#5504). + + This is the exact bug scenario: ParameterInteger in hyperparameters + caused TypeError in safe_serialize before the fix. + """ + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"max_depth": max_depth}, + ) + assert trainer.hyperparameters["max_depth"] is max_depth + + def test_hyperparameters_accept_parameter_string_via_safe_serialize(self): + """ModelTrainer hyperparameters should accept ParameterString (GH#5504).""" + objective = ParameterString(name="Objective", default_value="reg:squarederror") + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"objective": objective}, + ) + assert trainer.hyperparameters["objective"] is objective + + def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self): + """ModelTrainer hyperparameters should accept a mix of PipelineVariable and plain values.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={ + "max_depth": max_depth, + "eta": 0.1, + "objective": "reg:squarederror", + }, + ) + assert trainer.hyperparameters["max_depth"] is max_depth + assert trainer.hyperparameters["eta"] == 0.1 + assert trainer.hyperparameters["objective"] == "reg:squarederror" + + @patch("sagemaker.train.model_trainer._get_unique_name", return_value="test-job-20240101") + def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters( + self, mock_unique_name + ): + """_create_training_job_args should preserve PipelineVariable in hyper_parameters dict. + + When safe_serialize is called on a PipelineVariable, it should return the + PipelineVariable object as-is, not attempt JSON serialization. + """ + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"max_depth": max_depth, "eta": 0.1}, + ) + args = trainer._create_training_job_args() + # PipelineVariable should be preserved as-is by safe_serialize + assert args["hyper_parameters"]["max_depth"] is max_depth + # Plain values should be JSON-serialized to strings + assert args["hyper_parameters"]["eta"] == "0.1" + + class TestModelTrainerRealValuesStillWork: """Regression tests: verify that passing real values still works after the change.""" From 17bcbacb7080832b9af1cfc771472c5239687ec0 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 30 Mar 2026 17:14:45 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- .../tests/unit/train/test_model_trainer_pipeline_variable.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index e18630adb3..3f830b5e6d 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -158,7 +158,7 @@ def test_safe_serialize_still_handles_dict(self): class TestModelTrainerHyperparametersPipelineVariable: """Test that ModelTrainer hyperparameters accept PipelineVariable objects (GH#5504).""" - def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self): + def test_hyperparameters_accept_parameter_integer(self): """ModelTrainer hyperparameters should accept ParameterInteger (GH#5504). This is the exact bug scenario: ParameterInteger in hyperparameters @@ -175,7 +175,7 @@ def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self): ) assert trainer.hyperparameters["max_depth"] is max_depth - def test_hyperparameters_accept_parameter_string_via_safe_serialize(self): + def test_hyperparameters_accept_parameter_string(self): """ModelTrainer hyperparameters should accept ParameterString (GH#5504).""" objective = ParameterString(name="Objective", default_value="reg:squarederror") trainer = ModelTrainer( @@ -228,6 +228,7 @@ def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters args = trainer._create_training_job_args() # PipelineVariable should be preserved as-is by safe_serialize assert args["hyper_parameters"]["max_depth"] is max_depth + assert isinstance(args["hyper_parameters"]["max_depth"], PipelineVariable) # Plain values should be JSON-serialized to strings assert args["hyper_parameters"]["eta"] == "0.1"