diff --git a/sagemaker-core/src/sagemaker/core/modules/utils.py b/sagemaker-core/src/sagemaker/core/modules/utils.py index 94dc2dff22..3cb884d475 100644 --- a/sagemaker-core/src/sagemaker/core/modules/utils.py +++ b/sagemaker-core/src/sagemaker/core/modules/utils.py @@ -24,6 +24,7 @@ from sagemaker.core.shapes import Unassigned from sagemaker.core.modules import logger +from sagemaker.core.helper.pipeline_variable import PipelineVariable def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: @@ -129,19 +130,24 @@ def safe_serialize(data): This function handles the following cases: 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + 2. If `data` is of type `PipelineVariable`, it returns the PipelineVariable object + as-is for pipeline serialization. + 3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns the JSON-encoded string using `json.dumps()`. - 3. If `data` cannot be serialized (e.g., a custom object), it returns the string + 4. If `data` cannot be serialized (e.g., a custom object), it returns the string representation of the data using `str(data)`. Args: data (Any): The data to serialize. Returns: - str: The serialized JSON-compatible string or the string representation of the input. + str | PipelineVariable: The serialized JSON-compatible string, the string + representation of the input, or the PipelineVariable object as-is. """ if isinstance(data, str): return data + elif isinstance(data, PipelineVariable): + return data try: return json.dumps(data) except TypeError: diff --git a/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py new file mode 100644 index 0000000000..ad54f3c28f --- /dev/null +++ b/sagemaker-core/tests/unit/core/modules/test_utils_safe_serialize.py @@ -0,0 +1,82 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for safe_serialize in sagemaker.core.modules.utils with PipelineVariable support. + +Verifies that safe_serialize correctly handles PipelineVariable objects +(e.g., ParameterInteger, ParameterString) by returning them as-is rather +than attempting str() conversion which would raise TypeError. + +See: https://github.com/aws/sagemaker-python-sdk/issues/5504 +""" +from __future__ import annotations + +import pytest + +from sagemaker.core.modules.utils import safe_serialize +from sagemaker.core.helper.pipeline_variable import PipelineVariable +from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString + + +class TestSafeSerializeWithPipelineVariables: + """Test safe_serialize handles PipelineVariable objects correctly.""" + + @pytest.mark.parametrize("param", [ + ParameterInteger(name="MaxDepth", default_value=5), + ParameterString(name="Algorithm", default_value="xgboost"), + ]) + def test_safe_serialize_returns_pipeline_variable_as_is(self, param): + """PipelineVariable objects should be returned as-is (identity preserved).""" + result = safe_serialize(param) + assert result is param + assert isinstance(result, PipelineVariable) + + def test_pipeline_variable_str_raises_type_error(self): + """Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug).""" + param = ParameterInteger(name="TestParam", default_value=10) + with pytest.raises(TypeError): + str(param) + + +class TestSafeSerializeBasicTypes: + """Regression tests: verify basic types still work after PipelineVariable support.""" + + def test_safe_serialize_with_string(self): + """Strings should be returned as-is without JSON wrapping.""" + assert safe_serialize("hello") == "hello" + + def test_safe_serialize_with_int(self): + """Integers should be JSON-serialized to string.""" + assert safe_serialize(42) == "42" + + def test_safe_serialize_with_dict(self): + """Dicts should be JSON-serialized.""" + result = safe_serialize({"key": "val"}) + assert result == '{"key": "val"}' + + def test_safe_serialize_with_bool(self): + """Booleans should be JSON-serialized.""" + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + + def test_safe_serialize_with_none(self): + """None should be JSON-serialized to 'null'.""" + assert safe_serialize(None) == "null" + + def test_safe_serialize_with_custom_object(self): + """Custom objects should fall back to str().""" + + class CustomObj: + def __str__(self): + return "custom" + + assert safe_serialize(CustomObj()) == "custom" diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 3fd34fa47b..e6aacf13f4 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -26,7 +26,7 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar -from sagemaker.core.workflow.parameters import ParameterString +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger from sagemaker.train.model_trainer import ModelTrainer, Mode from sagemaker.train.configs import ( Compute, @@ -176,3 +176,68 @@ def test_training_image_rejects_invalid_type(self): stopping_condition=DEFAULT_STOPPING, output_data_config=DEFAULT_OUTPUT, ) + + +class TestModelTrainerPipelineVariableHyperparameters: + """Test that PipelineVariable objects work correctly in ModelTrainer hyperparameters.""" + + def test_hyperparameters_with_parameter_integer(self): + """ParameterInteger in hyperparameters should be preserved through _create_training_job_args. + + This test documents the exact bug scenario from GH#5504: safe_serialize + would fall back to str(data) for PipelineVariable objects, but + PipelineVariable.__str__ intentionally raises TypeError. + Before the fix, this call would have raised TypeError. + """ + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"max_depth": max_depth}, + ) + # This call would have raised TypeError before the fix (GH#5504) + args = trainer._create_training_job_args() + # PipelineVariable should be preserved as-is, not stringified + assert args["hyper_parameters"]["max_depth"] is max_depth + + def test_hyperparameters_with_parameter_string(self): + """ParameterString in hyperparameters should be preserved.""" + algo = ParameterString(name="Algorithm", default_value="xgboost") + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"algorithm": algo}, + ) + args = trainer._create_training_job_args() + assert args["hyper_parameters"]["algorithm"] is algo + + def test_hyperparameters_with_mixed_pipeline_and_static_values(self): + """Mixed PipelineVariable and static values should both be handled correctly.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={ + "max_depth": max_depth, + "eta": 0.1, + "objective": "binary:logistic", + "num_round": 100, + }, + ) + args = trainer._create_training_job_args() + hp = args["hyper_parameters"] + # PipelineVariable preserved as-is + assert hp["max_depth"] is max_depth + # Static values serialized to strings + assert hp["eta"] == "0.1" + assert hp["objective"] == "binary:logistic" + assert hp["num_round"] == "100" diff --git a/sagemaker-train/tests/unit/train/test_safe_serialize.py b/sagemaker-train/tests/unit/train/test_safe_serialize.py new file mode 100644 index 0000000000..1ad227a278 --- /dev/null +++ b/sagemaker-train/tests/unit/train/test_safe_serialize.py @@ -0,0 +1,92 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for safe_serialize with PipelineVariable support. + +Verifies that safe_serialize in sagemaker.train.utils correctly handles +PipelineVariable objects (e.g., ParameterInteger, ParameterString) by +returning them as-is rather than attempting str() conversion which would +raise TypeError. + +See: https://github.com/aws/sagemaker-python-sdk/issues/5504 +""" +from __future__ import annotations + +import pytest + +from sagemaker.train.utils import safe_serialize +from sagemaker.core.helper.pipeline_variable import PipelineVariable +from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString + + +class TestSafeSerializeWithPipelineVariables: + """Test safe_serialize handles PipelineVariable objects correctly.""" + + @pytest.mark.parametrize("param", [ + ParameterInteger(name="MaxDepth", default_value=5), + ParameterString(name="Algorithm", default_value="xgboost"), + ]) + def test_safe_serialize_returns_pipeline_variable_as_is(self, param): + """PipelineVariable objects should be returned as-is (identity preserved).""" + result = safe_serialize(param) + assert result is param + assert isinstance(result, PipelineVariable) + + def test_pipeline_variable_str_raises_type_error(self): + """Confirm PipelineVariable.__str__ raises TypeError (the root cause of the bug).""" + param = ParameterInteger(name="TestParam", default_value=10) + with pytest.raises(TypeError): + str(param) + + +class TestSafeSerializeBasicTypes: + """Regression tests: verify basic types still work after PipelineVariable support.""" + + def test_safe_serialize_with_string(self): + """Strings should be returned as-is without JSON wrapping.""" + assert safe_serialize("hello") == "hello" + assert safe_serialize("12345") == "12345" + + def test_safe_serialize_with_int(self): + """Integers should be JSON-serialized to string.""" + assert safe_serialize(42) == "42" + + def test_safe_serialize_with_float(self): + """Floats should be JSON-serialized to string.""" + assert safe_serialize(3.14) == "3.14" + + def test_safe_serialize_with_dict(self): + """Dicts should be JSON-serialized.""" + result = safe_serialize({"key": "val"}) + assert result == '{"key": "val"}' + + def test_safe_serialize_with_bool(self): + """Booleans should be JSON-serialized.""" + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + + def test_safe_serialize_with_none(self): + """None should be JSON-serialized to 'null'.""" + assert safe_serialize(None) == "null" + + def test_safe_serialize_with_list(self): + """Lists should be JSON-serialized.""" + assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" + + def test_safe_serialize_with_custom_object(self): + """Custom objects should fall back to str().""" + + class CustomObj: + def __str__(self): + return "custom" + + assert safe_serialize(CustomObj()) == "custom"