|
26 | 26 |
|
27 | 27 | from sagemaker.core.helper.session_helper import Session |
28 | 28 | from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar |
29 | | -from sagemaker.core.workflow.parameters import ParameterString |
| 29 | +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat |
30 | 30 | from sagemaker.train.model_trainer import ModelTrainer, Mode |
31 | 31 | from sagemaker.train.configs import ( |
32 | 32 | Compute, |
33 | 33 | StoppingCondition, |
34 | 34 | OutputDataConfig, |
35 | 35 | ) |
36 | 36 | from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE |
| 37 | +from sagemaker.train.utils import safe_serialize |
37 | 38 |
|
38 | 39 |
|
39 | 40 | DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" |
@@ -117,6 +118,120 @@ def test_environment_values_accept_parameter_string(self): |
117 | 118 | assert trainer.environment["STATIC_VAR"] == "hello" |
118 | 119 |
|
119 | 120 |
|
class TestSafeSerializePipelineVariable:
    """Verify safe_serialize's treatment of PipelineVariable inputs (GH#5504).

    Pipeline parameters must pass through untouched (same object identity),
    while ordinary Python values continue to be JSON-serialized as before.
    """

    def test_safe_serialize_returns_pipeline_variable_as_is(self):
        """A ParameterInteger must come back as the identical object."""
        max_depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        assert safe_serialize(max_depth_param) is max_depth_param

    def test_safe_serialize_returns_parameter_string_as_is(self):
        """A ParameterString must come back as the identical object."""
        algo_param = ParameterString(name="Algorithm", default_value="xgboost")
        assert safe_serialize(algo_param) is algo_param

    def test_safe_serialize_returns_parameter_float_as_is(self):
        """A ParameterFloat must come back as the identical object."""
        lr_param = ParameterFloat(name="LearningRate", default_value=0.01)
        assert safe_serialize(lr_param) is lr_param

    def test_safe_serialize_still_handles_plain_string(self):
        """Plain strings are returned unchanged."""
        assert safe_serialize("hello") == "hello"

    def test_safe_serialize_still_handles_int(self):
        """Integers are JSON-encoded to their string form."""
        assert safe_serialize(42) == "42"

    def test_safe_serialize_still_handles_dict(self):
        """Dicts are JSON-encoded to a JSON object string."""
        assert safe_serialize({"key": "value"}) == '{"key": "value"}'
| 157 | + |
class TestModelTrainerHyperparametersPipelineVariable:
    """Verify ModelTrainer hyperparameters accept PipelineVariable objects (GH#5504)."""

    def _build_trainer(self, hyperparameters):
        """Construct a ModelTrainer using the module's shared default configs."""
        return ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
            hyperparameters=hyperparameters,
        )

    def test_hyperparameters_accept_parameter_integer_via_safe_serialize(self):
        """ModelTrainer hyperparameters should accept ParameterInteger (GH#5504).

        This is the exact bug scenario: a ParameterInteger hyperparameter
        used to raise TypeError inside safe_serialize before the fix.
        """
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        trainer = self._build_trainer({"max_depth": max_depth})
        assert trainer.hyperparameters["max_depth"] is max_depth

    def test_hyperparameters_accept_parameter_string_via_safe_serialize(self):
        """ModelTrainer hyperparameters should accept ParameterString (GH#5504)."""
        objective = ParameterString(name="Objective", default_value="reg:squarederror")
        trainer = self._build_trainer({"objective": objective})
        assert trainer.hyperparameters["objective"] is objective

    def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self):
        """A mix of PipelineVariable and plain values must be stored unchanged."""
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        mixed_hyperparams = {
            "max_depth": max_depth,
            "eta": 0.1,
            "objective": "reg:squarederror",
        }
        trainer = self._build_trainer(mixed_hyperparams)
        assert trainer.hyperparameters["max_depth"] is max_depth
        assert trainer.hyperparameters["eta"] == 0.1
        assert trainer.hyperparameters["objective"] == "reg:squarederror"

    @patch("sagemaker.train.model_trainer._get_unique_name", return_value="test-job-20240101")
    def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters(
        self, mock_unique_name
    ):
        """_create_training_job_args should keep PipelineVariable objects intact.

        safe_serialize must hand a PipelineVariable back as-is rather than
        attempting JSON serialization, while plain values are still
        JSON-serialized to strings.
        """
        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
        trainer = self._build_trainer({"max_depth": max_depth, "eta": 0.1})
        args = trainer._create_training_job_args()
        # The PipelineVariable survives safe_serialize untouched.
        assert args["hyper_parameters"]["max_depth"] is max_depth
        # Plain values still get the legacy JSON-to-string treatment.
        assert args["hyper_parameters"]["eta"] == "0.1"
| 234 | + |
120 | 235 | class TestModelTrainerRealValuesStillWork: |
121 | 236 | """Regression tests: verify that passing real values still works after the change.""" |
122 | 237 |
|
|
0 commit comments