Skip to content

Commit 046e284

Browse files
authored
Add flash attention MMA / Tiles to support MiMo-V2.5 (#22812)
* mimo-v2.5: add flash attention mma/tiles for d_kq=192 d_v=128
* mimo-v2.5: follow (256, 256) fattn templates
* mimo-v2.5: cleanup comments
* mimo-v2.5: further comment cleanup
* mimo-v2.5: address PR feedback
  - fix GQA handling
  - check for other dangling 320/576 carveouts and mirror them for 192
  - add to backend ops test so new paths are covered
1 parent 6600172 commit 046e284

14 files changed

Lines changed: 105 additions & 10 deletions

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
6161
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
6262
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
6363

64+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true);
65+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true);
66+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true);
67+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true);
68+
6469
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
6570
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
6671
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
@@ -1561,6 +1566,10 @@ static __global__ void flash_attn_ext_f16(
15611566
NO_DEVICE_CODE;
15621567
return;
15631568
}
1569+
if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) {
1570+
NO_DEVICE_CODE;
1571+
return;
1572+
}
15641573
#ifdef VOLTA_MMA_AVAILABLE
15651574
if (ncols1*ncols2 < 32) {
15661575
NO_DEVICE_CODE;

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
3434
GGML_ASSERT(V->ne[0] == K->ne[0]);
3535
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
3636
} break;
37+
case 192: {
38+
GGML_ASSERT(V->ne[0] == 128);
39+
ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst);
40+
} break;
3741
case 256: {
3842
GGML_ASSERT(V->ne[0] == K->ne[0]);
3943
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
6262
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64)
6363
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
6464

65+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 2, 64, 64)
66+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 2, 64, 64)
67+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64)
68+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64)
69+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64)
70+
6571
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64)
6672
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64)
6773
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64)
@@ -124,6 +130,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
124130
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128)
125131
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
126132

133+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 128, 3, 64, 64)
134+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 3, 32, 64)
135+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 32, 64)
136+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64)
137+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64)
138+
127139
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64)
128140
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64)
129141
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256)
@@ -193,6 +205,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
193205
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
194206
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
195207

208+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 256, 2, 128, 64)
209+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 256, 2, 64, 64)
210+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64)
211+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64)
212+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64)
213+
196214
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
197215
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
198216
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
@@ -264,6 +282,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
264282
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
265283
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
266284

285+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 8, 32, 64)
286+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 6, 32, 64)
287+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 128, 6, 32, 64)
288+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64)
289+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64)
290+
267291
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
268292
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
269293
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
@@ -1250,7 +1274,20 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
12501274
}
12511275
}
12521276

1253-
if constexpr (DKQ <= 512 && DKQ != 320) {
1277+
if constexpr (DKQ == 192) {
1278+
// MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
1279+
if (use_gqa_opt && gqa_ratio % 16 == 0) {
1280+
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
1281+
return;
1282+
}
1283+
if (use_gqa_opt && gqa_ratio % 8 == 0) {
1284+
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
1285+
return;
1286+
}
1287+
GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8");
1288+
}
1289+
1290+
if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) {
12541291
if (use_gqa_opt && gqa_ratio % 8 == 0) {
12551292
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
12561293
return;
@@ -1303,6 +1340,7 @@ extern DECL_FATTN_TILE_CASE( 80, 80);
13031340
extern DECL_FATTN_TILE_CASE( 96, 96);
13041341
extern DECL_FATTN_TILE_CASE(112, 112);
13051342
extern DECL_FATTN_TILE_CASE(128, 128);
1343+
extern DECL_FATTN_TILE_CASE(192, 128);
13061344
extern DECL_FATTN_TILE_CASE(256, 256);
13071345
extern DECL_FATTN_TILE_CASE(320, 256);
13081346
extern DECL_FATTN_TILE_CASE(512, 512);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
139139
GGML_ASSERT(V->ne[0] == 128);
140140
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
141141
break;
142+
case 192: {
143+
// MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
144+
GGML_ASSERT(V->ne[0] == 128);
145+
float max_bias = 0.0f;
146+
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
147+
const bool use_gqa_opt = mask && max_bias == 0.0f;
148+
GGML_ASSERT(use_gqa_opt);
149+
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
150+
const int gqa_ratio = Q->ne[2] / K->ne[2];
151+
if (gqa_ratio % 16 == 0) {
152+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst);
153+
} else {
154+
GGML_ASSERT(gqa_ratio % 8 == 0);
155+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst);
156+
}
157+
} break;
142158
case 256:
143159
GGML_ASSERT(V->ne[0] == 256);
144160
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
@@ -368,6 +384,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
368384
return BEST_FATTN_KERNEL_NONE;
369385
}
370386
break;
387+
case 192:
388+
if (V->ne[0] != 128 || !gqa_opt_applies) {
389+
return BEST_FATTN_KERNEL_NONE;
390+
}
391+
if (gqa_ratio % 8 != 0) {
392+
return BEST_FATTN_KERNEL_NONE;
393+
}
394+
break;
371395
case 320:
372396
if (V->ne[0] != 256 || !gqa_opt_applies) {
373397
return BEST_FATTN_KERNEL_NONE;
@@ -425,7 +449,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
425449
}
426450

427451
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
428-
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
452+
// 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path.
453+
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0;
429454

430455
// If Turing tensor cores are available, use them:
431456
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
@@ -454,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
454479

455480
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
456481
int gqa_ratio_eff = 1;
457-
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
482+
const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8;
458483
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
459484
gqa_ratio_eff *= 2;
460485
}
@@ -468,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
468493
}
469494

470495
// Use the WMMA kernel if possible:
471-
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
496+
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) {
472497
if (can_use_vector_kernel && Q->ne[1] <= 2) {
473498
return BEST_FATTN_KERNEL_VEC;
474499
}
@@ -501,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
501526
}
502527

503528
// Use MFMA flash attention for CDNA (MI100+):
504-
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
529+
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
505530
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
506531
// MMA vs tile crossover benchmarked on MI300X @ d32768:
507532
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
#include "../fattn-mma-f16.cuh"
44

5+
DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16);
56
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
77
DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
88
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
99
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
10+
DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8);
1011
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
1112
DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
#include "../fattn-mma-f16.cuh"
44

5+
DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16);
56
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
77
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
88
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
99
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
10+
DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8);
1011
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
1112
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
#include "../fattn-mma-f16.cuh"
44

5+
DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16);
56
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);

ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
77
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
88
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
99
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
10+
DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8);
1011
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
1112
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);

0 commit comments

Comments
 (0)