|
26 | 26 |
|
27 | 27 | from sagemaker.core.helper.session_helper import Session |
28 | 28 | from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar |
29 | | -from sagemaker.core.workflow.parameters import ParameterString |
| 29 | +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat |
30 | 30 | from sagemaker.train.model_trainer import ModelTrainer, Mode |
31 | 31 | from sagemaker.train.configs import ( |
32 | 32 | Compute, |
33 | 33 | StoppingCondition, |
34 | 34 | OutputDataConfig, |
35 | 35 | ) |
| 36 | +from sagemaker.core.workflow.pipeline_context import PipelineSession |
36 | 37 | from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE |
37 | 38 |
|
38 | 39 |
|
@@ -176,3 +177,61 @@ def test_training_image_rejects_invalid_type(self): |
176 | 177 | stopping_condition=DEFAULT_STOPPING, |
177 | 178 | output_data_config=DEFAULT_OUTPUT, |
178 | 179 | ) |
| 180 | + |
| 181 | + |
| 182 | +class TestModelTrainerHyperparametersPipelineVariable: |
| 183 | + """Test that PipelineVariable objects in hyperparameters survive safe_serialize.""" |
| 184 | + |
| 185 | + def test_hyperparameters_with_pipeline_variable_integer(self): |
| 186 | + """ParameterInteger in hyperparameters should be passed through as-is.""" |
| 187 | + max_depth = ParameterInteger(name="MaxDepth", default_value=5) |
| 188 | + trainer = ModelTrainer( |
| 189 | + training_image=DEFAULT_IMAGE, |
| 190 | + role=DEFAULT_ROLE, |
| 191 | + compute=DEFAULT_COMPUTE, |
| 192 | + stopping_condition=DEFAULT_STOPPING, |
| 193 | + output_data_config=DEFAULT_OUTPUT, |
| 194 | + hyperparameters={"max_depth": max_depth}, |
| 195 | + ) |
| 196 | + # safe_serialize should return the PipelineVariable object directly |
| 197 | + from sagemaker.train.utils import safe_serialize |
| 198 | + result = safe_serialize(max_depth) |
| 199 | + assert result is max_depth |
| 200 | + |
| 201 | + def test_hyperparameters_with_pipeline_variable_string(self): |
| 202 | + """ParameterString in hyperparameters should be passed through as-is.""" |
| 203 | + optimizer = ParameterString(name="Optimizer", default_value="sgd") |
| 204 | + trainer = ModelTrainer( |
| 205 | + training_image=DEFAULT_IMAGE, |
| 206 | + role=DEFAULT_ROLE, |
| 207 | + compute=DEFAULT_COMPUTE, |
| 208 | + stopping_condition=DEFAULT_STOPPING, |
| 209 | + output_data_config=DEFAULT_OUTPUT, |
| 210 | + hyperparameters={"optimizer": optimizer}, |
| 211 | + ) |
| 212 | + from sagemaker.train.utils import safe_serialize |
| 213 | + result = safe_serialize(optimizer) |
| 214 | + assert result is optimizer |
| 215 | + |
| 216 | + def test_hyperparameters_with_mixed_pipeline_and_regular_values(self): |
| 217 | + """Mixed PipelineVariable and regular values should both serialize correctly.""" |
| 218 | + max_depth = ParameterInteger(name="MaxDepth", default_value=5) |
| 219 | + trainer = ModelTrainer( |
| 220 | + training_image=DEFAULT_IMAGE, |
| 221 | + role=DEFAULT_ROLE, |
| 222 | + compute=DEFAULT_COMPUTE, |
| 223 | + stopping_condition=DEFAULT_STOPPING, |
| 224 | + output_data_config=DEFAULT_OUTPUT, |
| 225 | + hyperparameters={ |
| 226 | + "max_depth": max_depth, |
| 227 | + "eta": 0.1, |
| 228 | + "objective": "binary:logistic", |
| 229 | + }, |
| 230 | + ) |
| 231 | + from sagemaker.train.utils import safe_serialize |
| 232 | + # PipelineVariable should be returned as-is |
| 233 | + assert safe_serialize(max_depth) is max_depth |
| 234 | + # Float should be JSON-serialized |
| 235 | + assert safe_serialize(0.1) == "0.1" |
| 236 | + # String should be returned as-is |
| 237 | + assert safe_serialize("binary:logistic") == "binary:logistic" |
0 commit comments