Skip to content

Commit 0ed06b2

Browse files
committed
fix: address review comments (iteration #1)
1 parent 57b5c02 commit 0ed06b2

File tree

4 files changed

+26
-3
lines changed

4 files changed

+26
-3
lines changed

sagemaker-core/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ test = [
6666
"pytest>=8.0.0, <9.0.0",
6767
"pytest-cov>=4.0.0",
6868
"pytest-xdist>=3.0.0",
69+
"sagemaker-core[torch]",
6970
]
7071

7172
[project.urls]

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,13 +445,14 @@ def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446446
try:
447447
from torch import Tensor
448+
449+
self.torch_tensor = Tensor
448450
except ImportError as e:
449451
raise ImportError(
450452
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451453
"pip install 'sagemaker-core[torch]'"
452454
) from e
453455

454-
self.torch_tensor = Tensor
455456
self.numpy_serializer = NumpySerializer()
456457

457458
def serialize(self, data):

sagemaker-core/tests/unit/test_serializer_implementations.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Unit tests for sagemaker.core.serializers.implementations module."""
14-
from __future__ import absolute_import
14+
from __future__ import annotations
1515

1616
import pytest
1717
from unittest.mock import Mock, patch
@@ -162,3 +162,25 @@ def test_numpy_serializer_import(self):
162162
def test_record_serializer_deprecated(self):
163163
"""Test that numpy_to_record_serializer is available as deprecated."""
164164
assert hasattr(implementations, "numpy_to_record_serializer")
165+
166+
167+
class TestTorchSerializerWithOptionalDependency:
168+
"""Test torch serializer/deserializer with optional torch dependency."""
169+
170+
def test_torch_tensor_serializer_instantiation(self):
171+
"""Test that TorchTensorSerializer can be instantiated when torch is available."""
172+
torch = pytest.importorskip("torch")
173+
from sagemaker.core.serializers.base import TorchTensorSerializer
174+
175+
serializer = TorchTensorSerializer()
176+
assert serializer is not None
177+
assert serializer.content_type == "tensor/pt"
178+
179+
def test_torch_tensor_deserializer_instantiation(self):
180+
"""Test that TorchTensorDeserializer can be instantiated when torch is available."""
181+
torch = pytest.importorskip("torch")
182+
from sagemaker.core.deserializers.base import TorchTensorDeserializer
183+
184+
deserializer = TorchTensorDeserializer()
185+
assert deserializer is not None
186+
assert deserializer.accept == "tensor/pt"

sagemaker-core/tox.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
[tox]
77
isolated_build = true
88
envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py39,py310,py311,py312
9-
109
skip_missing_interpreters = False
1110

1211
[flake8]

0 commit comments

Comments
 (0)