Skip to content

Commit 471ee99

Browse files
committed
fix: address review comments (iteration #2)
1 parent a996ac7 commit 471ee99

File tree

3 files changed

+139
-89
lines changed

3 files changed

+139
-89
lines changed

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,13 @@ def __del__(self):
411411

412412
def _validate_training_image_and_algorithm_name(
413413
self,
414-
training_image: "str | PipelineVariable | None",
415-
algorithm_name: "str | PipelineVariable | None",
414+
training_image: Optional[StrPipeVar],
415+
algorithm_name: Optional[StrPipeVar],
416416
):
417417
"""Validate that exactly one of 'training_image' or 'algorithm_name' is provided."""
418-
# PipelineVariables are truthy for validation purposes
418+
# A PipelineVariable is only resolved at pipeline execution time, so its
419+
# truthiness says nothing about the eventual value. Detect it explicitly
420+
# with isinstance and treat it as "provided" for validation purposes.
419421
has_image = isinstance(training_image, PipelineVariable) or bool(training_image)
420422
has_algo = isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name)
421423
if not has_image and not has_algo:
@@ -552,7 +554,10 @@ def model_post_init(self, __context: Any):
552554

553555
if self.training_image:
554556
if isinstance(self.training_image, PipelineVariable):
555-
logger.info("Training image URI: (PipelineVariable - resolved at pipeline execution)")
557+
logger.info(
558+
"Training image URI: "
559+
"(PipelineVariable - resolved at pipeline execution)"
560+
)
556561
else:
557562
logger.info(f"Training image URI: {self.training_image}")
558563

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from datetime import datetime
2424
from typing import Literal, Any
2525

26+
from typing import Union
27+
2628
from sagemaker.core.helper.session_helper import Session
2729
from sagemaker.core.shapes import Unassigned
2830
from sagemaker.train import logger
@@ -144,7 +146,7 @@ def _get_unique_name(base, max_length=63):
144146
return unique_name
145147

146148

147-
def _get_repo_name_from_image(image: "str | PipelineVariable") -> str:
149+
def _get_repo_name_from_image(image: Union[str, PipelineVariable]) -> str:
148150
"""Get the repository name from the image URI.
149151
150152
Example:
@@ -154,7 +156,7 @@ def _get_repo_name_from_image(image: "str | PipelineVariable") -> str:
154156
```
155157
156158
Args:
157-
image (str or PipelineVariable): The image URI
159+
image (str or PipelineVariable): The image URI.
158160
159161
Returns:
160162
str: The repository name

sagemaker-train/tst/unit/sagemaker/train/test_model_trainer_pipeline_variable.py

Lines changed: 126 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,26 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Tests for PipelineVariable support in ModelTrainer."""
14-
from __future__ import absolute_import
14+
from __future__ import annotations
1515

1616
import pytest
1717
from unittest.mock import MagicMock, patch
1818

19-
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
19+
from sagemaker.core.workflow.parameters import (
20+
ParameterString,
21+
ParameterInteger,
22+
)
2023
from sagemaker.core.helper.pipeline_variable import PipelineVariable
21-
from sagemaker.train.utils import safe_serialize, _get_repo_name_from_image, _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
24+
from sagemaker.train.utils import (
25+
safe_serialize,
26+
_get_repo_name_from_image,
27+
_PIPELINE_VARIABLE_IMAGE_PLACEHOLDER,
28+
)
29+
30+
_TEST_IMAGE_URI = (
31+
"683313688378.dkr.ecr.us-east-1.amazonaws.com/"
32+
"sagemaker-xgboost:1.0-1-cpu-py3"
33+
)
2234

2335

2436
class TestSafeSerializeWithPipelineVariable:
@@ -63,13 +75,14 @@ class TestGetRepoNameFromImage:
6375

6476
def test_get_repo_name_from_image_string(self):
6577
"""Test that a normal image URI returns the repo name."""
66-
image = "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
67-
result = _get_repo_name_from_image(image)
78+
result = _get_repo_name_from_image(_TEST_IMAGE_URI)
6879
assert result == "sagemaker-xgboost"
6980

