Skip to content

Commit ff5ef82

Browse files
CUDA: skip compilation of superfluous FA kernels (ggml-org#21768)
1 parent 073bb2c commit ff5ef82

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

ggml/src/ggml-cuda/fattn.cu

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
7575
return;
7676
}
7777

78-
if (use_gqa_opt && gqa_ratio % 2 == 0) {
79-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
78+
if constexpr (DKQ <= 256) {
79+
if (use_gqa_opt && gqa_ratio % 2 == 0) {
80+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
81+
return;
82+
}
83+
84+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
8085
return;
86+
} else {
87+
GGML_ABORT("fatal error");
8188
}
82-
83-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
84-
return;
8589
}
8690

8791
if (use_gqa_opt && gqa_ratio > 4) {
@@ -94,12 +98,16 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
9498
return;
9599
}
96100

97-
if (use_gqa_opt && gqa_ratio > 1) {
98-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
99-
return;
100-
}
101+
if constexpr (DKQ <= 256) {
102+
if (use_gqa_opt && gqa_ratio > 1) {
103+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
104+
return;
105+
}
101106

102-
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
107+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
108+
} else {
109+
GGML_ABORT("fatal error");
110+
}
103111
}
104112

105113
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

0 commit comments

Comments
 (0)