Skip to content

Commit a2f9e47

Browse files
committed
fix: address review comments (iteration #1)
1 parent 04d768d commit a2f9e47

File tree

3 files changed

+124
-10
lines changed

3 files changed

+124
-10
lines changed

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116
from sagemaker.core.jumpstart.utils import get_eula_url
117117
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
118118
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
119-
from sagemaker.core.helper.pipeline_variable import StrPipeVar
119+
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
120120

121121
from sagemaker.train.local.local_container import _LocalContainer
122122

@@ -410,7 +410,9 @@ def __del__(self):
410410
self._temp_code_dir.cleanup()
411411

412412
def _validate_training_image_and_algorithm_name(
413-
self, training_image, algorithm_name
413+
self,
414+
training_image: Union[str, PipelineVariable, None],
415+
algorithm_name: Union[str, PipelineVariable, None],
414416
):
415417
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
416418
has_image = training_image is not None

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
from typing import Literal, Any
2525

2626
from sagemaker.core.helper.session_helper import Session
27+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
2728
from sagemaker.core.shapes import Unassigned
2829
from sagemaker.train import logger
29-
from sagemaker.core.workflow.parameters import PipelineVariable
3030

3131

3232
def _default_bucket_and_prefix(session: Session) -> str:
@@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
142142
return unique_name
143143

144144

145-
def _get_repo_name_from_image(image) -> str:
145+
def _get_repo_name_from_image(image: "str | PipelineVariable") -> "str | None":
146146
"""Get the repository name from the image URI.
147147
148148
Example:
@@ -152,13 +152,11 @@ def _get_repo_name_from_image(image) -> str:
152152
```
153153
154154
Args:
155-
image: The image URI (str or PipelineVariable)
155+
image (str | PipelineVariable): The image URI
156156
157157
Returns:
158-
str: The repository name, or None if image is a PipelineVariable
158+
str | None: The repository name, or None if image is a PipelineVariable
159159
"""
160-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
161-
162160
if isinstance(image, PipelineVariable):
163161
return None
164162
return image.split("/")[-1].split(":")[0].split("@")[0]

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626

2727
from sagemaker.core.helper.session_helper import Session
2828
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29-
from sagemaker.core.workflow.parameters import ParameterString
29+
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
3030
from sagemaker.train.model_trainer import ModelTrainer, Mode
3131
from sagemaker.train.configs import (
3232
Compute,
3333
StoppingCondition,
3434
OutputDataConfig,
3535
)
36-
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults
36+
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
37+
from sagemaker.train.utils import _get_repo_name_from_image, safe_serialize
3738

3839

3940
DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
@@ -176,3 +177,116 @@ def test_training_image_rejects_invalid_type(self):
176177
stopping_condition=DEFAULT_STOPPING,
177178
output_data_config=DEFAULT_OUTPUT,
178179
)
180+
181+
182+
class TestValidateTrainingImageAndAlgorithmName:
183+
"""Tests for _validate_training_image_and_algorithm_name with PipelineVariable."""
184+
185+
def test_pipeline_variable_training_image_passes_validation(self):
186+
"""PipelineVariable as training_image should pass validation."""
187+
param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
188+
trainer = ModelTrainer(
189+
training_image=param,
190+
base_job_name="pipeline-test-job",
191+
role=DEFAULT_ROLE,
192+
compute=DEFAULT_COMPUTE,
193+
stopping_condition=DEFAULT_STOPPING,
194+
output_data_config=DEFAULT_OUTPUT,
195+
)
196+
assert trainer.training_image is param
197+
198+
def test_pipeline_variable_algorithm_name_passes_validation(self):
199+
"""PipelineVariable as algorithm_name should pass validation."""
200+
param = ParameterString(name="AlgoName", default_value="my-algo")
201+
trainer = ModelTrainer(
202+
algorithm_name=param,
203+
base_job_name="pipeline-test-job",
204+
role=DEFAULT_ROLE,
205+
compute=DEFAULT_COMPUTE,
206+
stopping_condition=DEFAULT_STOPPING,
207+
output_data_config=DEFAULT_OUTPUT,
208+
)
209+
assert trainer.algorithm_name is param
210+
211+
def test_both_pipeline_variables_raises_value_error(self):
212+
"""Both training_image and algorithm_name as PipelineVariable should raise ValueError."""
213+
image_param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
214+
algo_param = ParameterString(name="AlgoName", default_value="my-algo")
215+
with pytest.raises(ValueError, match="Only one of"):
216+
ModelTrainer(
217+
training_image=image_param,
218+
algorithm_name=algo_param,
219+
base_job_name="pipeline-test-job",
220+
role=DEFAULT_ROLE,
221+
compute=DEFAULT_COMPUTE,
222+
stopping_condition=DEFAULT_STOPPING,
223+
output_data_config=DEFAULT_OUTPUT,
224+
)
225+
226+
def test_neither_provided_raises_value_error(self):
227+
"""Neither training_image nor algorithm_name should raise ValueError."""
228+
with pytest.raises(ValueError, match="Atleast one of"):
229+
ModelTrainer(
230+
training_image=None,
231+
algorithm_name=None,
232+
base_job_name="pipeline-test-job",
233+
role=DEFAULT_ROLE,
234+
compute=DEFAULT_COMPUTE,
235+
stopping_condition=DEFAULT_STOPPING,
236+
output_data_config=DEFAULT_OUTPUT,
237+
)
238+
239+
240+
class TestGetRepoNameFromImage:
241+
"""Tests for _get_repo_name_from_image with PipelineVariable."""
242+
243+
def test_returns_none_for_pipeline_variable(self):
244+
"""_get_repo_name_from_image should return None for PipelineVariable."""
245+
param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
246+
result = _get_repo_name_from_image(param)
247+
assert result is None
248+
249+
def test_returns_repo_name_for_string(self):
250+
"""_get_repo_name_from_image should return repo name for a normal string."""
251+
result = _get_repo_name_from_image(
252+
"123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest"
253+
)
254+
assert result == "my-repo"
255+
256+
def test_returns_repo_name_without_tag(self):
257+
"""_get_repo_name_from_image should handle image URIs without tags."""
258+
result = _get_repo_name_from_image(
259+
"123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo"
260+
)
261+
assert result == "my-repo"
262+
263+
264+
class TestSafeSerialize:
265+
"""Tests for safe_serialize with PipelineVariable."""
266+
267+
def test_safe_serialize_pipeline_variable_returns_variable(self):
268+
"""safe_serialize should return the PipelineVariable object as-is."""
269+
param = ParameterInteger(name="MaxDepth", default_value=5)
270+
result = safe_serialize(param)
271+
assert result is param
272+
273+
def test_safe_serialize_string_returns_string(self):
274+
"""safe_serialize should return strings as-is."""
275+
result = safe_serialize("hello")
276+
assert result == "hello"
277+
278+
def test_safe_serialize_int_returns_json(self):
279+
"""safe_serialize should JSON-encode integers."""
280+
result = safe_serialize(5)
281+
assert result == "5"
282+
283+
def test_safe_serialize_dict_returns_json(self):
284+
"""safe_serialize should JSON-encode dicts."""
285+
result = safe_serialize({"key": "value"})
286+
assert result == '{"key": "value"}'
287+
288+
def test_safe_serialize_parameter_string_returns_variable(self):
289+
"""safe_serialize should return ParameterString as-is."""
290+
param = ParameterString(name="MyParam", default_value="val")
291+
result = safe_serialize(param)
292+
assert result is param

0 commit comments

Comments
 (0)