Skip to content

Commit 808472d

Browse files
committed
fix: address review comments (iteration #2)
1 parent 0c7374c commit 808472d

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,13 +445,14 @@ def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446446
try:
447447
from torch import Tensor
448+
449+
self.torch_tensor = Tensor
448450
except ImportError as e:
449451
raise ImportError(
450452
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451453
"pip install 'sagemaker-core[torch]'"
452454
) from e
453455

454-
self.torch_tensor = Tensor
455456
self.numpy_serializer = NumpySerializer()
456457

457458
def serialize(self, data):

sagemaker-core/tests/unit/test_optional_torch_dependency.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ def _block_torch():
2525
"""Block torch imports by setting sys.modules['torch'] to None.
2626
2727
Returns a dict of saved torch submodule entries so they can be restored.
28+
29+
Note: This only saves and removes torch submodules that exist at the time
30+
of the call. Submodules imported *during* the test (after blocking) are not
31+
tracked and will not be cleaned up automatically.
2832
"""
29-
saved = {}
3033
torch_keys = [key for key in sys.modules if key.startswith("torch.")]
3134
saved = {key: sys.modules.pop(key) for key in torch_keys}
3235
saved["torch"] = sys.modules.get("torch")
@@ -47,13 +50,11 @@ def _restore_torch(saved):
4750

4851
def test_serializer_module_imports_without_torch():
4952
"""Verify that importing non-torch serializers succeeds without torch installed."""
50-
saved = {}
51-
try:
52-
saved = _block_torch()
53+
import sagemaker.core.serializers.base as ser_module
5354

55+
saved = _block_torch()
56+
try:
5457
# Reload the module so it re-evaluates imports with torch blocked
55-
import sagemaker.core.serializers.base as ser_module
56-
5758
importlib.reload(ser_module)
5859

5960
# Verify non-torch serializers can be instantiated
@@ -63,16 +64,15 @@ def test_serializer_module_imports_without_torch():
6364
assert ser_module.IdentitySerializer() is not None
6465
finally:
6566
_restore_torch(saved)
67+
importlib.reload(ser_module)
6668

6769

6870
def test_deserializer_module_imports_without_torch():
6971
"""Verify that importing non-torch deserializers succeeds without torch installed."""
70-
saved = {}
71-
try:
72-
saved = _block_torch()
73-
74-
import sagemaker.core.deserializers.base as deser_module
72+
import sagemaker.core.deserializers.base as deser_module
7573

74+
saved = _block_torch()
75+
try:
7676
importlib.reload(deser_module)
7777

7878
# Verify non-torch deserializers can be instantiated
@@ -83,42 +83,45 @@ def test_deserializer_module_imports_without_torch():
8383
assert deser_module.JSONDeserializer() is not None
8484
finally:
8585
_restore_torch(saved)
86+
importlib.reload(deser_module)
8687

8788

8889
def test_torch_tensor_serializer_raises_import_error_without_torch():
8990
"""Verify TorchTensorSerializer raises ImportError when torch is not installed."""
9091
import sagemaker.core.serializers.base as ser_module
9192

92-
saved = {}
93+
saved = _block_torch()
9394
try:
94-
saved = _block_torch()
95+
# Reload after blocking torch for consistency — ensures the module
96+
# does not cache torch at import time.
97+
importlib.reload(ser_module)
9598

9699
with pytest.raises(ImportError, match="Unable to import torch"):
97100
ser_module.TorchTensorSerializer()
98101
finally:
99102
_restore_torch(saved)
103+
importlib.reload(ser_module)
100104

101105

102106
def test_torch_tensor_deserializer_raises_import_error_without_torch():
103107
"""Verify TorchTensorDeserializer raises ImportError when torch is not installed."""
104108
import sagemaker.core.deserializers.base as deser_module
105109

106-
saved = {}
110+
saved = _block_torch()
107111
try:
108-
saved = _block_torch()
112+
# Reload after blocking torch for consistency
113+
importlib.reload(deser_module)
109114

110115
with pytest.raises(ImportError, match="Unable to import torch"):
111116
deser_module.TorchTensorDeserializer()
112117
finally:
113118
_restore_torch(saved)
119+
importlib.reload(deser_module)
114120

115121

116122
def test_torch_tensor_serializer_works_with_torch():
117123
"""Verify TorchTensorSerializer works when torch is available."""
118-
try:
119-
import torch
120-
except ImportError:
121-
pytest.skip("torch is not installed")
124+
torch = pytest.importorskip("torch")
122125

123126
from sagemaker.core.serializers.base import TorchTensorSerializer
124127

@@ -133,10 +136,7 @@ def test_torch_tensor_serializer_works_with_torch():
133136

134137
def test_torch_tensor_deserializer_works_with_torch():
135138
"""Verify TorchTensorDeserializer works when torch is available."""
136-
try:
137-
import torch
138-
except ImportError:
139-
pytest.skip("torch is not installed")
139+
torch = pytest.importorskip("torch")
140140

141141
from sagemaker.core.deserializers.base import TorchTensorDeserializer
142142

0 commit comments

Comments
 (0)