7081
def test_get_repo_name_from_image_pipeline_variable(self):
7182
"""Test that a PipelineVariable returns the placeholder constant."""
72-
param = ParameterString(name="TrainingImage", default_value="some-image")
83+
param = ParameterString(
84+
name="TrainingImage", default_value="some-image"
85+
)
7386
result = _get_repo_name_from_image(param)
7487
assert result == _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
7588

@@ -80,159 +93,189 @@ def test_get_repo_name_from_image_simple_string(self):
8093

8194
def test_get_repo_name_from_image_with_digest(self):
8295
"""Test with an image URI containing a digest."""
83-
image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo@sha256:abc123"
96+
image = (
97+
"123456789012.dkr.ecr.us-west-2.amazonaws.com/"
98+
"my-repo@sha256:abc123"
99+
)
84100
result = _get_repo_name_from_image(image)
85101
assert result == "my-repo"
86102

87103

88-
class TestModelTrainerValidationWithPipelineVariable:
89-
"""Tests for ModelTrainer validation with PipelineVariable objects."""
104+
@pytest.fixture
105+
def mock_session():
106+
"""Create a mock SageMaker session."""
107+
session = MagicMock()
108+
session.boto_region_name = "us-east-1"
109+
session.default_bucket.return_value = "my-bucket"
110+
session.default_bucket_prefix = None
111+
return session
90112

91-
@patch("sagemaker.train.model_trainer.TrainDefaults")
92-
def test_training_image_accepts_parameter_string(self, mock_defaults):
93-
"""Test that training_image accepts ParameterString."""
94-
from sagemaker.train.model_trainer import ModelTrainer
95-
from sagemaker.train.configs import Compute
96113

97-
mock_session = MagicMock()
98-
mock_session.boto_region_name = "us-east-1"
99-
mock_session.default_bucket.return_value = "my-bucket"
100-
mock_session.default_bucket_prefix = None
114+
@pytest.fixture
115+
def mock_train_defaults():
116+
"""Patch TrainDefaults for ModelTrainer construction."""
117+
with patch("sagemaker.train.model_trainer.TrainDefaults") as mock_defaults:
118+
from sagemaker.train.configs import Compute
101119

102-
mock_defaults.get_sagemaker_session.return_value = mock_session
103-
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
120+
mock_defaults.get_sagemaker_session.return_value = MagicMock()
121+
mock_defaults.get_role.return_value = (
122+
"arn:aws:iam::123456789012:role/SageMakerRole"
123+
)
104124
mock_defaults.get_base_job_name.return_value = "test-job"
105125
mock_defaults.get_compute.return_value = Compute(
106126
instance_type="ml.m5.xlarge", instance_count=1
107127
)
108128
mock_defaults.get_stopping_condition.return_value = MagicMock()
109129
mock_defaults.get_output_data_config.return_value = MagicMock()
130+
yield mock_defaults
110131

111-
param = ParameterString(name="TrainingImage", default_value="some-image-uri")
132+
133+
class TestModelTrainerValidationWithPipelineVariable:
134+
"""Tests for ModelTrainer validation with PipelineVariable objects."""
135+
136+
def test_training_image_accepts_parameter_string(
137+
self, mock_session, mock_train_defaults
138+
):
139+
"""Test that training_image accepts ParameterString."""
140+
from sagemaker.train.model_trainer import ModelTrainer
141+
from sagemaker.train.configs import Compute
142+
143+
param = ParameterString(
144+
name="TrainingImage", default_value="some-image-uri"
145+
)
112146

113147
# Should not raise
114148
trainer = ModelTrainer(
115149
training_image=param,
116-
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
150+
compute=Compute(
151+
instance_type="ml.m5.xlarge", instance_count=1
152+
),
117153
sagemaker_session=mock_session,
118154
role="arn:aws:iam::123456789012:role/SageMakerRole",
119155
)
120156
assert trainer.training_image is param
121157

