Skip to content

Commit e49badb

Browse files
committed
save code
1 parent 7206605 commit e49badb

1 file changed

Lines changed: 59 additions & 25 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ static constexpr float quant_map_static[16] = {
6161
};
6262
#endif
6363

64-
using TileShape = Shape<_32, _128, _128>;
64+
using TileShape = Shape<_64, _128, _128>;
6565
using TiledMma =
6666
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
67-
Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
67+
Layout<Shape<_2, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
6868
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969
using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070
constexpr int PipelineStages = 2;
@@ -98,9 +98,9 @@ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
9898
using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
9999
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
100100

101-
static constexpr auto FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape());
102-
static constexpr auto FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape());
103-
static constexpr auto FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
101+
//static constexpr auto FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape());
102+
//static constexpr auto FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape());
103+
//static constexpr auto FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
104104

105105
// Design Scheduler
106106
using TileScheduler_ = PersistentScheduler;
@@ -395,35 +395,51 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
395395
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
396396
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
397397

398-
//if(cute::thread0()) printf("params.group_size = %d, k_reload_factor = %d, k_tile_count = %d, N = %d, K = %d, src_compress_size = %d, src_vec_size = %d, dst_compress_size = %d, dst_vec_size = %d\n",params.group_size, k_reload_factor, k_tile_count, N, K, src_compress_size, src_vec_size, dst_compress_size, dst_vec_size);
398+
if(cute::thread0()) printf("N = %d, K = %d, src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", N, K, src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
399399

400400
src_compress_type src[src_vec_size];
401401
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
402402

403403
#pragma unroll
404404
for (int n = 0; n < N; n++) {
405-
//float scale_value = fragment_scale(n);
406405
#pragma unroll
407406
for (int l = 0; l < src_loop_num; l++) {
408-
//src_compress_type src[src_vec_size];
409-
//ElementMMA dst[K/dst_loop_num];
410407
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l];
408+
409+
if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
410+
printf("n = %d, src_l = %d\n", n, l);
411+
print("======================= src vectorization: \n");
412+
print(" src_g_ptr : "); print(&(reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n * src_loop_num + l])); print("\n");
413+
print(" src_ptr : "); print(&(reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0])); print("\n");
414+
print("=======================\n");
415+
}
416+
411417
#pragma unroll
412418
for (int v = 0; v < src_vec_size; v++) {
413419
src_compress_type src_value = src[v];
414420
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
415421
#pragma unroll
416422
for (int c = 0; c < src_compress_size; c++) {
417423
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
418-
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
424+
float scale_value = 1.0f; //fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
419425
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
426+
//if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) printf("n = %d, src_l = %d, dst_base_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n", n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
420427
}
421428
}
422429
}
423430

424431
#pragma unroll
425432
for (int l = 0; l < dst_loop_num; l++) {
426-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
433+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
434+
435+
if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
436+
printf("n = %d, dst_l = %d\n", n, l);
437+
print("======================= dst vectorization: \n");
438+
print(" dst_g_ptr : "); print(&(reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l])); print("\n");
439+
print(" dst_ptr : "); print(&(reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l])); print("\n");
440+
print("=======================\n");
441+
}
442+
427443
}
428444
}
429445
};
@@ -474,21 +490,39 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
474490
}
475491

476492
//replace epilogue for store
477-
Tensor mD_mnl = cute::get_pvc_tensor(make_shape(params.m, params.n, params.l));
478-
Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(WorkgroupTileShape{}), make_coord(m_coord,n_coord,l_coord));
479-
Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(
480-
get_sub_group_id() / ATOM_N,
481-
get_sub_group_id() % ATOM_N
482-
));
483-
484-
auto thread_xe_store_d = params.tiled_store_d.get_thread_slice(thread_idx);
485-
Tensor tCgD = thread_xe_store_d.partition_D(gD);
493+
// Tensor mD_mnl = cute::get_pvc_tensor(make_shape(params.m, params.n, params.l));
494+
// Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(WorkgroupTileShape{}), make_coord(m_coord,n_coord,l_coord));
495+
// Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(
496+
// get_sub_group_id() / ATOM_N,
497+
// get_sub_group_id() % ATOM_N
498+
// ));
499+
//
500+
// auto thread_xe_store_d = params.tiled_store_d.get_thread_slice(thread_idx);
501+
// Tensor tCgD = thread_xe_store_d.partition_D(gD);
486502

487-
#pragma unroll
488-
for (int epi = 0; epi < FragsM * FragsN; ++epi) {
489-
int epi_m = epi / FragsN;
490-
int epi_n = epi % FragsN;
491-
copy(params.tiled_store_d, accumulators(_, epi_m, epi_n), tCgD(_, epi_m, epi_n));
503+
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // atom numbers per thread; A frags per sub_group
504+
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // atom numbers per thread; B frags per sub_group
505+
506+
auto m_sg = get_sub_group_id() / ATOM_N;
507+
auto n_sg = get_sub_group_id() % ATOM_N;
508+
509+
Tensor mD_mnl = cute::get_pvc_tensor(make_shape(params.m, params.n, params.l)); // Logical full output tensor
510+
511+
// Tile the output tensor per WG and select the tile for current WG
512+
Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(TileShape{}), make_coord(m_coord,n_coord,l_coord));
513+
514+
// Tile the output tensor per SG and select tile for the current SG
515+
Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg));
516+
517+
auto thread_xe_store_d = params.tiled_store_d.get_thread_slice(thread_idx); //partial copy_atom for current thread
518+
Tensor tCgD = thread_xe_store_d.partition_D(gD); //values for current thread
519+
520+
CUTLASS_PRAGMA_UNROLL
521+
for (int epi_n = 0; epi_n < FragsN; ++epi_n) {
522+
CUTLASS_PRAGMA_UNROLL
523+
for (int epi_m = 0; epi_m < FragsM; ++epi_m) {
524+
copy(params.tiled_store_d, accumulators(_, epi_m, epi_n), tCgD(_, epi_m, epi_n));
525+
}
492526
}
493527
}
494528
};

0 commit comments

Comments
 (0)