option(BUILD_TESTS "Compile the tests" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
# Semicolon-separated list limiting which flash-attention head dimensions are
# compiled; an empty value compiles every supported dimension.
set(FLASH_ATTN_HDIMS "" CACHE STRING
    "Head dimensions to compile for flash attention (e.g. '32;64'). Empty means all.")
option(ENABLE_ADDRESS_SANITIZER "ASAN" OFF)

message(STATUS "Compiler Id: ${CMAKE_CXX_COMPILER_ID}")
@@ -606,74 +607,28 @@ if (WITH_CUDA)
606607 endif ()
607608 if (WITH_FLASH_ATTN)
608609 add_definitions (-DCT2_WITH_FLASH_ATTN )
609- list (APPEND SOURCES
610- src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
611- src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
612- src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
613- src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
614- src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
615- src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
616- src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
617- src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
618- src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
619- src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
620- src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
621- src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
622- src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
623- src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
624- src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
625- src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
626- src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
627- src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
628- src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
629- src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
630- src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
631- src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
632- src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
633- src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
634- src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
635- src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
636- src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
637- src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
638- src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
639- src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
640- src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
641- src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
642- )
643610
644- set_source_files_properties (
645- src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
646- src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
647- src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
648- src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
649- src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
650- src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
651- src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
652- src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
653- src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
654- src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
655- src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
656- src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
657- src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
658- src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
659- src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
660- src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
661- src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
662- src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
663- src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
664- src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
665- src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
666- src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
667- src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
668- src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
669- src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
670- src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
671- src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
672- src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
673- src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
674- src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
675- src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
676- src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
611+ set (_ALL_FLASH_HDIMS 32 64 96 128 160 192 224 256)
612+ if (FLASH_ATTN_HDIMS)
613+ set (_FLASH_HDIMS ${FLASH_ATTN_HDIMS} )
614+ else ()
615+ set (_FLASH_HDIMS ${_ALL_FLASH_HDIMS} )
616+ endif ()
617+
618+ message (STATUS "Flash attention head dimensions: ${_FLASH_HDIMS} " )
619+
620+ set (_FLASH_ATTN_SOURCES "" )
621+ foreach (_hdim ${_FLASH_HDIMS} )
622+ list (APPEND _FLASH_ATTN_SOURCES
623+ src/ops/flash-attention/flash_fwd_hdim${_hdim} _bf16_sm80.cu
624+ src/ops/flash-attention/flash_fwd_hdim${_hdim} _fp16_sm80.cu
625+ src/ops/flash-attention/flash_fwd_split_hdim${_hdim} _bf16_sm80.cu
626+ src/ops/flash-attention/flash_fwd_split_hdim${_hdim} _fp16_sm80.cu
627+ )
628+ endforeach ()
629+
630+ list (APPEND SOURCES ${_FLASH_ATTN_SOURCES} )
631+ set_source_files_properties (${_FLASH_ATTN_SOURCES}
677632 PROPERTIES COMPILE_FLAGS "--use_fast_math" )
678633 endif ()
679634 set (CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
0 commit comments