Skip to content

Commit 15f786e

Browse files
authored
[CUDA ] Write an optimized flash_attn_stream_k_fixup kernel (ggml-org#21159)
* Write an optimized flash_attn_stream_k_fixup kernel Write a specialized and more optimized kernel for cases where nblocks_stream_k is multiple of ntiles_dst. Make nblocks_stream_k to multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst * Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs * Address review comments * Address review comments * Revert variable names to original
1 parent 94ca829 commit 15f786e

1 file changed

Lines changed: 153 additions & 25 deletions

File tree

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 153 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -676,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max(
676676

677677
template<int D, int ncols1, int ncols2> // D == head size
678678
__launch_bounds__(D, 1)
679-
static __global__ void flash_attn_stream_k_fixup(
680-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
681-
const int ne11, const int ne12, const int nbatch_fa) {
679+
static __global__ void flash_attn_stream_k_fixup_uniform(
680+
float * __restrict__ dst,
681+
const float2 * __restrict__ dst_fixup,
682+
const int ne01, const int ne02,
683+
const int ne12, const int nblocks_stream_k,
684+
const int gqa_ratio,
685+
const int blocks_per_tile,
686+
const uint3 fd_iter_j_z_ne12,
687+
const uint3 fd_iter_j_z,
688+
const uint3 fd_iter_j) {
689+
constexpr int ncols = ncols1*ncols2;
690+
691+
const int tile_idx = blockIdx.x; // One block per output tile.
692+
const int j = blockIdx.y;
693+
const int c = blockIdx.z;
694+
const int jc = j*ncols2 + c;
695+
const int tid = threadIdx.x;
696+
697+
// nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
698+
const int b_first = tile_idx * blocks_per_tile;
699+
const int b_last = b_first + blocks_per_tile - 1;
700+
701+
const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);
702+
703+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
704+
const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12);
705+
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z);
706+
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j);
707+
708+
const int sequence = dm0.x;
709+
const int z_KV = dm1.x;
710+
const int zt_gqa = dm2.x;
711+
const int jt = dm2.y;
712+
713+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
714+
715+
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
716+
return;
717+
}
718+
719+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
720+
721+
// Load the partial result that needs a fixup
722+
float dst_val = *dst;
723+
float max_val;
724+
float rowsum;
725+
{
726+
const float2 tmp = dst_fixup[b_last*ncols + jc];
727+
max_val = tmp.x;
728+
rowsum = tmp.y;
729+
}
730+
731+
// Combine with all previous blocks in this tile.
732+
for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
733+
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
734+
735+
const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];
736+
737+
const float max_val_new = fmaxf(max_val, tmp.x);
738+
739+
const float diff_val = max_val - max_val_new;
740+
const float diff_add = tmp.x - max_val_new;
741+
742+
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
743+
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
744+
745+
dst_val = scale_val*dst_val + scale_add*dst_add;
746+
rowsum = scale_val*rowsum + scale_add*tmp.y;
747+
748+
max_val = max_val_new;
749+
}
750+
751+
// Write back final result:
752+
*dst = dst_val / rowsum;
753+
}
754+
755+
// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles
756+
// (blocks_num.x not a multiple of ntiles_dst)
757+
template <int D, int ncols1, int ncols2> // D == head size
758+
__launch_bounds__(D, 1)
759+
static __global__ void flash_attn_stream_k_fixup_general(
760+
float * __restrict__ dst,
761+
const float2 * __restrict__ dst_fixup,
762+
const int ne01, const int ne02,
763+
const int gqa_ratio,
764+
const int total_work,
765+
const uint3 fd_iter_k_j_z_ne12,
766+
const uint3 fd_iter_k_j_z,
767+
const uint3 fd_iter_k_j,
768+
const uint3 fd_iter_k) {
682769
constexpr int ncols = ncols1*ncols2;
683770

684771
const int bidx0 = blockIdx.x;
@@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup(
689776

690777
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
691778

692-
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
693-
694-
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
695-
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
696-
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
697-
698-
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
699-
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
779+
const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x;
780+
const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x;
700781

701782
const bool did_not_have_any_data = kbc0 == kbc0_stop;
702-
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
703-
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
783+
const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0;
784+
const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0;
704785
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
705786
return;
706787
}
707788

708789
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
709-
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
710-
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
711-
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
712-
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
790+
const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12);
791+
const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z);
792+
const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j);
793+
const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k);
794+
795+
const int sequence = dm0.x;
796+
const int z_KV = dm1.x;
797+
const int zt_gqa = dm2.x;
798+
const int jt = dm3.x;
713799

