Skip to content

Commit 43a7066

Browse files
JohannesGaessler authored and samuraieng committed
CUDA: fix tile FA kernel on Pascal (ggml-org#22541)
1 parent a2b63f6 commit 43a7066

1 file changed

Lines changed: 16 additions & 5 deletions

File tree

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
6868
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
6969
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
7070

71-
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
71+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
7272

7373
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
7474
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
@@ -130,7 +130,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
130130
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
131131
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
132132

133-
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
133+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
134134

135135
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
136136
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
@@ -1124,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11241124
constexpr size_t nbytes_shared = 0;
11251125

11261126
#ifdef GGML_USE_HIP
1127-
if constexpr (DV <= 128) {
1127+
if constexpr (DKQ <= 128) {
11281128
if (Q->ne[1] > 32/ncols2) {
11291129
constexpr int cols_per_block = 64;
11301130
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
@@ -1138,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11381138
#endif // GGML_USE_HIP
11391139

11401140
#ifndef GGML_USE_HIP
1141-
if constexpr (DV <= 256)
1141+
if constexpr (DKQ <= 256)
11421142
#endif // GGML_USE_HIP
11431143
{
11441144
if (Q->ne[1] > 16/ncols2) {
@@ -1220,11 +1220,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
12201220
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
12211221
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
12221222

1223-
if constexpr (DKQ == 320) { // Mistral Small 4
1223+
if constexpr (DKQ == 320) {
1224+
// This branch is only used for Mistral Small 4 which has a GQA ratio of 32.
1225+
// On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM.
1226+
// On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older).
1227+
// Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block.
1228+
#ifdef GGML_USE_HIP
12241229
if (use_gqa_opt && gqa_ratio % 32 == 0) {
12251230
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
12261231
return;
12271232
}
1233+
#else
1234+
if (use_gqa_opt && gqa_ratio % 16 == 0) {
1235+
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
1236+
return;
1237+
}
1238+
#endif // GGML_USE_HIP
12281239
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
12291240
}
12301241

0 commit comments

Comments (0)