|
26 | 26 |
|
27 | 27 | from sagemaker.core.helper.session_helper import Session |
28 | 28 | from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar |
29 | | -from sagemaker.core.workflow.parameters import ParameterString |
| 29 | +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger |
30 | 30 | from sagemaker.train.model_trainer import ModelTrainer, Mode |
31 | 31 | from sagemaker.train.configs import ( |
32 | 32 | Compute, |
33 | 33 | StoppingCondition, |
34 | 34 | OutputDataConfig, |
35 | 35 | ) |
36 | | -from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults |
| 36 | +from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE |
| 37 | +from sagemaker.train.utils import _get_repo_name_from_image, safe_serialize |
37 | 38 |
|
38 | 39 |
|
# Placeholder ECR-style image URI (all-zero account id, "dummy-image" repo).
# The tests in this module use it as the default value of image parameters.
DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
@@ -176,3 +177,116 @@ def test_training_image_rejects_invalid_type(self): |
176 | 177 | stopping_condition=DEFAULT_STOPPING, |
177 | 178 | output_data_config=DEFAULT_OUTPUT, |
178 | 179 | ) |
| 180 | + |
| 181 | + |
class TestValidateTrainingImageAndAlgorithmName:
    """Exercise _validate_training_image_and_algorithm_name via ModelTrainer.

    Each test constructs a ModelTrainer with a different combination of
    ``training_image`` and ``algorithm_name``: a pipeline parameter for one of
    them, pipeline parameters for both, or ``None`` for both.
    """

    @staticmethod
    def _shared_kwargs():
        """Return the constructor arguments every test in this class passes."""
        return {
            "base_job_name": "pipeline-test-job",
            "role": DEFAULT_ROLE,
            "compute": DEFAULT_COMPUTE,
            "stopping_condition": DEFAULT_STOPPING,
            "output_data_config": DEFAULT_OUTPUT,
        }

    def test_pipeline_variable_training_image_passes_validation(self):
        """A ParameterString training_image is accepted and stored as the same object."""
        image_param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
        model_trainer = ModelTrainer(training_image=image_param, **self._shared_kwargs())
        assert model_trainer.training_image is image_param

    def test_pipeline_variable_algorithm_name_passes_validation(self):
        """A ParameterString algorithm_name is accepted and stored as the same object."""
        algo_param = ParameterString(name="AlgoName", default_value="my-algo")
        model_trainer = ModelTrainer(algorithm_name=algo_param, **self._shared_kwargs())
        assert model_trainer.algorithm_name is algo_param

    def test_both_pipeline_variables_raises_value_error(self):
        """Passing pipeline parameters for both fields raises ValueError."""
        conflicting = {
            "training_image": ParameterString(
                name="TrainingImage", default_value=DEFAULT_IMAGE
            ),
            "algorithm_name": ParameterString(name="AlgoName", default_value="my-algo"),
        }
        with pytest.raises(ValueError, match="Only one of"):
            ModelTrainer(**conflicting, **self._shared_kwargs())

    def test_neither_provided_raises_value_error(self):
        """Passing None for both fields raises ValueError."""
        with pytest.raises(ValueError, match="Atleast one of"):
            ModelTrainer(
                training_image=None,
                algorithm_name=None,
                **self._shared_kwargs(),
            )
| 238 | + |
| 239 | + |
class TestGetRepoNameFromImage:
    """Check _get_repo_name_from_image for a pipeline parameter and plain URI strings."""

    def test_returns_none_for_pipeline_variable(self):
        """A ParameterString image yields None rather than a repository name."""
        image_param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
        assert _get_repo_name_from_image(image_param) is None

    def test_returns_repo_name_for_string(self):
        """A tagged image URI yields the repository name without the tag."""
        tagged_uri = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest"
        assert _get_repo_name_from_image(tagged_uri) == "my-repo"

    def test_returns_repo_name_without_tag(self):
        """An image URI that has no tag still yields the repository name."""
        untagged_uri = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo"
        assert _get_repo_name_from_image(untagged_uri) == "my-repo"
| 262 | + |
| 263 | + |
class TestSafeSerialize:
    """Check safe_serialize for pipeline parameters and plain Python values."""

    def test_safe_serialize_pipeline_variable_returns_variable(self):
        """A ParameterInteger comes back as the very same object."""
        depth_param = ParameterInteger(name="MaxDepth", default_value=5)
        assert safe_serialize(depth_param) is depth_param

    def test_safe_serialize_string_returns_string(self):
        """A plain string comes back unchanged."""
        assert safe_serialize("hello") == "hello"

    def test_safe_serialize_int_returns_json(self):
        """An integer is returned as its JSON text."""
        assert safe_serialize(5) == "5"

    def test_safe_serialize_dict_returns_json(self):
        """A dict is returned as its JSON text."""
        payload = {"key": "value"}
        assert safe_serialize(payload) == '{"key": "value"}'

    def test_safe_serialize_parameter_string_returns_variable(self):
        """A ParameterString comes back as the very same object."""
        string_param = ParameterString(name="MyParam", default_value="val")
        assert safe_serialize(string_param) is string_param
0 commit comments