Skip to content

Commit 62278ce

Browse files
authored
sycl : enhance fattn perf (#21185)
1 parent 90aa83c commit 62278ce

1 file changed

Lines changed: 43 additions & 40 deletions

File tree

ggml/src/ggml-sycl/fattn-tile.hpp

Lines changed: 43 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
7070
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
7171
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
7272
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
73+
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64)
7374

7475
return 0;
7576
}
@@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const
310311
sycl::half2 * const __restrict__ tile_KV,
311312
const int stride_KV,
312313
const int i_sup) {
314+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
313315
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
314316
constexpr int cpy_ne = cpy_nb / 4;
315317

316318
auto load = [&] (const int n) {
317-
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
318319
const int stride_j = warp_size >> n;
319320

320321
if (stride_j == 0) {
@@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
455456

456457
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
457458
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
458-
item_ct1.barrier();
459+
item_ct1.barrier(sycl::access::fence_space::local_space);
459460

460461
#ifdef SYCL_FAST_FP16
461462
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
@@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
505506
}
506507

507508
if (k_KQ_0 + nbatch_K < DKQ) {
508-
item_ct1.barrier(); // Sync not needed on last iteration.
509+
item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration.
509510
}
510511
}
511512

@@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
545546
const int k_VKQ_max,
546547
const int col_Q_0,
547548
float * KQ_max_new_shared) {
548-
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
549+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
549550
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
550551
constexpr int cpy_ne = cpy_nb / 4;
551552

@@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
620621
}
621622

622623
if constexpr (np == 1) {
623-
item_ct1.barrier();
624+
item_ct1.barrier(sycl::access::fence_space::local_space);
624625
} else {
625626
static_assert(cpw == 1, "bad cpw");
626627

627628
if (item_ct1.get_local_id(2) == 0) {
628629
KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
629630
}
630-
item_ct1.barrier();
631+
item_ct1.barrier(sycl::access::fence_space::local_space);
631632
KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
632633
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
633634
}
@@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
697698
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
698699
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
699700
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
700-
item_ct1.barrier();
701+
item_ct1.barrier(sycl::access::fence_space::local_space);
701702

702703
#ifdef SYCL_FAST_FP16
703704
#pragma unroll
@@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
765766
}
766767
}
767768
#endif // SYCL_FAST_FP16
768-
item_ct1.barrier();
769+
item_ct1.barrier(sycl::access::fence_space::local_space);
769770
}
770771
}
771772

@@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q,
972973
}
973974
}
974975

975-
item_ct1.barrier();
976+
item_ct1.barrier(sycl::access::fence_space::local_space);
976977

977978
// Main loop over KV cache:
978979
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
@@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q,
10511052
return;
10521053
}
10531054

1054-
item_ct1.barrier();
1055+
item_ct1.barrier(sycl::access::fence_space::local_space);
10551056

10561057
#pragma unroll
10571058
for (int ip = 1; ip < np; ++ip) {
@@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
11931194

11941195
constexpr size_t nbytes_shared = 0;
11951196

1196-
if constexpr (DV <= 256) {
1197-
if (Q->ne[1] > 16/ncols2) {
1198-
constexpr int cols_per_block = 32;
1199-
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1200-
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1201-
launch_fattn<DV, cols_per_block/ncols2, ncols2,
1202-
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1203-
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
1204-
return;
1197+
if (DV < 512 && Q->ne[1] < 32) {
1198+
if constexpr (ncols2 <= 32) {
1199+
if (Q->ne[1] > 16/ncols2) {
1200+
constexpr int cols_per_block = 32;
1201+
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1202+
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1203+
launch_fattn<DV, cols_per_block/ncols2, ncols2,
1204+
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1205+
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
1206+
return;
1207+
}
12051208
}
1206-
}
1207-
1208-
if (Q->ne[1] > 8/ncols2) {
1209-
constexpr int cols_per_block = 16;
1210-
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1211-
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1212-
launch_fattn<DV, cols_per_block/ncols2, ncols2,
1213-
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1214-
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
1215-
return;
1216-
}
1217-
1218-
if constexpr (ncols2 <= 8) {
1219-
if (Q->ne[1] > 4/ncols2) {
1220-
constexpr int cols_per_block = 8;
1221-
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1222-
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1223-
launch_fattn<DV, cols_per_block/ncols2, ncols2,
1224-
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1225-
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
1226-
return;
1209+
if constexpr (ncols2 <= 16) {
1210+
if (Q->ne[1] > 8/ncols2) {
1211+
constexpr int cols_per_block = 16;
1212+
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1213+
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1214+
launch_fattn<DV, cols_per_block/ncols2, ncols2,
1215+
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1216+
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
1217+
return;
1218+
}
1219+
}
1220+
if constexpr (ncols2 <= 8) {
1221+
if (Q->ne[1] > 4/ncols2) {
1222+
constexpr int cols_per_block = 8;
1223+
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1224+
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1225+
launch_fattn<DV, cols_per_block/ncols2, ncols2,
1226+
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1227+
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
1228+
return;
1229+
}
12271230
}
12281231
}
12291232

0 commit comments

Comments
 (0)