Skip to content

Commit 4703728

Browse files
committed
Add FLASH_ATTN_HDIMS option to limit kernel compilation
In many applications, model head dimensions are known in advance, and it is possible to opt out of compiling kernels that will never be used, regardless of the model choice.

Signed-off-by: Tin Švagelj <tin.svagelj@live.com>
1 parent 226c95d commit 4703728

1 file changed

Lines changed: 22 additions & 67 deletions

File tree

CMakeLists.txt

Lines changed: 22 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ option(BUILD_TESTS "Compile the tests" OFF)
2323
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
2424
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
2525
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
26+
set(FLASH_ATTN_HDIMS "" CACHE STRING "Head dimensions to compile for flash attention (e.g. '32;64'). Empty means all.")
2627
option(ENABLE_ADDRESS_SANITIZER "ASAN" OFF)
2728

2829
MESSAGE(STATUS "Compiler Id: ${CMAKE_CXX_COMPILER_ID}")
@@ -606,74 +607,28 @@ if (WITH_CUDA)
606607
endif()
607608
if (WITH_FLASH_ATTN)
608609
add_definitions(-DCT2_WITH_FLASH_ATTN)
609-
list(APPEND SOURCES
610-
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
611-
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
612-
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
613-
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
614-
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
615-
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
616-
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
617-
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
618-
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
619-
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
620-
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
621-
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
622-
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
623-
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
624-
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
625-
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
626-
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
627-
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
628-
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
629-
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
630-
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
631-
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
632-
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
633-
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
634-
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
635-
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
636-
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
637-
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
638-
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
639-
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
640-
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
641-
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
642-
)
643610

644-
set_source_files_properties(
645-
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
646-
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
647-
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
648-
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
649-
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
650-
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
651-
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
652-
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
653-
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
654-
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
655-
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
656-
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
657-
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
658-
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
659-
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
660-
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
661-
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
662-
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
663-
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
664-
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
665-
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
666-
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
667-
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
668-
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
669-
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
670-
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
671-
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
672-
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
673-
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
674-
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
675-
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
676-
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
611+
set(_ALL_FLASH_HDIMS 32 64 96 128 160 192 224 256)
612+
if(FLASH_ATTN_HDIMS)
613+
set(_FLASH_HDIMS ${FLASH_ATTN_HDIMS})
614+
else()
615+
set(_FLASH_HDIMS ${_ALL_FLASH_HDIMS})
616+
endif()
617+
618+
message(STATUS "Flash attention head dimensions: ${_FLASH_HDIMS}")
619+
620+
set(_FLASH_ATTN_SOURCES "")
621+
foreach(_hdim ${_FLASH_HDIMS})
622+
list(APPEND _FLASH_ATTN_SOURCES
623+
src/ops/flash-attention/flash_fwd_hdim${_hdim}_bf16_sm80.cu
624+
src/ops/flash-attention/flash_fwd_hdim${_hdim}_fp16_sm80.cu
625+
src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_bf16_sm80.cu
626+
src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_fp16_sm80.cu
627+
)
628+
endforeach()
629+
630+
list(APPEND SOURCES ${_FLASH_ATTN_SOURCES})
631+
set_source_files_properties(${_FLASH_ATTN_SOURCES}
677632
PROPERTIES COMPILE_FLAGS "--use_fast_math")
678633
endif()
679634
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)

0 commit comments

Comments
 (0)