@@ -69,7 +69,7 @@ using TiledMma =
6969 Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
7070
7171// Define Mainloop dispatch policy
72- constexpr int PipelineStages = 3 ;
72+ constexpr int PipelineStages = 1 ;
7373using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
7474static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // sub_group size
7575
@@ -231,15 +231,18 @@ class kgemm_4bit_inference_cutlass_dequant {
231231 // / Utilities to transform A.
232232 template <class EngineIn ,
233233 class EngineOut ,
234+ class EngineRef ,
234235 class EngineScales ,
235236 class LayoutIn ,
236237 class LayoutOut ,
238+ class LayoutRef ,
237239 class LayoutScales ,
238240 class ... Ts>
239241 CUTLASS_DEVICE
240242 void dequant (
241243 Tensor<EngineIn, LayoutIn> const & tCrA_load,
242244 Tensor<EngineOut, LayoutOut>& tCrA_mma,
245+ Tensor<EngineRef, LayoutRef>& A_ref, // mma_A for debug
243246 Tensor<EngineScales, LayoutScales>& tCrS_input,
244247 float * quant_map
245248 ) {
@@ -258,21 +261,27 @@ class kgemm_4bit_inference_cutlass_dequant {
258261 auto const & dst = tCrA_mma (_, _, _);
259262 auto pSrc = const_cast <SrcType*>(raw_pointer_cast (src.data ()));
260263 auto pDst = const_cast <DstType*>(raw_pointer_cast (dst.data ()));
264+ auto pA = const_cast <DstType*>(raw_pointer_cast (A_ref.data ()));
261265 constexpr int num_elements = decltype (size (src))::value / 2 ;
266+ for (int i=0 ; i<num_elements * 2 ; i++){
267+ if (cute::thread0 ())
268+ printf (" ThreadIdxX() = %d, i = %d, *(pSrc + i) = %d, *(pA + i*2) = %f, *(pA + i*2+1) = %f\n " , ThreadIdxX (), i, static_cast <int >(*(pSrc + i)), static_cast <int >(*(pA + i*2 )), static_cast <int >(*(pA + i*2 +1 )));
269+ }
262270
263271 // TODO(Codeplay): (perf) consider replacing `pack` with `num_elements` here - See xe_flash_attn_mma.hpp
264272 constexpr int pack = 1 ; // decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
265273 int src_size = sizeof_bits_v<SrcType>;
266274 int dst_size = sizeof_bits_v<DstType>;
267- if (cute::thread0 ()) printf (" Cosize = %d, src_size = %d, dst_size = %d\n " , num_elements, src_size, dst_size);
275+ // if(cute::thread0()) printf("Cosize = %d, src_size = %d, dst_size = %d\n", num_elements, src_size, dst_size);
268276 // using Converter = cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
269277#if 1
270278 for (int i=0 ; i<num_elements; i++){
271279 auto src_value = *(pSrc + i);
272- if (cute::thread0 ()) printf (" *(pSrc + i) = %d, src_value = %d\n " ,static_cast <int >(*(pSrc + i)), static_cast <int >(src_value));
280+ // if(cute::thread0()) printf("*(pSrc + i) = %d, src_value = %d\n",static_cast<int>(*(pSrc + i)), static_cast<int>(src_value));
273281 *(pDst + (2 * i)) = static_cast <DstType>(quant_map[src_value >> 4 ]);
274282 *(pDst + (2 * i + 1 )) = static_cast <DstType>(quant_map[src_value & 0x0f ]);
275- if (cute::thread0 ()) printf (" num_elements = %d, i = %d, *(pSrc + i) = %d, *(pSrc + i) >> 4= %d, *(pSrc + i) & 0x0f, quant_map[*(pSrc + i) >> 4] = %f, quant_map[src_value & 0x0f] = %f \n " , num_elements, i, static_cast <int >(*(pSrc + i)), static_cast <int >(*(pSrc + i) >> 4 ), static_cast <int >(*(pSrc + i) & 0x0f ), static_cast <int >(quant_map[*(pSrc + i) >> 4 ]), static_cast <int >(quant_map[src_value & 0x0f ]), static_cast <int >(*(pDst + (2 * i))), static_cast <int >(*(pDst + (2 * i + 1 ))));
283+ if (cute::thread0 ())
284+ printf (" num_elements = %d, i = %d, *(pSrc + i) = %d, *(pSrc + i) >> 4= %d, *(pSrc + i) & 0x0f, quant_map[*(pSrc + i) >> 4] = %f, quant_map[src_value & 0x0f] = %f \n " , num_elements, i, static_cast <int >(*(pSrc + i)), static_cast <int >(*(pSrc + i) >> 4 ), static_cast <int >(*(pSrc + i) & 0x0f ), static_cast <int >(quant_map[*(pSrc + i) >> 4 ]), static_cast <int >(quant_map[src_value & 0x0f ]), static_cast <int >(*(pDst + (2 * i))), static_cast <int >(*(pDst + (2 * i + 1 ))));
276285 }
277286#else
278287 using SrcArray = cutlass::Array<SrcType, pack>;
@@ -306,7 +315,9 @@ class kgemm_4bit_inference_cutlass_dequant {
306315 for (int i = 0 ; i < 4 ; ++i) {
307316 CUTLASS_PRAGMA_UNROLL
308317 for (int j = 0 ; j < 32 ; ++j) {
318+ if (cute::thread0 ()) printf (" tCrA_mma(_, i, _)[j] = %f, i = %d, j = %d, tCrS_input(i) = %f\n " ,tCrA_mma (_, i, _)[j], i, j, tCrS_input (i));
309319 tCrA_mma (_, i, _)[j] *= tCrS_input (i);
320+ // if(cute::thread0()) printf("after scaling tCrA_mma(_, i, _)[j] = %f\n", tCrA_mma(_, i, _)[j]);
310321 }
311322 }
312323#else
@@ -366,7 +377,7 @@ class kgemm_4bit_inference_cutlass_dequant {
366377
367378// // Get the block level coordinate(indexing) for current block
368379 auto blk_shape = TileShape{}; // 256,256,32
369- auto blk_shape_4bit = Shape<_256, _256, _16>{}; // TileShape{}; //256,256,32
380+ // auto blk_shape_4bit = Shape<_256, _256, _16>{}; //TileShape{}; //256,256,32
370381 int m_coord, n_coord, l_coord; // block index
371382 if (params.scheduler .raster_order_ == TileScheduler::RasterOrder::AlongN) {
372383 if (cute::thread0 ()) printf (" AlongN !!\n " );
@@ -393,7 +404,7 @@ class kgemm_4bit_inference_cutlass_dequant {
393404 // gA: 逻辑视图(无实际内存分配)
394405 Tensor gA = local_tile (mA_mkl , select<0 ,2 >(blk_shape), make_coord (m_coord,_,l_coord));
395406 Tensor gB = local_tile (mB_nkl , select<1 ,2 >(blk_shape), make_coord (n_coord,_,l_coord));
396- Tensor gB_4bit = local_tile (mB_nkl_4bit , select<1 ,2 >(blk_shape_4bit), make_coord (n_coord,_,l_coord / 2 ));
407+ // Tensor gB_4bit = local_tile(mB_nkl_4bit, select<1,2>(blk_shape_4bit), make_coord(n_coord,_,l_coord / 2));
397408
398409// // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_tile_shape
399410 TiledMma tiled_mma;
@@ -412,8 +423,8 @@ class kgemm_4bit_inference_cutlass_dequant {
412423 // 对于单维度,坐标直接等于索引值。
413424 // 使用方式:int k = get<0>(coord); // k = 0
414425 // cute::make_coord_iterator(A, B): 生成起始坐标A,步长B的迭代器
415- auto k_tile_iter = cute::make_coord_iterator (idx2crd (0 , make_shape (K)), make_shape (K));
416- int k_tile_count = ceil_div (K, get<2 >(workgroup_shape));
426+ auto k_tile_iter = cute::make_coord_iterator (idx2crd (0 , make_shape (K / 2 )), make_shape (K / 2 ));
427+ int k_tile_count = ceil_div (K / 2 , get<2 >(workgroup_shape));
417428 if (cute::thread0 ()) printf (" k_tile_count = %d\n " , k_tile_count);
418429
419430// ////Run MainLoop//////
@@ -440,7 +451,7 @@ class kgemm_4bit_inference_cutlass_dequant {
440451 // 虽然每个线程参与多个 Atom 的计算,但 tCgB 的 shape 是针对单个Atom 的线程分片
441452 Tensor tCgA = thr_mma.partition_A (gA );
442453 Tensor tCgB = thr_mma.partition_B (gB );
443- Tensor tCgB_4bit = thr_mma.partition_B (gB_4bit );
454+ // Tensor tCgB_4bit = thr_mma.partition_B(gB_4bit);
444455
445456// // Create fragments:将全局或共享内存中的数据分块转换为适合硬件加速器(如 Tensor Core)计算的寄存器格式
446457 // make_fragment_layout: 为寄存器片段(Fragment)创建内存布局(Layout),确保数据在寄存器中的排布符合硬件指令(如 Tensor Core)的要求
@@ -452,7 +463,7 @@ class kgemm_4bit_inference_cutlass_dequant {
452463 Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{}); // 创建scale 寄存器张量
453464
454465 // narrow input fragment
455- Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_b_4bit, tCgB_4bit (_,_,_,0 ).shape ()));
466+ Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_b_4bit, tCgB (_,_,_,0 ).shape ()));
456467 Tensor quant_frag = make_tensor<ElementQuant>(decltype (mma_B_4bit.layout ()){});
457468
458469 static_assert (std::is_same_v<typename decltype (quant_frag)::value_type, ElementQuant>);
@@ -515,7 +526,7 @@ class kgemm_4bit_inference_cutlass_dequant {
515526 // partition_S: 生成逻辑视图(源布局),不实际移动数据
516527 // partition_D: 实际复制数据到目标布局(如共享内存→寄存器)
517528 auto pAgA = thr_prefetch_A.partition_S (gA );
518- auto pBgB = thr_prefetch_B.partition_S (gB_4bit );
529+ auto pBgB = thr_prefetch_B.partition_S (gB );
519530
520531// //
521532// // Mainloop
@@ -528,7 +539,7 @@ class kgemm_4bit_inference_cutlass_dequant {
528539
529540 Tensor copy_iter_s = [&](){
530541 return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)), // 初始坐标:(n_coord, 0, l_coord),表示从 N 维的 n_coord 开始,K 维从 0 开始
531- make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count/ 2 ), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
542+ make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
532543 make_stride (E<0 >{} * _16{}, E<0 >{} * _32{}, _0{}, E<1 >{} * _1{}))); // 步长 [16, 32, 0, 1]:
533544 // E<0>{} * _16{}: 第一维度(N)的步长为 16;
534545 // E<0>{} * _32{}:第二维度(K)的步长为 32;
@@ -538,7 +549,7 @@ class kgemm_4bit_inference_cutlass_dequant {
538549 // E<N>:一个模板类,表示第 N 维的步长或索引,通常用于动态形状或步长的占位符
539550 // E<0>{}:表示第 0 维(最内层维度)的动态步长或索引值,具体值在运行时确定
540551 }();
541-
552+ # if 0
542553 #define CUTLASS_ENABLE_DEBUG_PRINTS 1
543554 #if CUTLASS_ENABLE_DEBUG_PRINTS
544555 #define PRINT(x) print(#x ": "); print(x); print("\n");
@@ -571,9 +582,10 @@ class kgemm_4bit_inference_cutlass_dequant {
571582 }
572583 #undef PRINT
573584 #endif
585+ #endif
574586
575587 // crd2idx: 将多维逻辑坐标转换为线性索引
576- const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
588+ const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K / 2 ));
577589 int prefetch_k = 0 ;
578590
579591 CUTLASS_PRAGMA_UNROLL
@@ -593,7 +605,7 @@ class kgemm_4bit_inference_cutlass_dequant {
593605
594606 copy (tiled_copy_scale, copy_iter_s (_, _, _, k_start_idx + (k_tile / k_reload_factor)), copy_tCrS);
595607 // dequant(quant_frag, mma_B_expanded, fragment_scale_input, quant_map);
596- dequant (quant_frag, mma_B, fragment_scale_input, quant_map);
608+ dequant (quant_frag, mma_B, mma_A, fragment_scale_input, quant_map);
597609
598610 if (prefetch_k < k_tile_count) {
599611 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
0 commit comments