Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions sagemaker-core/src/sagemaker/core/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from sagemaker.core.shapes import Unassigned
from sagemaker.core.modules import logger

try:
from sagemaker.core.helper.pipeline_variable import PipelineVariable
except ImportError:
Comment thread
aviruthen marked this conversation as resolved.
Outdated
PipelineVariable = None


def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
"""Check if the path is a valid S3 URI.
Expand Down Expand Up @@ -129,9 +134,11 @@ def safe_serialize(data):

This function handles the following cases:
1. If `data` is a string, it returns the string as-is without wrapping in quotes.
2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
2. If `data` is of type `PipelineVariable`, it returns the PipelineVariable object
as-is for pipeline serialization.
3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
the JSON-encoded string using `json.dumps()`.
3. If `data` cannot be serialized (e.g., a custom object), it returns the string
4. If `data` cannot be serialized (e.g., a custom object), it returns the string
representation of the data using `str(data)`.

Args:
Expand All @@ -142,6 +149,8 @@ def safe_serialize(data):
"""
if isinstance(data, str):
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
return data
elif PipelineVariable is not None and isinstance(data, PipelineVariable):
return data
try:
return json.dumps(data)
except TypeError:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for safe_serialize in sagemaker.core.modules.utils with PipelineVariable support.

Verifies that safe_serialize correctly handles PipelineVariable objects
(e.g., ParameterInteger, ParameterString) by returning them as-is rather
than attempting str() conversion which would raise TypeError.

See: https://github.com/aws/sagemaker-python-sdk/issues/5504
"""
from __future__ import annotations

import pytest

Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.core.modules.utils import safe_serialize
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.


class TestSafeSerializeWithPipelineVariables:
"""Test safe_serialize handles PipelineVariable objects correctly."""

def test_safe_serialize_with_parameter_integer(self):
"""ParameterInteger should be returned as-is (identity preserved)."""
Comment thread
aviruthen marked this conversation as resolved.
Outdated
param = ParameterInteger(name="MaxDepth", default_value=5)
result = safe_serialize(param)
assert result is param
assert isinstance(result, PipelineVariable)

def test_safe_serialize_with_parameter_string(self):
"""ParameterString should be returned as-is (identity preserved)."""
param = ParameterString(name="Algorithm", default_value="xgboost")
result = safe_serialize(param)
assert result is param
assert isinstance(result, PipelineVariable)

def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
"""Verify that PipelineVariable.__str__ is never invoked (would raise TypeError)."""
param = ParameterInteger(name="TestParam", default_value=10)
# This should NOT raise TypeError
result = safe_serialize(param)
assert result is param

def test_pipeline_variable_str_raises_type_error(self):
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
param = ParameterInteger(name="TestParam", default_value=10)
with pytest.raises(TypeError):
str(param)


class TestSafeSerializeBasicTypes:
"""Regression tests: verify basic types still work after PipelineVariable support."""

def test_safe_serialize_with_string(self):
"""Strings should be returned as-is without JSON wrapping."""
assert safe_serialize("hello") == "hello"

def test_safe_serialize_with_int(self):
"""Integers should be JSON-serialized to string."""
assert safe_serialize(42) == "42"

def test_safe_serialize_with_dict(self):
"""Dicts should be JSON-serialized."""
result = safe_serialize({"key": "val"})
assert result == '{"key": "val"}'

def test_safe_serialize_with_bool(self):
"""Booleans should be JSON-serialized."""
assert safe_serialize(True) == "true"
assert safe_serialize(False) == "false"

def test_safe_serialize_with_none(self):
"""None should be JSON-serialized to 'null'."""
assert safe_serialize(None) == "null"

def test_safe_serialize_with_custom_object(self):
"""Custom objects should fall back to str()."""

class CustomObj:
def __str__(self):
return "custom"

assert safe_serialize(CustomObj()) == "custom"
Comment thread
aviruthen marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.configs import (
Compute,
Expand Down Expand Up @@ -176,3 +176,87 @@ def test_training_image_rejects_invalid_type(self):
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)


class TestModelTrainerPipelineVariableHyperparameters:
"""Test that PipelineVariable objects work correctly in ModelTrainer hyperparameters."""

def test_hyperparameters_with_parameter_integer(self):
"""ParameterInteger in hyperparameters should be preserved through _create_training_job_args."""
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
Comment thread
aviruthen marked this conversation as resolved.
output_data_config=DEFAULT_OUTPUT,
hyperparameters={"max_depth": max_depth},
)
args = trainer._create_training_job_args()
# PipelineVariable should be preserved as-is, not stringified
Comment thread
aviruthen marked this conversation as resolved.
assert args["hyper_parameters"]["max_depth"] is max_depth

def test_hyperparameters_with_parameter_string(self):
Comment thread
aviruthen marked this conversation as resolved.
"""ParameterString in hyperparameters should be preserved through _create_training_job_args."""
Comment thread
aviruthen marked this conversation as resolved.
Outdated
algo = ParameterString(name="Algorithm", default_value="xgboost")
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
hyperparameters={"algorithm": algo},
)
args = trainer._create_training_job_args()
assert args["hyper_parameters"]["algorithm"] is algo

def test_hyperparameters_with_parameter_integer_does_not_raise(self):
"""Verify ParameterInteger in hyperparameters does NOT raise TypeError.

