@@ -75,13 +75,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
7575 return ;
7676 }
7777
78- if (use_gqa_opt && gqa_ratio % 2 == 0 ) {
79- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2 >(ctx, dst);
78+ if constexpr (DKQ <= 256 ) {
79+ if (use_gqa_opt && gqa_ratio % 2 == 0 ) {
80+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2 >(ctx, dst);
81+ return ;
82+ }
83+
84+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1 >(ctx, dst);
8085 return ;
86+ } else {
87+ GGML_ABORT (" fatal error" );
8188 }
82-
83- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1 >(ctx, dst);
84- return ;
8589 }
8690
8791 if (use_gqa_opt && gqa_ratio > 4 ) {
@@ -94,12 +98,16 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
9498 return ;
9599 }
96100
97- if (use_gqa_opt && gqa_ratio > 1 ) {
98- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2 >(ctx, dst);
99- return ;
100- }
101+ if constexpr (DKQ <= 256 ) {
102+ if (use_gqa_opt && gqa_ratio > 1 ) {
103+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2 >(ctx, dst);
104+ return ;
105+ }
101106
102- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1 >(ctx, dst);
107+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1 >(ctx, dst);
108+ } else {
109+ GGML_ABORT (" fatal error" );
110+ }
103111}
104112
105113static void ggml_cuda_flash_attn_ext_mma_f16 (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
0 commit comments