@@ -149,30 +149,33 @@ using GmemTiledCopyC = CopyOpG2R;
149149using GmemTiledCopyD = cute::conditional_t <not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
150150 CopyOpR2G, XE_2D_U32x8x16_ST_N>;
151151
152- // Calculate subgroup_tile_shape (reminder: not the same thing with "subgroup_size" in sycl!!)
153- static constexpr auto BLK_M = get<0 >(WorkgroupTileShape{});
154- static constexpr auto BLK_N = get<1 >(WorkgroupTileShape{});
155- static constexpr auto BLK_K = get<2 >(WorkgroupTileShape{});
156-
157- static constexpr auto ATOM_M = get<1 >(typename TiledMma::ThrLayoutVMNK{}.shape());
158- static constexpr auto ATOM_N = get<2 >(typename TiledMma::ThrLayoutVMNK{}.shape());
159- static constexpr auto ATOM_K = get<3 >(typename TiledMma::ThrLayoutVMNK{}.shape());
160-
161- static_assert (BLK_M % TiledMma{}.template tile_size_mnk<0 >() == 0 , " TiledMma permutation size must match block size." );
162- static_assert (BLK_N % TiledMma{}.template tile_size_mnk<1 >() == 0 , " TiledMma permutation size must match block size." );
163- static_assert (BLK_K % TiledMma{}.template tile_size_mnk<2 >() == 0 , " TiledMma permutation size must match block size." );
164-
165- static constexpr auto SG_M = ceil_div(BLK_M , ATOM_M );
166- static constexpr auto SG_N = ceil_div(BLK_N , ATOM_N );
167- static constexpr auto SG_K = ceil_div(BLK_K , ATOM_K );
168- using SubgroupTileShape = Shape<decltype (SG_M ), decltype (SG_N ), decltype (SG_K )>;
169-
170- static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K ; // 32
171- static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
172-
173152template <typename T, int BITS >
174153class kgemm_4bit_inference_cutlass_dequant {
175154public:
155+ // Calculate subgroup_tile_shape (reminder: not the same thing with "subgroup_size" in sycl!!)
156+ static constexpr auto BLK_M = get<0 >(WorkgroupTileShape{});
157+ static constexpr auto BLK_N = get<1 >(WorkgroupTileShape{});
158+ static constexpr auto BLK_K = get<2 >(WorkgroupTileShape{});
159+
160+ // Threads number
161+ static constexpr auto ATOM_M = get<1 >(typename TiledMma::ThrLayoutVMNK{}.shape());
162+ static constexpr auto ATOM_N = get<2 >(typename TiledMma::ThrLayoutVMNK{}.shape());
163+ static constexpr auto ATOM_K = get<3 >(typename TiledMma::ThrLayoutVMNK{}.shape());
164+
165+ static_assert (BLK_M % TiledMma{}.template tile_size_mnk<0 >() == 0 , " TiledMma permutation size must match block size." );
166+ static_assert (BLK_N % TiledMma{}.template tile_size_mnk<1 >() == 0 , " TiledMma permutation size must match block size." );
167+ static_assert (BLK_K % TiledMma{}.template tile_size_mnk<2 >() == 0 , " TiledMma permutation size must match block size." );
168+
169+ // sub-tile shape
170+ static constexpr auto SG_M = ceil_div(BLK_M , ATOM_M );
171+ static constexpr auto SG_N = ceil_div(BLK_N , ATOM_N );
172+ static constexpr auto SG_K = ceil_div(BLK_K , ATOM_K );
173+ using SubgroupTileShape = Shape<decltype (SG_M ), decltype (SG_N ), decltype (SG_K )>;
174+
175+ // Total Threads number
176+ static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K ; // 32
177+ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
178+
176179 // Kernel level shared memory storage
177180 struct SharedStorage {
178181 using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
@@ -374,6 +377,8 @@ class kgemm_4bit_inference_cutlass_dequant {
374377 Tensor mB_nkl = cute::get_pvc_tensor (make_shape (N,K,L)); // coordinate tensor: 0,1,2....
375378 Tensor mB_nkl_4bit = cute::get_pvc_tensor (make_shape (N,K/2 ,L)); // coordinate tensor: 0,1,2....
376379
380+ // local_tile: 从全局张量中提取线程块(CTA)级别的局部子块
381+ // gA: 逻辑视图(无实际内存分配)
377382 Tensor gA = local_tile (mA_mkl , select<0 ,2 >(blk_shape), make_coord (m_coord,_,l_coord));
378383 Tensor gB = local_tile (mB_nkl , select<1 ,2 >(blk_shape), make_coord (n_coord,_,l_coord));
379384 Tensor gB_4bit = local_tile (mB_nkl_4bit , select<1 ,2 >(blk_shape_4bit), make_coord (n_coord,_,l_coord / 2 ));
@@ -417,16 +422,21 @@ class kgemm_4bit_inference_cutlass_dequant {
417422 // thr_mma:线程的 MMA(矩阵乘累加)分片
418423 // gA:矩阵 A 的全局或共享内存分块
419424 // tCgA,一个逻辑张量,表示当前线程负责的寄存器片段, 形状由 TiledMMA 策略决定
425+ // tCgA: t(tensor) C(compute) gA(globaleA);
426+ // tCsA: s (shared memory)
427+ // tCrA: r (register)
420428 Tensor tCgA = thr_mma.partition_A (gA );
421429 Tensor tCgB = thr_mma.partition_B (gB );
422430 Tensor tCgB_4bit = thr_mma.partition_B (gB_4bit );
423431
424- // Create fragments
432+ // // Create fragments:将全局或共享内存中的数据分块转换为适合硬件加速器(如 Tensor Core)计算的寄存器格式
433+ // make_fragment_layout: 为寄存器片段(Fragment)创建内存布局(Layout),确保数据在寄存器中的排布符合硬件指令(如 Tensor Core)的要求
434+ // 提取分块形状(tCgA) → 生成寄存器布局(make_fragment_layout) → 创建逻辑张量(make_tensor)
425435 Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_a, tCgA (_,_,_,0 ).shape ()));
426436 Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_b, tCgB (_,_,_,0 ).shape ()));
427437
428- using FragScaleLayout = Layout<Shape<_2, _2, _1>>;
429- Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{});
438+ using FragScaleLayout = Layout<Shape<_2, _2, _1>>; // scale 寄存器分布
439+ Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{}); // 创建scale 寄存器张量
430440
431441 // narrow input fragment
432442 Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_b_4bit, tCgB_4bit (_,_,_,0 ).shape ()));
@@ -436,46 +446,90 @@ class kgemm_4bit_inference_cutlass_dequant {
436446 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
437447 static_assert (std::is_same_v<typename decltype (mma_B)::value_type, ElementMMA>);
438448
439- // Retile for copy
449+ // // Retile for copy
450+ // retile_D: 将数据从一种布局(如共享内存)转换为另一种布局(如寄存器片段),确保数据在寄存器中的排列符合硬件指令(如 Tensor Core)的要求
451+ // 为什么需要 retile_D?共享内存的布局(如行主序 Stride<_1,_128>)可能与硬件指令(如 Tensor Core 的 8x8 分块)不兼容, 通过 retile_D 将数据重排为寄存器需要的布局(如 Stride<_1,_8>)
452+ // D(Destination):数据最终需要适配的布局(通常是寄存器布局)
453+ // thr_copy_A.retile_D(mma_A): 将线程分片的数据(thr_copy_A)从原始布局(共享内存的行主序)重映射为目标布局(mma_A 的寄存器布局)。
454+ // frag_copy_A: 数据按 mma_A 的布局重新排列后的寄存器片段
455+ // code Analyze:
456+ // (1) Lambda 表达式 [&](){ ... }()
457+ // [&]:捕获当前作用域的所有变量(按引用)。
458+ // std::make_pair:返回 frag_copy_A 和 frag_copy_B 的元组。
459+ // 立即执行:() 表示直接调用该 Lambda。
460+ // (2) thr_copy_A.retile_D(mma_A)
461+ // 作用:将 thr_copy_A 的数据按 mma_A 的布局重排到寄存器。
462+ // 底层操作:
463+ // 从共享内存读取数据。
464+ // 按 mma_A.layout() 的步长(如 Stride<_1,_8>)重新排列。
465+ // 写入寄存器片段 frag_copy_A。
466+ // (3) thr_copy_B_4bit.retile_D(quant_frag)
467+ // 作用:将 4-bit 量化的 thr_copy_B_4bit 数据解压并按 quant_frag 布局重排。
468+ // 特殊处理:
469+ // 4-bit 解压:将每字节的 2 个 4-bit 数值解压为 2 个 8-bit 数值。
470+ // 布局适配:确保解压后的数据符合 MMA 指令的输入要求(如 int8 或 fp16)。
471+ // (4) 为什么需要 make_pair?: C++ 函数(或 Lambda)只能返回一个值,无法直接返回多个独立对象。 std::pair 或 std::tuple 将多个值封装为单个对象。允许 Lambda 函数通过单一 return 返回多个值。
440472 auto [frag_copy_A, frag_copy_B] = [&](){
441473 return std::make_pair (thr_copy_A.retile_D (mma_A), thr_copy_B_4bit.retile_D (quant_frag));
442474 }();
443475
444476 Tensor copy_tCrS = thr_copy_scale.retile_D (fragment_scale_input);
445477
446- // Retile global counting tensors for copies
478+ // // Retile global counting tensors for copies:
479+ // retile_D:将数据 物理复制到目标布局(如寄存器)。
480+ // retile_S:仅生成一个 逻辑视图,不实际移动数据(类似 reinterpret_cast)
481+ // 生成一个逻辑视图 tAgA,其形状和步长与 tCgA 相同,但数据仍存储在原始位置(共享内存)
482+ // 共享内存 → retile_S → 逻辑视图 (next step later → 寄存器 (实际复制))
447483 Tensor tAgA = thr_copy_A.retile_S (tCgA);
448484 Tensor tBgB = thr_copy_B_4bit.retile_S (tCgB);
449-
485+
486+ // // Prepare for prefetch
487+ // BLK_M, BLK_N, BLK_K, Num_SGs: Gemm Tile Atom information.
488+ // tiled_copy_a: Copy Atom information
489+ // prefetch_selector: 选择适合硬件架构的预取策略
450490 auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M >,Int<BLK_K >>, Num_SGs>(tiled_copy_a);
451491 auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N >,Int<BLK_K >>, Num_SGs>(tiled_copy_b_4bit);
492+ // get_slice: 获取当前线程负责的预取分片
452493 auto thr_prefetch_A = tiled_prefetch_a.get_slice (thread_idx);
453494 auto thr_prefetch_B = tiled_prefetch_b.get_slice (thread_idx);
454495
455496 // Partition global tile for prefetch
497+ // partition_S:将全局数据划分为预取分片,生成逻辑视图(不实际移动数据)
498+ // pAgA 和 pBgB:线程私有的全局内存分片视图,用于后续预取操作
499+ // code analyze:
500+ // (1) 预取(Prefetch)的作用: 隐藏延迟:在计算当前分块时,异步预取下一个分块的数据到缓存或共享内存。
501+ // (2) partition_S vs partition_D:
502+ // partition_S: 生成逻辑视图(源布局),不实际移动数据
503+ // partition_D: 实际复制数据到目标布局(如共享内存→寄存器)
456504 auto pAgA = thr_prefetch_A.partition_S (gA );
457505 auto pBgB = thr_prefetch_B.partition_S (gB );
458506
459- //
460- // Mainloop
461- //
507+ // //
508+ // // Mainloop
509+ // //
510+ // 在矩阵乘法(GEMM)中动态计算每个线程块(CTA)需要处理的数据分块位置
462511 auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
463- m_coord = m_idx * BLK_M + (get_sub_group_id () / ATOM_N ) * SG_M ;
464- n_coord = n_idx * BLK_N + (get_sub_group_id () % ATOM_N ) * SG_N ;
512+ m_coord = m_idx * BLK_M + (get_sub_group_id () / ATOM_N ) * SG_M ; // m_idx * BLK_M:分块在 M 维度的起始全局坐标; get_sub_group_id() / ATOM_N) * SG_M:子组在 M 维度的偏移(用于细粒度并行)
513+ n_coord = n_idx * BLK_N + (get_sub_group_id () % ATOM_N ) * SG_N ; // n_idx * BLK_N:分块在 N 维度的起始全局坐标; (get_sub_group_id() % ATOM_N) * SG_N:子组在 N 维度的偏移
465514 l_coord = l_idx;
466515
467516 Tensor copy_iter_s = [&](){
468- return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)),
469- make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count),
470- make_stride (E<0 >{} * _16{}, E<0 >{} * _32{}, _0{}, E<1 >{} * _1{})));
517+ return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)), // 初始坐标:(n_coord, 0, l_coord),表示从 N 维的 n_coord 开始,K 维从 0 开始
518+ make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
519+ make_stride (E<0 >{} * _16{}, E<0 >{} * _32{}, _0{}, E<1 >{} * _1{}))); // 步长 [16, 32, 0, 1]:
520+ // E<0>{} * _16{}: 第一维度(N)的步长为 16;
521+ // E<0>{} * _32{}:第二维度(K)的步长为 32;
522+ // 0{}:第三维度(L)的步长为 0(固定);
523+ // E<1>{} * _1{}:第四维度(迭代次数)的步长为 1.
524+ // E<0>{} 是一个编译时表达式,用于表示步长(Stride)或布局(Layout)中的占位符或动态值
525+ // E<N>:一个模板类,表示第 N 维的步长或索引,通常用于动态形状或步长的占位符
526+ // E<0>{}:表示第 0 维(最内层维度)的动态步长或索引值,具体值在运行时确定
471527 }();
472528
473- // #define LOG_GROUP 1
474- // #define LOG_THREAD 1
475529 #define CUTLASS_ENABLE_DEBUG_PRINTS 1
476530 #if CUTLASS_ENABLE_DEBUG_PRINTS
477531 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
478- if (cute::thread0 ()){ // (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
532+ if (cute::thread0 ()){
479533 print (" ======================= A: \n " );
480534 print (" gA : " ); print (gA ); print (" \n " );
481535 print (" tCgA : " ); print (tCgA); print (" \n " );
@@ -504,7 +558,7 @@ class kgemm_4bit_inference_cutlass_dequant {
504558 }
505559 #undef PRINT
506560 #endif
507-
561+ // crd2idx: 将多维逻辑坐标转换为线性索引
508562 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
509563 int prefetch_k = 0 ;
510564
0 commit comments