|
| 1 | +# Plan: Speed Up CUDA Build via Per-Library Architecture Splitting and nvcc_threads Optimization |
| 2 | + |
| 3 | +## TL;DR |
| 4 | + |
| 5 | +Split the monolithic CUDA provider compilation into architecture-specific OBJECT libraries for |
| 6 | +flash_attention (SM80+, `--threads 1`) and llm/ (SM75+), allowing the main target and llm to use |
| 7 | +higher `nvcc_threads` for faster parallel compilation. Merge fpA_intB SM90 launchers into the |
| 8 | +existing SM90 TMA OBJECT library. |
| 9 | + |
| 10 | +## Architecture Requirements (Verified) |
| 11 | + |
| 12 | +| Directory | Min SM | Notes | |
| 13 | +|-----------|--------|-------| |
| 14 | +| `bert/flash_attention/` (48 .cu) | **SM80** | `__CUDA_ARCH__ >= 800` in kernel_traits.h | |
| 15 | +| `llm/fpA_intB_gemv/` (11 .cu) | **SM75** | `ORT_ENFORCE(arch >= 75)` | |
| 16 | +| `llm/fpA_intB_gemm/` (8 base .cu) | **SM75** | SM75+ base support | |
| 17 | +| `llm/fpA_intB_gemm/launchers/` (2 .cu) | **SM90** | `#ifndef EXCLUDE_SM_90` | |
| 18 | +| `llm/moe_gemm/` (14 root .cu) | **SM75** | CUTLASS stages=2 fallback for SM75 | |
| 19 | +| `llm/moe_gemm/launchers/fused_moe_sm80` (2 .cu) | **SM80** | `#ifndef EXCLUDE_SM_80` (has arch guard in code, safe to compile at SM75+) | |
| 20 | +| `llm/moe_gemm/launchers/` SM90 TMA (324 .cu) | SM90 | **Already extracted** | |
| 21 | +| `llm/moe_gemm/launchers/` SM120 TMA (11 .cu) | SM120 | **Already extracted** | |
| 22 | +| `llm/kernels/` (1 .cu) | SM50 | BF16 guarded by `__CUDA_ARCH__ >= 800` | |
| 23 | + |
| 24 | +## Steps |
| 25 | + |
| 26 | +### Phase 1: Flash Attention OBJECT Library |
| 27 | + |
| 28 | +1. Add macro `onnxruntime_extract_flash_attention_sources()` in |
| 29 | + `cmake/onnxruntime_cuda_source_filters.cmake` — extracts `*/bert/flash_attention/*.cu` |
| 30 | + from the main CU source list. |
| 31 | + |
| 32 | +2. In both provider cmake files, call this macro after existing filtering. Create OBJECT library: |
| 33 | + - `CUDA_ARCHITECTURES` = entries from `CMAKE_CUDA_ARCHITECTURES` where arch >= 80 |
| 34 | + - `--threads ${onnxruntime_FLASH_NVCC_THREADS}` |
| 35 | + - Same includes/compile defs as parent (`config_cuda_provider_shared_module()`) |
| 36 | + - Link into parent |
| 37 | + |
| 38 | +3. Add CMake cache option: `onnxruntime_FLASH_NVCC_THREADS` (default `"1"`, type STRING) |
| 39 | + |
| 40 | +### Phase 2: LLM OBJECT Library (SM75+ — Backward Compatible) |
| 41 | + |
| 42 | +4. Add macro `onnxruntime_extract_llm_sources()` — extracts `*/contrib_ops/cuda/llm/*.cu`, |
| 43 | + then further extracts SM90 launcher files (`fpA_intB_gemm_launcher_*.generated.cu`) into a |
| 44 | + separate output variable. |
| 45 | + |
| 46 | +5. Create `onnxruntime_providers_cuda_llm` OBJECT library: |
| 47 | + - `CUDA_ARCHITECTURES` = entries from `CMAKE_CUDA_ARCHITECTURES` where arch >= 75 |
| 48 | + - `--threads ${onnxruntime_NVCC_THREADS}` (user can now safely set to 2-4) |
| 49 | + - Contains all llm/ .cu files EXCEPT SM90 TMA (already extracted) and fpA_intB SM90 launchers |
| 50 | + |
| 51 | +6. **Merge fpA_intB SM90 launchers** (`fpA_intB_gemm_launcher_1.generated.cu`, |
| 52 | + `fpA_intB_gemm_launcher_2.generated.cu`) into existing |
| 53 | + `onnxruntime_providers_cuda_sm90_tma` OBJECT library — both need |
| 54 | + `CUDA_ARCHITECTURES "90a-real"` and `COMPILE_HOPPER_TMA_GEMMS`. |
| 55 | + |
| 56 | +### Phase 3: nvcc_threads Configuration |
| 57 | + |
| 58 | +7. Define `onnxruntime_FLASH_NVCC_THREADS` (default `"1"`). Flash attention target uses this. |
| 59 | + Main target and LLM target use existing `onnxruntime_NVCC_THREADS` (can be raised to 2-4 |
| 60 | + since flash attention is isolated). |
| 61 | + |
| 62 | +### Phase 4: Mirror in Plugin Build |
| 63 | + |
| 64 | +8. Identical pattern in `onnxruntime_providers_cuda_plugin.cmake`: |
| 65 | + - `onnxruntime_providers_cuda_plugin_flash_attention` (SM80+, threads from `onnxruntime_FLASH_NVCC_THREADS`) |
| 66 | + - `onnxruntime_providers_cuda_plugin_llm` (SM75+) |
| 67 | + - fpA_intB SM90 launchers merged into `onnxruntime_providers_cuda_plugin_sm90_tma` |
| 68 | + |
| 69 | +### Phase 5: Build Script for Testing |
| 70 | + |
| 71 | +9. Create `.env/cuda_build_time_test.sh` (based on `.env/cuda13_all.sh`): |
| 72 | + - `CMAKE_CUDA_ARCHITECTURES="75-real;80-real;86-real;89-real;90-real;100-real;120-real;120-virtual"` |
| 73 | + - `onnxruntime_NVCC_THREADS=4` |
| 74 | + - `onnxruntime_FLASH_NVCC_THREADS=1` |
| 75 | + - Build with timing, report total duration |
| 76 | + |
| 77 | +## Relevant Files |
| 78 | + |
| 79 | +- `cmake/onnxruntime_cuda_source_filters.cmake` — new macros |
| 80 | +- `cmake/onnxruntime_providers_cuda.cmake` — create flash_attention and llm OBJECT libraries |
| 81 | +- `cmake/onnxruntime_providers_cuda_plugin.cmake` — mirror for plugin build |
| 82 | +- `.env/cuda_build_time_test.sh` — new build script for benchmarking |
| 83 | + |
| 84 | +## Verification |
| 85 | + |
| 86 | +1. Check `build.ninja`: flash_attention files have only SM80+ `--generate-code`; llm files have SM75+ |
| 87 | +2. Build with `onnxruntime_NVCC_THREADS=4`, `onnxruntime_FLASH_NVCC_THREADS=1` — no OOM |
| 88 | +3. Compare total build time before/after using multi-arch build script |
| 89 | +4. Run: `./onnxruntime_test_all --gtest_filter=*FlashAttention*:*MoE*:*FpAIntB*` |
| 90 | +5. Run: `python test_gqa.py`, `python test_moe_cuda.py` |
| 91 | +6. No link errors in both in-tree and plugin builds |
| 92 | + |
| 93 | +## Decisions |
| 94 | + |
| 95 | +- **LLM = SM75+** (not SM80+) — preserves backward compatibility for `fpA_intB_gemv/gemm` |
| 96 | +- **Flash attention = SM80+** — all kernel files are `_sm80` suffixed with arch guards |
| 97 | +- **fpA_intB SM90 launchers merged into SM90 TMA lib** — both need "90a-real" + `COMPILE_HOPPER_TMA_GEMMS` |
| 98 | +- **`onnxruntime_FLASH_NVCC_THREADS`** = new option (default 1); `onnxruntime_NVCC_THREADS` remains |
| 99 | +- **`onnxruntime_QUICK_BUILD`** filtering applies before OBJECT library creation (no behavior change) |
| 100 | +- SM90/SM120 TMA MoE extraction unchanged (existing mechanism) |
| 101 | + |
| 102 | +## Reference |
| 103 | + |
| 104 | +- TRT-LLM pattern: `~/tensorrt-llm/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt` |
| 105 | + — uses `set_cuda_architectures()` per sub-library with separate OBJECT targets |
| 106 | +- Flash Attention official: `~/flash-attention/setup.py` |
| 107 | + — defaults to `NVCC_THREADS=4`, architectures `80;90;100;110;120`, ~5GB per nvcc thread |
0 commit comments