|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific |
12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | """Tests to verify torch dependency is optional in sagemaker-core.""" |
14 | | -from __future__ import absolute_import |
| 14 | +from __future__ import annotations |
15 | 15 |
|
| 16 | +import importlib |
16 | 17 | import io |
17 | 18 | import sys |
18 | | -from unittest import mock |
19 | 19 |
|
20 | 20 | import numpy as np |
21 | 21 | import pytest |
22 | 22 |
|
23 | 23 |
|
| 24 | +def _block_torch(): |
| 25 | + """Block torch imports by setting sys.modules['torch'] to None. |
| 26 | +
|
| 27 | + Returns a dict of saved torch submodule entries so they can be restored. |
| 28 | + """ |
| 29 | + saved = {} |
| 30 | + torch_keys = [key for key in sys.modules if key.startswith("torch.")] |
| 31 | + saved = {key: sys.modules.pop(key) for key in torch_keys} |
| 32 | + saved["torch"] = sys.modules.get("torch") |
| 33 | + sys.modules["torch"] = None |
| 34 | + return saved |
| 35 | + |
| 36 | + |
| 37 | +def _restore_torch(saved): |
| 38 | + """Restore torch modules from saved dict.""" |
| 39 | + original_torch = saved.pop("torch", None) |
| 40 | + if original_torch is not None: |
| 41 | + sys.modules["torch"] = original_torch |
| 42 | + elif "torch" in sys.modules: |
| 43 | + del sys.modules["torch"] |
| 44 | + for key, val in saved.items(): |
| 45 | + sys.modules[key] = val |
| 46 | + |
| 47 | + |
24 | 48 | def test_serializer_module_imports_without_torch(): |
25 | | - """Verify that importing serializers module succeeds without torch installed.""" |
26 | | - # The serializers module should be importable even without torch |
27 | | - # because TorchTensorSerializer uses lazy import in __init__ |
28 | | - from sagemaker.core.serializers.base import ( |
29 | | - CSVSerializer, |
30 | | - NumpySerializer, |
31 | | - JSONSerializer, |
32 | | - IdentitySerializer, |
33 | | - SparseMatrixSerializer, |
34 | | - JSONLinesSerializer, |
35 | | - LibSVMSerializer, |
36 | | - DataSerializer, |
37 | | - StringSerializer, |
38 | | - ) |
39 | | - # Verify non-torch serializers can be instantiated |
40 | | - assert CSVSerializer() is not None |
41 | | - assert NumpySerializer() is not None |
42 | | - assert JSONSerializer() is not None |
43 | | - assert IdentitySerializer() is not None |
| 49 | + """Verify that importing non-torch serializers succeeds without torch installed.""" |
| 50 | + saved = {} |
| 51 | + try: |
| 52 | + saved = _block_torch() |
| 53 | + |
| 54 | + # Reload the module so it re-evaluates imports with torch blocked |
| 55 | + import sagemaker.core.serializers.base as ser_module |
| 56 | + |
| 57 | + importlib.reload(ser_module) |
| 58 | + |
| 59 | + # Verify non-torch serializers can be instantiated |
| 60 | + assert ser_module.CSVSerializer() is not None |
| 61 | + assert ser_module.NumpySerializer() is not None |
| 62 | + assert ser_module.JSONSerializer() is not None |
| 63 | + assert ser_module.IdentitySerializer() is not None |
| 64 | + finally: |
| 65 | + _restore_torch(saved) |
44 | 66 |
|
45 | 67 |
|
46 | 68 | def test_deserializer_module_imports_without_torch(): |
47 | | - """Verify that importing deserializers module succeeds without torch installed.""" |
48 | | - from sagemaker.core.deserializers.base import ( |
49 | | - StringDeserializer, |
50 | | - BytesDeserializer, |
51 | | - CSVDeserializer, |
52 | | - StreamDeserializer, |
53 | | - NumpyDeserializer, |
54 | | - JSONDeserializer, |
55 | | - PandasDeserializer, |
56 | | - JSONLinesDeserializer, |
57 | | - ) |
58 | | - # Verify non-torch deserializers can be instantiated |
59 | | - assert StringDeserializer() is not None |
60 | | - assert BytesDeserializer() is not None |
61 | | - assert CSVDeserializer() is not None |
62 | | - assert NumpyDeserializer() is not None |
63 | | - assert JSONDeserializer() is not None |
| 69 | + """Verify that importing non-torch deserializers succeeds without torch installed.""" |
| 70 | + saved = {} |
| 71 | + try: |
| 72 | + saved = _block_torch() |
| 73 | + |
| 74 | + import sagemaker.core.deserializers.base as deser_module |
| 75 | + |
| 76 | + importlib.reload(deser_module) |
| 77 | + |
| 78 | + # Verify non-torch deserializers can be instantiated |
| 79 | + assert deser_module.StringDeserializer() is not None |
| 80 | + assert deser_module.BytesDeserializer() is not None |
| 81 | + assert deser_module.CSVDeserializer() is not None |
| 82 | + assert deser_module.NumpyDeserializer() is not None |
| 83 | + assert deser_module.JSONDeserializer() is not None |
| 84 | + finally: |
| 85 | + _restore_torch(saved) |
64 | 86 |
|
65 | 87 |
|
66 | 88 | def test_torch_tensor_serializer_raises_import_error_without_torch(): |
67 | 89 | """Verify TorchTensorSerializer raises ImportError when torch is not installed.""" |
68 | | - import importlib |
69 | 90 | import sagemaker.core.serializers.base as ser_module |
70 | 91 |
|
71 | | - # Save original torch module if present |
72 | | - original_torch = sys.modules.get('torch') |
73 | | - |
| 92 | + saved = {} |
74 | 93 | try: |
75 | | - # Simulate torch not being installed |
76 | | - sys.modules['torch'] = None |
77 | | - # Need to also handle the case where torch submodules are cached |
78 | | - torch_keys = [key for key in sys.modules if key.startswith('torch.')] |
79 | | - saved = {key: sys.modules.pop(key) for key in torch_keys} |
80 | | - |
| 94 | + saved = _block_torch() |
| 95 | + |
81 | 96 | with pytest.raises(ImportError, match="Unable to import torch"): |
82 | 97 | ser_module.TorchTensorSerializer() |
83 | 98 | finally: |
84 | | - # Restore original state |
85 | | - if original_torch is not None: |
86 | | - sys.modules['torch'] = original_torch |
87 | | - elif 'torch' in sys.modules: |
88 | | - del sys.modules['torch'] |
89 | | - for key, val in saved.items(): |
90 | | - sys.modules[key] = val |
| 99 | + _restore_torch(saved) |
91 | 100 |
|
92 | 101 |
|
93 | 102 | def test_torch_tensor_deserializer_raises_import_error_without_torch(): |
94 | 103 | """Verify TorchTensorDeserializer raises ImportError when torch is not installed.""" |
95 | 104 | import sagemaker.core.deserializers.base as deser_module |
96 | 105 |
|
97 | | - # Save original torch module if present |
98 | | - original_torch = sys.modules.get('torch') |
99 | | - |
| 106 | + saved = {} |
100 | 107 | try: |
101 | | - # Simulate torch not being installed |
102 | | - sys.modules['torch'] = None |
103 | | - torch_keys = [key for key in sys.modules if key.startswith('torch.')] |
104 | | - saved = {key: sys.modules.pop(key) for key in torch_keys} |
105 | | - |
| 108 | + saved = _block_torch() |
| 109 | + |
106 | 110 | with pytest.raises(ImportError, match="Unable to import torch"): |
107 | 111 | deser_module.TorchTensorDeserializer() |
108 | 112 | finally: |
109 | | - # Restore original state |
110 | | - if original_torch is not None: |
111 | | - sys.modules['torch'] = original_torch |
112 | | - elif 'torch' in sys.modules: |
113 | | - del sys.modules['torch'] |
114 | | - for key, val in saved.items(): |
115 | | - sys.modules[key] = val |
| 113 | + _restore_torch(saved) |
116 | 114 |
|
117 | 115 |
|
118 | 116 | def test_torch_tensor_serializer_works_with_torch(): |
|
0 commit comments