Skip to content

Commit 0c21677

Browse files
CUDA: faster FA for GQA > 1 but not power of 2 (ggml-org#19092)
1 parent 0440bfd commit 0c21677

6 files changed

Lines changed: 99 additions & 36 deletions

File tree

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -643,9 +643,10 @@ static __global__ void flash_attn_stream_k_fixup(
643643

644644
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
645645
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
646+
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
646647

647-
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
648-
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
648+
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
649+
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
649650

650651
const bool did_not_have_any_data = kbc0 == kbc0_stop;
651652
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -654,15 +655,15 @@ static __global__ void flash_attn_stream_k_fixup(
654655
return;
655656
}
656657

657-
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
658-
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
659-
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
658+
const int sequence = kbc0 / (iter_k*iter_j*iter_z);
659+
const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
660+
const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
660661

661-
if (jt*ncols1 + j >= ne01) {
662+
if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
662663
return;
663664
}
664665

665-
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
666+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
666667

667668
// Load the partial result that needs a fixup:
668669
float dst_val = 0.0f;
@@ -681,7 +682,7 @@ static __global__ void flash_attn_stream_k_fixup(
681682
int bidx = bidx0 - 1;
682683
int kbc_stop = kbc0;
683684
while(true) {
684-
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
685+
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
685686
if (kbc == kbc_stop) { // Did not have any data.
686687
bidx--;
687688
kbc_stop = kbc;
@@ -883,7 +884,8 @@ void launch_fattn(
883884
}
884885

885886
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
886-
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
887+
const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2);
888+
const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
887889

888890
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
889891
// Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -958,7 +960,7 @@ void launch_fattn(
958960

959961
blocks_num.x = ntiles_x;
960962
blocks_num.y = parallel_blocks;
961-
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
963+
blocks_num.z = ntiles_z*Q->ne[3];
962964

963965
if (parallel_blocks > 1) {
964966
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
940940
const int stride_V,
941941
const int stride_mask,
942942
const int jt,
943+
const int zt,
943944
const int kb0_start,
944945
const int kb0_stop) {
945946
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1022,7 +1023,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
10221023
const int j = jc / ncols2;
10231024
const int c = jc % ncols2;
10241025

1025-
if (jt*ncols1 + j < int(ne01.z)) {
1026+
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
10261027
#pragma unroll
10271028
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
10281029
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -1408,7 +1409,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
14081409
const int j_dst = jc_dst / ncols2;
14091410
const int c_dst = jc_dst % ncols2;
14101411

1411-
if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
1412+
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
14121413
continue;
14131414
}
14141415

@@ -1522,10 +1523,11 @@ static __global__ void flash_attn_ext_f16(
15221523

15231524
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
15241525
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1526+
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
15251527

15261528
// kbc == k block continuous, current index in continuous ijk space.
1527-
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1528-
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1529+
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
1530+
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
15291531

15301532
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
15311533
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1536,9 +1538,9 @@ static __global__ void flash_attn_ext_f16(
15361538
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
15371539

15381540
while (kbc < kbc_stop && kb0_stop == iter_k) {
1539-
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1540-
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1541-
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1541+
const int sequence = kbc / (iter_k*iter_j*iter_z);
1542+
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
1543+
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
15421544

15431545
const int head0 = zt * ncols2;
15441546

@@ -1561,12 +1563,12 @@ static __global__ void flash_attn_ext_f16(
15611563
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
15621564
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
15631565
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1564-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1566+
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
15651567
} else {
15661568
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
15671569
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
15681570
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1569-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1571+
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
15701572
}
15711573

15721574
kbc += iter_k;
@@ -1580,9 +1582,9 @@ static __global__ void flash_attn_ext_f16(
15801582
return;
15811583
}
15821584

1583-
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1584-
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1585-
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1585+
const int sequence = kbc / (iter_k*iter_j*iter_z);
1586+
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
1587+
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
15861588

15871589
const int head0 = zt * ncols2;
15881590

@@ -1605,7 +1607,7 @@ static __global__ void flash_attn_ext_f16(
16051607
constexpr bool needs_fixup = false;
16061608
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
16071609
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1608-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1610+
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
16091611
#else
16101612
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
16111613
max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1739,3 +1741,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
17391741
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
17401742
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
17411743
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
1744+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
1745+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
1818
}
1919
}
2020

21-
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
22-
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
23-
return;
21+
if constexpr (ncols2 <= 16) {
22+
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
23+
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
24+
return;
25+
}
2426
}
2527

2628
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
@@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
3335

3436
template <int DKQ, int DV>
3537
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
38+
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
3639
const ggml_tensor * KQV = dst;
3740
const ggml_tensor * Q = dst->src[0];
3841
const ggml_tensor * K = dst->src[1];
@@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
6063
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
6164
const int gqa_ratio = Q->ne[2] / K->ne[2];
6265

63-
if (use_gqa_opt && gqa_ratio % 8 == 0) {
66+
// On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
67+
if (cc == GGML_CUDA_CC_VOLTA) {
68+
if (use_gqa_opt && gqa_ratio % 8 == 0) {
69+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
70+
return;
71+
}
72+
73+
if (use_gqa_opt && gqa_ratio % 4 == 0) {
74+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
75+
return;
76+
}
77+
78+
if (use_gqa_opt && gqa_ratio % 2 == 0) {
79+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
80+
return;
81+
}
82+
83+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
84+
return;
85+
}
86+
87+
if (use_gqa_opt && gqa_ratio > 4) {
6488
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
6589
return;
6690
}
6791

68-
if (use_gqa_opt && gqa_ratio % 4 == 0) {
92+
if (use_gqa_opt && gqa_ratio > 2) {
6993
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
7094
return;
7195
}
7296

73-
if (use_gqa_opt && gqa_ratio % 2 == 0) {
97+
if (use_gqa_opt && gqa_ratio > 1) {
7498
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
7599
return;
76100
}
@@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
79103
}
80104

81105
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
106+
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
82107
const ggml_tensor * KQV = dst;
83108
const ggml_tensor * Q = dst->src[0];
84109
const ggml_tensor * K = dst->src[1];
@@ -121,8 +146,30 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
121146

122147
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
123148
const int gqa_ratio = Q->ne[2] / K->ne[2];
124-
GGML_ASSERT(gqa_ratio % 4 == 0);
125-
if (gqa_ratio % 16 == 0) {
149+
if (gqa_ratio == 20) { // GLM 4.7 Flash
150+
if (cc >= GGML_CUDA_CC_BLACKWELL) {
151+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
152+
break;
153+
}
154+
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
155+
if (Q->ne[1] <= 4) {
156+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
157+
break;
158+
}
159+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
160+
break;
161+
}
162+
if (cc >= GGML_CUDA_CC_TURING) {
163+
if (Q->ne[1] <= 4) {
164+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
165+
break;
166+
}
167+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
168+
break;
169+
}
170+
// Volta:
171+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
172+
} else if (gqa_ratio % 16 == 0) {
126173
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
127174
} else {
128175
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
@@ -234,7 +281,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
234281

235282
// The effective batch size for the kernel can be increased by gqa_ratio.
236283
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
237-
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
284+
bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
238285
for (const ggml_tensor * t : {Q, K, V, mask}) {
239286
if (t == nullptr || ggml_is_quantized(t->type)) {
240287
continue;
@@ -268,7 +315,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
268315
if (V->ne[0] != 512) {
269316
return BEST_FATTN_KERNEL_NONE;
270317
}
271-
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
318+
if (!gqa_opt_applies) {
272319
return BEST_FATTN_KERNEL_NONE;
273320
}
274321
if (!V_is_K_view) {
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-mma-f16.cuh"
4+
5+
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-mma-f16.cuh"
4+
5+
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_short_name(long_quant_name):
7171
f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
7272

7373
for ncols in [8, 16, 32, 64]:
74-
for ncols2 in [1, 2, 4, 8, 16]:
74+
for ncols2 in [1, 2, 4, 8, 16, 32]:
7575
if ncols2 > ncols:
7676
continue
7777
ncols1 = ncols // ncols2
@@ -83,9 +83,9 @@ def get_short_name(long_quant_name):
8383
continue
8484
if head_size_kq == 72:
8585
continue
86-
if head_size_kq != 576 and ncols2 == 16:
86+
if head_size_kq != 576 and ncols2 in (16, 32):
8787
continue
88-
if head_size_kq == 576 and ncols2 not in (4, 16):
88+
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
8989
continue
9090
head_size_v = head_size_kq if head_size_kq != 576 else 512
9191
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))

0 commit comments

Comments
 (0)