Skip to content

Commit 98194fb

Browse files
committed
refine code
1 parent 1ea5d07 commit 98194fb

1 file changed

Lines changed: 93 additions & 39 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 93 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -149,30 +149,33 @@ using GmemTiledCopyC = CopyOpG2R;
149149
using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
150150
CopyOpR2G, XE_2D_U32x8x16_ST_N>;
151151

152-
// Calculate subgroup_tile_shape (reminder: not the same thing with "subgroup_size" in sycl!!)
153-
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{});
154-
static constexpr auto BLK_N = get<1>(WorkgroupTileShape{});
155-
static constexpr auto BLK_K = get<2>(WorkgroupTileShape{});
156-
157-
static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
158-
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
159-
static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
160-
161-
static_assert(BLK_M % TiledMma{}.template tile_size_mnk<0>() == 0, "TiledMma permutation size must match block size.");
162-
static_assert(BLK_N % TiledMma{}.template tile_size_mnk<1>() == 0, "TiledMma permutation size must match block size.");
163-
static_assert(BLK_K % TiledMma{}.template tile_size_mnk<2>() == 0, "TiledMma permutation size must match block size.");
164-
165-
static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M);
166-
static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N);
167-
static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
168-
using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;
169-
170-
static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; //32
171-
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
172-
173152
template <typename T, int BITS>
174153
class kgemm_4bit_inference_cutlass_dequant {
175154
public:
155+
// Calculate subgroup_tile_shape (reminder: not the same thing with "subgroup_size" in sycl!!)
156+
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{});
157+
static constexpr auto BLK_N = get<1>(WorkgroupTileShape{});
158+
static constexpr auto BLK_K = get<2>(WorkgroupTileShape{});
159+
160+
//Threads number
161+
static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
162+
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
163+
static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
164+
165+
static_assert(BLK_M % TiledMma{}.template tile_size_mnk<0>() == 0, "TiledMma permutation size must match block size.");
166+
static_assert(BLK_N % TiledMma{}.template tile_size_mnk<1>() == 0, "TiledMma permutation size must match block size.");
167+
static_assert(BLK_K % TiledMma{}.template tile_size_mnk<2>() == 0, "TiledMma permutation size must match block size.");
168+
169+
//sub-tile shape
170+
static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M);
171+
static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N);
172+
static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
173+
using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;
174+
175+
//Total Threads number
176+
static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; //32
177+
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
178+
176179
// Kernel level shared memory storage
177180
struct SharedStorage {
178181
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
@@ -374,6 +377,8 @@ class kgemm_4bit_inference_cutlass_dequant {
374377
Tensor mB_nkl = cute::get_pvc_tensor(make_shape(N,K,L)); //coordinate tensor: 0,1,2....
375378
Tensor mB_nkl_4bit = cute::get_pvc_tensor(make_shape(N,K/2,L)); //coordinate tensor: 0,1,2....
376379

380+
// local_tile: 从全局张量中提取线程块(CTA)级别的局部子块
381+
// gA: 逻辑视图(无实际内存分配)
377382
Tensor gA = local_tile(mA_mkl, select<0,2>(blk_shape), make_coord(m_coord,_,l_coord));
378383
Tensor gB = local_tile(mB_nkl, select<1,2>(blk_shape), make_coord(n_coord,_,l_coord));
379384
Tensor gB_4bit = local_tile(mB_nkl_4bit, select<1,2>(blk_shape_4bit), make_coord(n_coord,_,l_coord / 2));
@@ -417,16 +422,21 @@ class kgemm_4bit_inference_cutlass_dequant {
417422
// thr_mma:线程的 MMA(矩阵乘累加)分片
418423
// gA:矩阵 A 的全局或共享内存分块
419424
// tCgA,一个逻辑张量,表示当前线程负责的寄存器片段, 形状由 TiledMMA 策略决定
425+
// tCgA: t(tensor) C(compute) gA(globaleA);
426+
// tCsA: s (shared memory)
427+
// tCrA: r (register)
420428
Tensor tCgA = thr_mma.partition_A(gA);
421429
Tensor tCgB = thr_mma.partition_B(gB);
422430
Tensor tCgB_4bit = thr_mma.partition_B(gB_4bit);
423431

424-
// Create fragments
432+
//// Create fragments:将全局或共享内存中的数据分块转换为适合硬件加速器(如 Tensor Core)计算的寄存器格式
433+
// make_fragment_layout: 为寄存器片段(Fragment)创建内存布局(Layout),确保数据在寄存器中的排布符合硬件指令(如 Tensor Core)的要求
434+
// 提取分块形状(tCgA) → 生成寄存器布局(make_fragment_layout) → 创建逻辑张量(make_tensor)
425435
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape()));
426436
Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape()));
427437

