Skip to content

Commit 6b4a35f

Browse files
authored
feature: Torch dependency in sagameker-core to be made optional (5457) (aws#5713)
* feature: Torch dependency in sagameker-core to be made optional (5457) * fix: address review comments (iteration #1) * fix: address review comments (iteration #1) * fix: use subprocess instead of importlib.reload to avoid breaking six.with_metaclass super()
1 parent 215713f commit 6b4a35f

File tree

5 files changed

+231
-5
lines changed

5 files changed

+231
-5
lines changed

sagemaker-core/pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ dependencies = [
3232
"smdebug_rulesconfig>=1.0.1",
3333
"schema>=0.7.5",
3434
"omegaconf>=2.1.0",
35-
"torch>=1.9.0",
3635
"scipy>=1.5.0",
3736
# Remote function dependencies
3837
"cloudpickle>=2.0.0",
@@ -52,6 +51,12 @@ classifiers = [
5251
]
5352

5453
[project.optional-dependencies]
54+
torch = [
55+
"torch>=1.9.0",
56+
]
57+
all = [
58+
"sagemaker-core[torch]",
59+
]
5560
codegen = [
5661
"black>=24.3.0, <25.0.0",
5762
"pandas>=2.0.0, <3.0.0",
@@ -62,6 +67,7 @@ test = [
6267
"pytest>=8.0.0, <9.0.0",
6368
"pytest-cov>=4.0.0",
6469
"pytest-xdist>=3.0.0",
70+
"sagemaker-core[torch]",
6571
]
6672

6773
[project.urls]

sagemaker-core/src/sagemaker/core/deserializers/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,11 @@ def __init__(self, accept="tensor/pt"):
365365
from torch import from_numpy
366366

367367
self.convert_npy_to_tensor = from_numpy
368-
except ImportError:
369-
raise Exception("Unable to import pytorch.")
368+
except ImportError as e:
369+
raise ImportError(
370+
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
371+
"pip install 'sagemaker-core[torch]'"
372+
) from e
370373

371374
def deserialize(self, stream, content_type="tensor/pt"):
372375
"""Deserialize streamed data to TorchTensor

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,16 @@ class TorchTensorSerializer(SimpleBaseSerializer):
443443

444444
def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446-
from torch import Tensor
446+
try:
447+
from torch import Tensor
448+
449+
self.torch_tensor = Tensor
450+
except ImportError as e:
451+
raise ImportError(
452+
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
453+
"pip install 'sagemaker-core[torch]'"
454+
) from e
447455

448-
self.torch_tensor = Tensor
449456
self.numpy_serializer = NumpySerializer()
450457

451458
def serialize(self, data):
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests to verify torch dependency is optional in sagemaker-core.
14+
15+
The "module imports without torch" tests use subprocess instead of
16+
importlib.reload to avoid poisoning the class hierarchy in the current
17+
process. six.with_metaclass + old-style super() breaks when a module
18+
is reloaded because the class identity changes, causing
19+
``TypeError: super(type, obj): obj must be an instance or subtype of type``
20+
in subsequent tests that instantiate serializers/deserializers.
21+
"""
22+
from __future__ import absolute_import
23+
24+
import io
25+
import subprocess
26+
import sys
27+
import textwrap
28+
29+
import numpy as np
30+
import pytest
31+
32+
33+
def _block_torch():
34+
"""Block torch imports by setting sys.modules['torch'] to None.
35+
36+
Returns a dict of saved torch submodule entries so they can be restored.
37+
"""
38+
torch_keys = [key for key in sys.modules if key.startswith("torch.")]
39+
saved = {key: sys.modules.pop(key) for key in torch_keys}
40+
saved["torch"] = sys.modules.get("torch")
41+
sys.modules["torch"] = None
42+
return saved
43+
44+
45+
def _restore_torch(saved):
46+
"""Restore torch modules from saved dict."""
47+
original_torch = saved.pop("torch", None)
48+
if original_torch is not None:
49+
sys.modules["torch"] = original_torch
50+
elif "torch" in sys.modules:
51+
del sys.modules["torch"]
52+
for key, val in saved.items():
53+
sys.modules[key] = val
54+
55+
56+
def test_serializer_module_imports_without_torch():
57+
"""Verify that non-torch serializers can be imported and instantiated without torch.
58+
59+
Runs in a subprocess to avoid polluting the current process's class
60+
hierarchy via importlib.reload (which breaks six.with_metaclass).
61+
"""
62+
code = textwrap.dedent("""\
63+
import sys
64+
# Block torch before any sagemaker imports
65+
sys.modules["torch"] = None
66+
67+
from sagemaker.core.serializers.base import (
68+
CSVSerializer,
69+
NumpySerializer,
70+
JSONSerializer,
71+
IdentitySerializer,
72+
)
73+
74+
assert CSVSerializer() is not None
75+
assert NumpySerializer() is not None
76+
assert JSONSerializer() is not None
77+
assert IdentitySerializer() is not None
78+
print("OK")
79+
""")
80+
result = subprocess.run(
81+
[sys.executable, "-c", code],
82+
capture_output=True,
83+
text=True,
84+
)
85+
assert result.returncode == 0, (
86+
f"Subprocess failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
87+
)
88+
89+
90+
def test_deserializer_module_imports_without_torch():
91+
"""Verify that non-torch deserializers can be imported and instantiated without torch.
92+
93+
Runs in a subprocess for the same reason as the serializer test above.
94+
"""
95+
code = textwrap.dedent("""\
96+
import sys
97+
sys.modules["torch"] = None
98+
99+
from sagemaker.core.deserializers.base import (
100+
StringDeserializer,
101+
BytesDeserializer,
102+
CSVDeserializer,
103+
NumpyDeserializer,
104+
JSONDeserializer,
105+
)
106+
107+
assert StringDeserializer() is not None
108+
assert BytesDeserializer() is not None
109+
assert CSVDeserializer() is not None
110+
assert NumpyDeserializer() is not None
111+
assert JSONDeserializer() is not None
112+
print("OK")
113+
""")
114+
result = subprocess.run(
115+
[sys.executable, "-c", code],
116+
capture_output=True,
117+
text=True,
118+
)
119+
assert result.returncode == 0, (
120+
f"Subprocess failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
121+
)
122+
123+
124+
def test_torch_tensor_serializer_raises_import_error_without_torch():
125+
"""Verify TorchTensorSerializer raises ImportError when torch is not installed."""
126+
import sagemaker.core.serializers.base as ser_module
127+
128+
saved = {}
129+
try:
130+
saved = _block_torch()
131+
132+
with pytest.raises(ImportError, match="Unable to import torch"):
133+
ser_module.TorchTensorSerializer()
134+
finally:
135+
_restore_torch(saved)
136+
137+
138+
def test_torch_tensor_deserializer_raises_import_error_without_torch():
139+
"""Verify TorchTensorDeserializer raises ImportError when torch is not installed."""
140+
import sagemaker.core.deserializers.base as deser_module
141+
142+
saved = {}
143+
try:
144+
saved = _block_torch()
145+
146+
with pytest.raises(ImportError, match="Unable to import torch"):
147+
deser_module.TorchTensorDeserializer()
148+
finally:
149+
_restore_torch(saved)
150+
151+
152+
def test_torch_tensor_serializer_works_with_torch():
153+
"""Verify TorchTensorSerializer works when torch is available."""
154+
try:
155+
import torch
156+
except ImportError:
157+
pytest.skip("torch is not installed")
158+
159+
from sagemaker.core.serializers.base import TorchTensorSerializer
160+
161+
serializer = TorchTensorSerializer()
162+
tensor = torch.tensor([1.0, 2.0, 3.0])
163+
result = serializer.serialize(tensor)
164+
assert result is not None
165+
# Verify the result can be loaded back as numpy
166+
array = np.load(io.BytesIO(result))
167+
assert np.array_equal(array, np.array([1.0, 2.0, 3.0]))
168+
169+
170+
def test_torch_tensor_deserializer_works_with_torch():
171+
"""Verify TorchTensorDeserializer works when torch is available."""
172+
try:
173+
import torch
174+
except ImportError:
175+
pytest.skip("torch is not installed")
176+
177+
from sagemaker.core.deserializers.base import TorchTensorDeserializer
178+
179+
deserializer = TorchTensorDeserializer()
180+
# Create a numpy array, save it, and deserialize to tensor
181+
array = np.array([1.0, 2.0, 3.0])
182+
buffer = io.BytesIO()
183+
np.save(buffer, array)
184+
buffer.seek(0)
185+
186+
result = deserializer.deserialize(buffer, "tensor/pt")
187+
assert isinstance(result, torch.Tensor)
188+
assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0]))

sagemaker-core/tests/unit/test_serializer_implementations.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,25 @@ def test_numpy_serializer_import(self):
162162
def test_record_serializer_deprecated(self):
163163
"""Test that numpy_to_record_serializer is available as deprecated."""
164164
assert hasattr(implementations, "numpy_to_record_serializer")
165+
166+
167+
class TestTorchSerializerWithOptionalDependency:
168+
"""Test torch serializer/deserializer with optional torch dependency."""
169+
170+
def test_torch_tensor_serializer_instantiation(self):
171+
"""Test that TorchTensorSerializer can be instantiated when torch is available."""
172+
torch = pytest.importorskip("torch")
173+
from sagemaker.core.serializers.base import TorchTensorSerializer
174+
175+
serializer = TorchTensorSerializer()
176+
assert serializer is not None
177+
assert serializer.content_type == "tensor/pt"
178+
179+
def test_torch_tensor_deserializer_instantiation(self):
180+
"""Test that TorchTensorDeserializer can be instantiated when torch is available."""
181+
torch = pytest.importorskip("torch")
182+
from sagemaker.core.deserializers.base import TorchTensorDeserializer
183+
184+
deserializer = TorchTensorDeserializer()
185+
assert deserializer is not None
186+
assert deserializer.accept == "tensor/pt"

0 commit comments

Comments
 (0)