|
10 | 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
11 | 11 | # ANY KIND, either express or implied. See the License for the specific |
12 | 12 | # language governing permissions and limitations under the License. |
13 | | -"""Tests for PipelineVariable support in ModelTrainer (GH#5524). |
| 13 | +"""Tests for PipelineVariable support in ModelTrainer. |
14 | 14 |
|
15 | 15 | Verifies that ModelTrainer fields accept PipelineVariable objects |
16 | 16 | (e.g., ParameterString) in addition to their concrete types, following |
17 | 17 | the existing V3 pattern established by SourceCode and OutputDataConfig. |
18 | 18 |
|
19 | | -See: https://github.com/aws/sagemaker-python-sdk/issues/5524 |
| 19 | +Also verifies that safe_serialize correctly handles PipelineVariable objects |
| 20 | +in hyperparameters (returning them as-is instead of attempting json.dumps), |
| 21 | +and that _create_training_job_args preserves PipelineVariable objects through |
| 22 | +the serialization pipeline. |
20 | 23 | """ |
21 | 24 | from __future__ import absolute_import |
22 | 25 |
|
|
26 | 29 |
|
27 | 30 | from sagemaker.core.helper.session_helper import Session |
28 | 31 | from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar |
29 | | -from sagemaker.core.workflow.parameters import ParameterString |
| 32 | +from sagemaker.core.workflow.parameters import ( |
| 33 | + ParameterString, |
| 34 | + ParameterInteger, |
| 35 | + ParameterFloat, |
| 36 | +) |
30 | 37 | from sagemaker.train.model_trainer import ModelTrainer, Mode |
31 | 38 | from sagemaker.train.configs import ( |
32 | 39 | Compute, |
33 | 40 | StoppingCondition, |
34 | 41 | OutputDataConfig, |
35 | 42 | ) |
| 43 | +from sagemaker.train.utils import safe_serialize |
36 | 44 | from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE |
37 | 45 |
|
38 | 46 |
|
@@ -176,3 +184,114 @@ def test_training_image_rejects_invalid_type(self): |
176 | 184 | stopping_condition=DEFAULT_STOPPING, |
177 | 185 | output_data_config=DEFAULT_OUTPUT, |
178 | 186 | ) |
| 187 | + |
| 188 | + |
| 189 | +class TestSafeSerializeWithPipelineVariables: |
| 190 | + """Tests that safe_serialize handles PipelineVariable objects correctly. |
| 191 | +
|
| 192 | + The safe_serialize function must return PipelineVariable objects as-is |
| 193 | + instead of attempting json.dumps(), which would raise TypeError. |
| 194 | + """ |
| 195 | + |
| 196 | + def test_safe_serialize_with_parameter_integer_returns_pipeline_variable(self): |
| 197 | + """safe_serialize should return ParameterInteger as-is.""" |
| 198 | + param = ParameterInteger(name="MaxDepth", default_value=5) |
| 199 | + result = safe_serialize(param) |
| 200 | + assert result is param |
| 201 | + assert isinstance(result, PipelineVariable) |
| 202 | + |
| 203 | + def test_safe_serialize_with_parameter_string_returns_pipeline_variable(self): |
| 204 | + """safe_serialize should return ParameterString as-is.""" |
| 205 | + param = ParameterString(name="Optimizer", default_value="adam") |
| 206 | + result = safe_serialize(param) |
| 207 | + assert result is param |
| 208 | + assert isinstance(result, PipelineVariable) |
| 209 | + |
| 210 | + def test_safe_serialize_with_parameter_float_returns_pipeline_variable(self): |
| 211 | + """safe_serialize should return ParameterFloat as-is.""" |
| 212 | + param = ParameterFloat(name="LearningRate", default_value=0.01) |
| 213 | + result = safe_serialize(param) |
| 214 | + assert result is param |
| 215 | + assert isinstance(result, PipelineVariable) |
| 216 | + |
| 217 | + def test_safe_serialize_still_handles_strings(self): |
| 218 | + """safe_serialize should return plain strings as-is (no quotes wrapping).""" |
| 219 | + result = safe_serialize("hello") |
| 220 | + assert result == "hello" |
| 221 | + |
| 222 | + def test_safe_serialize_still_handles_integers(self): |
| 223 | + """safe_serialize should JSON-encode integers.""" |
| 224 | + result = safe_serialize(42) |
| 225 | + assert result == "42" |
| 226 | + |
| 227 | + def test_safe_serialize_still_handles_dicts(self): |
| 228 | + """safe_serialize should JSON-encode dicts.""" |
| 229 | + result = safe_serialize({"key": "value"}) |
| 230 | + assert result == '{"key": "value"}' |
| 231 | + |
| 232 | + def test_safe_serialize_still_handles_floats(self): |
| 233 | + """safe_serialize should JSON-encode floats.""" |
| 234 | + result = safe_serialize(0.01) |
| 235 | + assert result == "0.01" |
| 236 | + |
| 237 | + def test_safe_serialize_still_handles_booleans(self): |
| 238 | + """safe_serialize should JSON-encode booleans.""" |
| 239 | + assert safe_serialize(True) == "true" |
| 240 | + assert safe_serialize(False) == "false" |
| 241 | + |
| 242 | + |
| 243 | +class TestModelTrainerHyperparametersWithPipelineVariables: |
| 244 | + """Tests that ModelTrainer accepts PipelineVariable objects in hyperparameters.""" |
| 245 | + |
| 246 | + def test_hyperparameters_accept_pipeline_variable_values(self): |
| 247 | + """ModelTrainer should accept PipelineVariable objects as hyperparameter values.""" |
| 248 | + max_depth = ParameterInteger(name="MaxDepth", default_value=5) |
| 249 | + learning_rate = ParameterFloat(name="LearningRate", default_value=0.01) |
| 250 | + optimizer = ParameterString(name="Optimizer", default_value="adam") |
| 251 | + |
| 252 | + trainer = ModelTrainer( |
| 253 | + training_image=DEFAULT_IMAGE, |
| 254 | + role=DEFAULT_ROLE, |
| 255 | + compute=DEFAULT_COMPUTE, |
| 256 | + stopping_condition=DEFAULT_STOPPING, |
| 257 | + output_data_config=DEFAULT_OUTPUT, |
| 258 | + hyperparameters={ |
| 259 | + "max_depth": max_depth, |
| 260 | + "learning_rate": learning_rate, |
| 261 | + "optimizer": optimizer, |
| 262 | + "static_param": 10, |
| 263 | + }, |
| 264 | + ) |
| 265 | + assert trainer.hyperparameters["max_depth"] is max_depth |
| 266 | + assert trainer.hyperparameters["learning_rate"] is learning_rate |
| 267 | + assert trainer.hyperparameters["optimizer"] is optimizer |
| 268 | + assert trainer.hyperparameters["static_param"] == 10 |
| 269 | + |
| 270 | + def test_create_training_job_args_with_pipeline_variable_hyperparameters(self): |
| 271 | + """_create_training_job_args should preserve PipelineVariable in hyper_parameters.""" |
| 272 | + max_depth = ParameterInteger(name="MaxDepth", default_value=5) |
| 273 | + learning_rate = ParameterFloat(name="LearningRate", default_value=0.01) |
| 274 | + |
| 275 | + trainer = ModelTrainer( |
| 276 | + training_image=DEFAULT_IMAGE, |
| 277 | + role=DEFAULT_ROLE, |
| 278 | + compute=DEFAULT_COMPUTE, |
| 279 | + stopping_condition=DEFAULT_STOPPING, |
| 280 | + output_data_config=DEFAULT_OUTPUT, |
| 281 | + hyperparameters={ |
| 282 | + "max_depth": max_depth, |
| 283 | + "learning_rate": learning_rate, |
| 284 | + "epochs": 10, |
| 285 | + "verbose": "true", |
| 286 | + }, |
| 287 | + ) |
| 288 | + |
| 289 | + training_args = trainer._create_training_job_args() |
| 290 | + hyper_params = training_args["hyper_parameters"] |
| 291 | + |
| 292 | + # PipelineVariable objects should be preserved as-is by safe_serialize |
| 293 | + assert hyper_params["max_depth"] is max_depth |
| 294 | + assert hyper_params["learning_rate"] is learning_rate |
| 295 | + # Regular values should be serialized to strings |
| 296 | + assert hyper_params["epochs"] == "10" |
| 297 | + assert hyper_params["verbose"] == "true" |
0 commit comments