1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313"""Tests for PipelineVariable support in ModelTrainer."""
14- from __future__ import absolute_import
14+ from __future__ import annotations
1515
1616import pytest
1717from unittest .mock import MagicMock , patch
1818
19- from sagemaker .core .workflow .parameters import ParameterString , ParameterInteger
19+ from sagemaker .core .workflow .parameters import (
20+ ParameterString ,
21+ ParameterInteger ,
22+ )
2023from sagemaker .core .helper .pipeline_variable import PipelineVariable
21- from sagemaker .train .utils import safe_serialize , _get_repo_name_from_image , _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
24+ from sagemaker .train .utils import (
25+ safe_serialize ,
26+ _get_repo_name_from_image ,
27+ _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER ,
28+ )
29+
30+ _TEST_IMAGE_URI = (
31+ "683313688378.dkr.ecr.us-east-1.amazonaws.com/"
32+ "sagemaker-xgboost:1.0-1-cpu-py3"
33+ )
2234
2335
2436class TestSafeSerializeWithPipelineVariable :
@@ -63,13 +75,14 @@ class TestGetRepoNameFromImage:
6375
def test_get_repo_name_from_image_string(self):
    """Check that a full ECR image URI yields its repository name."""
    # The shared test URI points at the "sagemaker-xgboost" repository.
    assert _get_repo_name_from_image(_TEST_IMAGE_URI) == "sagemaker-xgboost"
6980
def test_get_repo_name_from_image_pipeline_variable(self):
    """Check that a PipelineVariable image maps to the placeholder constant."""
    image_param = ParameterString(
        name="TrainingImage",
        default_value="some-image",
    )
    repo_name = _get_repo_name_from_image(image_param)
    assert repo_name == _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
7588
@@ -80,159 +93,189 @@ def test_get_repo_name_from_image_simple_string(self):
8093
def test_get_repo_name_from_image_with_digest(self):
    """Check that a digest-pinned image URI still yields the repo name."""
    registry = "123456789012.dkr.ecr.us-west-2.amazonaws.com/"
    digest_reference = "my-repo@sha256:abc123"
    repo_name = _get_repo_name_from_image(registry + digest_reference)
    assert repo_name == "my-repo"
86102
87103
88- class TestModelTrainerValidationWithPipelineVariable :
89- """Tests for ModelTrainer validation with PipelineVariable objects."""
@pytest.fixture
def mock_session():
    """Provide a MagicMock standing in for a SageMaker session.

    The mock reports ``us-east-1`` as ``boto_region_name``, returns
    ``"my-bucket"`` from ``default_bucket()`` and has
    ``default_bucket_prefix`` set to ``None``.
    """
    # Dotted keys are resolved by MagicMock's configure_mock handling, so
    # "default_bucket.return_value" sets the return value of the child mock.
    session_attrs = {
        "boto_region_name": "us-east-1",
        "default_bucket.return_value": "my-bucket",
        "default_bucket_prefix": None,
    }
    return MagicMock(**session_attrs)
90112
91- @patch ("sagemaker.train.model_trainer.TrainDefaults" )
92- def test_training_image_accepts_parameter_string (self , mock_defaults ):
93- """Test that training_image accepts ParameterString."""
94- from sagemaker .train .model_trainer import ModelTrainer
95- from sagemaker .train .configs import Compute
96113
97- mock_session = MagicMock ()
98- mock_session .boto_region_name = "us-east-1"
99- mock_session .default_bucket .return_value = "my-bucket"
100- mock_session .default_bucket_prefix = None
@pytest.fixture
def mock_train_defaults(mock_session):
    """Patch ``TrainDefaults`` for ModelTrainer construction.

    While the fixture is active, ``sagemaker.train.model_trainer.TrainDefaults``
    is replaced by a mock whose ``get_*`` helpers return fixed values: the
    ``mock_session`` fixture, a role ARN, the base job name ``"test-job"``,
    a single ``ml.m5.xlarge`` ``Compute`` and MagicMock stand-ins for the
    stopping condition and output data config.

    Args:
        mock_session: The configured session mock from the ``mock_session``
            fixture.

    Yields:
        The mock that replaces ``TrainDefaults``.
    """
    from sagemaker.train.configs import Compute

    with patch("sagemaker.train.model_trainer.TrainDefaults") as mock_defaults:
        # Return the same configured session the tests hand to ModelTrainer
        # (region, bucket and prefix set) instead of an unconfigured
        # MagicMock, matching how the tests wired this before the fixtures
        # were extracted.
        mock_defaults.get_sagemaker_session.return_value = mock_session
        mock_defaults.get_role.return_value = (
            "arn:aws:iam::123456789012:role/SageMakerRole"
        )
        mock_defaults.get_base_job_name.return_value = "test-job"
        mock_defaults.get_compute.return_value = Compute(
            instance_type="ml.m5.xlarge", instance_count=1
        )
        mock_defaults.get_stopping_condition.return_value = MagicMock()
        mock_defaults.get_output_data_config.return_value = MagicMock()
        yield mock_defaults
