@@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
6868 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 16 , 256 , 2 , 64 , 64 )
6969 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 32 , 256 , 2 , 64 , 64 )
7070
71- GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 32 , 256 , 2 , 64 , 64 )
71+ GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 16 , 256 , 2 , 64 , 64 )
7272
7373 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 4 , 128 , 2 , 64 , 64 )
7474 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 8 , 256 , 2 , 64 , 64 )
@@ -130,7 +130,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
130130 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 16 , 256 , 2 , 32 , 128 )
131131 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 32 , 256 , 2 , 32 , 64 )
132132
133- GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 32 , 256 , 2 , 32 , 64 )
133+ GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 16 , 256 , 2 , 32 , 64 )
134134
135135 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 4 , 128 , 2 , 32 , 64 )
136136 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 8 , 256 , 2 , 32 , 64 )
@@ -1124,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11241124 constexpr size_t nbytes_shared = 0 ;
11251125
11261126#ifdef GGML_USE_HIP
1127- if constexpr (DV <= 128 ) {
1127+ if constexpr (DKQ <= 128 ) {
11281128 if (Q->ne [1 ] > 32 /ncols2) {
11291129 constexpr int cols_per_block = 64 ;
11301130 const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
@@ -1138,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11381138#endif // GGML_USE_HIP
11391139
11401140#ifndef GGML_USE_HIP
1141- if constexpr (DV <= 256 )
1141+ if constexpr (DKQ <= 256 )
11421142#endif // GGML_USE_HIP
11431143 {
11441144 if (Q->ne [1 ] > 16 /ncols2) {
@@ -1220,11 +1220,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
12201220 const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
12211221 const bool use_gqa_opt = mask && max_bias == 0 .0f && Q->ne [1 ] <= gqa_limit && K->ne [1 ] % FATTN_KQ_STRIDE == 0 ;
12221222
1223- if constexpr (DKQ == 320 ) { // Mistral Small 4
1223+ if constexpr (DKQ == 320 ) {
1224+ // This branch is only used for Mistral Small 4, which has a GQA ratio of 32.
1225+ // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM.
1226+ // On NVIDIA, however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older).
1227+ // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 KiB of SRAM / block.
1228+ #ifdef GGML_USE_HIP
12241229 if (use_gqa_opt && gqa_ratio % 32 == 0 ) {
12251230 launch_fattn_tile_switch_ncols1<DKQ, DV, 32 , use_logit_softcap>(ctx, dst);
12261231 return ;
12271232 }
1233+ #else
1234+ if (use_gqa_opt && gqa_ratio % 16 == 0 ) {
1235+ launch_fattn_tile_switch_ncols1<DKQ, DV, 16 , use_logit_softcap>(ctx, dst);
1236+ return ;
1237+ }
1238+ #endif // GGML_USE_HIP
12281239 GGML_ABORT (" flash-attn tile (320/256): expected GQA ratio multiple of 32" );
12291240 }
12301241
0 commit comments