Skip to content

Commit a996ac7

Browse files
committed
fix: address review comments (iteration #1)
1 parent 9450aee commit a996ac7

File tree

3 files changed

+249
-9
lines changed

3 files changed

+249
-9
lines changed

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116
from sagemaker.core.jumpstart.utils import get_eula_url
117117
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
118118
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
119-
from sagemaker.core.helper.pipeline_variable import StrPipeVar
119+
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
120120

121121
from sagemaker.train.local.local_container import _LocalContainer
122122

@@ -410,13 +410,14 @@ def __del__(self):
410410
self._temp_code_dir.cleanup()
411411

412412
def _validate_training_image_and_algorithm_name(
413-
self, training_image, algorithm_name
413+
self,
414+
training_image: "str | PipelineVariable | None",
415+
algorithm_name: "str | PipelineVariable | None",
414416
):
415417
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
416-
from sagemaker.core.helper.pipeline_variable import PipelineVariable as _PV
417418
# PipelineVariables are truthy for validation purposes
418-
has_image = isinstance(training_image, _PV) or bool(training_image)
419-
has_algo = isinstance(algorithm_name, _PV) or bool(algorithm_name)
419+
has_image = isinstance(training_image, PipelineVariable) or bool(training_image)
420+
has_algo = isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name)
420421
if not has_image and not has_algo:
421422
raise ValueError(
422423
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
@@ -550,7 +551,6 @@ def model_post_init(self, __context: Any):
550551
)
551552

552553
if self.training_image:
553-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
554554
if isinstance(self.training_image, PipelineVariable):
555555
logger.info("Training image URI: (PipelineVariable - resolved at pipeline execution)")
556556
else:

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from sagemaker.train import logger
2929
from sagemaker.core.workflow.parameters import PipelineVariable
3030

31+
_PIPELINE_VARIABLE_IMAGE_PLACEHOLDER = "pipeline-variable-image"
32+
3133

