Skip to content

Commit 0c7374c

Browse files
committed
fix: address review comments (iteration #1)
1 parent 8d6fc7f commit 0c7374c

File tree

4 files changed

+159
-7
lines changed

4 files changed

+159
-7
lines changed

sagemaker-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ torch = [
5454
"torch>=1.9.0",
5555
]
5656
all = [
57-
"torch>=1.9.0",
57+
"sagemaker-core[torch]",
5858
]
5959
codegen = [
6060
"black>=24.3.0, <25.0.0",

sagemaker-core/src/sagemaker/core/deserializers/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,11 +365,11 @@ def __init__(self, accept="tensor/pt"):
365365
from torch import from_numpy
366366

367367
self.convert_npy_to_tensor = from_numpy
368-
except ImportError:
368+
except ImportError as e:
369369
raise ImportError(
370370
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
371-
"'pip install torch' or 'pip install sagemaker-core[torch]'"
372-
)
371+
"pip install 'sagemaker-core[torch]'"
372+
) from e
373373

374374
def deserialize(self, stream, content_type="tensor/pt"):
375375
"""Deserialize streamed data to TorchTensor

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,11 +445,11 @@ def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446446
try:
447447
from torch import Tensor
448-
except ImportError:
448+
except ImportError as e:
449449
raise ImportError(
450450
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451-
"'pip install torch' or 'pip install sagemaker-core[torch]'"
452-
)
451+
"pip install 'sagemaker-core[torch]'"
452+
) from e
453453

454454
self.torch_tensor = Tensor
455455
self.numpy_serializer = NumpySerializer()
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests to verify torch dependency is optional in sagemaker-core."""
14+
from __future__ import annotations
15+
16+
import importlib
17+
import io
18+
import sys
19+
20+
import numpy as np
21+
import pytest
22+
23+
24+
def _block_torch():
25+
"""Block torch imports by setting sys.modules['torch'] to None.
26+
27+
Returns a dict of saved torch submodule entries so they can be restored.
28+
"""
29+
saved = {}
30+
torch_keys = [key for key in sys.modules if key.startswith("torch.")]
31+
saved = {key: sys.modules.pop(key) for key in torch_keys}
32+
saved["torch"] = sys.modules.get("torch")
33+
sys.modules["torch"] = None
34+
return saved
35+
36+
37+
def _restore_torch(saved):
38+
"""Restore torch modules from saved dict."""
39+
original_torch = saved.pop("torch", None)
40+
if original_torch is not None:
41+
sys.modules["torch"] = original_torch
42+
elif "torch" in sys.modules:
43+
del sys.modules["torch"]
44+
for key, val in saved.items():
45+
sys.modules[key] = val
46+
47+
48+
def test_serializer_module_imports_without_torch():
49+
"""Verify that importing non-torch serializers succeeds without torch installed."""
50+
saved = {}
51+
try:
52+
saved = _block_torch()
53+
54+
# Reload the module so it re-evaluates imports with torch blocked
55+
import sagemaker.core.serializers.base as ser_module
56+
57+
importlib.reload(ser_module)
58+
59+
# Verify non-torch serializers can be instantiated
60+
assert ser_module.CSVSerializer() is not None
61+
assert ser_module.NumpySerializer() is not None
62+
assert ser_module.JSONSerializer() is not None
63+
assert ser_module.IdentitySerializer() is not None
64+
finally:
65+
_restore_torch(saved)
66+
67+
68+
def test_deserializer_module_imports_without_torch():
69+
"""Verify that importing non-torch deserializers succeeds without torch installed."""
70+
saved = {}
71+
try:
72+
saved = _block_torch()
73+
74+
import sagemaker.core.deserializers.base as deser_module
75+
76+
importlib.reload(deser_module)
77+
78+
# Verify non-torch deserializers can be instantiated
79+
assert deser_module.StringDeserializer() is not None
80+
assert deser_module.BytesDeserializer() is not None
81+
assert deser_module.CSVDeserializer() is not None
82+
assert deser_module.NumpyDeserializer() is not None
83+
assert deser_module.JSONDeserializer() is not None
84+
finally:
85+
_restore_torch(saved)
86+
87+
88+
def test_torch_tensor_serializer_raises_import_error_without_torch():
89+
"""Verify TorchTensorSerializer raises ImportError when torch is not installed."""
90+
import sagemaker.core.serializers.base as ser_module
91+
92+
saved = {}
93+
try:
94+
saved = _block_torch()
95+
96+
with pytest.raises(ImportError, match="Unable to import torch"):
97+
ser_module.TorchTensorSerializer()
98+
finally:
99+
_restore_torch(saved)
100+
101+
102+
def test_torch_tensor_deserializer_raises_import_error_without_torch():
103+
"""Verify TorchTensorDeserializer raises ImportError when torch is not installed."""
104+
import sagemaker.core.deserializers.base as deser_module
105+
106+
saved = {}
107+
try:
108+
saved = _block_torch()
109+
110+
with pytest.raises(ImportError, match="Unable to import torch"):
111+
deser_module.TorchTensorDeserializer()
112+
finally:
113+
_restore_torch(saved)
114+
115+
116+
def test_torch_tensor_serializer_works_with_torch():
117+
"""Verify TorchTensorSerializer works when torch is available."""
118+
try:
119+
import torch
120+
except ImportError:
121+
pytest.skip("torch is not installed")
122+
123+
from sagemaker.core.serializers.base import TorchTensorSerializer
124+
125+
serializer = TorchTensorSerializer()
126+
tensor = torch.tensor([1.0, 2.0, 3.0])
127+
result = serializer.serialize(tensor)
128+
assert result is not None
129+
# Verify the result can be loaded back as numpy
130+
array = np.load(io.BytesIO(result))
131+
assert np.array_equal(array, np.array([1.0, 2.0, 3.0]))
132+
133+
134+
def test_torch_tensor_deserializer_works_with_torch():
135+
"""Verify TorchTensorDeserializer works when torch is available."""
136+
try:
137+
import torch
138+
except ImportError:
139+
pytest.skip("torch is not installed")
140+
141+
from sagemaker.core.deserializers.base import TorchTensorDeserializer
142+
143+
deserializer = TorchTensorDeserializer()
144+
# Create a numpy array, save it, and deserialize to tensor
145+
array = np.array([1.0, 2.0, 3.0])
146+
buffer = io.BytesIO()
147+
np.save(buffer, array)
148+
buffer.seek(0)
149+
150+
result = deserializer.deserialize(buffer, "tensor/pt")
151+
assert isinstance(result, torch.Tensor)
152+
assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0]))

0 commit comments

Comments
 (0)