// Fixup kernel for the stream-k decomposition when every output tile is processed by the
// same number of CUDA blocks, i.e. nblocks_stream_k is a multiple of the number of output
// tiles. Launched with one block per output tile:
//   grid  = (ntiles_dst, ncols1, ncols2), block = (D, 1, 1).
// dst_fixup layout: nblocks_stream_k float2 meta entries per column pair (max/rowsum),
// followed by the per-block partial dst values (see dst_fixup_data below).
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup_uniform(
        float        * __restrict__ dst,
        const float2 * __restrict__ dst_fixup,
        const int ne01, const int ne02,
        const int ne12, const int nblocks_stream_k,
        const int gqa_ratio,
        const int blocks_per_tile,
        const uint3 fd_iter_j_z_ne12,
        const uint3 fd_iter_j_z,
        const uint3 fd_iter_j) {
    constexpr int ncols = ncols1*ncols2;

    const int tile_idx = blockIdx.x; // One block per output tile.
    const int j        = blockIdx.y;
    const int c        = blockIdx.z;
    const int jc       = j*ncols2 + c;
    const int tid      = threadIdx.x;

    // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
    const int b_first = tile_idx * blocks_per_tile;
    const int b_last  = b_first + blocks_per_tile - 1;

    // Partial dst values are stored after the nblocks_stream_k*2 float2 meta entries.
    const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);

    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
    const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12);
    const uint2 dm1 = fast_div_modulo(dm0.y,    fd_iter_j_z);
    const uint2 dm2 = fast_div_modulo(dm1.y,    fd_iter_j);

    const int sequence = dm0.x;
    const int z_KV     = dm1.x;
    const int zt_gqa   = dm2.x;
    const int jt       = dm2.y;

    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.

    // Skip rows/columns that fall outside the actual tensor extents (padding of the last tile).
    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
        return;
    }

    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;

    // Load the partial result that needs a fixup, together with its softmax meta data
    // (running max and row sum) from the last block of this tile.
    float dst_val = *dst;
    float max_val;
    float rowsum;
    {
        const float2 tmp = dst_fixup[b_last*ncols + jc];
        max_val = tmp.x;
        rowsum  = tmp.y;
    }

    // Combine with all previous blocks in this tile using the numerically stable
    // online-softmax rescaling: rescale both partials to the shared new maximum.
    for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];

        const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];

        const float max_val_new = fmaxf(max_val, tmp.x);

        const float diff_val = max_val - max_val_new;
        const float diff_add = tmp.x  - max_val_new;

        // Flush tiny scale factors to zero to avoid denormal slowdowns / garbage contributions.
        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;

        dst_val = scale_val*dst_val + scale_add*dst_add;
        rowsum  = scale_val*rowsum  + scale_add*tmp.y;

        max_val = max_val_new;
    }

    // Write back final result, normalized by the combined row sum:
    *dst = dst_val / rowsum;
}
754+
755+ // General fixup kernel for the case where the number of blocks per tile is not uniform across tiles
756+ // (blocks_num.x not a multiple of ntiles_dst)
757+ template <int D, int ncols1, int ncols2> // D == head size
758+ __launch_bounds__ (D, 1 )
759+ static __global__ void flash_attn_stream_k_fixup_general(
760+ float * __restrict__ dst,
761+ const float2 * __restrict__ dst_fixup,
762+ const int ne01, const int ne02,
763+ const int gqa_ratio,
764+ const int total_work,
765+ const uint3 fd_iter_k_j_z_ne12,
766+ const uint3 fd_iter_k_j_z,
767+ const uint3 fd_iter_k_j,
768+ const uint3 fd_iter_k) {
682769 constexpr int ncols = ncols1*ncols2;
683770
684771 const int bidx0 = blockIdx .x ;
@@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup(
689776
690777 const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim .x *(2 *2 *ncols);
691778
692- const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
693-
694- const int iter_k = (ne11 + (nbatch_fa - 1 )) / nbatch_fa;
695- const int iter_j = (ne01 + (ncols1 - 1 )) / ncols1;
696- const int iter_z_gqa = (gqa_ratio + (ncols2 - 1 )) / ncols2;
697-
698- const int kbc0 = int64_t (bidx0 + 0 )*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim .x ;
699- const int kbc0_stop = int64_t (bidx0 + 1 )*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim .x ;
779+ const int kbc0 = int64_t (bidx0 + 0 )*total_work / gridDim .x ;
780+ const int kbc0_stop = int64_t (bidx0 + 1 )*total_work / gridDim .x ;
700781
701782 const bool did_not_have_any_data = kbc0 == kbc0_stop;
702- const bool wrote_beginning_of_tile = kbc0 % iter_k == 0 ;
703- const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0 ;
783+ const bool wrote_beginning_of_tile = fastmodulo ( kbc0, fd_iter_k) == 0 ;
784+ const bool did_not_write_last = fastdiv ( kbc0, fd_iter_k) == fastdiv ( kbc0_stop, fd_iter_k) && fastmodulo ( kbc0_stop, fd_iter_k) != 0 ;
704785 if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
705786 return ;
706787 }
707788
708789 // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
709- const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
710- const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
711- const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
712- const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
790+ const uint2 dm0 = fast_div_modulo (kbc0, fd_iter_k_j_z_ne12);
791+ const uint2 dm1 = fast_div_modulo (dm0.y , fd_iter_k_j_z);
792+ const uint2 dm2 = fast_div_modulo (dm1.y , fd_iter_k_j);
793+ const uint2 dm3 = fast_div_modulo (dm2.y , fd_iter_k);
794+
795+ const int sequence = dm0.x ;
796+ const int z_KV = dm1.x ;
797+ const int zt_gqa = dm2.x ;
798+ const int jt = dm3.x ;
713799
714800 const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
715801
@@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup(
733819
734820 // Iterate over previous blocks and compute the combined results.
735821 // All CUDA blocks that get here must have a previous block that needs a fixup.
822+ const int tile_kbc0 = fastdiv (kbc0, fd_iter_k);
736823 int bidx = bidx0 - 1 ;
737824 int kbc_stop = kbc0;
738825 while (true ) {
739- const int kbc = int64_t (bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim .x ;
826+ const int kbc = int64_t (bidx)*total_work / gridDim .x ;
740827 if (kbc == kbc_stop) { // Did not have any data.
741828 bidx--;
742829 kbc_stop = kbc;
@@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup(
762849 max_val = max_val_new;
763850
764851 // If this block started in a previous tile we are done and don't need to combine additional partial results.
765- if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k ) {
852+ if (fastmodulo ( kbc, fd_iter_k) == 0 || fastdiv ( kbc, fd_iter_k) < tile_kbc0 ) {
766853 break ;
767854 }
768855 bidx--;
@@ -976,14 +1063,28 @@ void launch_fattn(
9761063 const int tiles_nwaves = (ntiles_dst + max_blocks - 1 ) / max_blocks;
9771064 const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
9781065
979- const int nblocks_stream_k = std::min (max_blocks, ntiles_KV*ntiles_dst);
980-
9811066 const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available (cc) || tiles_efficiency_percent < 75 ;
9821067
983- blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
1068+ blocks_num.x = ntiles_dst;
9841069 blocks_num.y = 1 ;
9851070 blocks_num.z = 1 ;
9861071
1072+ if (use_stream_k) {
1073+ const int nblocks_stream_k_raw = std::min (max_blocks, ntiles_KV*ntiles_dst);
1074+ // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup).
1075+ // Only do this if the occupancy loss from rounding is acceptable.
1076+ const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
1077+ const int max_efficiency_loss_percent = 5 ;
1078+ const int efficiency_loss_percent = nblocks_stream_k_rounded > 0
1079+ ? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw
1080+ : 100 ;
1081+ const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent
1082+ ? nblocks_stream_k_rounded
1083+ : nblocks_stream_k_raw;
1084+
1085+ blocks_num.x = nblocks_stream_k;
1086+ }
1087+
9871088 if (ntiles_dst % blocks_num.x != 0 ) { // Fixup is only needed if the SMs work on fractional tiles.
9881089 dst_tmp_meta.alloc ((size_t (blocks_num.x ) * ncols * (2 + DV/2 )));
9891090 }
@@ -1063,13 +1164,40 @@ void launch_fattn(
10631164 CUDA_CHECK (cudaGetLastError ());
10641165
10651166 if (stream_k) {
1066- if (ntiles_dst % blocks_num.x != 0 ) { // Fixup is only needed if the SMs work on fractional tiles.
1167+ if ((int )blocks_num.x % ntiles_dst == 0 && (int )blocks_num.x > ntiles_dst) {
1168+ // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
1169+ const int nblocks_sk = (int )blocks_num.x ;
1170+ const int bpt = nblocks_sk / ntiles_dst;
1171+
1172+ const uint3 fd0 = init_fastdiv_values (ntiles_x * ntiles_z_gqa * K->ne [2 ]);
1173+ const uint3 fd1 = init_fastdiv_values (ntiles_x * ntiles_z_gqa);
1174+ const uint3 fd2 = init_fastdiv_values (ntiles_x);
1175+
1176+ const dim3 block_dim_combine (DV, 1 , 1 );
1177+ const dim3 blocks_num_combine = {(unsigned )ntiles_dst, ncols1, ncols2};
1178+
1179+ flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
1180+ <<<blocks_num_combine, block_dim_combine, 0 , main_stream>>>
1181+ ((float *) KQV->data , dst_tmp_meta.ptr ,
1182+ Q->ne [1 ], Q->ne [2 ], K->ne [2 ], nblocks_sk,
1183+ gqa_ratio, bpt, fd0, fd1, fd2);
1184+ } else if (ntiles_dst % blocks_num.x != 0 ) {
1185+ // General fixup for the cases where nblocks_stream_k < ntiles_dst.
1186+ const int total_work = ntiles_KV * ntiles_dst;
1187+
1188+ const uint3 fd_k_j_z_ne12 = init_fastdiv_values (ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne [2 ]);
1189+ const uint3 fd_k_j_z = init_fastdiv_values (ntiles_KV * ntiles_x * ntiles_z_gqa);
1190+ const uint3 fd_k_j = init_fastdiv_values (ntiles_KV * ntiles_x);
1191+ const uint3 fd_k = init_fastdiv_values (ntiles_KV);
1192+
10671193 const dim3 block_dim_combine (DV, 1 , 1 );
10681194 const dim3 blocks_num_combine = {blocks_num.x , ncols1, ncols2};
10691195
1070- flash_attn_stream_k_fixup <DV, ncols1, ncols2>
1196+ flash_attn_stream_k_fixup_general <DV, ncols1, ncols2>
10711197 <<<blocks_num_combine, block_dim_combine, 0 , main_stream>>>
1072- ((float *) KQV->data , dst_tmp_meta.ptr , Q->ne [1 ], Q->ne [2 ], Q->ne [3 ], K->ne [1 ], K->ne [2 ], nbatch_fa);
1198+ ((float *) KQV->data , dst_tmp_meta.ptr ,
1199+ Q->ne [1 ], Q->ne [2 ], gqa_ratio, total_work,
1200+ fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
10731201 }
10741202 } else if (parallel_blocks > 1 ) {
10751203 const dim3 block_dim_combine (DV, 1 , 1 );
0 commit comments