@@ -25,11 +25,8 @@ def _block_torch():
2525 """Block torch imports by setting sys.modules['torch'] to None.
2626
2727 Returns a dict of saved torch submodule entries so they can be restored.
28-
29- Note: This only saves and removes torch submodules that exist at the time
30- of the call. Submodules imported *during* the test (after blocking) are not
31- tracked and will not be cleaned up automatically.
3228 """
29+ saved = {}
3330 torch_keys = [key for key in sys .modules if key .startswith ("torch." )]
3431 saved = {key : sys .modules .pop (key ) for key in torch_keys }
3532 saved ["torch" ] = sys .modules .get ("torch" )
@@ -50,11 +47,13 @@ def _restore_torch(saved):
5047
5148def test_serializer_module_imports_without_torch ():
5249 """Verify that importing non-torch serializers succeeds without torch installed."""
53- import sagemaker .core .serializers .base as ser_module
54-
55- saved = _block_torch ()
50+ saved = {}
5651 try :
52+ saved = _block_torch ()
53+
5754 # Reload the module so it re-evaluates imports with torch blocked
55+ import sagemaker .core .serializers .base as ser_module
56+
5857 importlib .reload (ser_module )
5958
6059 # Verify non-torch serializers can be instantiated
@@ -64,15 +63,16 @@ def test_serializer_module_imports_without_torch():
6463 assert ser_module .IdentitySerializer () is not None
6564 finally :
6665 _restore_torch (saved )
67- importlib .reload (ser_module )
6866
6967
7068def test_deserializer_module_imports_without_torch ():
7169 """Verify that importing non-torch deserializers succeeds without torch installed."""
72- import sagemaker .core .deserializers .base as deser_module
73-
74- saved = _block_torch ()
70+ saved = {}
7571 try :
72+ saved = _block_torch ()
73+
74+ import sagemaker .core .deserializers .base as deser_module
75+
7676 importlib .reload (deser_module )
7777
7878 # Verify non-torch deserializers can be instantiated
@@ -83,45 +83,42 @@ def test_deserializer_module_imports_without_torch():
8383 assert deser_module .JSONDeserializer () is not None
8484 finally :
8585 _restore_torch (saved )
86- importlib .reload (deser_module )
8786
8887
8988def test_torch_tensor_serializer_raises_import_error_without_torch ():
9089 """Verify TorchTensorSerializer raises ImportError when torch is not installed."""
9190 import sagemaker .core .serializers .base as ser_module
9291
93- saved = _block_torch ()
92+ saved = {}
9493 try :
95- # Reload after blocking torch for consistency — ensures the module
96- # does not cache torch at import time.
97- importlib .reload (ser_module )
94+ saved = _block_torch ()
9895
9996 with pytest .raises (ImportError , match = "Unable to import torch" ):
10097 ser_module .TorchTensorSerializer ()
10198 finally :
10299 _restore_torch (saved )
103- importlib .reload (ser_module )
104100
105101
106102def test_torch_tensor_deserializer_raises_import_error_without_torch ():
107103 """Verify TorchTensorDeserializer raises ImportError when torch is not installed."""
108104 import sagemaker .core .deserializers .base as deser_module
109105
110- saved = _block_torch ()
106+ saved = {}
111107 try :
112- # Reload after blocking torch for consistency
113- importlib .reload (deser_module )
108+ saved = _block_torch ()
114109
115110 with pytest .raises (ImportError , match = "Unable to import torch" ):
116111 deser_module .TorchTensorDeserializer ()
117112 finally :
118113 _restore_torch (saved )
119- importlib .reload (deser_module )
120114
121115
122116def test_torch_tensor_serializer_works_with_torch ():
123117 """Verify TorchTensorSerializer works when torch is available."""
124- torch = pytest .importorskip ("torch" )
118+ try :
119+ import torch
120+ except ImportError :
121+ pytest .skip ("torch is not installed" )
125122
126123 from sagemaker .core .serializers .base import TorchTensorSerializer
127124
@@ -136,7 +133,10 @@ def test_torch_tensor_serializer_works_with_torch():
136133
137134def test_torch_tensor_deserializer_works_with_torch ():
138135 """Verify TorchTensorDeserializer works when torch is available."""
139- torch = pytest .importorskip ("torch" )
136+ try :
137+ import torch
138+ except ImportError :
139+ pytest .skip ("torch is not installed" )
140140
141141 from sagemaker .core .deserializers .base import TorchTensorDeserializer
142142
0 commit comments