714800
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
715801

@@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup(
733819

734820
// Iterate over previous blocks and compute the combined results.
735821
// All CUDA blocks that get here must have a previous block that needs a fixup.
822+
const int tile_kbc0 = fastdiv(kbc0, fd_iter_k);
736823
int bidx = bidx0 - 1;
737824
int kbc_stop = kbc0;
738825
while(true) {
739-
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
826+
const int kbc = int64_t(bidx)*total_work / gridDim.x;
740827
if (kbc == kbc_stop) { // Did not have any data.
741828
bidx--;
742829
kbc_stop = kbc;
@@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup(
762849
max_val = max_val_new;
763850

764851
// If this block started in a previous tile we are done and don't need to combine additional partial results.
765-
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
852+
if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) {
766853
break;
767854
}
768855
bidx--;
@@ -976,14 +1063,28 @@ void launch_fattn(
9761063
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
9771064
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
9781065

979-
const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
980-
9811066
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
9821067

983-
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
1068+
blocks_num.x = ntiles_dst;
9841069
blocks_num.y = 1;
9851070
blocks_num.z = 1;
9861071

1072+
if(use_stream_k) {
1073+
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
1074+
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup).
1075+
// Only do this if the occupancy loss from rounding is acceptable.
1076+
const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
1077+
const int max_efficiency_loss_percent = 5;
1078+
const int efficiency_loss_percent = nblocks_stream_k_rounded > 0
1079+
? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw
1080+
: 100;
1081+
const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent
1082+
? nblocks_stream_k_rounded
1083+
: nblocks_stream_k_raw;
1084+
1085+
blocks_num.x = nblocks_stream_k;
1086+
}
1087+
9871088
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
9881089
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
9891090
}
@@ -1063,13 +1164,40 @@ void launch_fattn(
10631164
CUDA_CHECK(cudaGetLastError());
10641165

10651166
if (stream_k) {
1066-
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
1167+
if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) {
1168+
// Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
1169+
const int nblocks_sk = (int)blocks_num.x;
1170+
const int bpt = nblocks_sk / ntiles_dst;
1171+
1172+
const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]);
1173+
const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa);
1174+
const uint3 fd2 = init_fastdiv_values(ntiles_x);
1175+
1176+
const dim3 block_dim_combine(DV, 1, 1);
1177+
const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
1178+
1179+
flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
1180+
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
1181+
((float *) KQV->data, dst_tmp_meta.ptr,
1182+
Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk,
1183+
gqa_ratio, bpt, fd0, fd1, fd2);
1184+
} else if (ntiles_dst % blocks_num.x != 0) {
1185+
// General fixup for the cases where nblocks_stream_k < ntiles_dst.
1186+
const int total_work = ntiles_KV * ntiles_dst;
1187+
1188+
const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]);
1189+
const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa);
1190+
const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x);
1191+
const uint3 fd_k = init_fastdiv_values(ntiles_KV);
1192+
10671193
const dim3 block_dim_combine(DV, 1, 1);
10681194
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
10691195

1070-
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
1196+
flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
10711197
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
1072-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
1198+
((float *) KQV->data, dst_tmp_meta.ptr,
1199+
Q->ne[1], Q->ne[2], gqa_ratio, total_work,
1200+
fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
10731201
}
10741202
} else if (parallel_blocks > 1) {
10751203
const dim3 block_dim_combine(DV, 1, 1);

0 commit comments

Comments
 (0)