
Commit 4eac5b4

CUDA: refactor mma data loading for AMD (#22051)
* CUDA: refactor mma data loading for AMD
* fix CDNA MMQ occupancy
* fix CDNA3 mma
* fix RDNA3 compile
1 parent: d5b780a

4 files changed

Lines changed: 112 additions & 395 deletions


ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 4 deletions
@@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) {
 #define FLASH_ATTN_AVAILABLE
 #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 
-#if defined(TURING_MMA_AVAILABLE)
-#define LDMATRIX_TRANS_AVAILABLE
-#endif // defined(TURING_MMA_AVAILABLE)
-
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
            (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 17 additions & 40 deletions
@@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+    // The minimum granularity is 16 bytes.
+    constexpr int h2_per_chunk = 16/sizeof(half2);
+    const int chunks_per_row = D2 / h2_per_chunk;
     if constexpr (use_cp_async) {
+        static_assert(warp_size == 32, "bad warp_size");
         static_assert(!oob_check, "OOB check not compatible with cp_async");
         constexpr int preload = 64;
-        constexpr int h2_per_chunk = 16/sizeof(half2);
-        const int chunks_per_row = D2 / h2_per_chunk;
 
         const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
@@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         // 6: max  1*16=  16 bytes,   8 half
         ggml_cuda_unroll<6>{}(load);
     } else {
-        // TODO use ggml_cuda_memcpy_1
+        const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}};
         auto load = [&] __device__ (const int n) {
-            const int stride_k = warp_size >> n;
-            const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
-            const int k0_stop  =                             D2 - D2 % (1*stride_k);
+            const int stride_k = 32 >> n;
+            const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int k0_stop  =                      chunks_per_row - chunks_per_row % (1*stride_k);
             const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
@@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
+                    ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4,
+                                           !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero);
                 }
             }
         };
-        // 1: max 32* 4=128 bytes, 64 half
-        // 2: max 16* 4= 64 bytes, 32 half
-        // 3: max  8* 4= 32 bytes, 16 half
-        // 4: max  4* 4= 16 bytes,  8 half
-        ggml_cuda_unroll<4>{}(load);
+        // 1: max 32*16=512 bytes, 256 half
+        // 2: max 16*16=256 bytes, 128 half
+        // 3: max  8*16=128 bytes,  64 half
+        // 4: max  4*16= 64 bytes,  32 half
+        // 5: max  2*16= 32 bytes,  16 half
+        // 6: max  1*16= 16 bytes,   8 half
+        ggml_cuda_unroll<6>{}(load);
     }
 }
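As a reading aid, not part of the patch: the decreasing-granularity scheme above can be replayed on the host by mirroring the k0_start/k0_stop arithmetic of the synchronous branch. The chunks_per_row value and the pass index n below are assumptions chosen only for illustration; the kernel itself drives the passes through ggml_cuda_unroll<6>.

// Host-side sketch, not part of the commit: replay the chunk-range arithmetic
// of the synchronous load branch for one hypothetical row of 16-byte chunks.
#include <cstdio>

int main() {
    // Hypothetical example: head size 80 -> D2 = 40 half2 -> 10 chunks of 4 half2 (16 bytes) each.
    const int chunks_per_row = 10;

    for (int n = 0; n < 6; ++n) {     // pass order/indexing is an assumption for this sketch
        const int stride_k = 32 >> n; // threads cooperating on one row in this pass
        const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
        const int k0_stop  =                      chunks_per_row - chunks_per_row % (1*stride_k);

        if (k0_start == k0_stop) {
            continue; // this granularity contributes nothing for this row length
        }
        printf("stride_k = %2d covers chunks [%2d, %2d) = %3d bytes of the row\n",
               stride_k, k0_start, k0_stop, (k0_stop - k0_start)*16);
    }
    return 0;
}

For 10 chunks, only the stride_k = 8 and stride_k = 2 passes do any work; every other granularity is skipped by the k0_start == k0_stop check, which is what lets the unroll cover arbitrary row lengths without wasted loads.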

@@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
 
-#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
-    T_A_VKQ A_identity;
-    make_identity_mat(A_identity);
-#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
-
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
     for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
@@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 
             T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
-#if defined(LDMATRIX_TRANS_AVAILABLE)
             load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-#elif defined(AMD_MFMA_AVAILABLE)
-            // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
-            // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
-            // Load with transposed addressing: 4 strided half loads.
-            {
-                const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
-                const half * xs0_h = (const half *) xs0;
-                const int stride_h = stride_tile_V * 2; // stride in half units
-                half * A_h = (half *) A.x;
-#pragma unroll
-                for (int l = 0; l < 4; ++l) {
-                    A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
-                }
-            }
-#else
-            // TODO: Try to transpose tile_V when loading gmem to smem.
-            // Use mma to transpose T_A_VKQ for RDNA.
-            T_A_VKQ A_trans;
-            load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-            mma(A, A_trans, A_identity);
-#endif // defined(LDMATRIX_TRANS_AVAILABLE)
             if constexpr (T_B_KQ::I == 8) {
                 mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
             } else {