-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathtest_cuda_export.py
More file actions
396 lines (322 loc) · 14.2 KB
/
test_cuda_export.py
File metadata and controls
396 lines (322 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Tuple

import torch
from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.examples.models.toy_model import SdpaModule
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    to_edge_transform_and_lower,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import export
class TestCudaExport(unittest.TestCase):
    """Test CUDA export of various operations via to_edge_transform_and_lower.

    Every test is skipped when ``torch.cuda.is_available()`` returns False
    (see ``setUp``).
    """

    def setUp(self):
        """Skip the current test if CUDA is not available."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def _export_to_cuda_with_lower(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor, ...],
        compile_specs: list[CompileSpec] | None = None,
    ) -> EdgeProgramManager:
        """Export a module and lower it to the CUDA backend.

        Runs ``torch.export.export`` followed by ``to_edge_transform_and_lower``
        with a ``CudaPartitioner``, then asserts that the lowered graph
        contains at least one delegate call.

        Args:
            module: The torch.nn.Module to export.
            inputs: The example inputs for the module.
            compile_specs: Optional list of compile specs. If not provided,
                defaults to only the method name compile spec for "forward".

        Returns:
            The edge program manager returned by
            ``to_edge_transform_and_lower``.
        """
        # Export the model
        exported_program = export(module, inputs, strict=True)

        # Create partitioner with compile specs
        if compile_specs is None:
            compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
        partitioner = CudaPartitioner(compile_specs)

        # Use to_edge_transform_and_lower for the complete pipeline
        edge_program_manager = to_edge_transform_and_lower(
            exported_program,
            partitioner=[partitioner],
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        )

        # Verify that the pipeline succeeded
        self.assertIsNotNone(edge_program_manager)
        self.assertTrue(hasattr(edge_program_manager, "exported_program"))

        # Verify that the final exported program contains delegated calls
        lowered_program = edge_program_manager.exported_program()
        has_delegate_call = any(
            node.op == "call_function"
            and "executorch_call_delegate" in str(node.target)
            for node in lowered_program.graph.nodes
        )
        self.assertTrue(
            has_delegate_call, "No delegate calls found in final exported program"
        )
        return edge_program_manager

    def test_simple_add(self):
        """Test CUDA export for simple element-wise addition."""

        class AddModule(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        module = AddModule()
        module.eval()
        inputs = (torch.randn(3, 4), torch.randn(3, 4))

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(edge_program_manager, "Simple add operation export failed")

    def test_conv2d(self):
        """Test CUDA export for 2D convolution."""

        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.conv(x)

        module = Conv2dModule()
        module.eval()
        inputs = (torch.randn(1, 3, 32, 32),)

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(edge_program_manager, "Conv2d operation export failed")

    def test_linear(self):
        """Test CUDA export for linear layer."""

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 64)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

        module = LinearModule()
        module.eval()
        inputs = (torch.randn(8, 128),)

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(edge_program_manager, "Linear operation export failed")

    def test_resnet_block(self):
        """Test CUDA export for a ResNet-style block."""

        class ResNetBlock(torch.nn.Module):
            def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    bias=False,
                )
                # Use eval mode to avoid batch norm mutations during export
                self.bn1 = torch.nn.BatchNorm2d(out_channels)
                self.relu = torch.nn.ReLU(inplace=True)
                self.conv2 = torch.nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
                self.bn2 = torch.nn.BatchNorm2d(out_channels)

                # Shortcut connection: identity unless the shape changes
                self.shortcut = torch.nn.Sequential()
                if stride != 1 or in_channels != out_channels:
                    self.shortcut = torch.nn.Sequential(
                        torch.nn.Conv2d(
                            in_channels,
                            out_channels,
                            kernel_size=1,
                            stride=stride,
                            bias=False,
                        ),
                        torch.nn.BatchNorm2d(out_channels),
                    )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                identity = self.shortcut(x)
                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)
                out = self.conv2(out)
                out = self.bn2(out)
                out += identity
                out = self.relu(out)
                return out

        module = ResNetBlock(64, 64)
        # Set module to eval mode to avoid batch norm running statistics mutations
        module.eval()
        inputs = (torch.randn(1, 64, 32, 32),)

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(edge_program_manager, "ResNet block export failed")

    def test_multi_operation_module(self):
        """Test CUDA export for a module with multiple operations."""

        class MultiOpModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
                self.relu = torch.nn.ReLU()
                self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.linear = torch.nn.Linear(32, 10)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.conv(x)
                x = self.relu(x)
                x = self.pool(x)
                x = x.view(x.size(0), -1)
                x = self.linear(x)
                return x

        module = MultiOpModule()
        module.eval()
        inputs = (torch.randn(2, 3, 16, 16),)

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(
            edge_program_manager, "Multi-operation module export failed"
        )

    def test_activation_functions(self):
        """Test CUDA export for various activation functions."""

        class ActivationModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Test multiple activation functions
                x1 = torch.relu(x)
                x2 = torch.sigmoid(x)
                x3 = torch.tanh(x)
                return x1 + x2 + x3

        module = ActivationModule()
        module.eval()
        inputs = (torch.randn(4, 8),)

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(edge_program_manager, "Activation functions export failed")

    def test_mathematical_operations(self):
        """Test CUDA export for mathematical operations."""

        class MathOpsModule(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                # Test various mathematical operations
                add_result = x + y
                mul_result = x * y
                sub_result = x - y
                div_result = x / (y + 1e-8)  # Add epsilon to avoid division by zero
                return add_result + mul_result + sub_result + div_result

        module = MathOpsModule()
        module.eval()
        inputs = (torch.randn(4, 4), torch.randn(4, 4))

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(
            edge_program_manager, "Mathematical operations export failed"
        )

    def test_conv1d(self):
        """Test CUDA export for 1D convolution."""

        class Conv1dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(3, 16, kernel_size=3, padding=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.conv(x)

        module = Conv1dModule()
        module.eval()
        inputs = (torch.randn(1, 3, 10),)

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed")

    def test_sdpa_single_kernel(self):
        """
        Test CUDA export for model containing single SDPA kernel.

        SDPA: Scaled Dot Product Attention
        """
        sdpa = SdpaModule()

        # Test export
        edge_program_manager = self._export_to_cuda_with_lower(
            sdpa.get_eager_model(), sdpa.get_example_inputs()
        )
        self.assertIsNotNone(
            edge_program_manager,
            "SDPA single kernel operation export failed",
        )

    def test_triton_kernel_mode_off(self):
        """
        Test CUDA export with triton_kernel_mode set to OFF for SDPA kernel.

        This validates that the backend correctly processes the
        triton_kernel_mode compile spec and can export SDPA operations without
        Triton kernel replacements. When triton_kernel_mode is OFF, SDPA should
        be decomposed using the MATH backend.
        """
        sdpa = SdpaModule()

        # Create compile specs with triton_kernel_mode set to OFF
        compile_specs = [
            CudaBackend.generate_method_name_compile_spec("forward"),
            CompileSpec(key="triton_kernel_mode", value=b"OFF"),
        ]

        # Test export with triton_kernel_mode=OFF
        edge_program_manager = self._export_to_cuda_with_lower(
            sdpa.get_eager_model(), sdpa.get_example_inputs(), compile_specs
        )
        self.assertIsNotNone(
            edge_program_manager,
            "SDPA kernel export with triton_kernel_mode=OFF failed",
        )

    def test_device_info_propagated_to_cuda_delegate_outputs(self):
        """
        Test that device info is correctly propagated from export to
        serialization for CUDA delegate outputs.

        This verifies the device propagation flow:
        1. CudaPartitioner adds target_device="cuda:0" CompileSpec
        2. PropagateDevicePass sets TensorSpec.device = CUDA for delegate outputs
        3. Emitter serializes device info into ExtraTensorInfo.device_type
        4. Serialized tensors have device_type = DeviceType.CUDA

        Note: At this stage, the tensor memory is still on CPU. The CUDA backend
        will copy data to GPU device at runtime. Device info tagging is the first
        step toward full device-aware memory allocation.
        """
        from executorch.exir import schema

        class AddModule(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        module = AddModule()
        module.eval()
        inputs = (torch.randn(2, 3), torch.randn(2, 3))

        # Export to CUDA with full pipeline
        edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
        self.assertIsNotNone(edge_program_manager, "CUDA export failed")

        # Convert to ExecuTorch and access the serialized program
        et_prog = edge_program_manager.to_executorch()
        program = et_prog._emitter_output.program

        # Get the execution plan and verify delegate exists
        plan = program.execution_plan[0]
        self.assertGreater(
            len(plan.delegates),
            0,
            "Expected at least one delegate in the execution plan",
        )

        # Bucket the plan's tensors by their serialized device type
        cpu_tensors = []
        cuda_tensors = []
        for value in plan.values:
            if isinstance(value.val, schema.Tensor):
                tensor = value.val
                if (
                    tensor.extra_tensor_info is not None
                    and tensor.extra_tensor_info.device_type == schema.DeviceType.CUDA
                ):
                    cuda_tensors.append(tensor)
                else:
                    # Either no extra_tensor_info or device_type is CPU (default)
                    cpu_tensors.append(tensor)

        # Both input and output tensors should be on CUDA device for now.
        self.assertEqual(
            len(cpu_tensors),
            0,
            "Expected all tensors to be tagged with the CUDA device, "
            "but some are tagged CPU",
        )
        # NOTE(review): this requires strictly more than 3 CUDA tensors; confirm
        # the threshold against the tensors actually emitted for x + y.
        self.assertGreater(
            len(cuda_tensors),
            3,
            "Expected CUDA tensors for delegate outputs",
        )