Skip to content

Commit b6cdb04

Browse files
JohannesGaessler authored and meh committed
CUDA: fix tile FA kernel on Pascal (ggml-org#22541)
1 parent aaa0a91 commit b6cdb04

1 file changed

Lines changed: 16 additions & 5 deletions

File tree

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
6868
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
6969
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
7070

71-
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
71+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
7272

7373
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
7474
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
@@ -134,7 +134,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
134134
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
135135
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
136136

137-
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
137+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
138138

139139
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
140140
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
@@ -1142,7 +1142,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11421142
constexpr size_t nbytes_shared = 0;
11431143

11441144
#ifdef GGML_USE_HIP
1145-
if constexpr (DV <= 128) {
1145+
if constexpr (DKQ <= 128) {
11461146
if (Q->ne[1] > 32/ncols2) {
11471147
constexpr int cols_per_block = 64;
11481148
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
@@ -1156,7 +1156,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11561156
#endif // GGML_USE_HIP
11571157

11581158
#ifndef GGML_USE_HIP
1159-
if constexpr (DV <= 256)
1159+
if constexpr (DKQ <= 256)
11601160
#endif // GGML_USE_HIP
11611161
{
11621162
if (Q->ne[1] > 16/ncols2) {
@@ -1238,11 +1238,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
12381238
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
12391239
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
12401240

1241-
if constexpr (DKQ == 320) { // Mistral Small 4
1241+
if constexpr (DKQ == 320) {
1242+
// This branch is only used for Mistral Small 4 which has a GQA ratio of 32.
1243+
// On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM.
1244+
// On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older).
1245+
// Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block.
1246+
#ifdef GGML_USE_HIP
12421247
if (use_gqa_opt && gqa_ratio % 32 == 0) {
12431248
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
12441249
return;
12451250
}
1251+
#else
1252+
if (use_gqa_opt && gqa_ratio % 16 == 0) {
1253+
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
1254+
return;
1255+
}
1256+
#endif // GGML_USE_HIP
12461257
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
12471258
}
12481259

0 commit comments

Comments (0)