
Commit 3cd3b92

add cuda doc

Differential Revision: D90635397
Pull Request resolved: #16570

1 parent 5dbbec3 commit 3cd3b92

3 files changed: 136 additions & 21 deletions

File tree

docs/source/backends-overview.md

Lines changed: 8 additions & 6 deletions
```diff
@@ -18,12 +18,13 @@ Backends are the bridge between your exported model and the hardware it runs on.
 
 ## Choosing a Backend
 
-| Backend                                                      | Platform(s) | Hardware Type | Typical Use Case                |
-|--------------------------------------------------------------|-------------|---------------|---------------------------------|
-| [XNNPACK](backends/xnnpack/xnnpack-overview.md)              | All         | CPU           | General-purpose, fallback       |
-| [Core ML](/backends/coreml/coreml-overview.md)               | iOS, macOS  | NPU/GPU/CPU   | Apple devices, high performance |
-| [Metal Performance Shaders](/backends/mps/mps-overview.md)   | iOS, macOS  | GPU           | Apple GPU acceleration          |
-| [Vulkan ](/backends/vulkan/vulkan-overview.md)               | Android     | GPU           | Android GPU acceleration        |
+| Backend                                                      | Platform(s)   | Hardware Type | Typical Use Case                |
+|--------------------------------------------------------------|---------------|---------------|---------------------------------|
+| [XNNPACK](backends/xnnpack/xnnpack-overview.md)              | All           | CPU           | General-purpose, fallback       |
+| [CUDA](/backends/cuda/cuda-overview.md)                      | Linux/Windows | GPU           | NVIDIA GPU acceleration         |
+| [Core ML](/backends/coreml/coreml-overview.md)               | iOS, macOS    | NPU/GPU/CPU   | Apple devices, high performance |
+| [Metal Performance Shaders](/backends/mps/mps-overview.md)   | iOS, macOS    | GPU           | Apple GPU acceleration          |
+| [Vulkan ](/backends/vulkan/vulkan-overview.md)               | Android       | GPU           | Android GPU acceleration        |
 | [Qualcomm](backends-qualcomm)                                | Android       | NPU           | Qualcomm SoCs                   |
 | [MediaTek](backends-mediatek)                                | Android       | NPU           | MediaTek SoCs                   |
 | [Arm Ethos-U](/backends/arm-ethos-u/arm-ethos-u-overview.md) | Embedded      | NPU           | Arm MCUs                        |
@@ -51,6 +52,7 @@ Backends are the bridge between your exported model and the hardware it runs on.
 :caption: Backend Overview
 
 backends-xnnpack
+backends/cuda/cuda-overview
 backends/coreml/coreml-overview
 backends-mps
 backends-vulkan
```
docs/source/backends/cuda/cuda-overview.md

Lines changed: 122 additions & 0 deletions

# CUDA Backend

The CUDA backend is the ExecuTorch solution for running models on NVIDIA GPUs. It leverages the [AOTInductor](https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html) compiler to generate optimized, libtorch-free CUDA kernels, and uses [Triton](https://triton-lang.org/) for high-performance GPU kernel generation.

## Features

- **Optimized GPU Execution**: Uses AOTInductor to generate highly optimized CUDA kernels for model operators.
- **Triton Kernel Support**: Leverages Triton for GEMM (General Matrix Multiply), convolution, and SDPA (Scaled Dot-Product Attention) kernels.
- **Quantization Support**: INT4 weight quantization with a tile-packed format for improved performance and a reduced memory footprint.
- **Cross-Platform**: Supports both Linux and Windows.
- **Multiple Model Support**: Works with a variety of models, including LLMs, vision-language models, and audio models.

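To make the quantization bullet above concrete, here is a minimal, illustrative sketch of the core idea behind INT4 weight packing: two signed 4-bit values stored per byte. This is a simplification for exposition only; it does not reproduce ExecuTorch's actual tile-packed layout, which also carries per-group scales and arranges bytes in GPU-friendly tiles.

```python
def pack_int4_pair(lo: int, hi: int) -> int:
    """Pack two signed 4-bit values (each in [-8, 7]) into one byte."""
    assert -8 <= lo <= 7 and -8 <= hi <= 7
    return ((hi & 0xF) << 4) | (lo & 0xF)


def unpack_int4_pair(byte: int) -> tuple:
    """Recover the two signed 4-bit values from a packed byte."""
    def to_signed(nibble: int) -> int:
        return nibble - 16 if nibble >= 8 else nibble
    return to_signed(byte & 0xF), to_signed(byte >> 4)


# Halving the bytes per weight is where the memory savings come from:
# 8 INT4 weights occupy 4 bytes instead of 8 as INT8 (or 32 as FP32).
packed = [pack_int4_pair(-3, 5), pack_int4_pair(7, -8)]
```

Real INT4 kernels additionally store quantization scales (and often zero points) per weight group and unpack these nibbles on the fly inside the GPU kernel.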
## Target Requirements

Below are the requirements for running a CUDA-delegated ExecuTorch model:

- **Hardware**: CUDA-capable NVIDIA GPU
- **CUDA Toolkit**: CUDA 11.x or later (CUDA 12.x recommended)
- **Operating System**: Linux or Windows
- **Drivers**: PyTorch-compatible NVIDIA GPU drivers installed

## Development Requirements

To develop and export models using the CUDA backend, you need:

- **Python**: Python 3.8+
- **PyTorch**: PyTorch built with CUDA support
- **ExecuTorch**: ExecuTorch installed with CUDA backend support

## Using the CUDA Backend

### Exporting Models with Python API

The CUDA backend uses the `CudaBackend` and `CudaPartitioner` classes to export models. Here is a complete example:

```python
import torch

from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.extension.export_util.utils import save_pte_program
from torch._inductor.decomposition import conv1d_to_conv2d

# Configure edge compilation
edge_compile_config = EdgeCompileConfig(
    _check_ir_validity=False,
    _skip_dim_order=True,
)

# Define your model
model_name = "my_model"  # used to name the exported method's compile spec
model = YourModel().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Export the model using torch.export
exported_program = torch.export.export(model, example_inputs)

# Create the CUDA partitioner
partitioner = CudaPartitioner(
    [CudaBackend.generate_method_name_compile_spec(model_name)]
)

# Add decompositions so Triton can generate kernels
exported_program = exported_program.run_decompositions({
    torch.ops.aten.conv1d.default: conv1d_to_conv2d,
})

# Lower to ExecuTorch with the CUDA backend
et_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[partitioner],
    compile_config=edge_compile_config,
)

# Convert to an executable program and save it
exec_program = et_program.to_executorch()
save_pte_program(exec_program, model_name, "./output_dir")
```

This generates `.pte` and `.ptd` files that can be executed on CUDA devices.

For a complete working example, see the [CUDA export script](https://github.com/pytorch/executorch/blob/main/examples/cuda/scripts/export.py).

----

## Runtime Integration

To run the model on device, use the standard ExecuTorch runtime APIs. See [Running on Device](getting-started.md#running-on-device) for more information.

When building from source, pass `-DEXECUTORCH_BUILD_CUDA=ON` when configuring the CMake build to compile the CUDA backend.

```cmake
# CMakeLists.txt
add_subdirectory("executorch")
...
target_link_libraries(
  my_target
  PRIVATE executorch
          extension_module_static
          extension_tensor
          aoti_cuda_backend)
```

No additional steps are necessary to use the backend beyond linking the target. CUDA-delegated `.pte` and `.ptd` files will automatically run on the registered backend.

----

## Examples

For complete end-to-end examples of exporting and running models with the CUDA backend, see:

- [Whisper](https://github.com/pytorch/executorch/blob/main/examples/models/whisper/README.md) — Audio transcription model with CUDA support
- [Voxtral](https://github.com/pytorch/executorch/blob/main/examples/models/voxtral/README.md) — Audio multimodal model with CUDA support
- [Gemma3](https://github.com/pytorch/executorch/blob/main/examples/models/gemma3/README.md) — Vision-language model with CUDA support

These examples demonstrate the full workflow, including model export, quantization options, building runners, and runtime execution.

ExecuTorch provides Makefile targets for building these example runners:

```bash
make whisper-cuda   # Build Whisper runner with CUDA
make voxtral-cuda   # Build Voxtral runner with CUDA
make gemma3-cuda    # Build Gemma3 runner with CUDA
```

examples/cuda/scripts/export.py

Lines changed: 6 additions & 15 deletions
```diff
@@ -21,8 +21,6 @@
 from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
 
 from executorch.extension.export_util.utils import save_pte_program
-from torch._inductor.decomposition import conv1d_to_conv2d
-from torch.nn.attention import SDPBackend
 
 # Script to export a model with CUDA delegation.
 
@@ -88,24 +86,17 @@ def main():
         kwargs=example_kwargs,
         dynamic_shapes=dynamic_shapes,
     )
-    print(exported_programs)
 
     partitioner = CudaPartitioner(
         [CudaBackend.generate_method_name_compile_spec(args.model_name)]
     )
-    # Add decompositions for triton to generate kernels.
-    exported_programs = exported_programs.run_decompositions(
-        {
-            torch.ops.aten.conv1d.default: conv1d_to_conv2d,
-        }
+
+    et_prog = to_edge_transform_and_lower(
+        exported_programs,
+        partitioner=[partitioner],
+        compile_config=_EDGE_COMPILE_CONFIG,
+        generate_etrecord=args.generate_etrecord,
     )
-    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]):
-        et_prog = to_edge_transform_and_lower(
-            exported_programs,
-            partitioner=[partitioner],
-            compile_config=_EDGE_COMPILE_CONFIG,
-            generate_etrecord=args.generate_etrecord,
-        )
     exec_program = et_prog.to_executorch()
     save_pte_program(exec_program, args.model_name, args.output_dir)
     if args.generate_etrecord:
```
