Skip to content

Commit 2cc8af9

Browse files
committed
fix: address review comments (iteration #1)
1 parent 769d06e commit 2cc8af9

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

sagemaker-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ torch = [
6060
"torch>=1.9.0",
6161
]
6262
all = [
63-
"torch>=1.9.0",
63+
"sagemaker-core[torch]",
6464
]
6565
test = [
6666
"pytest>=8.0.0, <9.0.0",

sagemaker-core/src/sagemaker/core/deserializers/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,11 +365,11 @@ def __init__(self, accept="tensor/pt"):
365365
from torch import from_numpy
366366

367367
self.convert_npy_to_tensor = from_numpy
368-
except ImportError:
368+
except ImportError as e:
369369
raise ImportError(
370370
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
371371
"pip install 'sagemaker-core[torch]'"
372-
)
372+
) from e
373373

374374
def deserialize(self, stream, content_type="tensor/pt"):
375375
"""Deserialize streamed data to TorchTensor

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,11 +445,11 @@ def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446446
try:
447447
from torch import Tensor
448-
except ImportError:
448+
except ImportError as e:
449449
raise ImportError(
450450
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451451
"pip install 'sagemaker-core[torch]'"
452-
)
452+
) from e
453453

454454
self.torch_tensor = Tensor
455455
self.numpy_serializer = NumpySerializer()

sagemaker-core/tests/unit/test_torch_optional_dependency.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import sys
17-
from unittest.mock import patch, MagicMock
17+
from unittest.mock import patch
1818

1919
import numpy as np
2020
import pytest
@@ -36,7 +36,7 @@ def test_torch_tensor_serializer_raises_import_error_when_torch_missing():
3636

3737

3838
def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
39-
"""Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing."""
39+
"""Verify TorchTensorDeserializer raises ImportError when torch is missing."""
4040
import importlib
4141
import sagemaker.core.deserializers.base as base_module
4242

@@ -51,6 +51,7 @@ def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
5151

5252
def test_torch_tensor_serializer_works_when_torch_installed():
5353
"""Verify TorchTensorSerializer can be instantiated when torch is available."""
54+
pytest.importorskip("torch")
5455
from sagemaker.core.serializers.base import TorchTensorSerializer
5556

5657
serializer = TorchTensorSerializer()
@@ -60,6 +61,7 @@ def test_torch_tensor_serializer_works_when_torch_installed():
6061

6162
def test_torch_tensor_deserializer_works_when_torch_installed():
6263
"""Verify TorchTensorDeserializer can be instantiated when torch is available."""
64+
pytest.importorskip("torch")
6365
from sagemaker.core.deserializers.base import TorchTensorDeserializer
6466

6567
deserializer = TorchTensorDeserializer()

0 commit comments

Comments
 (0)