@@ -940,6 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
940940 const int stride_V,
941941 const int stride_mask,
942942 const int jt,
943+ const int zt,
943944 const int kb0_start,
944945 const int kb0_stop) {
945946#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1022,7 +1023,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
10221023 const int j = jc / ncols2;
10231024 const int c = jc % ncols2;
10241025
1025- if (jt*ncols1 + j < int (ne01.z )) {
1026+ if ((ncols1 == 1 || jt*ncols1 + j < int (ne01.z )) && (ncols2 == 1 || zt*ncols2 + c < ne02 )) {
10261027#pragma unroll
10271028 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
10281029 const int k = k0 + (stride_k == WARP_SIZE ? threadIdx .x : threadIdx .x % stride_k);
@@ -1408,7 +1409,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
14081409 const int j_dst = jc_dst / ncols2;
14091410 const int c_dst = jc_dst % ncols2;
14101411
1411- if (!is_fixup && jt*ncols1 + j_dst >= int (ne01.z )) {
1412+ if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int (ne01.z )) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02) )) {
14121413 continue ;
14131414 }
14141415
@@ -1522,10 +1523,11 @@ static __global__ void flash_attn_ext_f16(
15221523
15231524 const int iter_k = (ne11 + (nbatch_fa - 1 )) / nbatch_fa;
15241525 const int iter_j = (ne01.z + (ncols1 - 1 )) / ncols1;
1526+ const int iter_z = (ne02 + (ncols2 - 1 )) / ncols2;
15251527
15261528 // kbc == k block continuous, current index in continuous ijk space.
1527- int kbc = int64_t (blockIdx .x + 0 )*(iter_k*iter_j*(ne02/ncols2) *ne03) / gridDim .x ;
1528- const int kbc_stop = int64_t (blockIdx .x + 1 )*(iter_k*iter_j*(ne02/ncols2) *ne03) / gridDim .x ;
1529+ int kbc = int64_t (blockIdx .x + 0 )*(iter_k*iter_j*iter_z *ne03) / gridDim .x ;
1530+ const int kbc_stop = int64_t (blockIdx .x + 1 )*(iter_k*iter_j*iter_z *ne03) / gridDim .x ;
15291531
15301532 // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
15311533 // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1536,9 +1538,9 @@ static __global__ void flash_attn_ext_f16(
15361538 int kb0_stop = min (iter_k, kb0_start + kbc_stop - kbc);
15371539
15381540 while (kbc < kbc_stop && kb0_stop == iter_k) {
1539- const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2) );
1540- const int zt = (kbc - iter_k*iter_j*(ne02/ncols2) *sequence) / (iter_k*iter_j); // head in units of ncols2
1541- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2) *sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1541+ const int sequence = kbc / (iter_k*iter_j*iter_z );
1542+ const int zt = (kbc - iter_k*iter_j*iter_z *sequence) / (iter_k*iter_j); // head in units of ncols2
1543+ const int jt = (kbc - iter_k*iter_j*iter_z *sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
15421544
15431545 const int head0 = zt * ncols2;
15441546
@@ -1561,12 +1563,12 @@ static __global__ void flash_attn_ext_f16(
15611563 constexpr bool needs_fixup = false ; // CUDA block is working on an entire tile.
15621564 flash_attn_ext_f16_process_tile<DKQ , DV , ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
15631565 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1564- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1566+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
15651567 } else {
15661568 constexpr bool needs_fixup = true ; // CUDA block is missing the beginning of a tile.
15671569 flash_attn_ext_f16_process_tile<DKQ , DV , ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
15681570 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1569- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1571+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
15701572 }
15711573
15721574 kbc += iter_k;
@@ -1580,9 +1582,9 @@ static __global__ void flash_attn_ext_f16(
15801582 return ;
15811583 }
15821584
1583- const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2) );
1584- const int zt = (kbc - iter_k*iter_j*(ne02/ncols2) *sequence) / (iter_k*iter_j); // head in units of ncols2
1585- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2) *sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1585+ const int sequence = kbc / (iter_k*iter_j*iter_z );
1586+ const int zt = (kbc - iter_k*iter_j*iter_z *sequence) / (iter_k*iter_j); // head in units of ncols2
1587+ const int jt = (kbc - iter_k*iter_j*iter_z *sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
15861588
15871589 const int head0 = zt * ncols2;
15881590
@@ -1605,7 +1607,7 @@ static __global__ void flash_attn_ext_f16(
16051607 constexpr bool needs_fixup = false ;
16061608 flash_attn_ext_f16_process_tile<DKQ , DV , ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
16071609 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1608- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1610+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
16091611#else
16101612 GGML_UNUSED_VARS (Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
16111613 max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1739,3 +1741,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
17391741extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 4 , 4 );
17401742extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 8 , 4 );
17411743extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 16 , 4 );
1744+ extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 1 , 32 );
1745+ extern DECL_FATTN_MMA_F16_CASE (576 , 512 , 2 , 32 );
0 commit comments