428-
using FragScaleLayout = Layout<Shape<_2, _2, _1>>;
429-
Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{});
438+
using FragScaleLayout = Layout<Shape<_2, _2, _1>>; // scale 寄存器分布
439+
Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{}); // 创建scale 寄存器张量
430440

431441
// narrow input fragment
432442
Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_b_4bit, tCgB_4bit(_,_,_,0).shape()));
@@ -436,46 +446,90 @@ class kgemm_4bit_inference_cutlass_dequant {
436446
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
437447
static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
438448

439-
// Retile for copy
449+
//// Retile for copy
450+
// retile_D: 将数据从一种布局(如共享内存)转换为另一种布局(如寄存器片段),确保数据在寄存器中的排列符合硬件指令(如 Tensor Core)的要求
451+
// 为什么需要 retile_D?共享内存的布局(如行主序 Stride<_1,_128>)可能与硬件指令(如 Tensor Core 的 8x8 分块)不兼容, 通过 retile_D 将数据重排为寄存器需要的布局(如 Stride<_1,_8>)
452+
// D(Destination):数据最终需要适配的布局(通常是寄存器布局)
453+
// thr_copy_A.retile_D(mma_A): 将线程分片的数据(thr_copy_A)从原始布局(共享内存的行主序)重映射为目标布局(mma_A 的寄存器布局)。
454+
// frag_copy_A: 数据按 mma_A 的布局重新排列后的寄存器片段
455+
// code Analyze:
456+
// (1) Lambda 表达式 [&](){ ... }()
457+
// [&]:捕获当前作用域的所有变量(按引用)。
458+
// std::make_pair:返回 frag_copy_A 和 frag_copy_B 的元组。
459+
// 立即执行:() 表示直接调用该 Lambda。
460+
// (2) thr_copy_A.retile_D(mma_A)
461+
// 作用:将 thr_copy_A 的数据按 mma_A 的布局重排到寄存器。
462+
// 底层操作:
463+
// 从共享内存读取数据。
464+
// 按 mma_A.layout() 的步长(如 Stride<_1,_8>)重新排列。
465+
// 写入寄存器片段 frag_copy_A。
466+
// (3) thr_copy_B_4bit.retile_D(quant_frag)
467+
// 作用:将 4-bit 量化的 thr_copy_B_4bit 数据解压并按 quant_frag 布局重排。
468+
// 特殊处理:
469+
// 4-bit 解压:将每字节的 2 个 4-bit 数值解压为 2 个 8-bit 数值。
470+
// 布局适配:确保解压后的数据符合 MMA 指令的输入要求(如 int8 或 fp16)。
471+
// (4) 为什么需要 make_pair?: C++ 函数(或 Lambda)只能返回一个值,无法直接返回多个独立对象。 std::pair 或 std::tuple 将多个值封装为单个对象。允许 Lambda 函数通过单一 return 返回多个值。
440472
auto [frag_copy_A, frag_copy_B] = [&](){
441473
return std::make_pair(thr_copy_A.retile_D(mma_A), thr_copy_B_4bit.retile_D(quant_frag));
442474
}();
443475

444476
Tensor copy_tCrS = thr_copy_scale.retile_D(fragment_scale_input);
445477

446-
// Retile global counting tensors for copies
478+
//// Retile global counting tensors for copies:
479+
// retile_D:将数据 物理复制到目标布局(如寄存器)。
480+
// retile_S:仅生成一个 逻辑视图,不实际移动数据(类似 reinterpret_cast)
481+
// 生成一个逻辑视图 tAgA,其形状和步长与 tCgA 相同,但数据仍存储在原始位置(共享内存)
482+
// 共享内存 → retile_S → 逻辑视图 (next step later → 寄存器 (实际复制))
447483
Tensor tAgA = thr_copy_A.retile_S(tCgA);
448484
Tensor tBgB = thr_copy_B_4bit.retile_S(tCgB);
449-
485+
486+
//// Prepare for prefetch
487+
// BLK_M, BLK_N, BLK_K, Num_SGs: Gemm Tile Atom information.
488+
// tiled_copy_a: Copy Atom information
489+
// prefetch_selector: 选择适合硬件架构的预取策略
450490
auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(tiled_copy_a);
451491
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(tiled_copy_b_4bit);
492+
// get_slice: 获取当前线程负责的预取分片
452493
auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
453494
auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
454495

455496
// Partition global tile for prefetch
497+
// partition_S:将全局数据划分为预取分片,生成逻辑视图(不实际移动数据)
498+
// pAgA 和 pBgB:线程私有的全局内存分片视图,用于后续预取操作
499+
// code analyze:
500+
// (1) 预取(Prefetch)的作用: 隐藏延迟:在计算当前分块时,异步预取下一个分块的数据到缓存或共享内存。
501+
// (2) partition_S vs partition_D:
502+
// partition_S: 生成逻辑视图(源布局),不实际移动数据
503+
// partition_D: 实际复制数据到目标布局(如共享内存→寄存器)
456504
auto pAgA = thr_prefetch_A.partition_S(gA);
457505
auto pBgB = thr_prefetch_B.partition_S(gB);
458506

459-
//
460-
// Mainloop
461-
//
507+
////
508+
//// Mainloop
509+
////
510+
// 在矩阵乘法(GEMM)中动态计算每个线程块(CTA)需要处理的数据分块位置
462511
auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
463-
m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
464-
n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
512+
m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; // m_idx * BLK_M:分块在 M 维度的起始全局坐标; get_sub_group_id() / ATOM_N) * SG_M:子组在 M 维度的偏移(用于细粒度并行)
513+
n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; // n_idx * BLK_N:分块在 N 维度的起始全局坐标; (get_sub_group_id() % ATOM_N) * SG_N:子组在 N 维度的偏移
465514
l_coord = l_idx;
466515

467516
Tensor copy_iter_s = [&](){
468-
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
469-
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count),
470-
make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{})));
517+
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)), // 初始坐标:(n_coord, 0, l_coord),表示从 N 维的 n_coord 开始,K 维从 0 开始
518+
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
519+
make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); // 步长 [16, 32, 0, 1]:
520+
// E<0>{} * _16{}: 第一维度(N)的步长为 16;
521+
// E<0>{} * _32{}:第二维度(K)的步长为 32;
522+
// 0{}:第三维度(L)的步长为 0(固定);
523+
// E<1>{} * _1{}:第四维度(迭代次数)的步长为 1.
524+
// E<0>{} 是一个编译时表达式,用于表示步长(Stride)或布局(Layout)中的占位符或动态值
525+
// E<N>:一个模板类,表示第 N 维的步长或索引,通常用于动态形状或步长的占位符
526+
// E<0>{}:表示第 0 维(最内层维度)的动态步长或索引值,具体值在运行时确定
471527
}();
472528

473-
// #define LOG_GROUP 1
474-
// #define LOG_THREAD 1
475529
#define CUTLASS_ENABLE_DEBUG_PRINTS 1
476530
#if CUTLASS_ENABLE_DEBUG_PRINTS
477531
#define PRINT(x) print(#x ": "); print(x); print("\n");
478-
if (cute::thread0()){ //(cutlass::thread(LOG_THREAD, LOG_GROUP)) {
532+
if (cute::thread0()){
479533
print("======================= A: \n");
480534
print(" gA : "); print(gA); print("\n");
481535
print(" tCgA : "); print(tCgA); print("\n");
@@ -504,7 +558,7 @@ class kgemm_4bit_inference_cutlass_dequant {
504558
}
505559
#undef PRINT
506560
#endif
507-
561+
// crd2idx: 将多维逻辑坐标转换为线性索引
508562
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
509563
int prefetch_k = 0;
510564

0 commit comments

Comments
 (0)