Skip to content

Commit 8d6fc7f

Browse files
committed
feature: Torch dependency in sagameker-core to be made optional (5457)
1 parent 6497a94 commit 8d6fc7f

File tree

4 files changed

+108
-3
lines changed

4 files changed

+108
-3
lines changed

sagemaker-core/pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ dependencies = [
3232
"smdebug_rulesconfig>=1.0.1",
3333
"schema>=0.7.5",
3434
"omegaconf>=2.1.0",
35-
"torch>=1.9.0",
3635
"scipy>=1.5.0",
3736
# Remote function dependencies
3837
"cloudpickle>=2.0.0",
@@ -51,6 +50,12 @@ classifiers = [
5150
]
5251

5352
[project.optional-dependencies]
53+
torch = [
54+
"torch>=1.9.0",
55+
]
56+
all = [
57+
"torch>=1.9.0",
58+
]
5459
codegen = [
5560
"black>=24.3.0, <25.0.0",
5661
"pandas>=2.0.0, <3.0.0",

sagemaker-core/src/sagemaker/core/deserializers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"):
366366

367367
self.convert_npy_to_tensor = from_numpy
368368
except ImportError:
369-
raise Exception("Unable to import pytorch.")
369+
raise ImportError(
370+
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
371+
"'pip install torch' or 'pip install sagemaker-core[torch]'"
372+
)
370373

371374
def deserialize(self, stream, content_type="tensor/pt"):
372375
"""Deserialize streamed data to TorchTensor

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer):
443443

444444
def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446-
from torch import Tensor
446+
try:
447+
from torch import Tensor
448+
except ImportError:
449+
raise ImportError(
450+
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451+
"'pip install torch' or 'pip install sagemaker-core[torch]'"
452+
)
447453

448454
self.torch_tensor = Tensor
449455
self.numpy_serializer = NumpySerializer()
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import sys
16+
from unittest.mock import patch, MagicMock
17+
18+
import pytest
19+
import numpy as np
20+
21+
22+
def test_torch_tensor_serializer_raises_import_error_when_torch_missing():
23+
"""Verify TorchTensorSerializer() raises ImportError with helpful install message
24+
when torch is not installed."""
25+
import sagemaker.core.serializers.base as base_module
26+
27+
with patch.dict(sys.modules, {"torch": None}):
28+
with pytest.raises(ImportError, match="pip install.*torch"):
29+
base_module.TorchTensorSerializer()
30+
31+
32+
def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
33+
"""Verify TorchTensorDeserializer() raises ImportError with helpful install message
34+
when torch is not installed."""
35+
import sagemaker.core.deserializers.base as base_module
36+
37+
with patch.dict(sys.modules, {"torch": None}):
38+
with pytest.raises(ImportError, match="pip install.*torch"):
39+
base_module.TorchTensorDeserializer()
40+
41+
42+
def test_non_torch_serializers_work_without_torch():
43+
"""Verify CSVSerializer, JSONSerializer, NumpySerializer etc. all work fine
44+
even if torch is not available."""
45+
from sagemaker.core.serializers.base import (
46+
CSVSerializer,
47+
JSONSerializer,
48+
NumpySerializer,
49+
IdentitySerializer,
50+
)
51+
52+
csv_ser = CSVSerializer()
53+
assert csv_ser.serialize([1, 2, 3]) == "1,2,3"
54+
55+
json_ser = JSONSerializer()
56+
assert json_ser.serialize({"a": 1}) == '{"a": 1}'
57+
58+
numpy_ser = NumpySerializer()
59+
result = numpy_ser.serialize(np.array([1, 2, 3]))
60+
assert result is not None
61+
62+
identity_ser = IdentitySerializer()
63+
assert identity_ser.serialize(b"hello") == b"hello"
64+
65+
66+
def test_torch_tensor_serializer_works_when_torch_available():
67+
"""Verify TorchTensorSerializer works normally when torch is installed."""
68+
try:
69+
import torch
70+
except ImportError:
71+
pytest.skip("torch not installed")
72+
73+
from sagemaker.core.serializers.base import TorchTensorSerializer
74+
75+
serializer = TorchTensorSerializer()
76+
tensor = torch.tensor([1.0, 2.0, 3.0])
77+
result = serializer.serialize(tensor)
78+
assert result is not None
79+
80+
81+
def test_torch_tensor_deserializer_works_when_torch_available():
82+
"""Verify TorchTensorDeserializer works normally when torch is installed."""
83+
try:
84+
import torch
85+
except ImportError:
86+
pytest.skip("torch not installed")
87+
88+
from sagemaker.core.deserializers.base import TorchTensorDeserializer
89+
90+
deserializer = TorchTensorDeserializer()
91+
assert deserializer is not None

0 commit comments

Comments
 (0)