forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_cpp_extensions_mtia_backend.py
More file actions
170 lines (146 loc) · 6.74 KB
/
test_cpp_extensions_mtia_backend.py
File metadata and controls
170 lines (146 loc) · 6.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Owner(s): ["module: mtia"]
import os
import shutil
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_PRIVATEUSE1,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
# define TEST_ROCM before changing TEST_CUDA: TEST_ROCM must be derived from the
# TEST_CUDA flag as imported from common_utils, because the next statement
# rebinds TEST_CUDA to additionally require CUDA_HOME. Do not reorder these.
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
# Narrow TEST_CUDA to builds where a CUDA toolkit (CUDA_HOME) was also found.
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
@unittest.skipIf(
    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
    "Only on linux platform and mutual exclusive to other backends",
)
@common.markDynamoStrictTest
class TestCppExtensionMTIABackend(common.TestCase):
    """Tests MTIA backend with C++ extensions.

    ``setUpClass`` builds ``cpp_extensions/mtia_extension.cpp`` once for the
    whole class and loads it into the process (``is_python_module=False``).
    The tests then exercise the ``torch.mtia`` stream, device and generator
    APIs against that fake implementation.
    """

    # Return value of ``torch.utils.cpp_extension.load`` in ``setUpClass``.
    # The extension is loaded for its side effects (the fake device guard
    # impl), so no test reads this attribute.
    module = None

    def setUp(self):
        super().setUp()
        # Run each test with this file's directory as the working directory,
        # following the cpp-extension tests' convention of paths relative to
        # this file. Restored in tearDown.
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # Always restore the working directory saved in setUp, even if the
            # base-class teardown raises, so later tests are unaffected.
            os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        try:
            common.remove_cpp_extensions_build_root()
        finally:
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        """Build and load the fake MTIA extension once for all tests.

        ``setUpClass`` runs before any ``setUp``, so the working-directory
        change made there does not apply here. Source and include paths are
        therefore resolved explicitly against this file's directory.
        """
        super().setUpClass()
        common.remove_cpp_extensions_build_root()

        build_dir = tempfile.mkdtemp()
        # Class cleanups run even if the build below raises, so the temporary
        # build directory is always removed.
        cls.addClassCleanup(shutil.rmtree, build_dir, ignore_errors=True)

        test_dir = os.path.dirname(os.path.abspath(__file__))
        # Load the fake device guard impl.
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=[os.path.join(test_dir, "cpp_extensions", "mtia_extension.cpp")],
            build_directory=build_dir,
            extra_include_paths=[
                os.path.join(test_dir, "cpp_extensions"),
                # Presumably here to exercise quoting of include paths that
                # contain spaces / quote characters — confirm before removing.
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_get_device_module(self):
        device = torch.device("mtia:0")
        default_stream = torch.get_device_module(device).current_stream()
        self.assertEqual(
            default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA)
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_basic(self):
        default_stream = torch.mtia.current_stream()
        user_stream = torch.mtia.Stream()
        self.assertEqual(torch.mtia.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        # Check mtia_extension.cpp, default stream id starts from 0.
        self.assertEqual(default_stream.stream_id, 0)
        self.assertNotEqual(user_stream.stream_id, 0)
        with torch.mtia.stream(user_stream):
            self.assertEqual(torch.mtia.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context(self):
        mtia_stream_0 = torch.mtia.Stream(device="mtia:0")
        mtia_stream_1 = torch.mtia.Stream(device="mtia:0")
        # The f-string messages below also exercise the streams' string form.
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)
        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context_different_device(self):
        # Uses mtia:1, so this assumes the fake extension exposes at least two
        # devices — TODO confirm against mtia_extension.cpp.
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        mtia_stream_0 = torch.mtia.Stream(device=device_0)
        mtia_stream_1 = torch.mtia.Stream(device=device_1)
        orig_current_device = torch.mtia.current_device()
        # Entering a stream context should switch to the stream's device and
        # restore the previous device on exit.
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_0.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)
        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_1.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_device_context(self):
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        with torch.mtia.device(device_0):
            self.assertEqual(torch.mtia.current_device(), device_0.index)
        with torch.mtia.device(device_1):
            self.assertEqual(torch.mtia.current_device(), device_1.index)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_default_generators(self):
        # Trigger lazy initialization first by calling current_stream()
        torch.mtia.current_stream()
        device_count = torch.mtia.device_count()
        # Verify the interface exists and is properly initialized
        self.assertTrue(hasattr(torch.mtia, "default_generators"))
        self.assertIsInstance(torch.mtia.default_generators, tuple)
        self.assertEqual(len(torch.mtia.default_generators), device_count)
        # Verify we can access generators by device index. Indexing [1]
        # assumes device_count >= 2 — TODO confirm in mtia_extension.cpp.
        gen_0 = torch.mtia.default_generators[0]
        gen_1 = torch.mtia.default_generators[1]
        self.assertIsInstance(gen_0, torch.Generator)
        self.assertIsInstance(gen_1, torch.Generator)
        # Different devices should have different generator objects
        self.assertIsNot(gen_0, gen_1)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_new_generator(self):
        # Verify we can create a generator via the hooks interface
        gen = torch.Generator(device="mtia:0")
        self.assertIsInstance(gen, torch.Generator)
if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test runner.
    common.run_tests()