110131
111- param = ParameterString (name = "TrainingImage" , default_value = "some-image-uri" )
132+
class TestModelTrainerValidationWithPipelineVariable:
    """Tests for ModelTrainer validation with PipelineVariable objects."""

    # Role ARN passed to every ModelTrainer built by these tests.
    _ROLE_ARN = "arn:aws:iam::123456789012:role/SageMakerRole"

    @classmethod
    def _make_trainer(cls, session, **overrides):
        """Build a ModelTrainer with the compute, session and role used here.

        Args:
            session: Object passed as ``sagemaker_session``.
            **overrides: Extra keyword arguments forwarded to ModelTrainer
                (for example ``training_image`` or ``environment``).

        Returns:
            The constructed ModelTrainer.
        """
        from sagemaker.train.model_trainer import ModelTrainer
        from sagemaker.train.configs import Compute

        return ModelTrainer(
            **overrides,
            compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
            sagemaker_session=session,
            role=cls._ROLE_ARN,
        )

    @staticmethod
    def _bare_trainer():
        """Return a ModelTrainer allocated via ``__new__`` (no ``__init__``)."""
        from sagemaker.train.model_trainer import ModelTrainer

        return ModelTrainer.__new__(ModelTrainer)

    # NOTE: ``mock_train_defaults`` is requested only to activate the
    # TrainDefaults patch; the test bodies never read it.

    def test_training_image_accepts_parameter_string(
        self, mock_session, mock_train_defaults
    ):
        """Test that training_image accepts ParameterString."""
        image_param = ParameterString(
            name="TrainingImage", default_value="some-image-uri"
        )

        # Construction must complete without raising.
        trainer = self._make_trainer(mock_session, training_image=image_param)

        assert trainer.training_image is image_param

    def test_algorithm_name_accepts_parameter_string(
        self, mock_session, mock_train_defaults
    ):
        """Test that algorithm_name accepts ParameterString."""
        algo_param = ParameterString(
            name="AlgorithmName", default_value="some-algo"
        )

        # Construction must complete without raising.
        trainer = self._make_trainer(mock_session, algorithm_name=algo_param)

        assert trainer.algorithm_name is algo_param

    def test_environment_values_accept_parameter_string(
        self, mock_session, mock_train_defaults
    ):
        """Test that environment dict values accept ParameterString."""
        env_param = ParameterString(name="EnvValue", default_value="val")

        # Construction must complete without raising.
        trainer = self._make_trainer(
            mock_session,
            training_image=_TEST_IMAGE_URI,
            environment={"MY_VAR": env_param},
        )

        assert trainer.environment["MY_VAR"] is env_param

    def test_plain_string_values_still_work(
        self, mock_session, mock_train_defaults
    ):
        """Regression test: plain string values continue to work."""
        # Construction must complete without raising.
        trainer = self._make_trainer(
            mock_session, training_image=_TEST_IMAGE_URI
        )

        assert trainer.training_image == _TEST_IMAGE_URI

    def test_validation_accepts_pipeline_variable_image_none_algo(self):
        """Test validation accepts PipelineVariable image with None algorithm."""
        validate = (
            self._bare_trainer()._validate_training_image_and_algorithm_name
        )
        image_param = ParameterString(name="Image", default_value="img")

        # Passing means the call returns without raising.
        validate(image_param, None)

    def test_validation_accepts_none_image_pipeline_variable_algo(self):
        """Test validation accepts None image with PipelineVariable algorithm."""
        validate = (
            self._bare_trainer()._validate_training_image_and_algorithm_name
        )
        algo_param = ParameterString(name="Algo", default_value="algo")

        # Passing means the call returns without raising.
        validate(None, algo_param)

    def test_validation_rejects_no_image_or_algorithm(self):
        """Test that validation rejects when neither is provided."""
        validate = (
            self._bare_trainer()._validate_training_image_and_algorithm_name
        )

        with pytest.raises(ValueError, match="Atleast one of"):
            validate(None, None)

    def test_validation_rejects_both_image_and_algorithm(self):
        """Test that validation rejects when both are provided."""
        validate = (
            self._bare_trainer()._validate_training_image_and_algorithm_name
        )

        with pytest.raises(ValueError, match="Only one of"):
            validate("image", "algo")

    def test_validation_rejects_both_pipeline_variables(self):
        """Test that validation rejects when both are PipelineVariables."""
        validate = (
            self._bare_trainer()._validate_training_image_and_algorithm_name
        )
        img_param = ParameterString(name="Image", default_value="img")
        algo_param = ParameterString(name="Algo", default_value="algo")

        with pytest.raises(ValueError, match="Only one of"):
            validate(img_param, algo_param)
0 commit comments