diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
index 3fd34fa47b..24f7d2184e 100644
--- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
+++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
@@ -26,7 +26,7 @@
 from sagemaker.core.helper.session_helper import Session
 from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
-from sagemaker.core.workflow.parameters import ParameterString
+from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
 
 from sagemaker.train.model_trainer import ModelTrainer, Mode
 from sagemaker.train.configs import (
     Compute,
@@ -116,6 +116,123 @@ def test_environment_values_accept_parameter_string(self):
         assert trainer.environment["DATASET_VERSION"] is param
         assert trainer.environment["STATIC_VAR"] == "hello"
 
+    def test_hyperparameters_accept_parameter_integer(self):
+        """ModelTrainer.hyperparameters should accept ParameterInteger values (GH#5504)."""
+        param = ParameterInteger(name="MaxDepth", default_value=5)
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            hyperparameters={"max_depth": param},
+            base_job_name="pipeline-test-job",
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+        )
+        assert trainer.hyperparameters["max_depth"] is param
+
+    def test_hyperparameters_accept_parameter_string(self):
+        """ModelTrainer.hyperparameters should accept ParameterString values (GH#5504)."""
+        param = ParameterString(name="Algorithm", default_value="xgboost")
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            hyperparameters={"algorithm": param},
+            base_job_name="pipeline-test-job",
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+        )
+        assert trainer.hyperparameters["algorithm"] is param
+
+    def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self):
+        """ModelTrainer.hyperparameters should accept a mix of PipelineVariable and plain values.
+
+        Regression test for GH#5504.
+        """
+        param_int = ParameterInteger(name="MaxDepth", default_value=5)
+        param_str = ParameterString(name="Objective", default_value="reg:squarederror")
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            hyperparameters={
+                "max_depth": param_int,
+                "objective": param_str,
+                "eta": 0.1,
+                "num_round": "100",
+            },
+            base_job_name="pipeline-test-job",
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+        )
+        assert trainer.hyperparameters["max_depth"] is param_int
+        assert trainer.hyperparameters["objective"] is param_str
+        assert trainer.hyperparameters["eta"] == 0.1
+        assert trainer.hyperparameters["num_round"] == "100"
+
+
+class TestSafeSerializePipelineVariable:
+    """Test that safe_serialize correctly preserves PipelineVariable objects (GH#5504)."""
+
+    def test_safe_serialize_preserves_parameter_integer(self):
+        """safe_serialize should return PipelineVariable as-is, not stringify it."""
+        from sagemaker.train.utils import safe_serialize
+
+        param = ParameterInteger(name="MaxDepth", default_value=5)
+        result = safe_serialize(param)
+        assert result is param
+        assert isinstance(result, PipelineVariable)
+
+    def test_safe_serialize_preserves_parameter_string(self):
+        """safe_serialize should return ParameterString as-is."""
+        from sagemaker.train.utils import safe_serialize
+
+        param = ParameterString(name="Objective", default_value="reg:squarederror")
+        result = safe_serialize(param)
+        assert result is param
+        assert isinstance(result, PipelineVariable)
+
+    def test_safe_serialize_still_serializes_plain_values(self):
+        """safe_serialize should still JSON-serialize plain values."""
+        from sagemaker.train.utils import safe_serialize
+
+        assert safe_serialize(42) == "42"
+        assert safe_serialize("hello") == '"hello"'
+        assert safe_serialize(0.1) == "0.1"
+
+    def test_create_training_job_args_preserves_pipeline_hyperparameters(
+        self,
+    ):
+        """_create_training_job_args should preserve PipelineVariable in hyperparameters.
+
+        Regression test for GH#5504.
+        """
+        param_int = ParameterInteger(name="MaxDepth", default_value=5)
+        param_str = ParameterString(
+            name="Objective", default_value="reg:squarederror"
+        )
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            hyperparameters={
+                "max_depth": param_int,
+                "objective": param_str,
+                "eta": 0.1,
+                "num_round": "100",
+            },
+            base_job_name="pipeline-test-job",
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+        )
+        args = trainer._create_training_job_args()
+        hp = args["hyperparameters"]
+        assert hp["max_depth"] is param_int
+        assert hp["objective"] is param_str
+        # Plain values should be JSON-serialized strings
+        assert hp["eta"] == "0.1"
+        assert hp["num_round"] == '"100"'
+
 
 class TestModelTrainerRealValuesStillWork:
     """Regression tests: verify that passing real values still works after the change."""