3234
def _default_bucket_and_prefix(session: Session) -> str:
3335
"""Helper function to get the bucket name with the corresponding prefix if applicable
@@ -142,7 +144,7 @@ def _get_unique_name(base, max_length=63):
142144
return unique_name
143145

144146

145-
def _get_repo_name_from_image(image) -> str:
147+
def _get_repo_name_from_image(image: "str | PipelineVariable") -> str:
146148
"""Get the repository name from the image URI.
147149
148150
Example:
@@ -152,13 +154,13 @@ def _get_repo_name_from_image(image) -> str:
152154
```
153155
154156
Args:
155-
image: The image URI (str or PipelineVariable)
157+
image (str or PipelineVariable): The image URI
156158
157159
Returns:
158160
str: The repository name
159161
"""
160162
if isinstance(image, PipelineVariable):
161-
return "pipeline-variable-image"
163+
return _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
162164
return image.split("/")[-1].split(":")[0].split("@")[0]
163165

164166

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests for PipelineVariable support in ModelTrainer."""
14+
from __future__ import absolute_import
15+
16+
import pytest
17+
from unittest.mock import MagicMock, patch
18+
19+
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
20+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
21+
from sagemaker.train.utils import safe_serialize, _get_repo_name_from_image, _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
22+
23+
24+
class TestSafeSerializeWithPipelineVariable:
25+
"""Tests for safe_serialize handling of PipelineVariable objects."""
26+
27+
def test_safe_serialize_string(self):
28+
"""Test that plain strings are returned as-is."""
29+
assert safe_serialize("hello") == "hello"
30+
31+
def test_safe_serialize_int(self):
32+
"""Test that integers are JSON-serialized."""
33+
assert safe_serialize(5) == "5"
34+
35+
def test_safe_serialize_float(self):
36+
"""Test that floats are JSON-serialized."""
37+
assert safe_serialize(3.14) == "3.14"
38+
39+
def test_safe_serialize_dict(self):
40+
"""Test that dicts are JSON-serialized."""
41+
result = safe_serialize({"key": "value"})
42+
assert result == '{"key": "value"}'
43+
44+
def test_safe_serialize_pipeline_variable_parameter_string(self):
45+
"""Test that ParameterString is returned as the PipelineVariable object itself."""
46+
param = ParameterString(name="MyParam", default_value="test")
47+
result = safe_serialize(param)
48+
# Should return the PipelineVariable object, not raise TypeError
49+
assert isinstance(result, PipelineVariable)
50+
assert result is param
51+
52+
def test_safe_serialize_pipeline_variable_parameter_integer(self):
53+
"""Test that ParameterInteger is returned as the PipelineVariable object itself."""
54+
param = ParameterInteger(name="MaxDepth", default_value=5)
55+
result = safe_serialize(param)
56+
# Should return the PipelineVariable object, not raise TypeError
57+
assert isinstance(result, PipelineVariable)
58+
assert result is param
59+
60+
61+
class TestGetRepoNameFromImage:
62+
"""Tests for _get_repo_name_from_image handling of PipelineVariable objects."""
63+
64+
def test_get_repo_name_from_image_string(self):
65+
"""Test that a normal image URI returns the repo name."""
66+
image = "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
67+
result = _get_repo_name_from_image(image)
68+
assert result == "sagemaker-xgboost"
69+
70+
def test_get_repo_name_from_image_pipeline_variable(self):
71+
"""Test that a PipelineVariable returns the placeholder constant."""
72+
param = ParameterString(name="TrainingImage", default_value="some-image")
73+
result = _get_repo_name_from_image(param)
74+
assert result == _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER
75+
76+
def test_get_repo_name_from_image_simple_string(self):
77+
"""Test with a simple image name."""
78+
result = _get_repo_name_from_image("my-repo:latest")
79+
assert result == "my-repo"
80+
81+
def test_get_repo_name_from_image_with_digest(self):
82+
"""Test with an image URI containing a digest."""
83+
image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo@sha256:abc123"
84+
result = _get_repo_name_from_image(image)
85+
assert result == "my-repo"
86+
87+
88+
class TestModelTrainerValidationWithPipelineVariable:
89+
"""Tests for ModelTrainer validation with PipelineVariable objects."""
90+
91+
@patch("sagemaker.train.model_trainer.TrainDefaults")
92+
def test_training_image_accepts_parameter_string(self, mock_defaults):
93+
"""Test that training_image accepts ParameterString."""
94+
from sagemaker.train.model_trainer import ModelTrainer
95+
from sagemaker.train.configs import Compute
96+
97+
mock_session = MagicMock()
98+
mock_session.boto_region_name = "us-east-1"
99+
mock_session.default_bucket.return_value = "my-bucket"
100+
mock_session.default_bucket_prefix = None
101+
102+
mock_defaults.get_sagemaker_session.return_value = mock_session
103+
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
104+
mock_defaults.get_base_job_name.return_value = "test-job"
105+
mock_defaults.get_compute.return_value = Compute(
106+
instance_type="ml.m5.xlarge", instance_count=1
107+
)
108+
mock_defaults.get_stopping_condition.return_value = MagicMock()
109+
mock_defaults.get_output_data_config.return_value = MagicMock()
110+
111+
param = ParameterString(name="TrainingImage", default_value="some-image-uri")
112+
113+
# Should not raise
114+
trainer = ModelTrainer(
115+
training_image=param,
116+
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
117+
sagemaker_session=mock_session,
118+
role="arn:aws:iam::123456789012:role/SageMakerRole",
119+
)
120+
assert trainer.training_image is param
121+
122+
@patch("sagemaker.train.model_trainer.TrainDefaults")
123+
def test_algorithm_name_accepts_parameter_string(self, mock_defaults):
124+
"""Test that algorithm_name accepts ParameterString."""
125+
from sagemaker.train.model_trainer import ModelTrainer
126+
from sagemaker.train.configs import Compute
127+
128+
mock_session = MagicMock()
129+
mock_session.boto_region_name = "us-east-1"
130+
mock_session.default_bucket.return_value = "my-bucket"
131+
mock_session.default_bucket_prefix = None
132+
133+
mock_defaults.get_sagemaker_session.return_value = mock_session
134+
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
135+
mock_defaults.get_base_job_name.return_value = "test-job"
136+
mock_defaults.get_compute.return_value = Compute(
137+
instance_type="ml.m5.xlarge", instance_count=1
138+
)
139+
mock_defaults.get_stopping_condition.return_value = MagicMock()
140+
mock_defaults.get_output_data_config.return_value = MagicMock()
141+
142+
param = ParameterString(name="AlgorithmName", default_value="some-algo")
143+
144+
# Should not raise
145+
trainer = ModelTrainer(
146+
algorithm_name=param,
147+
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
148+
sagemaker_session=mock_session,
149+
role="arn:aws:iam::123456789012:role/SageMakerRole",
150+
)
151+
assert trainer.algorithm_name is param
152+
153+
@patch("sagemaker.train.model_trainer.TrainDefaults")
154+
def test_environment_values_accept_parameter_string(self, mock_defaults):
155+
"""Test that environment dict values accept ParameterString."""
156+
from sagemaker.train.model_trainer import ModelTrainer
157+
from sagemaker.train.configs import Compute
158+
159+
mock_session = MagicMock()
160+
mock_session.boto_region_name = "us-east-1"
161+
mock_session.default_bucket.return_value = "my-bucket"
162+
mock_session.default_bucket_prefix = None
163+
164+
mock_defaults.get_sagemaker_session.return_value = mock_session
165+
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
166+
mock_defaults.get_base_job_name.return_value = "test-job"
167+
mock_defaults.get_compute.return_value = Compute(
168+
instance_type="ml.m5.xlarge", instance_count=1
169+
)
170+
mock_defaults.get_stopping_condition.return_value = MagicMock()
171+
mock_defaults.get_output_data_config.return_value = MagicMock()
172+
173+
env_param = ParameterString(name="EnvValue", default_value="val")
174+
175+
# Should not raise
176+
trainer = ModelTrainer(
177+
training_image="683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3",
178+
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
179+
sagemaker_session=mock_session,
180+
role="arn:aws:iam::123456789012:role/SageMakerRole",
181+
environment={"MY_VAR": env_param},
182+
)
183+
assert trainer.environment["MY_VAR"] is env_param
184+
185+
@patch("sagemaker.train.model_trainer.TrainDefaults")
186+
def test_plain_string_values_still_work(self, mock_defaults):
187+
"""Regression test: plain string values continue to work."""
188+
from sagemaker.train.model_trainer import ModelTrainer
189+
from sagemaker.train.configs import Compute
190+
191+
mock_session = MagicMock()
192+
mock_session.boto_region_name = "us-east-1"
193+
mock_session.default_bucket.return_value = "my-bucket"
194+
mock_session.default_bucket_prefix = None
195+
196+
mock_defaults.get_sagemaker_session.return_value = mock_session
197+
mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole"
198+
mock_defaults.get_base_job_name.return_value = "test-job"
199+
mock_defaults.get_compute.return_value = Compute(
200+
instance_type="ml.m5.xlarge", instance_count=1
201+
)
202+
mock_defaults.get_stopping_condition.return_value = MagicMock()
203+
mock_defaults.get_output_data_config.return_value = MagicMock()
204+
205+
# Should not raise
206+
trainer = ModelTrainer(
207+
training_image="683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3",
208+
compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
209+
sagemaker_session=mock_session,
210+
role="arn:aws:iam::123456789012:role/SageMakerRole",
211+
)
212+
assert trainer.training_image == "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
213+
214+
def test_validation_rejects_no_image_or_algorithm(self):
215+
"""Test that validation rejects when neither training_image nor algorithm_name is provided."""
216+
from sagemaker.train.model_trainer import ModelTrainer
217+
218+
trainer = ModelTrainer.__new__(ModelTrainer)
219+
with pytest.raises(ValueError, match="Atleast one of"):
220+
trainer._validate_training_image_and_algorithm_name(None, None)
221+
222+
def test_validation_rejects_both_image_and_algorithm(self):
223+
"""Test that validation rejects when both training_image and algorithm_name are provided."""
224+
from sagemaker.train.model_trainer import ModelTrainer
225+
226+
trainer = ModelTrainer.__new__(ModelTrainer)
227+
with pytest.raises(ValueError, match="Only one of"):
228+
trainer._validate_training_image_and_algorithm_name("image", "algo")
229+
230+
def test_validation_rejects_both_pipeline_variables(self):
231+
"""Test that validation rejects when both are PipelineVariables."""
232+
from sagemaker.train.model_trainer import ModelTrainer
233+
234+
trainer = ModelTrainer.__new__(ModelTrainer)
235+
img_param = ParameterString(name="Image", default_value="img")
236+
algo_param = ParameterString(name="Algo", default_value="algo")
237+
with pytest.raises(ValueError, match="Only one of"):
238+
trainer._validate_training_image_and_algorithm_name(img_param, algo_param)

0 commit comments

Comments
 (0)