@@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+    // The minimum granularity is 16 bytes.
+    constexpr int h2_per_chunk = 16/sizeof(half2);
+    const int chunks_per_row = D2 / h2_per_chunk;
     if constexpr (use_cp_async) {
+        static_assert(warp_size == 32, "bad warp_size");
         static_assert(!oob_check, "OOB check not compatible with cp_async");
         constexpr int preload = 64;
-        constexpr int h2_per_chunk = 16/sizeof(half2);
-        const int chunks_per_row = D2 / h2_per_chunk;
 
         const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
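As an aside, here is a minimal standalone sketch of the chunk arithmetic this hunk hoists out of the cp.async branch, worked for an example head size of 128; the concrete head size and the host-side stand-in for half2 are assumptions for illustration only.

// Host-side sketch of the 16-byte chunking: one chunk holds 4 half2 values (8 halves),
// so a row of D2 half2 elements splits into D2/4 chunks.
#include <cstdio>

struct half2 { unsigned short x, y; }; // stand-in so sizeof(half2) == 4 on the host

int main() {
    const int D  = 128;                               // example head size
    const int D2 = D/2;                               // half2 elements per K/V row
    constexpr int h2_per_chunk = 16/sizeof(half2);    // = 4
    const int chunks_per_row   = D2 / h2_per_chunk;   // = 16
    std::printf("D2 = %d, h2_per_chunk = %d, chunks_per_row = %d\n", D2, h2_per_chunk, chunks_per_row);
    return 0;
}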
@@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         // 6: max  1*16=  16 bytes,   8 half
         ggml_cuda_unroll<6>{}(load);
     } else {
-        // TODO use ggml_cuda_memcpy_1
+        const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}};
         auto load = [&] __device__ (const int n) {
-            const int stride_k = warp_size >> n;
-            const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
-            const int k0_stop  = D2 - D2 % (1*stride_k);
+            const int stride_k = 32 >> n;
+            const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
             const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
@@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
+                    ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4,
+                        !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero);
                 }
             }
         };
-        // 1: max 32* 4=128 bytes,  64 half
-        // 2: max 16* 4= 64 bytes,  32 half
-        // 3: max  8* 4= 32 bytes,  16 half
-        // 4: max  4* 4= 16 bytes,   8 half
-        ggml_cuda_unroll<4>{}(load);
+        // 1: max 32*16=512 bytes, 256 half
+        // 2: max 16*16=256 bytes, 128 half
+        // 3: max  8*16=128 bytes,  64 half
+        // 4: max  4*16= 64 bytes,  32 half
+        // 5: max  2*16= 32 bytes,  16 half
+        // 6: max  1*16= 16 bytes,   8 half
+        ggml_cuda_unroll<6>{}(load);
     }
 }
 
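For reference, a standalone sketch of how the rewritten synchronous path splits one row into 16-byte chunks with decreasing granularity. The plain loop below stands in for ggml_cuda_unroll and assumes it steps n through 0..5 (so stride_k sweeps 32 down to 1); chunks_per_row = 10 corresponds to an example head size of 80.

// Which chunk range [k0_start, k0_stop) each unroll step covers, and how many bytes
// of one row that is. Steps whose range is empty are skipped by the kernel.
#include <cstdio>

int main() {
    const int chunks_per_row = 10; // e.g. D = 80 -> D2 = 40 half2 -> 10 chunks of 16 bytes
    for (int n = 0; n < 6; ++n) {
        const int stride_k = 32 >> n;
        const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
        const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
        std::printf("step %d: stride_k = %2d, chunks [%2d, %2d), %3d bytes per row\n",
                    n + 1, stride_k, k0_start, k0_stop, (k0_stop - k0_start)*16);
    }
    return 0;
}

The non-empty steps cover every chunk exactly once, with as many lanes as possible working on the same row first, which is where the bandwidth benefit mentioned in the comment comes from.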
@@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
 
-#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
-    T_A_VKQ A_identity;
-    make_identity_mat(A_identity);
-#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
-
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
     for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
@@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 
                 T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
-#if defined(LDMATRIX_TRANS_AVAILABLE)
                 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-#elif defined(AMD_MFMA_AVAILABLE)
-                // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
-                // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
-                // Load with transposed addressing: 4 strided half loads.
-                {
-                    const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
-                    const half * xs0_h = (const half *) xs0;
-                    const int stride_h = stride_tile_V * 2; // stride in half units
-                    half * A_h = (half *) A.x;
-#pragma unroll
-                    for (int l = 0; l < 4; ++l) {
-                        A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
-                    }
-                }
-#else
-                // TODO: Try to transpose tile_V when loading gmem to smem.
-                // Use mma to transpose T_A_VKQ for RDNA.
-                T_A_VKQ A_trans;
-                load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-                mma(A, A_trans, A_identity);
-#endif // defined(LDMATRIX_TRANS_AVAILABLE)
                 if constexpr (T_B_KQ::I == 8) {
                     mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
                 } else {
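For context on the MFMA path removed above, here is a standalone sketch of its "transposed addressing" index math. It assumes a 16x16 half tile stored with a row stride of stride_h halves and the register layout stated in the removed comment; the sample lane numbers and stride_h value are illustrative only.

// Compares the normal and transposed addressing for the (now removed) MFMA load.
// Register layout assumed from the removed comment: A_mat[i = lane%16][k = 4*(lane/16) + reg].
#include <cstdio>

int main() {
    const int stride_h = 16;                 // example row stride of the shared tile, in half units
    const int lanes[]  = {0, 17, 35};        // a few sample lanes of the wavefront
    for (int lane : lanes) {
        for (int reg = 0; reg < 4; ++reg) {
            const int i = lane % 16, k = 4*(lane/16) + reg;
            const int normal_off = i*stride_h + k; // contiguous load: tile[i][k]
            const int trans_off  = k*stride_h + i; // removed code's load: tile[k][i]
            std::printf("lane %2d reg %d: tile[%2d][%2d] (off %3d) vs tile[%2d][%2d] (off %3d)\n",
                        lane, reg, i, k, normal_off, k, i, trans_off);
        }
    }
    return 0;
}

The strided read swaps the two tile indices per register, which is why that path could hand the mma a V^T fragment without a separate transpose step.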