Skip to content

Commit 00f479e

Browse files
committed
fix: address review comments (iteration #1)
1 parent 95a19c5 commit 00f479e

File tree

3 files changed

+36
-36
lines changed

3 files changed

+36
-36
lines changed

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -168,22 +168,26 @@ def convert_unassigned_to_none(instance) -> Any:
168168
return instance
169169

170170

171-
def safe_serialize(data):
171+
def safe_serialize(data) -> "str | PipelineVariable":
172172
"""Serialize the data without wrapping strings in quotes.
173173
174174
This function handles the following cases:
175-
1. If `data` is a string, it returns the string as-is without wrapping in quotes.
176-
2. If `data` is of type `PipelineVariable`, it returns the json representation of the PipelineVariable
177-
3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
178-
the JSON-encoded string using `json.dumps()`.
179-
4. If `data` cannot be serialized (e.g., a custom object), it returns the string
180-
representation of the data using `str(data)`.
175+
1. If ``data`` is a string, it returns the string as-is without wrapping in quotes.
176+
2. If ``data`` is of type :class:`~sagemaker.core.helper.pipeline_variable.PipelineVariable`,
177+
it returns the object directly so that pipeline serialization can handle it
178+
downstream. Callers should be aware that the return value may be a
179+
``PipelineVariable`` rather than a plain ``str``.
180+
3. If ``data`` is serializable (e.g., a dictionary, list, int, float), it returns
181+
the JSON-encoded string using ``json.dumps()``.
182+
4. If ``data`` cannot be serialized (e.g., a custom object), it returns the string
183+
representation of the data using ``str(data)``.
181184
182185
Args:
183186
data (Any): The data to serialize.
184187
185188
Returns:
186-
str: The serialized JSON-compatible string or the string representation of the input.
189+
str | PipelineVariable: The serialized JSON-compatible string, the string
190+
representation of the input, or the original ``PipelineVariable`` object.
187191
"""
188192
if isinstance(data, str):
189193
return data
@@ -197,11 +201,9 @@ def safe_serialize(data):
197201
except TypeError:
198202
# PipelineVariable.__str__ raises TypeError by design.
199203
# If the isinstance check above didn't catch it (e.g. import
200-
# path mismatch), fall back to returning the object directly
201-
# when it looks like a PipelineVariable (has an ``expr`` property).
202-
if hasattr(data, "expr"):
203-
return data
204-
raise
204+
# path mismatch or reload issues), return the object directly
205+
# so pipeline serialization can handle it downstream (note: any object reaching this branch is returned unchanged).
206+
return data
205207

206208

207209
def _run_clone_command_silent(repo_url, dest_dir):

sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,13 @@
2626

2727
from sagemaker.core.helper.session_helper import Session
2828
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29-
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
29+
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
3030
from sagemaker.train.model_trainer import ModelTrainer, Mode
3131
from sagemaker.train.configs import (
3232
Compute,
3333
StoppingCondition,
3434
OutputDataConfig,
3535
)
36-
from sagemaker.core.workflow.pipeline_context import PipelineSession
3736
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
3837

3938

@@ -180,10 +179,10 @@ def test_training_image_rejects_invalid_type(self):
180179

181180

182181
class TestModelTrainerHyperparametersPipelineVariable:
183-
"""Test that PipelineVariable objects in hyperparameters survive safe_serialize."""
182+
"""Test that ModelTrainer correctly preserves PipelineVariable objects in hyperparameters."""
184183

185-
def test_hyperparameters_with_pipeline_variable_integer(self):
186-
"""ParameterInteger in hyperparameters should be passed through as-is."""
184+
def test_hyperparameters_preserves_pipeline_variable_integer(self):
185+
"""ParameterInteger in hyperparameters should be preserved in ModelTrainer."""
187186
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
188187
trainer = ModelTrainer(
189188
training_image=DEFAULT_IMAGE,
@@ -193,13 +192,10 @@ def test_hyperparameters_with_pipeline_variable_integer(self):
193192
output_data_config=DEFAULT_OUTPUT,
194193
hyperparameters={"max_depth": max_depth},
195194
)
196-
# safe_serialize should return the PipelineVariable object directly
197-
from sagemaker.train.utils import safe_serialize
198-
result = safe_serialize(max_depth)
199-
assert result is max_depth
195+
assert trainer.hyperparameters["max_depth"] is max_depth
200196

201-
def test_hyperparameters_with_pipeline_variable_string(self):
202-
"""ParameterString in hyperparameters should be passed through as-is."""
197+
def test_hyperparameters_preserves_pipeline_variable_string(self):
198+
"""ParameterString in hyperparameters should be preserved in ModelTrainer."""
203199
optimizer = ParameterString(name="Optimizer", default_value="sgd")
204200
trainer = ModelTrainer(
205201
training_image=DEFAULT_IMAGE,
@@ -209,12 +205,10 @@ def test_hyperparameters_with_pipeline_variable_string(self):
209205
output_data_config=DEFAULT_OUTPUT,
210206
hyperparameters={"optimizer": optimizer},
211207
)
212-
from sagemaker.train.utils import safe_serialize
213-
result = safe_serialize(optimizer)
214-
assert result is optimizer
208+
assert trainer.hyperparameters["optimizer"] is optimizer
215209

216-
def test_hyperparameters_with_mixed_pipeline_and_regular_values(self):
217-
"""Mixed PipelineVariable and regular values should both serialize correctly."""
210+
def test_hyperparameters_preserves_mixed_pipeline_and_regular_values(self):
211+
"""Mixed PipelineVariable and regular values should all be preserved."""
218212
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
219213
trainer = ModelTrainer(
220214
training_image=DEFAULT_IMAGE,
@@ -228,10 +222,8 @@ def test_hyperparameters_with_mixed_pipeline_and_regular_values(self):
228222
"objective": "binary:logistic",
229223
},
230224
)
231-
from sagemaker.train.utils import safe_serialize
232-
# PipelineVariable should be returned as-is
233-
assert safe_serialize(max_depth) is max_depth
234-
# Float should be JSON-serialized
235-
assert safe_serialize(0.1) == "0.1"
236-
# String should be returned as-is
237-
assert safe_serialize("binary:logistic") == "binary:logistic"
225+
# PipelineVariable should be preserved as-is
226+
assert trainer.hyperparameters["max_depth"] is max_depth
227+
# Regular values should also be preserved
228+
assert trainer.hyperparameters["eta"] == 0.1
229+
assert trainer.hyperparameters["objective"] == "binary:logistic"

sagemaker-train/tests/unit/train/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def test_safe_serialize_with_string_returns_string_as_is():
5454
assert safe_serialize("12345") == "12345"
5555

5656

57+
def test_safe_serialize_with_json_like_string_returns_as_is():
58+
"""A string that looks like JSON should be returned as-is, not double-serialized."""
59+
json_str = '{"key": "value"}'
60+
assert safe_serialize(json_str) == json_str
61+
62+
5763
def test_safe_serialize_with_int_returns_json_string():
5864
assert safe_serialize(5) == "5"
5965
assert safe_serialize(0) == "0"

0 commit comments

Comments
 (0)