@@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
6868 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 16 , 256 , 2 , 64 , 64 )
6969 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 32 , 256 , 2 , 64 , 64 )
7070
71- GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 32 , 256 , 2 , 64 , 64 )
71+ GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 16 , 256 , 2 , 64 , 64 )
7272
7373 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 4 , 128 , 2 , 64 , 64 )
7474 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 8 , 256 , 2 , 64 , 64 )
@@ -134,7 +134,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
134134 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 16 , 256 , 2 , 32 , 128 )
135135 GGML_CUDA_FATTN_TILE_CONFIG_CASE (256 , 256 , 32 , 256 , 2 , 32 , 64 )
136136
137- GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 32 , 256 , 2 , 32 , 64 )
137+ GGML_CUDA_FATTN_TILE_CONFIG_CASE (320 , 256 , 16 , 256 , 2 , 32 , 64 )
138138
139139 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 4 , 128 , 2 , 32 , 64 )
140140 GGML_CUDA_FATTN_TILE_CONFIG_CASE (512 , 512 , 8 , 256 , 2 , 32 , 64 )
@@ -1142,7 +1142,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11421142 constexpr size_t nbytes_shared = 0 ;
11431143
11441144#ifdef GGML_USE_HIP
1145- if constexpr (DV <= 128 ) {
1145+ if constexpr (DKQ <= 128 ) {
11461146 if (Q->ne [1 ] > 32 /ncols2) {
11471147 constexpr int cols_per_block = 64 ;
11481148 const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
@@ -1156,7 +1156,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11561156#endif // GGML_USE_HIP
11571157
11581158#ifndef GGML_USE_HIP
1159- if constexpr (DV <= 256 )
1159+ if constexpr (DKQ <= 256 )
11601160#endif // GGML_USE_HIP
11611161 {
11621162 if (Q->ne [1 ] > 16 /ncols2) {
@@ -1238,11 +1238,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
12381238 const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
12391239 const bool use_gqa_opt = mask && max_bias == 0 .0f && Q->ne [1 ] <= gqa_limit && K->ne [1 ] % FATTN_KQ_STRIDE == 0 ;
12401240
1241- if constexpr (DKQ == 320 ) { // Mistral Small 4
1241+ if constexpr (DKQ == 320 ) {
1242+ // This branch is only used for Mistral Small 4 which has a GQA ratio of 32.
1243+ // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM.
1244+ // On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older).
1245+ // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block.
1246+ #ifdef GGML_USE_HIP
12421247 if (use_gqa_opt && gqa_ratio % 32 == 0 ) {
12431248 launch_fattn_tile_switch_ncols1<DKQ, DV, 32 , use_logit_softcap>(ctx, dst);
12441249 return ;
12451250 }
1251+ #else
1252+ if (use_gqa_opt && gqa_ratio % 16 == 0 ) {
1253+ launch_fattn_tile_switch_ncols1<DKQ, DV, 16 , use_logit_softcap>(ctx, dst);
1254+ return ;
1255+ }
1256+ #endif // GGML_USE_HIP
12461257 GGML_ABORT (" flash-attn tile (320/256): expected GQA ratio multiple of 32 (HIP) or 16 (non-HIP)" ); // reached when neither GQA-optimized launch above was taken; HIP builds require a multiple of 32, all other builds a multiple of 16
12471258 }
12481259
0 commit comments