Commit fa905d2
authored
[Compile] accelerate compilation speed using NVRTC (#18519)
This PR supports NVRTC as an alternative to NVCC for faster, device-side
JIT compilation of CUDA kernels, in favor of the PR
[https://github.com/apache/tvm-ffi/pull/283](https://github.com/apache/tvm-ffi/pull/283).
It enhances the CUDA compilation backend by:
- Adding Python NVRTC support using cuda-python bindings
- Removing legacy C++ NVRTC fallback in favor of a Python-first approach
- Keeping nvcc as the default compiler with fatbin output (no behavior
change for existing users)
Users can choose the compilation backend using an environment variable
`TVM_CUDA_COMPILE_MODE`, choosing from "nvcc" and "nvrtc". For example,
`TVM_CUDA_COMPILE_MODE=nvrtc python3 your_program.py`
Here is a short benchmark of the compilation speed of kernels in
`test_target_codegen_cuda.py`.
### NVCC vs NVRTC Compilation Time Comparison (Python-side Call)
| Test Case | Code Size | NVCC Time (ms) | NVRTC Time (ms) | Speedup |
| :--- | :--- | :--- | :--- | :--- |
| `test_crossthread_reduction1` | 1945 B | 241.27 | 51.23 | **4.7x** |
| `test_cuda_bf16_vectorize_add` | 3760 B | 342.72 | 44.50 | **7.7x** |
| `test_cuda_const_float_to_half` | 12394 B | 272.85 | 31.99 | **8.5x**
|
| `test_cuda_device_func_call` | 975 B | 215.58 | 21.47 | **10.0x** |
| `test_cuda_float_const_hex_format` | 685 B | 217.39 | 20.52 |
**10.6x** |
| `test_cuda_floordiv_with_vectorization` | 1050 B | 213.88 | 23.32 |
**9.2x** |
| `test_cuda_inf_nan` | 673 B | 214.33 | 24.94 | **8.6x** |
| `test_cuda_tensormap` | 755 B | 213.91 | 20.74 | **10.3x** |
| `test_cuda_thread_sync_inside_condition` | 1007 B | 213.43 | 28.29 |
**7.5x** |
| `test_cuda_vectorize_add` | 908 B | 226.81 | 40.39 | **5.6x** |
| `test_cuda_vectorize_load` | 734 B | 217.25 | 24.02 | **9.0x** |
| `test_device_host_call_same_func` | 924 B | 216.03 | 21.21 | **10.2x**
|
| `test_vectorized_intrin1` | 847 B | 226.15 | 26.34 | **8.6x** |
### NVSHMEM Support
Currently, NVSHMEM is **not** supported via NVRTC.
- Fallback Behavior: When NVSHMEM is required, the compilation pipeline
will automatically fall back to NVCC, even if `TVM_CUDA_COMPILE_MODE` is
set to nvrtc.
- Future Roadmap: Support for NVRTC with NVSHMEM is planned for
follow-up PRs.1 parent b3b6024 commit fa905d2
File tree
13 files changed
+465
-150
lines changed- cmake
- modules
- utils
- docker
- install
- python/tvm
- contrib
- script/ir_builder/tir
- src
- runtime/contrib/nvshmem
- target
- opt
- source
- literal
- tests/python
- codegen
- disco
- tir-transform
13 files changed
+465
-150
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
54 | 54 | | |
55 | 55 | | |
56 | 56 | | |
57 | | - | |
58 | 57 | | |
59 | 58 | | |
60 | | - | |
61 | 59 | | |
62 | 60 | | |
63 | 61 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
36 | | - | |
37 | 36 | | |
38 | 37 | | |
39 | 38 | | |
| |||
64 | 63 | | |
65 | 64 | | |
66 | 65 | | |
67 | | - | |
68 | | - | |
69 | | - | |
70 | 66 | | |
71 | 67 | | |
72 | 68 | | |
| |||
81 | 77 | | |
82 | 78 | | |
83 | 79 | | |
84 | | - | |
85 | | - | |
86 | | - | |
87 | | - | |
88 | 80 | | |
89 | 81 | | |
90 | 82 | | |
| |||
140 | 132 | | |
141 | 133 | | |
142 | 134 | | |
143 | | - | |
144 | 135 | | |
145 | 136 | | |
146 | 137 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
60 | 60 | | |
61 | 61 | | |
62 | 62 | | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
63 | 66 | | |
64 | 67 | | |
65 | 68 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
0 commit comments