Skip to content

Commit e81b491

Browse files
committed
fix: address review feedback on flash attention quick-build and pre-SM80 fallback
- Update quick-build comment to reflect that both FP16 and BF16 hdim128 kernels are intentionally retained (not just FP16).
- Add fallback for pre-SM80 builds: when no SM80+ architectures are configured, flash attention sources are added back to the parent target so the linker can find host-side symbols referenced by flash_api.cc.
1 parent 6636c99 commit e81b491

3 files changed

Lines changed: 12 additions & 2 deletions

File tree

cmake/onnxruntime_cuda_source_filters.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ function(onnxruntime_filter_cuda_cu_sources CU_SRC_LIST)
1414
set(_list "${${CU_SRC_LIST}}")
1515

1616
# Quick build mode: Filter flash attention kernels for faster development iteration.
17-
# - We keep only hdim128 fp16 flash attention kernels in quick build mode.
17+
# - We keep only hdim128 fp16 and bf16 flash attention kernels in quick build mode.
1818
# - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256).
1919
# If new head dimensions are added or removed, update this list to match the supported set.
2020
if(onnxruntime_QUICK_BUILD)
21-
message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
21+
message(STATUS "Quick build mode enabled: Only building hdim128 fp16/bf16 flash attention kernels")
2222
list(FILTER _list EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
2323
endif()
2424

cmake/onnxruntime_providers_cuda.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,11 @@
480480
CUDA_ARCHITECTURES "${_ort_flash_cuda_architectures}"
481481
NVCC_THREADS "${onnxruntime_FLASH_NVCC_THREADS}"
482482
SOURCES ${onnxruntime_cuda_flash_attention_srcs})
483+
else()
484+
# No SM80+ architectures available: compile flash sources in parent target so the
485+
# linker can find the host-side symbols referenced by flash_api.cc. The kernels
486+
# themselves will be empty stubs due to __CUDA_ARCH__ >= 800 guards.
487+
target_sources(onnxruntime_providers_cuda PRIVATE ${onnxruntime_cuda_flash_attention_srcs})
483488
endif()
484489
endif()
485490

cmake/onnxruntime_providers_cuda_plugin.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@ if(_cuda_plugin_flash_attention_srcs)
274274
NVCC_THREADS "${onnxruntime_FLASH_NVCC_THREADS}"
275275
COMPILE_OPTIONS ${_cuda_plugin_shared_compile_options}
276276
SOURCES ${_cuda_plugin_flash_attention_srcs})
277+
else()
278+
# No SM80+ architectures available: compile flash sources in parent target so the
279+
# linker can find the host-side symbols referenced by flash_api.cc. The kernels
280+
# themselves will be empty stubs due to __CUDA_ARCH__ >= 800 guards.
281+
target_sources(onnxruntime_providers_cuda_plugin PRIVATE ${_cuda_plugin_flash_attention_srcs})
277282
endif()
278283
endif()
279284

0 commit comments

Comments
 (0)