|
26 | 26 |
|
27 | 27 | from sagemaker.core.helper.session_helper import Session |
28 | 28 | from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar |
29 | | -from sagemaker.core.workflow.parameters import ParameterString |
| 29 | +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat |
30 | 30 | from sagemaker.train.model_trainer import ModelTrainer, Mode |
| 31 | +from sagemaker.train.utils import safe_serialize |
31 | 32 | from sagemaker.train.configs import ( |
32 | 33 | Compute, |
33 | 34 | StoppingCondition, |
@@ -116,6 +117,54 @@ def test_environment_values_accept_parameter_string(self): |
116 | 117 | assert trainer.environment["DATASET_VERSION"] is param |
117 | 118 | assert trainer.environment["STATIC_VAR"] == "hello" |
118 | 119 |
|
| 120 | + def test_hyperparameters_accept_parameter_integer(self): |
| 121 | + """ModelTrainer.hyperparameters should accept ParameterInteger values (GH#5504).""" |
| 122 | + param = ParameterInteger(name="MaxDepth", default_value=5) |
| 123 | + trainer = ModelTrainer( |
| 124 | + training_image=DEFAULT_IMAGE, |
| 125 | + hyperparameters={"max_depth": param}, |
| 126 | + role=DEFAULT_ROLE, |
| 127 | + compute=DEFAULT_COMPUTE, |
| 128 | + stopping_condition=DEFAULT_STOPPING, |
| 129 | + output_data_config=DEFAULT_OUTPUT, |
| 130 | + ) |
| 131 | + assert trainer.hyperparameters["max_depth"] is param |
| 132 | + |
| 133 | + def test_hyperparameters_accept_parameter_string(self): |
| 134 | + """ModelTrainer.hyperparameters should accept ParameterString values (GH#5504).""" |
| 135 | + param = ParameterString(name="LearningRate", default_value="0.01") |
| 136 | + trainer = ModelTrainer( |
| 137 | + training_image=DEFAULT_IMAGE, |
| 138 | + hyperparameters={"learning_rate": param}, |
| 139 | + role=DEFAULT_ROLE, |
| 140 | + compute=DEFAULT_COMPUTE, |
| 141 | + stopping_condition=DEFAULT_STOPPING, |
| 142 | + output_data_config=DEFAULT_OUTPUT, |
| 143 | + ) |
| 144 | + assert trainer.hyperparameters["learning_rate"] is param |
| 145 | + |
| 146 | + def test_hyperparameters_accept_mixed_pipeline_and_static_values(self): |
| 147 | + """ModelTrainer.hyperparameters should accept a mix of PipelineVariable and static values.""" |
| 148 | + param_int = ParameterInteger(name="MaxDepth", default_value=5) |
| 149 | + param_str = ParameterString(name="Objective", default_value="reg:squarederror") |
| 150 | + trainer = ModelTrainer( |
| 151 | + training_image=DEFAULT_IMAGE, |
| 152 | + hyperparameters={ |
| 153 | + "max_depth": param_int, |
| 154 | + "objective": param_str, |
| 155 | + "eta": 0.1, |
| 156 | + "num_round": 100, |
| 157 | + }, |
| 158 | + role=DEFAULT_ROLE, |
| 159 | + compute=DEFAULT_COMPUTE, |
| 160 | + stopping_condition=DEFAULT_STOPPING, |
| 161 | + output_data_config=DEFAULT_OUTPUT, |
| 162 | + ) |
| 163 | + assert trainer.hyperparameters["max_depth"] is param_int |
| 164 | + assert trainer.hyperparameters["objective"] is param_str |
| 165 | + assert trainer.hyperparameters["eta"] == 0.1 |
| 166 | + assert trainer.hyperparameters["num_round"] == 100 |
| 167 | + |
119 | 168 |
|
120 | 169 | class TestModelTrainerRealValuesStillWork: |
121 | 170 | """Regression tests: verify that passing real values still works after the change.""" |
@@ -176,3 +225,44 @@ def test_training_image_rejects_invalid_type(self): |
176 | 225 | stopping_condition=DEFAULT_STOPPING, |
177 | 226 | output_data_config=DEFAULT_OUTPUT, |
178 | 227 | ) |
| 228 | + |
| 229 | + |
| 230 | +class TestSafeSerializePipelineVariable: |
| 231 | + """Test that safe_serialize correctly handles PipelineVariable objects (GH#5504). |
| 232 | +
|
| 233 | + The bug was that safe_serialize would call json.dumps() on PipelineVariable objects, |
| 234 | + causing TypeError. The fix returns PipelineVariable objects as-is. |
| 235 | + """ |
| 236 | + |
| 237 | + def test_safe_serialize_with_parameter_integer(self): |
| 238 | + """safe_serialize should return ParameterInteger as-is, not attempt JSON serialization.""" |
| 239 | + param = ParameterInteger(name="MaxDepth", default_value=5) |
| 240 | + result = safe_serialize(param) |
| 241 | + assert result is param |
| 242 | + |
| 243 | + def test_safe_serialize_with_parameter_string(self): |
| 244 | + """safe_serialize should return ParameterString as-is, not attempt JSON serialization.""" |
| 245 | + param = ParameterString(name="Objective", default_value="reg:squarederror") |
| 246 | + result = safe_serialize(param) |
| 247 | + assert result is param |
| 248 | + |
| 249 | + def test_safe_serialize_with_parameter_float(self): |
| 250 | + """safe_serialize should return ParameterFloat as-is, not attempt JSON serialization.""" |
| 251 | + param = ParameterFloat(name="LearningRate", default_value=0.01) |
| 252 | + result = safe_serialize(param) |
| 253 | + assert result is param |
| 254 | + |
| 255 | + def test_safe_serialize_with_plain_string(self): |
| 256 | + """safe_serialize should return plain strings unchanged.""" |
| 257 | + result = safe_serialize("hello") |
| 258 | + assert result == "hello" |
| 259 | + |
| 260 | + def test_safe_serialize_with_int(self): |
| 261 | + """safe_serialize should JSON-dump integers to their string representation.""" |
| 262 | + result = safe_serialize(5) |
| 263 | + assert result == "5" |
| 264 | + |
| 265 | + def test_safe_serialize_with_dict(self): |
| 266 | + """safe_serialize should JSON-dump dicts to their string representation.""" |
| 267 | + result = safe_serialize({"key": "value"}) |
| 268 | + assert result == '{"key": "value"}' |
0 commit comments