Skip to content

Commit 5280803

Browse files
committed
fix: use subprocess instead of importlib.reload to avoid breaking six.with_metaclass super()
1 parent 0ed06b2 commit 5280803

2 files changed

Lines changed: 73 additions & 37 deletions

File tree

sagemaker-core/tests/unit/test_optional_torch_dependency.py

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,21 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Tests to verify torch dependency is optional in sagemaker-core."""
14-
from __future__ import annotations
13+
"""Tests to verify torch dependency is optional in sagemaker-core.
14+
15+
The "module imports without torch" tests use subprocess instead of
16+
importlib.reload to avoid poisoning the class hierarchy in the current
17+
process. six.with_metaclass + old-style super() breaks when a module
18+
is reloaded because the class identity changes, causing
19+
``TypeError: super(type, obj): obj must be an instance or subtype of type``
20+
in subsequent tests that instantiate serializers/deserializers.
21+
"""
22+
from __future__ import absolute_import
1523

16-
import importlib
1724
import io
25+
import subprocess
1826
import sys
27+
import textwrap
1928

2029
import numpy as np
2130
import pytest
@@ -26,7 +35,6 @@ def _block_torch():
2635
2736
Returns a dict of saved torch submodule entries so they can be restored.
2837
"""
29-
saved = {}
3038
torch_keys = [key for key in sys.modules if key.startswith("torch.")]
3139
saved = {key: sys.modules.pop(key) for key in torch_keys}
3240
saved["torch"] = sys.modules.get("torch")
@@ -46,43 +54,71 @@ def _restore_torch(saved):
4654

4755

4856
def test_serializer_module_imports_without_torch():
49-
"""Verify that importing non-torch serializers succeeds without torch installed."""
50-
saved = {}
51-
try:
52-
saved = _block_torch()
57+
"""Verify that non-torch serializers can be imported and instantiated without torch.
5358
54-
# Reload the module so it re-evaluates imports with torch blocked
55-
import sagemaker.core.serializers.base as ser_module
56-
57-
importlib.reload(ser_module)
58-
59-
# Verify non-torch serializers can be instantiated
60-
assert ser_module.CSVSerializer() is not None
61-
assert ser_module.NumpySerializer() is not None
62-
assert ser_module.JSONSerializer() is not None
63-
assert ser_module.IdentitySerializer() is not None
64-
finally:
65-
_restore_torch(saved)
59+
Runs in a subprocess to avoid polluting the current process's class
60+
hierarchy via importlib.reload (which breaks six.with_metaclass).
61+
"""
62+
code = textwrap.dedent("""\
63+
import sys
64+
# Block torch before any sagemaker imports
65+
sys.modules["torch"] = None
66+
67+
from sagemaker.core.serializers.base import (
68+
CSVSerializer,
69+
NumpySerializer,
70+
JSONSerializer,
71+
IdentitySerializer,
72+
)
73+
74+
assert CSVSerializer() is not None
75+
assert NumpySerializer() is not None
76+
assert JSONSerializer() is not None
77+
assert IdentitySerializer() is not None
78+
print("OK")
79+
""")
80+
result = subprocess.run(
81+
[sys.executable, "-c", code],
82+
capture_output=True,
83+
text=True,
84+
)
85+
assert result.returncode == 0, (
86+
f"Subprocess failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
87+
)
6688

6789

6890
def test_deserializer_module_imports_without_torch():
69-
"""Verify that importing non-torch deserializers succeeds without torch installed."""
70-
saved = {}
71-
try:
72-
saved = _block_torch()
73-
74-
import sagemaker.core.deserializers.base as deser_module
91+
"""Verify that non-torch deserializers can be imported and instantiated without torch.
7592
76-
importlib.reload(deser_module)
77-
78-
# Verify non-torch deserializers can be instantiated
79-
assert deser_module.StringDeserializer() is not None
80-
assert deser_module.BytesDeserializer() is not None
81-
assert deser_module.CSVDeserializer() is not None
82-
assert deser_module.NumpyDeserializer() is not None
83-
assert deser_module.JSONDeserializer() is not None
84-
finally:
85-
_restore_torch(saved)
93+
Runs in a subprocess for the same reason as the serializer test above.
94+
"""
95+
code = textwrap.dedent("""\
96+
import sys
97+
sys.modules["torch"] = None
98+
99+
from sagemaker.core.deserializers.base import (
100+
StringDeserializer,
101+
BytesDeserializer,
102+
CSVDeserializer,
103+
NumpyDeserializer,
104+
JSONDeserializer,
105+
)
106+
107+
assert StringDeserializer() is not None
108+
assert BytesDeserializer() is not None
109+
assert CSVDeserializer() is not None
110+
assert NumpyDeserializer() is not None
111+
assert JSONDeserializer() is not None
112+
print("OK")
113+
""")
114+
result = subprocess.run(
115+
[sys.executable, "-c", code],
116+
capture_output=True,
117+
text=True,
118+
)
119+
assert result.returncode == 0, (
120+
f"Subprocess failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
121+
)
86122

87123

88124
def test_torch_tensor_serializer_raises_import_error_without_torch():

sagemaker-core/tests/unit/test_serializer_implementations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Unit tests for sagemaker.core.serializers.implementations module."""
14-
from __future__ import annotations
14+
from __future__ import absolute_import
1515

1616
import pytest
1717
from unittest.mock import Mock, patch

0 commit comments

Comments
 (0)