@@ -129,16 +129,13 @@ using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{}));
129129
// Type aliases describing how operand A is loaded.
// The copy operation is XE_2D_U16x32x32_LD_N (the trailing comment keeps a
// previously used alternative) and the stride type is derived from the
// RowMajor layout tag.
130130using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; // XE_2D_U16x16x32_LD_N;
131131using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
132- // using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
// Copy traits for the chosen operation/stride pair, and the copy atom that
// binds those traits to the element type ElementA.
133132using traits_load_A = Copy_Traits<GmemTiledCopyA, StrideA>;
134133using atom_load_A = Copy_Atom<traits_load_A, ElementA>;
// Value layout: the traits' BlockShape divided by CopyThreadShape (shape_div).
135134using val_layout_load_A = decltype (make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{})));
// Tiled copy assembled from the atom, a thread layout of shape
// CopyThreadShape, and the value layout computed above.
136135using Copy_A = decltype (make_tiled_copy(atom_load_A{}, Layout<CopyThreadShape>{}, val_layout_load_A{}));
137136
// Type aliases describing how operand B is loaded.
// The copy operation is XE_2D_U4x32x16_LD_T and the stride type is derived
// from the ColumnMajor layout tag.
138137using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
139138using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::ColumnMajor>;
140- // using StrideB = Stride<int64_t, int64_t, int64_t>;
141- // using Copy_B = typename Copy_Traits<GmemTiledCopyB, StrideB>::template DefaultTiledCopy<ElementB>;
// Copy traits for the chosen operation/stride pair, and the copy atom that
// binds those traits to the element type ElementB.
142139using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
143140using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
// Value layout: the traits' BlockShape divided by CopyThreadShape (shape_div).
144141using val_layout_load_B = decltype (make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
@@ -148,12 +145,6 @@ using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}
// Type aliases describing how the group-wise scale values are loaded.
// The copy operation is XE_2D_U16x1x16_LD_N. StrideScale fixes the first mode
// to the static unit stride _1 and leaves the other two modes as int64_t
// values.
148145using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
149146using StrideScale = cute::Stride<_1, int64_t , int64_t >; // dynamic stride
150147using traits_load_scale = Copy_Traits<GmemTiledCopyScale, StrideScale>;
151- // using AtomLayout = Layout<
152- // Shape<_16, _2>, // matches the BlockShape of XE_2D_U16x1x32_LD_N
153- // Stride<_1, _16> // contiguous storage, stride 16
154- // >;
155- // using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale, AtomLayout>;
156- // using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, AtomLayout{})); //group-wise scale
// Copy atom binding the scale traits to the element type ElementScale.
157148using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale>;
// Unlike the A and B aliases, the scale path divides BlockShape by
// CopyThreadShapeRev (the reversed CopyThreadShape) and uses that same
// reversed shape for the thread layout.
158149using val_layout_load_scale = decltype (make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{})));
159150using Copy_Scale = decltype (make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, val_layout_load_scale{})); // group-wise scale
@@ -245,17 +236,20 @@ class kgemm_4bit_inference_cutlass_dequant {
245236 auto s_tensor = make_tensor ((format_type*)(raw_pointer_cast (in.data ())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
246237 auto d_tensor = make_tensor (out.data (), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
247238
239+ CUTLASS_PRAGMA_UNROLL
248240 for (int n = 0 ; n < N; n++) {
249241 const auto ts = tCrS_input (n);
250242
251243 auto & src = *(cute::array<format_type, loop_cnt / scalar>*)(s_tensor (_, n).data ());
252244
245+ CUTLASS_PRAGMA_UNROLL
253246 for (int s = 0 ; s < splits; s++) {
254247 auto idx = vec_size * s / scalar;
255248 auto format_data = src[idx];
256249
257250 auto & dst = *(cute::array<DstType, vec_size>*)(d_tensor (_, s, n).data ());
258251
252+ CUTLASS_PRAGMA_UNROLL
259253 for (int i = 0 ; i < vec_size; i++) {
260254 uint8_t value = (format_data >> (src_bits * i)) & 0xf ;
261255 if (i % 2 != 0 ) { // 1,3, high_4bit
@@ -271,27 +265,18 @@ class kgemm_4bit_inference_cutlass_dequant {
271265 CUTLASS_DEVICE
272266 void operator ()(Params const & params, char * smem_buf) {
273267 // if(cute::thread0()) printf("this is fusion kernel...........\n");
274-
275268 int M = params.m ;
276269 int N = params.n ;
277270 int K = params.k ;
278271 int L = 1 ;
279-
280- const int BLK_M = 256 ;
281- const int BLK_N = 256 ;
282- const int BLK_K = 32 ;
283272
284- const int ATOM_M = 8 ;
285- const int ATOM_N = 4 ;
286- const int ATOM_K = 1 ;
287-
288- const int SG_M = ceil_div (BLK_M , ATOM_M );
289- const int SG_N = ceil_div (BLK_N , ATOM_N );
290- const int SG_K = ceil_div (BLK_K , ATOM_K );
291-
292- const int Num_SGs = ATOM_N * ATOM_M * ATOM_K ;
273+ // Total Threads number
274+ static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K ; // 32 //2
275+
293276 static constexpr auto SG_QNT_WIDTH = Int<SG_N >{};
294277
278+ if (cute::thread0 ()) printf (" BLK_M = %d, BLK_N = %d, BLK_K = %d, ATOM_M = %d, ATOM_N = %d, ATOM_K = %d, SG_M = %d, SG_N = %d, SG_K = %d, Num_SGs = %d, SG_QNT_WIDTH = %d\n " , static_cast <int >(BLK_M ), static_cast <int >(BLK_N ), static_cast <int >(BLK_K ), static_cast <int >(ATOM_M ), static_cast <int >(ATOM_N ), static_cast <int >(ATOM_K ), static_cast <int >(SG_M ), static_cast <int >(SG_N ), static_cast <int >(SG_K ), static_cast <int >(Num_SGs), static_cast <int >(SG_QNT_WIDTH ));
279+
295280 T* A = params.A ;
296281 uint8_t * B = params.B ;
297282 float * out = params.out ;
@@ -401,14 +386,14 @@ class kgemm_4bit_inference_cutlass_dequant {
401386 auto pBgB = thr_prefetch_B.partition_S (gB );
402387
403388// Run mainloop
404- auto copy_iter_s = [&](){
389+ auto tSgS = [&](){
405390 return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)),
406391 make_layout (make_shape (Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
407392 make_stride (E<0 >{} * _16{}, E<0 >{} * decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1 >{} * _1{})));
408393
409394 }();
410395
411- #if 0
396+ #if 1
412397 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
413398 if (cutlass::thread (LOG_THREAD , LOG_GROUP )) {
414399 print (" \n\n ======================= A: \n " );
@@ -426,11 +411,17 @@ class kgemm_4bit_inference_cutlass_dequant {
426411 print (" frag_copy_B : " ); print (frag_copy_B); print (" \n " );
427412 print (" dequant_frag : " ); print (dequant_frag); print (" \n " );
428413
429- print("===================== D :\n");
430- print(" tiled_copy_scale : "); print(tiled_copy_scale); print("\n");
414+ print (" ===================== Scale :\n " );
415+ // print(" traits_load_scale::BlockShape{} : "); print(traits_load_scale::BlockShape{}); print("\n");
416+ // print(" CopyThreadShapeRev{} : "); print(CopyThreadShapeRev{}); print("\n");
417+ // print(" val_layout_load_scale{} : "); print(val_layout_load_scale{}); print("\n");
418+ // print(" atom_load_scale{} : "); print(atom_load_scale{}); print("\n");
419+ // print(" Layout<CopyThreadShapeRev>{} : "); print(Layout<CopyThreadShapeRev>{}); print("\n");
420+ // print(" Copy_Scale{} : "); print(Copy_Scale{}); print("\n");
421+ // print(" tiled_copy_scale : "); print(tiled_copy_scale); print("\n");
431422 print (" fragment_scale : " ); print (fragment_scale); print (" \n " );
432423 print (" frag_copy_Scale : " ); print (frag_copy_Scale); print (" \n " );
433- print(" copy_iter_s : "); print(copy_iter_s ); print("\n");
424+ print (" tSgS : " ); print (tSgS ); print (" \n " );
434425
435426 print (" ===================== D :\n " );
436427 print (" accumulators : " ); print (accumulators); print (" \n " );
@@ -439,9 +430,25 @@ class kgemm_4bit_inference_cutlass_dequant {
439430 print (" threads per workgroup : " ); print (MaxThreadsPerBlock); print (" \n " );
440431 print (" SubgroupTileShape : " ); print (SubgroupTileShape{}); print (" \n " );
441432
433+ print (" ===================== Config: \n " );
434+ print (" tiled_mma : " ); print (tiled_mma); print (" \n " );
435+
436+ print (" ===================== Config: \n " );
437+ print (" SubgroupTileShape : " ); print (SubgroupTileShape{}); print (" \n " );
438+
439+ print (" ===================== Config: \n " );
440+ print (" thr_mma : " ); print (thr_mma); print (" \n " );
441+
442+ print (" ===================== Config: \n " );
442443 print (" tiled_prefetch_a : " ); print (tiled_prefetch_a); print (" \n " );
444+
445+ print (" ===================== Config: \n " );
443446 print (" tiled_prefetch_b : " ); print (tiled_prefetch_b); print (" \n " );
447+
448+ print (" ===================== Config: \n " );
444449 print (" pAgA : " ); print (pAgA); print (" \n " );
450+
451+ print (" ===================== Config: \n " );
445452 print (" pBgB : " ); print (pBgB); print (" \n\n\n " );
446453 }
447454 #undef PRINT
@@ -450,7 +457,7 @@ class kgemm_4bit_inference_cutlass_dequant {
450457 int prefetch_k = k_start_idx;
451458
452459 const int k_reload_factor = ceil_div (params.group_size , BLK_K );
453- // if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f \n",params.group_size, BLK_K, k_reload_factor);
460+ if (cute::thread0 ()) printf (" params.group_size = %d, BLK_K = %d, k_reload_factor = %d \n " ,params.group_size , static_cast < int >( BLK_K ) , k_reload_factor);
454461
455462 CUTLASS_PRAGMA_UNROLL
456463 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
@@ -465,9 +472,9 @@ class kgemm_4bit_inference_cutlass_dequant {
465472 copy (tiled_copy_a, tAgA (_,_,_,k_tile), frag_copy_A);
466473 copy (tiled_copy_b, tBgB (_,_,_,k_tile), frag_copy_B);
467474
468- const int s_step = k_start_idx + ( k_s / k_reload_factor) ;
469- // if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step );
470- copy (tiled_copy_scale, copy_iter_s (_, _, _, s_step ), frag_copy_Scale);
475+ const int s_idx = ( k_start_idx + k_s) / k_reload_factor;
476+ if (cute::thread0 ()) printf (" k_start_idx = %d, k_s = %d, k_reload_factor = %d, s_idx = %d\n " ,k_start_idx, k_s, k_reload_factor, s_idx );
477+ copy (tiled_copy_scale, tSgS (_, _, _, s_idx ), frag_copy_Scale);
471478
472479 if (prefetch_k < k_tile_count) {
473480 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
@@ -591,7 +598,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
591598 StrideC stride_C = cutlass::make_cute_packed_stride (StrideC{}, cute::make_shape (m, n, l));
592599 StrideD stride_D = cutlass::make_cute_packed_stride (StrideD{}, cute::make_shape (m, n, l));
593600
594- #if 0
601+ #if 1
595602 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
596603 if (cutlass::thread (LOG_THREAD , LOG_GROUP )) {
597604 print (" ===================== stride :\n " );
0 commit comments