This test documents the exact bug scenario from GH#5504: safe_serialize
would fall back to str(data) for PipelineVariable objects, but
PipelineVariable.__str__ intentionally raises TypeError.
"""
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
hyperparameters={"max_depth": max_depth},
)
Comment thread
aviruthen marked this conversation as resolved.
Outdated
# This call would have raised TypeError before the fix
try:
args = trainer._create_training_job_args()
except TypeError:
pytest.fail(
"safe_serialize raised TypeError on PipelineVariable - "
"this is the bug described in GH#5504"
)
assert args["hyper_parameters"]["max_depth"] is max_depth

def test_hyperparameters_with_mixed_pipeline_and_static_values(self):
"""Mixed PipelineVariable and static values should both be handled correctly."""
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
hyperparameters={
"max_depth": max_depth,
"eta": 0.1,
"objective": "binary:logistic",
"num_round": 100,
},
)
args = trainer._create_training_job_args()
hp = args["hyper_parameters"]
# PipelineVariable preserved as-is
assert hp["max_depth"] is max_depth
# Static values serialized to strings
assert hp["eta"] == "0.1"
assert hp["objective"] == "binary:logistic"
assert hp["num_round"] == "100"
103 changes: 103 additions & 0 deletions sagemaker-train/tests/unit/train/test_safe_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for safe_serialize with PipelineVariable support.

Verifies that safe_serialize in sagemaker.train.utils correctly handles
PipelineVariable objects (e.g., ParameterInteger, ParameterString) by
returning them as-is rather than attempting str() conversion which would
raise TypeError.

See: https://github.com/aws/sagemaker-python-sdk/issues/5504
"""
from __future__ import annotations

import pytest

Comment thread
aviruthen marked this conversation as resolved.
from sagemaker.train.utils import safe_serialize
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString


class TestSafeSerializeWithPipelineVariables:
"""Test safe_serialize handles PipelineVariable objects correctly."""

def test_safe_serialize_with_parameter_integer(self):
Comment thread
aviruthen marked this conversation as resolved.
Outdated
"""ParameterInteger should be returned as-is (identity preserved)."""
param = ParameterInteger(name="MaxDepth", default_value=5)
result = safe_serialize(param)
assert result is param
assert isinstance(result, PipelineVariable)

def test_safe_serialize_with_parameter_string(self):
"""ParameterString should be returned as-is (identity preserved)."""
param = ParameterString(name="Algorithm", default_value="xgboost")
result = safe_serialize(param)
assert result is param
assert isinstance(result, PipelineVariable)

def test_safe_serialize_does_not_call_str_on_pipeline_variable(self):
"""Verify that PipelineVariable.__str__ is never invoked (would raise TypeError)."""
param = ParameterInteger(name="TestParam", default_value=10)
# This should NOT raise TypeError
result = safe_serialize(param)
assert result is param

def test_pipeline_variable_str_raises_type_error(self):
"""Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug)."""
param = ParameterInteger(name="TestParam", default_value=10)
with pytest.raises(TypeError):
str(param)


class TestSafeSerializeBasicTypes:
"""Regression tests: verify basic types still work after PipelineVariable support."""

def test_safe_serialize_with_string(self):
"""Strings should be returned as-is without JSON wrapping."""
assert safe_serialize("hello") == "hello"
assert safe_serialize("12345") == "12345"

def test_safe_serialize_with_int(self):
"""Integers should be JSON-serialized to string."""
assert safe_serialize(42) == "42"

def test_safe_serialize_with_float(self):
"""Floats should be JSON-serialized to string."""
assert safe_serialize(3.14) == "3.14"

def test_safe_serialize_with_dict(self):
"""Dicts should be JSON-serialized."""
result = safe_serialize({"key": "val"})
assert result == '{"key": "val"}'

def test_safe_serialize_with_bool(self):
"""Booleans should be JSON-serialized."""
assert safe_serialize(True) == "true"
assert safe_serialize(False) == "false"

def test_safe_serialize_with_none(self):
"""None should be JSON-serialized to 'null'."""
assert safe_serialize(None) == "null"

def test_safe_serialize_with_list(self):
"""Lists should be JSON-serialized."""
assert safe_serialize([1, 2, 3]) == "[1, 2, 3]"

def test_safe_serialize_with_custom_object(self):
"""Custom objects should fall back to str()."""

class CustomObj:
def __str__(self):
return "custom"

assert safe_serialize(CustomObj()) == "custom"
Loading