1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- """Tests to verify torch dependency is optional in sagemaker-core."""
14- from __future__ import annotations
13+ """Tests to verify torch dependency is optional in sagemaker-core.
14+
15+ The "module imports without torch" tests use subprocess instead of
16+ importlib.reload to avoid poisoning the class hierarchy in the current
17+ process. six.with_metaclass + old-style super() breaks when a module
18+ is reloaded because the class identity changes, causing
19+ ``TypeError: super(type, obj): obj must be an instance or subtype of type``
20+ in subsequent tests that instantiate serializers/deserializers.
21+ """
22+ from __future__ import absolute_import
1523
16- import importlib
1724import io
25+ import subprocess
1826import sys
27+ import textwrap
1928
2029import numpy as np
2130import pytest
@@ -26,7 +35,6 @@ def _block_torch():
2635
2736 Returns a dict of saved torch submodule entries so they can be restored.
2837 """
29- saved = {}
3038 torch_keys = [key for key in sys .modules if key .startswith ("torch." )]
3139 saved = {key : sys .modules .pop (key ) for key in torch_keys }
3240 saved ["torch" ] = sys .modules .get ("torch" )
@@ -46,43 +54,71 @@ def _restore_torch(saved):
4654
4755
4856def test_serializer_module_imports_without_torch ():
49- """Verify that importing non-torch serializers succeeds without torch installed."""
50- saved = {}
51- try :
52- saved = _block_torch ()
57+ """Verify that non-torch serializers can be imported and instantiated without torch.
5358
54- # Reload the module so it re-evaluates imports with torch blocked
55- import sagemaker .core .serializers .base as ser_module
56-
57- importlib .reload (ser_module )
58-
59- # Verify non-torch serializers can be instantiated
60- assert ser_module .CSVSerializer () is not None
61- assert ser_module .NumpySerializer () is not None
62- assert ser_module .JSONSerializer () is not None
63- assert ser_module .IdentitySerializer () is not None
64- finally :
65- _restore_torch (saved )
59+ Runs in a subprocess to avoid polluting the current process's class
60+ hierarchy via importlib.reload (which breaks six.with_metaclass).
61+ """
62+ code = textwrap .dedent ("""\
63+ import sys
64+ # Block torch before any sagemaker imports
65+ sys.modules["torch"] = None
66+
67+ from sagemaker.core.serializers.base import (
68+ CSVSerializer,
69+ NumpySerializer,
70+ JSONSerializer,
71+ IdentitySerializer,
72+ )
73+
74+ assert CSVSerializer() is not None
75+ assert NumpySerializer() is not None
76+ assert JSONSerializer() is not None
77+ assert IdentitySerializer() is not None
78+ print("OK")
79+ """ )
80+ result = subprocess .run (
81+ [sys .executable , "-c" , code ],
82+ capture_output = True ,
83+ text = True ,
84+ )
85+ assert result .returncode == 0 , (
86+ f"Subprocess failed:\n stdout: { result .stdout } \n stderr: { result .stderr } "
87+ )
6688
6789
6890def test_deserializer_module_imports_without_torch ():
69- """Verify that importing non-torch deserializers succeeds without torch installed."""
70- saved = {}
71- try :
72- saved = _block_torch ()
73-
74- import sagemaker .core .deserializers .base as deser_module
91+ """Verify that non-torch deserializers can be imported and instantiated without torch.
7592
76- importlib .reload (deser_module )
77-
78- # Verify non-torch deserializers can be instantiated
79- assert deser_module .StringDeserializer () is not None
80- assert deser_module .BytesDeserializer () is not None
81- assert deser_module .CSVDeserializer () is not None
82- assert deser_module .NumpyDeserializer () is not None
83- assert deser_module .JSONDeserializer () is not None
84- finally :
85- _restore_torch (saved )
93+ Runs in a subprocess for the same reason as the serializer test above.
94+ """
95+ code = textwrap .dedent ("""\
96+ import sys
97+ sys.modules["torch"] = None
98+
99+ from sagemaker.core.deserializers.base import (
100+ StringDeserializer,
101+ BytesDeserializer,
102+ CSVDeserializer,
103+ NumpyDeserializer,
104+ JSONDeserializer,
105+ )
106+
107+ assert StringDeserializer() is not None
108+ assert BytesDeserializer() is not None
109+ assert CSVDeserializer() is not None
110+ assert NumpyDeserializer() is not None
111+ assert JSONDeserializer() is not None
112+ print("OK")
113+ """ )
114+ result = subprocess .run (
115+ [sys .executable , "-c" , code ],
116+ capture_output = True ,
117+ text = True ,
118+ )
119+ assert result .returncode == 0 , (
120+ f"Subprocess failed:\n stdout: { result .stdout } \n stderr: { result .stderr } "
121+ )
86122
87123
88124def test_torch_tensor_serializer_raises_import_error_without_torch ():
0 commit comments