Skip to content

Commit e3b2d82

Browse files
committed
fix: address review comments (iteration #3)
1 parent 0d1195d commit e3b2d82

File tree

3 files changed

+11
-25
lines changed

3 files changed

+11
-25
lines changed

sagemaker-core/src/sagemaker/core/modules/utils.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@
2424

2525
from sagemaker.core.shapes import Unassigned
2626
from sagemaker.core.modules import logger
27-
28-
try:
29-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
30-
except ImportError:
31-
PipelineVariable = None
27+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
3228

3329

3430
def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
@@ -145,11 +141,12 @@ def safe_serialize(data):
145141
data (Any): The data to serialize.
146142
147143
Returns:
148-
str: The serialized JSON-compatible string or the string representation of the input.
144+
str | PipelineVariable: The serialized JSON-compatible string, the string
145+
representation of the input, or the PipelineVariable object as-is.
149146
"""
150147
if isinstance(data, str):
151148
return data
152-
elif PipelineVariable is not None and isinstance(data, PipelineVariable):
149+
elif isinstance(data, PipelineVariable):
153150
return data
154151
try:
155152
return json.dumps(data)

sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,27 +30,16 @@
3030
class TestSafeSerializeWithPipelineVariables:
3131
"""Test safe_serialize handles PipelineVariable objects correctly."""
3232

33-
def test_safe_serialize_with_parameter_integer(self):
34-
"""ParameterInteger should be returned as-is (identity preserved)."""
35-
param = ParameterInteger(name="MaxDepth", default_value=5)
33+
@pytest.mark.parametrize("param", [
34+
ParameterInteger(name="MaxDepth", default_value=5),
35+
ParameterString(name="Algorithm", default_value="xgboost"),
36+
])
37+
def test_safe_serialize_returns_pipeline_variable_as_is(self, param):
38+
"""PipelineVariable objects should be returned as-is (identity preserved)."""
3639
result = safe_serialize(param)
3740
assert result is param
3841
assert isinstance(result, PipelineVariable)
3942

40-
def test_safe_serialize_with_parameter_string(self):
41-
"""ParameterString should be returned as-is (identity preserved)."""
42-
param = ParameterString(name="Algorithm", default_value="xgboost")
43-
result = safe_serialize(param)
44-
assert result is param
45-
assert isinstance(result, PipelineVariable)
46-
47-
def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
48-
"""Verify that PipelineVariable.__str__ is never invoked (would raise TypeError)."""
49-
param = ParameterInteger(name="TestParam", default_value=10)
50-
# This should NOT raise TypeError
51-
result = safe_serialize(param)
52-
assert result is param
53-
5443
def test_pipeline_variable_str_raises_type_error(self):
5544
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
5645
param = ParameterInteger(name="TestParam", default_value=10)

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def test_hyperparameters_with_parameter_integer(self):
204204
assert args["hyper_parameters"]["max_depth"] is max_depth
205205

206206
def test_hyperparameters_with_parameter_string(self):
207-
"""ParameterString in hyperparameters should be preserved through _create_training_job_args."""
207+
"""ParameterString in hyperparameters should be preserved."""
208208
algo = ParameterString(name="Algorithm", default_value="xgboost")
209209
trainer = ModelTrainer(
210210
training_image=DEFAULT_IMAGE,

0 commit comments

Comments
 (0)