@@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
7070 GGML_SYCL_FATTN_TILE_CONFIG_CASE (576 , 512 , 4 , 128 , 2 , 64 , 64 )
7171 GGML_SYCL_FATTN_TILE_CONFIG_CASE (576 , 512 , 8 , 256 , 2 , 64 , 64 )
7272 GGML_SYCL_FATTN_TILE_CONFIG_CASE (576 , 512 , 16 , 256 , 2 , 64 , 64 )
73+ GGML_SYCL_FATTN_TILE_CONFIG_CASE (576 , 512 , 32 , 256 , 2 , 64 , 64 )
7374
7475 return 0 ;
7576}
@@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const
310311 sycl::half2 * const __restrict__ tile_KV,
311312 const int stride_KV,
312313 const int i_sup) {
314+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3 >();
313315 constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes ();
314316 constexpr int cpy_ne = cpy_nb / 4 ;
315317
316318 auto load = [&] (const int n) {
317- auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3 >();
318319 const int stride_j = warp_size >> n;
319320
320321 if (stride_j == 0 ) {
@@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
455456
456457 flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
457458 (K_h2 + int64_t (k_VKQ_0)*stride_K2 + k_KQ_0/2 , KV_tmp, stride_K2, k_VKQ_sup);
458- item_ct1.barrier ();
459+ item_ct1.barrier (sycl::access::fence_space::local_space );
459460
460461#ifdef SYCL_FAST_FP16
461462 static_assert ((nbatch_K/2 ) % cpy_ne == 0 , " bad nbatch_K" );
@@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
505506 }
506507
507508 if (k_KQ_0 + nbatch_K < DKQ) {
508- item_ct1.barrier (); // Sync not needed on last iteration.
509+ item_ct1.barrier (sycl::access::fence_space::local_space ); // Sync not needed on last iteration.
509510 }
510511}
511512
@@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
545546 const int k_VKQ_max,
546547 const int col_Q_0,
547548 float * KQ_max_new_shared) {
548- auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3 >();
549+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3 >();
549550 constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes ();
550551 constexpr int cpy_ne = cpy_nb / 4 ;
551552
@@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
620621 }
621622
622623 if constexpr (np == 1 ) {
623- item_ct1.barrier ();
624+ item_ct1.barrier (sycl::access::fence_space::local_space );
624625 } else {
625626 static_assert (cpw == 1 , " bad cpw" );
626627
627628 if (item_ct1.get_local_id (2 ) == 0 ) {
628629 KQ_max_new_shared[item_ct1.get_local_id (1 )] = KQ_max_new[0 ];
629630 }
630- item_ct1.barrier ();
631+ item_ct1.barrier (sycl::access::fence_space::local_space );
631632 KQ_max_new[0 ] = KQ_max_new_shared[(item_ct1.get_local_id (1 ) & ~(np - 1 )) + item_ct1.get_local_id (2 ) % np];
632633 KQ_max_new[0 ] = warp_reduce_max<np>(KQ_max_new[0 ]);
633634 }
@@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
697698 for (int k0 = 0 ; k0 < nbatch_fa; k0 += nbatch_V) {
698699 flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0 , oob_check>
699700 (V_h2 + int64_t (k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
700- item_ct1.barrier ();
701+ item_ct1.barrier (sycl::access::fence_space::local_space );
701702
702703#ifdef SYCL_FAST_FP16
703704#pragma unroll
@@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
765766 }
766767 }
767768#endif // SYCL_FAST_FP16
768- item_ct1.barrier ();
769+ item_ct1.barrier (sycl::access::fence_space::local_space );
769770 }
770771}
771772
@@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q,
972973 }
973974 }
974975
975- item_ct1.barrier ();
976+ item_ct1.barrier (sycl::access::fence_space::local_space );
976977
977978 // Main loop over KV cache:
978979 const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range (2 ) + item_ct1.get_group (2 )] : ne11;
@@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q,
10511052 return ;
10521053 }
10531054
1054- item_ct1.barrier ();
1055+ item_ct1.barrier (sycl::access::fence_space::local_space );
10551056
10561057#pragma unroll
10571058 for (int ip = 1 ; ip < np; ++ip) {
@@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
11931194
11941195 constexpr size_t nbytes_shared = 0 ;
11951196
1196- if constexpr (DV <= 256 ) {
1197- if (Q->ne [1 ] > 16 /ncols2) {
1198- constexpr int cols_per_block = 32 ;
1199- const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1200- const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa (DKQ, DV, cols_per_block, cc);
1201- launch_fattn<DV, cols_per_block/ncols2, ncols2,
1202- flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1203- (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true , true , false );
1204- return ;
1197+ if (DV < 512 && Q->ne [1 ] < 32 ) {
1198+ if constexpr (ncols2 <= 32 ) {
1199+ if (Q->ne [1 ] > 16 /ncols2) {
1200+ constexpr int cols_per_block = 32 ;
1201+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1202+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa (DKQ, DV, cols_per_block, cc);
1203+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
1204+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1205+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true , true , false );
1206+ return ;
1207+ }
12051208 }
1206- }
1207-
1208- if (Q-> ne [ 1 ] > 8 /ncols2) {
1209- constexpr int cols_per_block = 16 ;
1210- const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size ;
1211- const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa (DKQ, DV, cols_per_block, cc);
1212- launch_fattn< DV, cols_per_block/ ncols2, ncols2,
1213- flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1214- (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true , true , false ) ;
1215- return ;
1216- }
1217-
1218- if constexpr (ncols2 <= 8 ) {
1219- if (Q-> ne [ 1 ] > 4 /ncols2) {
1220- constexpr int cols_per_block = 8 ;
1221- const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size ;
1222- const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa (DKQ, DV, cols_per_block, cc);
1223- launch_fattn< DV, cols_per_block/ ncols2, ncols2,
1224- flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1225- (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true , true , false ) ;
1226- return ;
1209+ if constexpr (ncols2 <= 16 ) {
1210+ if (Q-> ne [ 1 ] > 8 /ncols2) {
1211+ constexpr int cols_per_block = 16 ;
1212+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size ;
1213+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa (DKQ, DV, cols_per_block, cc);
1214+ launch_fattn< DV, cols_per_block/ncols2, ncols2,
1215+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1216+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true , true , false );
1217+ return ;
1218+ }
1219+ }
1220+ if constexpr (ncols2 <= 8 ) {
1221+ if (Q-> ne [ 1 ] > 4 /ncols2 ) {
1222+ constexpr int cols_per_block = 8 ;
1223+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size ;
1224+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa (DKQ, DV, cols_per_block, cc);
1225+ launch_fattn< DV, cols_per_block/ncols2, ncols2,
1226+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
1227+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true , true , false );
1228+ return ;
1229+ }
12271230 }
12281231 }
12291232
0 commit comments