Skip to content

Commit d55ce8c

Browse files
committed
refine code
1 parent 98194fb commit d55ce8c

1 file changed

Lines changed: 11 additions & 7 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,24 +253,25 @@ class kgemm_4bit_inference_cutlass_dequant {
253253

254254
#if 1
255255
auto const& src = tCrA_load(_, _, _);
256+
//auto src = src_(_, cute::take(src_.size(1)/2), _);
257+
//auto src = src_(_, _0{size_t(src_.size(1)/2)}, _);
256258
auto const& dst = tCrA_mma(_, _, _);
257259
auto pSrc = const_cast<SrcType*>(raw_pointer_cast(src.data()));
258260
auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
259261
constexpr int num_elements = decltype(size(src))::value / 2;
260262

261263
// TODO(Codeplay): (perf) consider replacing `pack` with `num_elements` here - See xe_flash_attn_mma.hpp
262264
constexpr int pack = decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
263-
//if(cute::thread0()) printf("Cosize, sizeof_bits_v<SrcType> = %d, sizeof_bits_v<DstType> = %d, cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>) = %d, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>) = %d\n", num_elements, sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>, cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>), 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>));
264265
int src_size = sizeof_bits_v<SrcType>;
265266
int dst_size = sizeof_bits_v<DstType>;
266267
if(cute::thread0()) printf("Cosize = %d, src_size = %d, dst_size = %d\n", num_elements, src_size, dst_size);
267268
//using Converter = cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
268269
using SrcArray = cutlass::Array<SrcType, pack>;
269-
using DstArray = cutlass::Array<DstType, pack * 2>;
270+
using DstArray = cutlass::Array<DstType, pack>;
270271
constexpr int iters = num_elements / pack;
271272

272273
CUTLASS_PRAGMA_UNROLL
273-
for (int i = 0; i < iters; ++i) {
274+
for (int i = 0; i < iters / 2 ; ++i) {
274275
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
275276
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i * 2;
276277
//*pDstArr = Converter::convert(*pSrcArr);
@@ -425,6 +426,7 @@ class kgemm_4bit_inference_cutlass_dequant {
425426
// tCgA: t(tensor) C(compute) gA(globaleA);
426427
// tCsA: s (shared memory)
427428
// tCrA: r (register)
429+
// 虽然每个线程参与多个 Atom 的计算,但 tCgB 的 shape 是针对单个Atom 的线程分片
428430
Tensor tCgA = thr_mma.partition_A(gA);
429431
Tensor tCgB = thr_mma.partition_B(gB);
430432
Tensor tCgB_4bit = thr_mma.partition_B(gB_4bit);
@@ -502,7 +504,7 @@ class kgemm_4bit_inference_cutlass_dequant {
502504
// partition_S: 生成逻辑视图(源布局),不实际移动数据
503505
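// partition_D: likewise only builds a logical view (destination layout) and moves no data; the actual transfer happens when copy() is invoked on the partitioned tensors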
504506
auto pAgA = thr_prefetch_A.partition_S(gA);
505-
auto pBgB = thr_prefetch_B.partition_S(gB);
507+
auto pBgB = thr_prefetch_B.partition_S(gB_4bit);
506508

507509
////
508510
//// Mainloop
@@ -515,7 +517,7 @@ class kgemm_4bit_inference_cutlass_dequant {
515517

516518
Tensor copy_iter_s = [&](){
517519
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)), // 初始坐标:(n_coord, 0, l_coord),表示从 N 维的 n_coord 开始,K 维从 0 开始
518-
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
520+
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count/2), // logical shape of the iterator: [2, 2, 1, k_tile_count/2] — each step yields a 2x2x1 block of coordinates, for k_tile_count/2 steps in total
519521
make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); // 步长 [16, 32, 0, 1]:
520522
// E<0>{} * _16{}: 第一维度(N)的步长为 16;
521523
// E<0>{} * _32{}:第二维度(K)的步长为 32;
@@ -529,7 +531,7 @@ class kgemm_4bit_inference_cutlass_dequant {
529531
#define CUTLASS_ENABLE_DEBUG_PRINTS 1
530532
#if CUTLASS_ENABLE_DEBUG_PRINTS
531533
#define PRINT(x) print(#x ": "); print(x); print("\n");
532-
if (cute::thread0()){
534+
if (cute::thread0()){
533535
print("======================= A: \n");
534536
print(" gA : "); print(gA); print("\n");
535537
print(" tCgA : "); print(tCgA); print("\n");
@@ -558,6 +560,7 @@ class kgemm_4bit_inference_cutlass_dequant {
558560
}
559561
#undef PRINT
560562
#endif
563+
561564
// crd2idx: 将多维逻辑坐标转换为线性索引
562565
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
563566
int prefetch_k = 0;
@@ -639,12 +642,13 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
639642
auto mA_mkl = make_tensor(make_gmem_ptr(A), make_layout(make_shape(m, k, l), stride_A));
640643
Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
641644

645+
// make_cute_packed_stride: derives the memory strides for the given stride type from the tensor shape, assuming a densely packed layout (no padding between rows/batches)
642646
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
643647
auto mB_nkl = make_tensor(make_gmem_ptr(B), make_layout(make_shape(n, k, l), stride_B));
644648
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
645649

646650
StrideB stride_B_4bit = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k/2, l));
647-
auto mB_nkl_4bit = make_tensor(make_gmem_ptr(B), make_layout(make_shape(n, k/2, l), stride_B));
651+
auto mB_nkl_4bit = make_tensor(make_gmem_ptr(B), make_layout(make_shape(n, k/2, l), stride_B_4bit));
648652
Copy_B tiled_copy_b_4bit{Copy_B{}.with(mB_nkl_4bit)};
649653

650654
params.tiled_copy_a = tiled_copy_a;

0 commit comments

Comments
 (0)