Skip to content

Commit 57b5c02

Browse files
committed
fix: address review comments (iteration #1)
1 parent 89be62e commit 57b5c02

File tree

4 files changed

+73
-75
lines changed

4 files changed

+73
-75
lines changed

sagemaker-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ torch = [
5454
"torch>=1.9.0",
5555
]
5656
all = [
57-
"torch>=1.9.0",
57+
"sagemaker-core[torch]",
5858
]
5959
codegen = [
6060
"black>=24.3.0, <25.0.0",

sagemaker-core/src/sagemaker/core/deserializers/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,11 +365,11 @@ def __init__(self, accept="tensor/pt"):
365365
from torch import from_numpy
366366

367367
self.convert_npy_to_tensor = from_numpy
368-
except ImportError:
368+
except ImportError as e:
369369
raise ImportError(
370370
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
371371
"pip install 'sagemaker-core[torch]'"
372-
)
372+
) from e
373373

374374
def deserialize(self, stream, content_type="tensor/pt"):
375375
"""Deserialize streamed data to TorchTensor

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,11 +445,11 @@ def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446446
try:
447447
from torch import Tensor
448-
except ImportError:
448+
except ImportError as e:
449449
raise ImportError(
450450
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451451
"pip install 'sagemaker-core[torch]'"
452-
)
452+
) from e
453453

454454
self.torch_tensor = Tensor
455455
self.numpy_serializer = NumpySerializer()

sagemaker-core/tests/unit/test_optional_torch_dependency.py

Lines changed: 68 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -11,108 +11,106 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Tests to verify torch dependency is optional in sagemaker-core."""
14-
from __future__ import absolute_import
14+
from __future__ import annotations
1515

16+
import importlib
1617
import io
1718
import sys
18-
from unittest import mock
1919

2020
import numpy as np
2121
import pytest
2222

2323

24+
def _block_torch():
25+
"""Block torch imports by setting sys.modules['torch'] to None.
26+
27+
Returns a dict of saved torch submodule entries so they can be restored.
28+
"""
29+
saved = {}
30+
torch_keys = [key for key in sys.modules if key.startswith("torch.")]
31+
saved = {key: sys.modules.pop(key) for key in torch_keys}
32+
saved["torch"] = sys.modules.get("torch")
33+
sys.modules["torch"] = None
34+
return saved
35+
36+
37+
def _restore_torch(saved):
38+
"""Restore torch modules from saved dict."""
39+
original_torch = saved.pop("torch", None)
40+
if original_torch is not None:
41+
sys.modules["torch"] = original_torch
42+
elif "torch" in sys.modules:
43+
del sys.modules["torch"]
44+
for key, val in saved.items():
45+
sys.modules[key] = val
46+
47+
2448
def test_serializer_module_imports_without_torch():
25-
"""Verify that importing serializers module succeeds without torch installed."""
26-
# The serializers module should be importable even without torch
27-
# because TorchTensorSerializer uses lazy import in __init__
28-
from sagemaker.core.serializers.base import (
29-
CSVSerializer,
30-
NumpySerializer,
31-
JSONSerializer,
32-
IdentitySerializer,
33-
SparseMatrixSerializer,
34-
JSONLinesSerializer,
35-
LibSVMSerializer,
36-
DataSerializer,
37-
StringSerializer,
38-
)
39-
# Verify non-torch serializers can be instantiated
40-
assert CSVSerializer() is not None
41-
assert NumpySerializer() is not None
42-
assert JSONSerializer() is not None
43-
assert IdentitySerializer() is not None
49+
"""Verify that importing non-torch serializers succeeds without torch installed."""
50+
saved = {}
51+
try:
52+
saved = _block_torch()
53+
54+
# Reload the module so it re-evaluates imports with torch blocked
55+
import sagemaker.core.serializers.base as ser_module
56+
57+
importlib.reload(ser_module)
58+
59+
# Verify non-torch serializers can be instantiated
60+
assert ser_module.CSVSerializer() is not None
61+
assert ser_module.NumpySerializer() is not None
62+
assert ser_module.JSONSerializer() is not None
63+
assert ser_module.IdentitySerializer() is not None
64+
finally:
65+
_restore_torch(saved)
4466

4567

4668
def test_deserializer_module_imports_without_torch():
47-
"""Verify that importing deserializers module succeeds without torch installed."""
48-
from sagemaker.core.deserializers.base import (
49-
StringDeserializer,
50-
BytesDeserializer,
51-
CSVDeserializer,
52-
StreamDeserializer,
53-
NumpyDeserializer,
54-
JSONDeserializer,
55-
PandasDeserializer,
56-
JSONLinesDeserializer,
57-
)
58-
# Verify non-torch deserializers can be instantiated
59-
assert StringDeserializer() is not None
60-
assert BytesDeserializer() is not None
61-
assert CSVDeserializer() is not None
62-
assert NumpyDeserializer() is not None
63-
assert JSONDeserializer() is not None
69+
"""Verify that importing non-torch deserializers succeeds without torch installed."""
70+
saved = {}
71+
try:
72+
saved = _block_torch()
73+
74+
import sagemaker.core.deserializers.base as deser_module
75+
76+
importlib.reload(deser_module)
77+
78+
# Verify non-torch deserializers can be instantiated
79+
assert deser_module.StringDeserializer() is not None
80+
assert deser_module.BytesDeserializer() is not None
81+
assert deser_module.CSVDeserializer() is not None
82+
assert deser_module.NumpyDeserializer() is not None
83+
assert deser_module.JSONDeserializer() is not None
84+
finally:
85+
_restore_torch(saved)
6486

6587

6688
def test_torch_tensor_serializer_raises_import_error_without_torch():
6789
"""Verify TorchTensorSerializer raises ImportError when torch is not installed."""
68-
import importlib
6990
import sagemaker.core.serializers.base as ser_module
7091

71-
# Save original torch module if present
72-
original_torch = sys.modules.get('torch')
73-
92+
saved = {}
7493
try:
75-
# Simulate torch not being installed
76-
sys.modules['torch'] = None
77-
# Need to also handle the case where torch submodules are cached
78-
torch_keys = [key for key in sys.modules if key.startswith('torch.')]
79-
saved = {key: sys.modules.pop(key) for key in torch_keys}
80-
94+
saved = _block_torch()
95+
8196
with pytest.raises(ImportError, match="Unable to import torch"):
8297
ser_module.TorchTensorSerializer()
8398
finally:
84-
# Restore original state
85-
if original_torch is not None:
86-
sys.modules['torch'] = original_torch
87-
elif 'torch' in sys.modules:
88-
del sys.modules['torch']
89-
for key, val in saved.items():
90-
sys.modules[key] = val
99+
_restore_torch(saved)
91100

92101

93102
def test_torch_tensor_deserializer_raises_import_error_without_torch():
94103
"""Verify TorchTensorDeserializer raises ImportError when torch is not installed."""
95104
import sagemaker.core.deserializers.base as deser_module
96105

97-
# Save original torch module if present
98-
original_torch = sys.modules.get('torch')
99-
106+
saved = {}
100107
try:
101-
# Simulate torch not being installed
102-
sys.modules['torch'] = None
103-
torch_keys = [key for key in sys.modules if key.startswith('torch.')]
104-
saved = {key: sys.modules.pop(key) for key in torch_keys}
105-
108+
saved = _block_torch()
109+
106110
with pytest.raises(ImportError, match="Unable to import torch"):
107111
deser_module.TorchTensorDeserializer()
108112
finally:
109-
# Restore original state
110-
if original_torch is not None:
111-
sys.modules['torch'] = original_torch
112-
elif 'torch' in sys.modules:
113-
del sys.modules['torch']
114-
for key, val in saved.items():
115-
sys.modules[key] = val
113+
_restore_torch(saved)
116114

117115

118116
def test_torch_tensor_serializer_works_with_torch():

0 commit comments

Comments
 (0)