Skip to content

Commit 7b8443a

Browse files
ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (… (#22286)
* ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32) Adds MMA-f16 and tile kernel configs, dispatch logic, template instances, and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to ncols2=32 to support GQA ratio 32 only. * Adding check to return BEST_FATTN_KERNEL_NONE in case GQA!=32 * Apply suggestions from code review Address review comments Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Address review comments and making kernel config default to DQK=512, DV=512 instead of DQK=256,DV=256 * Fixed bug with sinks=1, with ncols=32, there are two warp-groups created but sinks index is same(0,...,15) for both the groups hence with sinks=1, output is not matching with CPU output. Added sink_base which will be base index for each warp_group (threadIdx.y / np) * Apply suggestions from code review Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent 5d56eff commit 7b8443a

8 files changed

Lines changed: 86 additions & 16 deletions

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
6666
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
6767
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
6868

69+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
70+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
71+
6972
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
7073
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
7174
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
@@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
8588
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
8689
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
8790

91+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
92+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
93+
8894
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
8995
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
9096
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
@@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
118124
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
119125
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
120126

127+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
128+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false);
129+
121130
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
122131
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
123132
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
@@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
12171226
float KQ_max_scale[cols_per_thread];
12181227
#pragma unroll
12191228
for (int col = 0; col < cols_per_thread; ++col) {
1220-
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
1229+
const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
12211230
const float sink = sinks_f[jc % ncols2];
12221231

12231232
const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
18251834
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
18261835
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
18271836

1837+
// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
1838+
extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
1839+
extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
1840+
18281841
// For GLM 4.7 Flash
18291842
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
18301843
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
3838
GGML_ASSERT(V->ne[0] == K->ne[0]);
3939
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
4040
} break;
41+
case 320: {
42+
GGML_ASSERT(V->ne[0] == 256);
43+
ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst);
44+
} break;
4145
case 512: {
4246
GGML_ASSERT(V->ne[0] == K->ne[0]);
4347
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
6868
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
6969
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
7070

71+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
72+
7173
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
7274
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
7375
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
128130
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
129131
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
130132

133+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
134+
131135
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
132136
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
133137
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
@@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
195199
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
196200
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
197201

202+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
203+
198204
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
199205
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
200206
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
264270
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
265271
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
266272

273+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
274+
267275
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
268276
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
269277
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
@@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
11441152
}
11451153
}
11461154

1147-
if (Q->ne[1] > 8/ncols2) {
1148-
constexpr int cols_per_block = 16;
1149-
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1150-
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1151-
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1152-
launch_fattn<DV, cols_per_block/ncols2, ncols2>
1153-
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1154-
return;
1155+
if constexpr (ncols2 <= 16) {
1156+
if (Q->ne[1] > 8/ncols2) {
1157+
constexpr int cols_per_block = 16;
1158+
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1159+
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1160+
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1161+
launch_fattn<DV, cols_per_block/ncols2, ncols2>
1162+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1163+
return;
1164+
}
11551165
}
11561166

11571167
if constexpr (ncols2 <= 8) {
@@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
12101220
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
12111221
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
12121222

1223+
if constexpr (DKQ == 320) { // Mistral Small 4
1224+
if (use_gqa_opt && gqa_ratio % 32 == 0) {
1225+
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
1226+
return;
1227+
}
1228+
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
1229+
}
1230+
12131231
if constexpr (DKQ == 576) {
12141232
if (use_gqa_opt && gqa_ratio % 16 == 0) {
12151233
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
@@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
12211239
}
12221240
}
12231241

1224-
if constexpr (DKQ <= 512) {
1242+
if constexpr (DKQ <= 512 && DKQ != 320) {
12251243
if (use_gqa_opt && gqa_ratio % 8 == 0) {
12261244
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
12271245
return;
@@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
12751293
extern DECL_FATTN_TILE_CASE(112, 112);
12761294
extern DECL_FATTN_TILE_CASE(128, 128);
12771295
extern DECL_FATTN_TILE_CASE(256, 256);
1296+
extern DECL_FATTN_TILE_CASE(320, 256);
12781297
extern DECL_FATTN_TILE_CASE(512, 512);
12791298
extern DECL_FATTN_TILE_CASE(576, 512);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
143143
GGML_ASSERT(V->ne[0] == 256);
144144
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
145145
break;
146+
case 320:
147+
// For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build).
148+
GGML_ASSERT(V->ne[0] == 256);
149+
{
150+
float max_bias = 0.0f;
151+
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
152+
153+
const bool use_gqa_opt = mask && max_bias == 0.0f;
154+
GGML_ASSERT(use_gqa_opt);
155+
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
156+
const int gqa_ratio = Q->ne[2] / K->ne[2];
157+
GGML_ASSERT(gqa_ratio % 32 == 0);
158+
159+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst);
160+
}
161+
break;
146162
case 512:
147163
GGML_ASSERT(V->ne[0] == 512);
148164
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
@@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
352368
return BEST_FATTN_KERNEL_NONE;
353369
}
354370
break;
371+
case 320:
372+
if (V->ne[0] != 256 || !gqa_opt_applies) {
373+
return BEST_FATTN_KERNEL_NONE;
374+
}
375+
if (gqa_ratio % 32 != 0) {
376+
return BEST_FATTN_KERNEL_NONE;
377+
}
378+
break;
355379
case 512:
356380
if (V->ne[0] != K->ne[0]) {
357381
return BEST_FATTN_KERNEL_NONE;

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
#include "../fattn-mma-f16.cuh"
44

5+
DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
56
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
#include "../fattn-mma-f16.cuh"
44

5+
DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
56
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-tile.cuh"
4+
5+
DECL_FATTN_TILE_CASE(320, 256);

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from glob import glob
44
import os
55

6-
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
6+
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
77

88
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
99

@@ -62,7 +62,7 @@ def get_short_name(long_quant_name):
6262
os.remove(filename)
6363

6464
for head_size_kq in HEAD_SIZES_KQ:
65-
head_size_v = head_size_kq if head_size_kq != 576 else 512
65+
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
6666
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
6767
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
6868

@@ -84,13 +84,16 @@ def get_short_name(long_quant_name):
8484
continue
8585
if head_size_kq == 72:
8686
continue
87-
if head_size_kq == 512 and ncols2 not in (4, 8):
87+
# Skip compilation of unused ncols2 values for niche head sizes:
88+
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
8889
continue
89-
if head_size_kq != 576 and ncols2 in (16, 32):
90+
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
9091
continue
91-
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
92+
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
9293
continue
93-
head_size_v = head_size_kq if head_size_kq != 576 else 512
94+
if head_size_kq not in (320, 576) and ncols2 in (16, 32):
95+
continue
96+
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
9497
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
9598

9699
for type in TYPES_MMQ:

0 commit comments

Comments
 (0)