Skip to content

Commit 89be62e

Browse files
committed
feature: Torch dependency in sagemaker-core to be made optional (5457)
1 parent 6497a94 commit 89be62e

File tree

4 files changed

+171
-3
lines changed

4 files changed

+171
-3
lines changed

sagemaker-core/pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ dependencies = [
3232
"smdebug_rulesconfig>=1.0.1",
3333
"schema>=0.7.5",
3434
"omegaconf>=2.1.0",
35-
"torch>=1.9.0",
3635
"scipy>=1.5.0",
3736
# Remote function dependencies
3837
"cloudpickle>=2.0.0",
@@ -51,6 +50,12 @@ classifiers = [
5150
]
5251

5352
[project.optional-dependencies]
53+
torch = [
54+
"torch>=1.9.0",
55+
]
56+
all = [
57+
"torch>=1.9.0",
58+
]
5459
codegen = [
5560
"black>=24.3.0, <25.0.0",
5661
"pandas>=2.0.0, <3.0.0",

sagemaker-core/src/sagemaker/core/deserializers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"):
366366

367367
self.convert_npy_to_tensor = from_numpy
368368
except ImportError:
369-
raise Exception("Unable to import pytorch.")
369+
raise ImportError(
370+
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
371+
"pip install 'sagemaker-core[torch]'"
372+
)
370373

371374
def deserialize(self, stream, content_type="tensor/pt"):
372375
"""Deserialize streamed data to TorchTensor

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer):
443443

444444
def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446-
from torch import Tensor
446+
try:
447+
from torch import Tensor
448+
except ImportError:
449+
raise ImportError(
450+
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451+
"pip install 'sagemaker-core[torch]'"
452+
)
447453

448454
self.torch_tensor = Tensor
449455
self.numpy_serializer = NumpySerializer()
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests to verify torch dependency is optional in sagemaker-core."""
14+
from __future__ import absolute_import
15+
16+
import io
17+
import sys
18+
from unittest import mock
19+
20+
import numpy as np
21+
import pytest
22+
23+
24+
def test_serializer_module_imports_without_torch():
    """Check that the serializers module exposes its non-torch classes.

    Every listed serializer must be importable from
    ``sagemaker.core.serializers.base``, and a representative subset must be
    constructible with default arguments. The module is expected to import
    without torch because TorchTensorSerializer imports torch lazily inside
    its ``__init__``.
    """
    from sagemaker.core.serializers.base import (
        CSVSerializer,
        DataSerializer,
        IdentitySerializer,
        JSONLinesSerializer,
        JSONSerializer,
        LibSVMSerializer,
        NumpySerializer,
        SparseMatrixSerializer,
        StringSerializer,
    )

    # Only these four are instantiated; the remaining names are imported
    # solely to prove that they resolve.
    constructible = (
        CSVSerializer,
        NumpySerializer,
        JSONSerializer,
        IdentitySerializer,
    )
    for serializer_cls in constructible:
        assert serializer_cls() is not None
44+
45+
46+
def test_deserializer_module_imports_without_torch():
    """Check that the deserializers module exposes its non-torch classes.

    Every listed deserializer must be importable from
    ``sagemaker.core.deserializers.base``, and a representative subset must
    be constructible with default arguments.
    """
    from sagemaker.core.deserializers.base import (
        BytesDeserializer,
        CSVDeserializer,
        JSONDeserializer,
        JSONLinesDeserializer,
        NumpyDeserializer,
        PandasDeserializer,
        StreamDeserializer,
        StringDeserializer,
    )

    # Only these five are instantiated; the remaining names are imported
    # solely to prove that they resolve.
    constructible = (
        StringDeserializer,
        BytesDeserializer,
        CSVDeserializer,
        NumpyDeserializer,
        JSONDeserializer,
    )
    for deserializer_cls in constructible:
        assert deserializer_cls() is not None
64+
65+
66+
def test_torch_tensor_serializer_raises_import_error_without_torch():
    """Verify TorchTensorSerializer raises ImportError when torch is not installed.

    The absence of torch is simulated by mapping ``torch`` (and any torch
    submodule that is already cached) to ``None`` in ``sys.modules``, which
    makes the import machinery raise ImportError for those names.
    """
    import sagemaker.core.serializers.base as ser_module

    # Build the overlay up front so nothing is mutated before patching starts.
    blocked = {
        name: None
        for name in list(sys.modules)
        if name == "torch" or name.startswith("torch.")
    }
    # Make sure the top-level package is blocked even if it was never imported.
    blocked["torch"] = None

    # patch.dict restores sys.modules to its exact prior contents on exit,
    # including when the assertion below fails.
    with mock.patch.dict(sys.modules, blocked):
        with pytest.raises(ImportError, match="Unable to import torch"):
            ser_module.TorchTensorSerializer()
91+
92+
93+
def test_torch_tensor_deserializer_raises_import_error_without_torch():
    """Verify TorchTensorDeserializer raises ImportError when torch is not installed.

    The absence of torch is simulated by mapping ``torch`` (and any torch
    submodule that is already cached) to ``None`` in ``sys.modules``, which
    makes the import machinery raise ImportError for those names.
    """
    import sagemaker.core.deserializers.base as deser_module

    # Build the overlay up front so nothing is mutated before patching starts.
    blocked = {
        name: None
        for name in list(sys.modules)
        if name == "torch" or name.startswith("torch.")
    }
    # Make sure the top-level package is blocked even if it was never imported.
    blocked["torch"] = None

    # patch.dict restores sys.modules to its exact prior contents on exit,
    # including when the assertion below fails.
    with mock.patch.dict(sys.modules, blocked):
        with pytest.raises(ImportError, match="Unable to import torch"):
            deser_module.TorchTensorDeserializer()
116+
117+
118+
def test_torch_tensor_serializer_works_with_torch():
    """Serialize a small tensor and read it back with numpy.

    Skipped when torch cannot be imported.
    """
    try:
        import torch
    except ImportError:
        pytest.skip("torch is not installed")

    from sagemaker.core.serializers.base import TorchTensorSerializer

    values = [1.0, 2.0, 3.0]
    payload = TorchTensorSerializer().serialize(torch.tensor(values))
    assert payload is not None

    # The serialized bytes must be loadable by numpy and carry the same values.
    restored = np.load(io.BytesIO(payload))
    assert np.array_equal(restored, np.array(values))
134+
135+
136+
def test_torch_tensor_deserializer_works_with_torch():
    """Deserialize a numpy-saved array into a torch tensor.

    Skipped when torch cannot be imported.
    """
    try:
        import torch
    except ImportError:
        pytest.skip("torch is not installed")

    from sagemaker.core.deserializers.base import TorchTensorDeserializer

    values = [1.0, 2.0, 3.0]

    # Write the array with np.save and rewind so the deserializer reads
    # from the start of the stream.
    stream = io.BytesIO()
    np.save(stream, np.array(values))
    stream.seek(0)

    tensor = TorchTensorDeserializer().deserialize(stream, "tensor/pt")
    assert isinstance(tensor, torch.Tensor)
    assert torch.equal(tensor, torch.tensor(values))

0 commit comments

Comments
 (0)