122-
@patch("sagemaker.train.model_trainer.TrainDefaults")
123-
def test_algorithm_name_accepts_parameter_string(self, mock_defaults):
158+
def test_algorithm_name_accepts_parameter_string(
159+
self, mock_session, mock_train_defaults
160+
):
124161
"""Test that algorithm_name accepts ParameterString."""
125162
from sagemaker.train.model_trainer import ModelTrainer
126163
from sagemaker.train.configs import Compute
127164

128-
mock_session = MagicMock()
129-
mock_session.boto_region_name = "us-east-1"
130-
mock_session.default_bucket.return_value = "my-bucket"
131-
mock_session.default_bucket_prefix = None
132-
133-
mock_defaults.get_sagemaker_session.return_value = mock_session
134-
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
135-
mock_defaults.get_base_job_name.return_value = "test-job"
136-
mock_defaults.get_compute.return_value = Compute(
137-
instance_type="ml.m5.xlarge", instance_count=1
165+
param = ParameterString(
166+
name="AlgorithmName", default_value="some-algo"
138167
)
139-
mock_defaults.get_stopping_condition.return_value = MagicMock()
140-
mock_defaults.get_output_data_config.return_value = MagicMock()
141-
142-
param = ParameterString(name="AlgorithmName", default_value="some-algo")
143168

144169
# Should not raise
145170
trainer = ModelTrainer(
146171
algorithm_name=param,
147-
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
172+
compute=Compute(
173+
instance_type="ml.m5.xlarge", instance_count=1
174+
),
148175
sagemaker_session=mock_session,
149176
role="arn:aws:iam::123456789012:role/SageMakerRole",
150177
)
151178
assert trainer.algorithm_name is param
152179

153-
@patch("sagemaker.train.model_trainer.TrainDefaults")
154-
def test_environment_values_accept_parameter_string(self, mock_defaults):
180+
def test_environment_values_accept_parameter_string(
181+
self, mock_session, mock_train_defaults
182+
):
155183
"""Test that environment dict values accept ParameterString."""
156184
from sagemaker.train.model_trainer import ModelTrainer
157185
from sagemaker.train.configs import Compute
158186

159-
mock_session = MagicMock()
160-
mock_session.boto_region_name = "us-east-1"
161-
mock_session.default_bucket.return_value = "my-bucket"
162-
mock_session.default_bucket_prefix = None
163-
164-
mock_defaults.get_sagemaker_session.return_value = mock_session
165-
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
166-
mock_defaults.get_base_job_name.return_value = "test-job"
167-
mock_defaults.get_compute.return_value = Compute(
168-
instance_type="ml.m5.xlarge", instance_count=1
187+
env_param = ParameterString(
188+
name="EnvValue", default_value="val"
169189
)
170-
mock_defaults.get_stopping_condition.return_value = MagicMock()
171-
mock_defaults.get_output_data_config.return_value = MagicMock()
172-
173-
env_param = ParameterString(name="EnvValue", default_value="val")
174190

175191
# Should not raise
176192
trainer = ModelTrainer(
177-
training_image="683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3",
178-
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
193+
training_image=_TEST_IMAGE_URI,
194+
compute=Compute(
195+
instance_type="ml.m5.xlarge", instance_count=1
196+
),
179197
sagemaker_session=mock_session,
180198
role="arn:aws:iam::123456789012:role/SageMakerRole",
181199
environment={"MY_VAR": env_param},
182200
)
183201
assert trainer.environment["MY_VAR"] is env_param
184202

185-
@patch("sagemaker.train.model_trainer.TrainDefaults")
186-
def test_plain_string_values_still_work(self, mock_defaults):
203+
def test_plain_string_values_still_work(
204+
self, mock_session, mock_train_defaults
205+
):
187206
"""Regression test: plain string values continue to work."""
188207
from sagemaker.train.model_trainer import ModelTrainer
189208
from sagemaker.train.configs import Compute
190209

191-
mock_session = MagicMock()
192-
mock_session.boto_region_name = "us-east-1"
193-
mock_session.default_bucket.return_value = "my-bucket"
194-
mock_session.default_bucket_prefix = None
195-
196-
mock_defaults.get_sagemaker_session.return_value = mock_session
197-
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
198-
mock_defaults.get_base_job_name.return_value = "test-job"
199-
mock_defaults.get_compute.return_value = Compute(
200-
instance_type="ml.m5.xlarge", instance_count=1
201-
)
202-
mock_defaults.get_stopping_condition.return_value = MagicMock()
203-
mock_defaults.get_output_data_config.return_value = MagicMock()
204-
205210
# Should not raise
206211
trainer = ModelTrainer(
207-
training_image="683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3",
208-
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
212+
training_image=_TEST_IMAGE_URI,
213+
compute=Compute(
214+
instance_type="ml.m5.xlarge", instance_count=1
215+
),
209216
sagemaker_session=mock_session,
210217
role="arn:aws:iam::123456789012:role/SageMakerRole",
211218
)
212-
assert trainer.training_image == "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
219+
assert trainer.training_image == _TEST_IMAGE_URI
220+
221+
def test_validation_accepts_pipeline_variable_image_none_algo(self):
222+
"""Test validation accepts PipelineVariable image with None algorithm."""
223+
from sagemaker.train.model_trainer import ModelTrainer
224+
225+
trainer = ModelTrainer.__new__(ModelTrainer)
226+
param = ParameterString(
227+
name="Image", default_value="img"
228+
)
229+
# Should not raise
230+
trainer._validate_training_image_and_algorithm_name(
231+
param, None
232+
)
233+
234+
def test_validation_accepts_none_image_pipeline_variable_algo(self):
235+
"""Test validation accepts None image with PipelineVariable algorithm."""
236+
from sagemaker.train.model_trainer import ModelTrainer
237+
238+
trainer = ModelTrainer.__new__(ModelTrainer)
239+
param = ParameterString(
240+
name="Algo", default_value="algo"
241+
)
242+
# Should not raise
243+
trainer._validate_training_image_and_algorithm_name(
244+
None, param
245+
)
213246

214247
def test_validation_rejects_no_image_or_algorithm(self):
215-
"""Test that validation rejects when neither training_image nor algorithm_name is provided."""
248+
"""Test that validation rejects when neither is provided."""
216249
from sagemaker.train.model_trainer import ModelTrainer
217250

218251
trainer = ModelTrainer.__new__(ModelTrainer)
219252
with pytest.raises(ValueError, match="Atleast one of"):
220-
trainer._validate_training_image_and_algorithm_name(None, None)
253+
trainer._validate_training_image_and_algorithm_name(
254+
None, None
255+
)
221256

222257
def test_validation_rejects_both_image_and_algorithm(self):
223-
"""Test that validation rejects when both training_image and algorithm_name are provided."""
258+
"""Test that validation rejects when both are provided."""
224259
from sagemaker.train.model_trainer import ModelTrainer
225260

226261
trainer = ModelTrainer.__new__(ModelTrainer)
227262
with pytest.raises(ValueError, match="Only one of"):
228-
trainer._validate_training_image_and_algorithm_name("image", "algo")
263+
trainer._validate_training_image_and_algorithm_name(
264+
"image", "algo"
265+
)
229266

230267
def test_validation_rejects_both_pipeline_variables(self):
231268
"""Test that validation rejects when both are PipelineVariables."""
232269
from sagemaker.train.model_trainer import ModelTrainer
233270

234271
trainer = ModelTrainer.__new__(ModelTrainer)
235-
img_param = ParameterString(name="Image", default_value="img")
236-
algo_param = ParameterString(name="Algo", default_value="algo")
272+
img_param = ParameterString(
273+
name="Image", default_value="img"
274+
)
275+
algo_param = ParameterString(
276+
name="Algo", default_value="algo"
277+
)
237278
with pytest.raises(ValueError, match="Only one of"):
238-
trainer._validate_training_image_and_algorithm_name(img_param, algo_param)
279+
trainer._validate_training_image_and_algorithm_name(
280+
img_param, algo_param
281+
)

0 commit comments

Comments
 (0)