@@ -209,16 +209,16 @@ class kgemm_4bit_inference_cutlass_dequant {
209209 // / Utilities to transform A.
210210 template <class EngineIn ,
211211 class EngineOut ,
212- // class EngineScales,
212+ class EngineScales ,
213213 class LayoutIn ,
214214 class LayoutOut ,
215- // class LayoutScales,
215+ class LayoutScales ,
216216 class ... Ts>
217217 CUTLASS_DEVICE
218218 void dequant (
219219 Tensor<EngineIn, LayoutIn> const & in,
220220 Tensor<EngineOut, LayoutOut>& out,
221- // Tensor<EngineScales, LayoutScales>& tCrS_input,
221+ Tensor<EngineScales, LayoutScales>& tCrS_input,
222222 float * quant_map
223223 ) {
224224 static_assert (is_rmem<EngineIn>::value, " Input tensor for A conversion must come from registers" );
@@ -227,7 +227,7 @@ class kgemm_4bit_inference_cutlass_dequant {
227227
228228 using SrcType = typename EngineIn::value_type;
229229 using DstType = typename EngineOut::value_type;
230- // using ScaleType = typename EngineScales::value_type;
230+ using ScaleType = typename EngineScales::value_type;
231231#if 0
232232 int numbers = decltype(size(in))::value;
233233 for(int i=0; i<numbers; i++){
@@ -263,7 +263,7 @@ class kgemm_4bit_inference_cutlass_dequant {
263263// printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
264264
265265 for (int n = 0 ; n < N; n++) {
266- // const auto ts = tCrS_input(n);
266+ const auto ts = tCrS_input (n);
267267
268268 auto & src = *(cute::array<format_type, loop_cnt / scalar>*)(s_tensor (_, n).data ());
269269
@@ -299,7 +299,7 @@ class kgemm_4bit_inference_cutlass_dequant {
299299
300300 auto tiled_copy_a = params.tiled_copy_a ;
301301 auto tiled_copy_b = params.tiled_copy_b ;
302- // auto tiled_copy_scale = params.tiled_copy_scale;
302+ auto tiled_copy_scale = params.tiled_copy_scale ;
303303
304304 auto problem_size = ProblemShape{M, N, K, L};
305305
@@ -320,7 +320,7 @@ class kgemm_4bit_inference_cutlass_dequant {
320320 barrier_wait (1 );
321321
322322// // Get the block level coordinate(indexing) for current block
323- auto blk_shape = TileShape{}; // 256,256,32
323+ auto blk_shape = TileShape{}; // 16,64,64
324324 int m_coord, n_coord, l_coord; // block index
325325 if (params.scheduler .raster_order_ == TileScheduler::RasterOrder::AlongN) {
326326 if (cute::thread0 ()) printf (" AlongN !!\n " );
@@ -363,7 +363,7 @@ class kgemm_4bit_inference_cutlass_dequant {
363363// //// MainLoop //////
364364 auto thr_copy_A = tiled_copy_a.get_slice (thread_idx);
365365 auto thr_copy_B = tiled_copy_b.get_slice (thread_idx);
366- // auto thr_copy_scale = tiled_copy_scale.get_slice(thread_idx);
366+ auto thr_copy_scale = tiled_copy_scale.get_slice (thread_idx);
367367
368368 auto sg = syclcompat::get_nd_item<1 >().get_sub_group ();
369369 auto first_thread_in_sg_idx = sg.get_group_linear_id () * DispatchPolicy::SubgroupSize;
@@ -379,10 +379,10 @@ class kgemm_4bit_inference_cutlass_dequant {
379379
380380 Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
381381
382- // static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
383- // static constexpr auto scale_traits_num = SG_QNT_WIDTH / size<1>(typename GmemTiledCopyScale::BlockShape{});
384- // using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
385- // Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
382+ static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
383+ static constexpr auto scale_traits_num = SG_QNT_WIDTH / size<1 >(typename GmemTiledCopyScale::BlockShape{});
384+ using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
385+ Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
386386
387387 static_assert (std::is_same_v<typename decltype (dequant_frag)::value_type, ElementQuant>);
388388 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
@@ -391,7 +391,7 @@ class kgemm_4bit_inference_cutlass_dequant {
391391// // Retile for copy
392392 Tensor frag_copy_A = thr_copy_A.retile_D (mma_A);
393393 Tensor frag_copy_B = thr_copy_B.retile_D (dequant_frag);
394- // Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
394+ Tensor frag_copy_Scale = thr_copy_scale.retile_D (fragment_scale);
395395
396396// // Retile global counting tensors for copies:
397397 Tensor tAgA = thr_copy_A.retile_S (tCgA);
@@ -407,17 +407,17 @@ class kgemm_4bit_inference_cutlass_dequant {
407407 auto pAgA = thr_prefetch_A.partition_S (gA );
408408 auto pBgB = thr_prefetch_B.partition_S (gB );
409409
410- // // Run mainloop
411- // auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
412- // const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
413- // const int l_coord_s = l_idx;
414- //
415- // auto copy_iter_s = [&](){
416- // return make_tensor(make_inttuple_iter(make_coord(n_coord_s, 0, l_coord_s)),
417- // make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
418- // make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
419- //
420- // }();
410+ // Run mainloop
411+ auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
412+ const int n_coord_s = n_idx * BLK_N + (get_sub_group_id () % ATOM_N ) * SG_N ;
413+ const int l_coord_s = l_idx;
414+
415+ auto copy_iter_s = [&](){
416+ return make_tensor (make_inttuple_iter (make_coord (n_coord_s, 0 , l_coord_s)),
417+ make_layout (make_shape (Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
418+ make_stride (E<0 >{} * _16{}, E<0 >{} * size<1 >(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1 >{} * _1{})));
419+
420+ }();
421421#if 1
422422 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
423423 if (cutlass::thread (LOG_THREAD , LOG_GROUP )) {
@@ -458,7 +458,7 @@ class kgemm_4bit_inference_cutlass_dequant {
458458 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
459459 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
460460 }
461- // k_tile_count=1;
461+
462462 for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
463463 barrier_arrive (2 );
464464
@@ -468,7 +468,7 @@ class kgemm_4bit_inference_cutlass_dequant {
468468
469469 const int k_reload_factor = params.group_size / BLK_K ;
470470
471- // copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
471+ copy (tiled_copy_scale, copy_iter_s (_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
472472
473473 if (prefetch_k < k_tile_count) {
474474 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
@@ -477,9 +477,8 @@ class kgemm_4bit_inference_cutlass_dequant {
477477 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
478478 }
479479
480- dequant (dequant_frag, mma_B, /* fragment_scale,*/ quant_map);
480+ dequant (dequant_frag, mma_B, fragment_scale, quant_map);
481481
482- // barrier_wait(1);
483482
484483 cute::gemm (tiled_mma, mma_A, mma_B, accumulators);
485484 barrier_wait (2 );
@@ -575,22 +574,8 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
575574 auto mA_mkl = make_tensor (make_gmem_ptr (A), make_layout (make_shape (m, k, l), stride_A));
576575 Copy_A tiled_copy_a{Copy_A{}.with (mA_mkl )};
577576
578- // StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n/2, k, l));
579- // auto stride_B_custom = cute::make_stride(
580- // cute::Int<1>{}, // 连续维度步幅(字节)
581- // (n * 4 + 7) / 8, // pitch = ceil(n * 4bit / 8bit)
582- // cute::Int<0>{} // 无批量步幅(根据需求调整)
583- // );
584- // constexpr int stride_k = (n * 4 ) / 8;
585- // constexpr int stride_l = (n * k * 4 ) / 8;
586- // auto stride_B = cute::make_stride(
587- // cute::Int<1>{},
588- // (n * 4 ) / 8,
589- // (n * k * 4 ) / 8
590- // );
591- // int k_half = k/2;
592- // StrideB stride_B = make_stride(int64_t{1}, int64_t{n}, int64_t{n * k});
593- StrideB stride_B = make_stride (int64_t {n}, cute::Int<1 >{}, int64_t {0 });
577+ // StrideB stride_B = make_stride(int64_t{n}, cute::Int<1>{}, int64_t{0});
578+ StrideB stride_B = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k, l));
594579 auto mB_nkl = make_tensor (cute::subbyte_iterator<ElementB>(B), make_layout (make_shape (n, k, l), stride_B));
595580 Copy_B tiled_copy_b{Copy_B{}.with (mB_nkl )};
596581
0 commit comments