@@ -126,13 +126,6 @@ using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A:
126126// val_layout_load_A: register fragment layout
127127using Copy_A = decltype (make_tiled_copy(atom_load_A{}, Layout<CopyThreadShape>{}, val_layout_load_A{}));
128128
129- using GmemTiledCopyB_4bit = XE_2D_U8x32x32_LD_V; // U8 (1-byte) block copy for 8bit-B (narrower type)
130- using StrideB_4bit = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
131- using traits_load_B_4bit = Copy_Traits<GmemTiledCopyB_4bit, StrideB_4bit>;
132- using atom_load_B_4bit = Copy_Atom<traits_load_B_4bit, ElementB>;
133- using val_layout_load_B_4bit = decltype (make_layout(shape_div(typename traits_load_B_4bit::BlockShape{}, CopyThreadShape{})));
134- using Copy_B_4bit = decltype (make_tiled_copy(atom_load_B_4bit{}, Layout<CopyThreadShape>{}, val_layout_load_B_4bit{}));
135-
136129using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; // U8 (1-byte) block copy for 8bit-B (narrower type)
137130using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
138131using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
@@ -200,7 +193,7 @@ class kgemm_4bit_inference_cutlass_dequant {
200193
201194 Copy_A tiled_copy_a;
202195 Copy_B tiled_copy_b;
203- Copy_B_4bit tiled_copy_b_4bit;
196+ Copy_B tiled_copy_b_4bit;
204197 Copy_Scale tiled_copy_scale;
205198 int group_size;
206199
@@ -293,13 +286,13 @@ class kgemm_4bit_inference_cutlass_dequant {
293286 // 2 x 16 of these are same K
294287 // 4 different scale/zero values per thread, no exchange needed
295288 // CUTLASS_PRAGMA_UNROLL
296- for (int i = 0 ; i < 4 ; ++i) {
297- // CUTLASS_PRAGMA_UNROLL
298- for (int j = 0 ; j < 32 ; ++j) {
299- tCrB_dst (_, i, _)[j] *= tCrS (i);
300- // printf("thread_idx = %d, i = %d, j = %d, scale_value = %f\n", thread_idx, i, j, tCrS(i));
301- }
302- }
289+ // for (int i = 0; i < 4; ++i) {
290+ // //CUTLASS_PRAGMA_UNROLL
291+ // for (int j = 0; j < 32; ++j) {
292+ // tCrB_dst(_, i, _)[j] *= tCrS(i);
293+ // //printf("thread_idx = %d, i = %d, j = %d, scale_value = %f\n", thread_idx, i, j, tCrS(i));
294+ // }
295+ // }
303296
304297#if 0
305298 for(int i=0; i<num_elements_dst; i++){
@@ -386,7 +379,7 @@ class kgemm_4bit_inference_cutlass_dequant {
386379
387380 Tensor gA = local_tile (mA_mkl , select<0 ,2 >(blk_shape), make_coord (m_coord,_,l_coord));
388381 Tensor gB = local_tile (mB_nkl , select<1 ,2 >(blk_shape), make_coord (n_coord,_,l_coord));
389- Tensor gB_4bit = local_tile (mB_nkl_4bit , select<1 ,2 >(blk_shape /* blk_shape_4bit*/ ), make_coord (n_coord,_,l_coord));
382+ Tensor gB_4bit = local_tile (mB_nkl_4bit , select<1 ,2 >(blk_shape_4bit), make_coord (n_coord,_,l_coord));
390383
391384// // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_tile_shape
392385 TiledMma tiled_mma;
@@ -496,8 +489,14 @@ class kgemm_4bit_inference_cutlass_dequant {
496489
497490 CUTLASS_PRAGMA_UNROLL
498491 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
499- prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
500- prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
492+ if (prefetch_k < k_tile_count) {
493+ prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
494+ }
495+ if (prefetch_k < k_tile_count/2 ) {
496+ prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
497+ }
498+ // prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
499+ // prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
501500 }
502501
503502 const int k_reload_factor = params.group_size / BLK_K ;
@@ -560,7 +559,7 @@ for(int i=0; i<num_Acc; i++) {
560559
561560 SharedStorage& shared_storage = *reinterpret_cast <SharedStorage*>((char *)nullptr );
562561 CollectiveEpilogue epilogue{params.epilogue , shared_storage.epilogue };
563- auto problem_shape_MNKL = problem_size; // append<4>(problem_size, 1);
562+ auto problem_shape_MNKL = append<4 >(problem_size, 1 );
564563 epilogue (
565564 problem_shape_MNKL,
566565 subgroup_tile_shape,
@@ -613,7 +612,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
613612
614613 StrideB stride_B_4bit = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k/2 , l));
615614 auto mB_nkl_4bit = make_tensor (make_gmem_ptr (B), make_layout (make_shape (n, k/2 , l), stride_B_4bit));
616- Copy_B_4bit tiled_copy_b_4bit{Copy_B_4bit {}.with (mB_nkl_4bit )};
615+ Copy_B tiled_copy_b_4bit{Copy_B {}.with (mB_nkl_4bit )};
617616
618617 params.tiled_copy_a = tiled_copy_a;
619618 params.tiled_copy_b = tiled_copy_b;
@@ -631,7 +630,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
631630
632631 cutlass::KernelHardwareInfo hw_info;
633632 hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count (hw_info.device_id );
634- auto problem_shape_MNKL = problem_size; // append<4>(problem_size, 1);
633+ auto problem_shape_MNKL = append<4 >(problem_size, 1 );
635634 float alpha=1.0 ;
636635 float beta=0 .f ;
637636 StrideC stride_C = cutlass::make_cute_packed_stride (StrideC{}, cute::make_shape (m, n, l));
0 commit comments