@@ -253,24 +253,25 @@ class kgemm_4bit_inference_cutlass_dequant {
253253
254254#if 1
255255 auto const & src = tCrA_load (_, _, _);
256+ // auto src = src_(_, cute::take(src_.size(1)/2), _);
257+ // auto src = src_(_, _0{size_t(src_.size(1)/2)}, _);
256258 auto const & dst = tCrA_mma (_, _, _);
257259 auto pSrc = const_cast <SrcType*>(raw_pointer_cast (src.data ()));
258260 auto pDst = const_cast <DstType*>(raw_pointer_cast (dst.data ()));
259261 constexpr int num_elements = decltype (size (src))::value / 2 ;
260262
261263 // TODO(Codeplay): (perf) consider replacing `pack` with `num_elements` here - See xe_flash_attn_mma.hpp
262264 constexpr int pack = decltype (select_packing<SrcType, DstType, num_elements>::value ())::value;
263- // if(cute::thread0()) printf("Cosize, sizeof_bits_v<SrcType> = %d, sizeof_bits_v<DstType> = %d, cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>) = %d, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>) = %d\n", num_elements, sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>, cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>), 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>));
264265 int src_size = sizeof_bits_v<SrcType>;
265266 int dst_size = sizeof_bits_v<DstType>;
266267 if (cute::thread0 ()) printf (" Cosize = %d, src_size = %d, dst_size = %d\n " , num_elements, src_size, dst_size);
267268 // using Converter = cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
268269 using SrcArray = cutlass::Array<SrcType, pack>;
269- using DstArray = cutlass::Array<DstType, pack * 2 >;
270+ using DstArray = cutlass::Array<DstType, pack>;
270271 constexpr int iters = num_elements / pack;
271272
272273 CUTLASS_PRAGMA_UNROLL
273- for (int i = 0 ; i < iters; ++i) {
274+ for (int i = 0 ; i < iters / 2 ; ++i) {
274275 SrcArray const * pSrcArr = reinterpret_cast <SrcArray const *>(pSrc) + i;
275276 DstArray* pDstArr = reinterpret_cast <DstArray*>(pDst) + i * 2 ;
276277 // *pDstArr = Converter::convert(*pSrcArr);
@@ -425,6 +426,7 @@ class kgemm_4bit_inference_cutlass_dequant {
425426 // tCgA: t(tensor) C(compute) gA(globaleA);
426427 // tCsA: s (shared memory)
427428 // tCrA: r (register)
429+ // Although each thread takes part in the computation of several Atoms, the shape of tCgB is the per-thread partition for a single Atom
428430 Tensor tCgA = thr_mma.partition_A (gA );
429431 Tensor tCgB = thr_mma.partition_B (gB );
430432 Tensor tCgB_4bit = thr_mma.partition_B (gB_4bit );
@@ -502,7 +504,7 @@ class kgemm_4bit_inference_cutlass_dequant {
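502504 // partition_S: builds this thread's logical view of the source tensor; no data is moved
503505 // partition_D: builds this thread's logical view of the destination tensor; NOTE(review): the earlier note said partition_D performs the copy, but the transfer is presumably issued by a later copy/prefetch call — confirm against the CuTe TiledCopy docs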
504506 auto pAgA = thr_prefetch_A.partition_S (gA );
505- auto pBgB = thr_prefetch_B.partition_S (gB );
507+ auto pBgB = thr_prefetch_B.partition_S (gB_4bit );
506508
507509// //
508510// // Mainloop
@@ -515,7 +517,7 @@ class kgemm_4bit_inference_cutlass_dequant {
515517
516518 Tensor copy_iter_s = [&](){
517519 return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)), // 初始坐标:(n_coord, 0, l_coord),表示从 N 维的 n_coord 开始,K 维从 0 开始
518- make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
520+ make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count/ 2 ), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
519521 make_stride (E<0 >{} * _16{}, E<0 >{} * _32{}, _0{}, E<1 >{} * _1{}))); // 步长 [16, 32, 0, 1]:
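520522 // E<0>{} * _16{}: the first mode advances coordinate component 0 (n_coord in the initial coordinate, i.e. N) by 16;
521523 // E<0>{} * _32{}: the second mode also advances coordinate component 0 (N) by 32 — it uses E<0>, not E<1>, so it does not step along K;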
@@ -529,7 +531,7 @@ class kgemm_4bit_inference_cutlass_dequant {
529531 #define CUTLASS_ENABLE_DEBUG_PRINTS 1
530532 #if CUTLASS_ENABLE_DEBUG_PRINTS
531533 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
532- if (cute::thread0 ()){
534+ if (cute::thread0 ()){
533535 print (" ======================= A: \n " );
534536 print (" gA : " ); print (gA ); print (" \n " );
535537 print (" tCgA : " ); print (tCgA); print (" \n " );
@@ -558,6 +560,7 @@ class kgemm_4bit_inference_cutlass_dequant {
558560 }
559561 #undef PRINT
560562 #endif
563+
561564 // crd2idx: 将多维逻辑坐标转换为线性索引
562565 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
563566 int prefetch_k = 0 ;
@@ -639,12 +642,13 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
639642 auto mA_mkl = make_tensor (make_gmem_ptr (A), make_layout (make_shape (m, k, l), stride_A));
640643 Copy_A tiled_copy_a{Copy_A{}.with (mA_mkl )};
641644
645+ // make_cute_packed_stride: 根据张量形状自动生成内存步长(Stride)的关键函数,其核心目标是优化内存访问模式以适配硬件指令
642646 StrideB stride_B = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k, l));
643647 auto mB_nkl = make_tensor (make_gmem_ptr (B), make_layout (make_shape (n, k, l), stride_B));
644648 Copy_B tiled_copy_b{Copy_B{}.with (mB_nkl )};
645649
646650 StrideB stride_B_4bit = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k/2 , l));
647- auto mB_nkl_4bit = make_tensor (make_gmem_ptr (B), make_layout (make_shape (n, k/2 , l), stride_B ));
651+ auto mB_nkl_4bit = make_tensor (make_gmem_ptr (B), make_layout (make_shape (n, k/2 , l), stride_B_4bit ));
648652 Copy_B tiled_copy_b_4bit{Copy_B{}.with (mB_nkl_4bit )};
649653
650654 params.tiled_copy_a = tiled_copy_a;
0 commit comments