forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_utils.py
More file actions
96 lines (67 loc) · 3.15 KB
/
test_utils.py
File metadata and controls
96 lines (67 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for sagemaker.train.utils – specifically safe_serialize."""
from __future__ import absolute_import
import json
from sagemaker.train.utils import safe_serialize
from sagemaker.core.workflow.parameters import (
ParameterInteger,
ParameterString,
ParameterFloat,
)
# ---------------------------------------------------------------------------
# PipelineVariable inputs – should be returned as-is (identity)
# ---------------------------------------------------------------------------
def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly():
    """An integer pipeline parameter must come back as the very same object."""
    max_depth = ParameterInteger(name="MaxDepth", default_value=5)
    assert safe_serialize(max_depth) is max_depth
def test_safe_serialize_with_pipeline_variable_string_returns_object_directly():
    """A string pipeline parameter must come back as the very same object."""
    optimizer = ParameterString(name="Optimizer", default_value="sgd")
    assert safe_serialize(optimizer) is optimizer
def test_safe_serialize_with_pipeline_variable_float_returns_object_directly():
    """A float pipeline parameter must come back as the very same object."""
    learning_rate = ParameterFloat(name="LearningRate", default_value=0.01)
    assert safe_serialize(learning_rate) is learning_rate
# ---------------------------------------------------------------------------
# Regular / primitive inputs
# ---------------------------------------------------------------------------
def test_safe_serialize_with_string_returns_string_as_is():
    """Plain strings, including digit-only ones, are returned unchanged."""
    for text in ("hello", "12345"):
        assert safe_serialize(text) == text
def test_safe_serialize_with_json_like_string_returns_as_is():
    """A string that already looks like JSON must not be serialized a second time."""
    already_json = '{"key": "value"}'
    result = safe_serialize(already_json)
    assert result == already_json
def test_safe_serialize_with_int_returns_json_string():
    """Integers serialize to their JSON text: positive, zero and negative values."""
    assert safe_serialize(5) == "5"
    assert safe_serialize(0) == "0"
    # Negative values must keep their sign in the serialized text.
    assert safe_serialize(-42) == "-42"
def test_safe_serialize_with_dict_returns_json_string():
    """Dictionaries serialize to exactly the text ``json.dumps`` produces."""
    payload = {"key": "value", "num": 1}
    expected = json.dumps(payload)
    assert safe_serialize(payload) == expected
def test_safe_serialize_with_bool_returns_json_string():
    """Booleans serialize to lowercase JSON literals, not Python's True/False."""
    cases = [(True, "true"), (False, "false")]
    for flag, expected in cases:
        assert safe_serialize(flag) == expected
def test_safe_serialize_with_custom_object_returns_str():
    """An arbitrary object defining only ``__str__`` serializes to its ``str()`` form."""

    class CustomObject:
        def __str__(self):
            return "CustomObject"

    assert safe_serialize(CustomObject()) == "CustomObject"
def test_safe_serialize_with_none_returns_json_null():
    """``None`` becomes the JSON literal ``null`` rather than ``"None"``."""
    result = safe_serialize(None)
    assert result == "null"
def test_safe_serialize_with_list_returns_json_string():
    """Lists serialize to JSON array text, including the empty-list edge case."""
    assert safe_serialize([1, 2, 3]) == "[1, 2, 3]"
    # An empty list is still a serializable value and must not collapse to "".
    assert safe_serialize([]) == "[]"
def test_safe_serialize_with_empty_string():
    """The empty string is returned unchanged rather than as a quoted JSON string."""
    empty = ""
    assert safe_serialize(empty) == empty