Skip to content

Commit d1ca9e0

Browse files
committed
fix: address review comments (iteration #3)
1 parent 808472d commit d1ca9e0

File tree

3 files changed

+337
-174
lines changed

3 files changed

+337
-174
lines changed

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,6 @@ def __init__(self, content_type="tensor/pt"):
452452
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
453453
"pip install 'sagemaker-core[torch]'"
454454
) from e
455-
456455
self.numpy_serializer = NumpySerializer()
457456

458457
def serialize(self, data):

sagemaker-core/tests/unit/test_optional_torch_dependency.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,8 @@ def _block_torch():
2525
"""Block torch imports by setting sys.modules['torch'] to None.
2626
2727
Returns a dict of saved torch submodule entries so they can be restored.
28-
29-
Note: This only saves and removes torch submodules that exist at the time
30-
of the call. Submodules imported *during* the test (after blocking) are not
31-
tracked and will not be cleaned up automatically.
3228
"""
29+
saved = {}
3330
torch_keys = [key for key in sys.modules if key.startswith("torch.")]
3431
saved = {key: sys.modules.pop(key) for key in torch_keys}
3532
saved["torch"] = sys.modules.get("torch")
@@ -50,11 +47,13 @@ def _restore_torch(saved):
5047

5148
def test_serializer_module_imports_without_torch():
5249
"""Verify that importing non-torch serializers succeeds without torch installed."""
53-
import sagemaker.core.serializers.base as ser_module
54-
55-
saved = _block_torch()
50+
saved = {}
5651
try:
52+
saved = _block_torch()
53+
5754
# Reload the module so it re-evaluates imports with torch blocked
55+
import sagemaker.core.serializers.base as ser_module
56+
5857
importlib.reload(ser_module)
5958

6059
# Verify non-torch serializers can be instantiated
@@ -64,15 +63,16 @@ def test_serializer_module_imports_without_torch():
6463
assert ser_module.IdentitySerializer() is not None
6564
finally:
6665
_restore_torch(saved)
67-
importlib.reload(ser_module)
6866

6967

7068
def test_deserializer_module_imports_without_torch():
7169
"""Verify that importing non-torch deserializers succeeds without torch installed."""
72-
import sagemaker.core.deserializers.base as deser_module
73-
74-
saved = _block_torch()
70+
saved = {}
7571
try:
72+
saved = _block_torch()
73+
74+
import sagemaker.core.deserializers.base as deser_module
75+
7676
importlib.reload(deser_module)
7777

7878
# Verify non-torch deserializers can be instantiated
@@ -83,45 +83,42 @@ def test_deserializer_module_imports_without_torch():
8383
assert deser_module.JSONDeserializer() is not None
8484
finally:
8585
_restore_torch(saved)
86-
importlib.reload(deser_module)
8786

8887

8988
def test_torch_tensor_serializer_raises_import_error_without_torch():
9089
"""Verify TorchTensorSerializer raises ImportError when torch is not installed."""
9190
import sagemaker.core.serializers.base as ser_module
9291

93-
saved = _block_torch()
92+
saved = {}
9493
try:
95-
# Reload after blocking torch for consistency — ensures the module
96-
# does not cache torch at import time.
97-
importlib.reload(ser_module)
94+
saved = _block_torch()
9895

9996
with pytest.raises(ImportError, match="Unable to import torch"):
10097
ser_module.TorchTensorSerializer()
10198
finally:
10299
_restore_torch(saved)
103-
importlib.reload(ser_module)
104100

105101

106102
def test_torch_tensor_deserializer_raises_import_error_without_torch():
107103
"""Verify TorchTensorDeserializer raises ImportError when torch is not installed."""
108104
import sagemaker.core.deserializers.base as deser_module
109105

110-
saved = _block_torch()
106+
saved = {}
111107
try:
112-
# Reload after blocking torch for consistency
113-
importlib.reload(deser_module)
108+
saved = _block_torch()
114109

115110
with pytest.raises(ImportError, match="Unable to import torch"):
116111
deser_module.TorchTensorDeserializer()
117112
finally:
118113
_restore_torch(saved)
119-
importlib.reload(deser_module)
120114

121115

122116
def test_torch_tensor_serializer_works_with_torch():
123117
"""Verify TorchTensorSerializer works when torch is available."""
124-
torch = pytest.importorskip("torch")
118+
try:
119+
import torch
120+
except ImportError:
121+
pytest.skip("torch is not installed")
125122

126123
from sagemaker.core.serializers.base import TorchTensorSerializer
127124

@@ -136,7 +133,10 @@ def test_torch_tensor_serializer_works_with_torch():
136133

137134
def test_torch_tensor_deserializer_works_with_torch():
138135
"""Verify TorchTensorDeserializer works when torch is available."""
139-
torch = pytest.importorskip("torch")
136+
try:
137+
import torch
138+
except ImportError:
139+
pytest.skip("torch is not installed")
140140

141141
from sagemaker.core.deserializers.base import TorchTensorDeserializer
142142

0 commit comments

Comments
 (0)