@@ -61,10 +61,10 @@ static constexpr float quant_map_static[16] = {
6161};
6262#endif
6363
64- using TileShape = Shape<_32 , _128, _128>;
64+ using TileShape = Shape<_64 , _128, _128>;
6565using TiledMma =
6666 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
67- Layout<Shape<_1, _8 , _1>, Stride<_8 , _1, _0>>>::TiledMMA;
67+ Layout<Shape<_2, _4 , _1>, Stride<_4 , _1, _0>>>::TiledMMA;
6868using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070constexpr int PipelineStages = 2 ;
@@ -98,9 +98,9 @@ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
9898using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
9999static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
100100
101- static constexpr auto FragsM = get<0 >(SubgroupTileShape{}) / get<0 >(MmaAtomShape());
102- static constexpr auto FragsN = get<1 >(SubgroupTileShape{}) / get<1 >(MmaAtomShape());
103- static constexpr auto FragmentSize = (get<0 >(MmaAtomShape()) * get<1 >(MmaAtomShape())) / SubgroupSize;
101+ // static constexpr auto FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape());
102+ // static constexpr auto FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape());
103+ // static constexpr auto FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
104104
105105// Design Scheduler
106106using TileScheduler_ = PersistentScheduler;
@@ -395,35 +395,51 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
395395 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
396396 constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
397397
398- // if(cute::thread0()) printf("params.group_size = %d, k_reload_factor = %d, k_tile_count = %d, N = %d, K = %d, src_compress_size = %d, src_vec_size = %d, dst_compress_size = %d, dst_vec_size = %d \n",params.group_size, k_reload_factor, k_tile_count, N, K, src_compress_size, src_vec_size, dst_compress_size , dst_vec_size);
398+ if (cute::thread0 ()) printf (" N = %d, K = %d, src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n " , N, K, src_compress_size, dst_compress_size, src_vec_size , dst_vec_size, src_loop_num, dst_loop_num);
399399
400400 src_compress_type src[src_vec_size];
401401 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
402402
403403 #pragma unroll
404404 for (int n = 0 ; n < N; n++) {
405- // float scale_value = fragment_scale(n);
406405 #pragma unroll
407406 for (int l = 0 ; l < src_loop_num; l++) {
408- // src_compress_type src[src_vec_size];
409- // ElementMMA dst[K/dst_loop_num];
410407 reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l];
408+
409+ if (thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0 ) {
410+ printf (" n = %d, src_l = %d\n " , n, l);
411+ print (" ======================= src vectorization: \n " );
412+ print (" src_g_ptr : " ); print (&(reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n * src_loop_num + l])); print (" \n " );
413+ print (" src_ptr : " ); print (&(reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ])); print (" \n " );
414+ print (" =======================\n " );
415+ }
416+
411417 #pragma unroll
412418 for (int v = 0 ; v < src_vec_size; v++) {
413419 src_compress_type src_value = src[v];
414420 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
415421 #pragma unroll
416422 for (int c = 0 ; c < src_compress_size; c++) {
417423 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
418- float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + c) / GROUP_SIZE );
424+ float scale_value = 1 . 0f ; // fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
419425 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
426+ // if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) printf("n = %d, src_l = %d, dst_base_idx+c = %d, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE) = %d, scale_value = %f\n", n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
420427 }
421428 }
422429 }
423430
424431 #pragma unroll
425432 for (int l = 0 ; l < dst_loop_num; l++) {
426- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
433+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
434+
435+ if (thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0 ) {
436+ printf (" n = %d, dst_l = %d\n " , n, l);
437+ print (" ======================= dst vectorization: \n " );
438+ print (" dst_g_ptr : " ); print (&(reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l])); print (" \n " );
439+ print (" dst_ptr : " ); print (&(reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l])); print (" \n " );
440+ print (" =======================\n " );
441+ }
442+
427443 }
428444 }
429445 };
@@ -474,21 +490,39 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
474490 }
475491
476492// replace epilogue for store
477- Tensor mD_mnl = cute::get_pvc_tensor (make_shape (params.m , params.n , params.l ));
478- Tensor g_wg_D = local_tile (mD_mnl , take<0 ,2 >(WorkgroupTileShape{}), make_coord (m_coord,n_coord,l_coord));
479- Tensor gD = local_tile (g_wg_D, take<0 ,2 >(SubgroupTileShape{}), make_coord (
480- get_sub_group_id () / ATOM_N ,
481- get_sub_group_id () % ATOM_N
482- ));
483-
484- auto thread_xe_store_d = params.tiled_store_d .get_thread_slice (thread_idx);
485- Tensor tCgD = thread_xe_store_d.partition_D (gD );
493+ // Tensor mD_mnl = cute::get_pvc_tensor(make_shape(params.m, params.n, params.l));
494+ // Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(WorkgroupTileShape{}), make_coord(m_coord,n_coord,l_coord));
495+ // Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(
496+ // get_sub_group_id() / ATOM_N,
497+ // get_sub_group_id() % ATOM_N
498+ // ));
499+ //
500+ // auto thread_xe_store_d = params.tiled_store_d.get_thread_slice(thread_idx);
501+ // Tensor tCgD = thread_xe_store_d.partition_D(gD);
486502
487- #pragma unroll
488- for (int epi = 0 ; epi < FragsM * FragsN; ++epi) {
489- int epi_m = epi / FragsN;
490- int epi_n = epi % FragsN;
491- copy (params.tiled_store_d , accumulators (_, epi_m, epi_n), tCgD (_, epi_m, epi_n));
503+ static constexpr int FragsM = get<0 >(SubgroupTileShape{}) / get<0 >(MmaAtomShape ()); // atom numbers per thread; A frags per sub_group
504+ static constexpr int FragsN = get<1 >(SubgroupTileShape{}) / get<1 >(MmaAtomShape ()); // atom numbers per thread; B frags per sub_group
505+
506+ auto m_sg = get_sub_group_id () / ATOM_N ;
507+ auto n_sg = get_sub_group_id () % ATOM_N ;
508+
509+ Tensor mD_mnl = cute::get_pvc_tensor (make_shape (params.m , params.n , params.l )); // Logical full output tensor
510+
511+ // Tile the output tensor per WG and select the tile for current WG
512+ Tensor g_wg_D = local_tile (mD_mnl , take<0 ,2 >(TileShape{}), make_coord (m_coord,n_coord,l_coord));
513+
514+ // Tile the output tensor per SG and select tile for the current SG
515+ Tensor gD = local_tile (g_wg_D, take<0 ,2 >(SubgroupTileShape{}), make_coord (m_sg,n_sg));
516+
517+ auto thread_xe_store_d = params.tiled_store_d .get_thread_slice (thread_idx); // partial copy_atom for current thread
518+ Tensor tCgD = thread_xe_store_d.partition_D (gD ); // values for current thread
519+
520+ CUTLASS_PRAGMA_UNROLL
521+ for (int epi_n = 0 ; epi_n < FragsN; ++epi_n) {
522+ CUTLASS_PRAGMA_UNROLL
523+ for (int epi_m = 0 ; epi_m < FragsM; ++epi_m) {
524+ copy (params.tiled_store_d , accumulators (_, epi_m, epi_n), tCgD (_, epi_m, epi_n));
525+ }
492526 }
493527 }
494528};
0 commit comments