Skip to content

Commit 9526a54

Browse files
committed
refine code
1 parent 617d605 commit 9526a54

2 files changed

Lines changed: 29 additions & 17 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 27 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ using TiledMma =
6969
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
7070

7171
// Define Mainloop dispatch policy
72-
constexpr int PipelineStages = 3;
72+
constexpr int PipelineStages = 1;
7373
using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
7474
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // sub_group size
7575

@@ -231,15 +231,18 @@ class kgemm_4bit_inference_cutlass_dequant {
231231
/// Utilities to transform A.
232232
template <class EngineIn,
233233
class EngineOut,
234+
class EngineRef,
234235
class EngineScales,
235236
class LayoutIn,
236237
class LayoutOut,
238+
class LayoutRef,
237239
class LayoutScales,
238240
class... Ts>
239241
CUTLASS_DEVICE
240242
void dequant(
241243
Tensor<EngineIn, LayoutIn> const& tCrA_load,
242244
Tensor<EngineOut, LayoutOut>& tCrA_mma,
245+
Tensor<EngineRef, LayoutRef>& A_ref, //mma_A for debug
243246
Tensor<EngineScales, LayoutScales>& tCrS_input,
244247
float* quant_map
245248
) {
@@ -258,21 +261,27 @@ class kgemm_4bit_inference_cutlass_dequant {
258261
auto const& dst = tCrA_mma(_, _, _);
259262
auto pSrc = const_cast<SrcType*>(raw_pointer_cast(src.data()));
260263
auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
264+
auto pA = const_cast<DstType*>(raw_pointer_cast(A_ref.data()));
261265
constexpr int num_elements = decltype(size(src))::value / 2;
266+
for(int i=0; i<num_elements * 2; i++){
267+
if(cute::thread0())
268+
printf("ThreadIdxX() = %d, i = %d, *(pSrc + i) = %d, *(pA + i*2) = %f, *(pA + i*2+1) = %f\n", ThreadIdxX(), i, static_cast<int>(*(pSrc + i)), static_cast<int>(*(pA + i*2)), static_cast<int>(*(pA + i*2+1)));
269+
}
262270

263271
// TODO(Codeplay): (perf) consider replacing `pack` with `num_elements` here - See xe_flash_attn_mma.hpp
264272
constexpr int pack = 1; //decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
265273
int src_size = sizeof_bits_v<SrcType>;
266274
int dst_size = sizeof_bits_v<DstType>;
267-
if(cute::thread0()) printf("Cosize = %d, src_size = %d, dst_size = %d\n", num_elements, src_size, dst_size);
275+
//if(cute::thread0()) printf("Cosize = %d, src_size = %d, dst_size = %d\n", num_elements, src_size, dst_size);
268276
//using Converter = cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
269277
#if 1
270278
for(int i=0; i<num_elements; i++){
271279
auto src_value = *(pSrc + i);
272-
if(cute::thread0()) printf("*(pSrc + i) = %d, src_value = %d\n",static_cast<int>(*(pSrc + i)), static_cast<int>(src_value));
280+
//if(cute::thread0()) printf("*(pSrc + i) = %d, src_value = %d\n",static_cast<int>(*(pSrc + i)), static_cast<int>(src_value));
273281
*(pDst + (2 * i)) = static_cast<DstType>(quant_map[src_value >> 4]);
274282
*(pDst + (2 * i + 1)) = static_cast<DstType>(quant_map[src_value & 0x0f]);
275-
if(cute::thread0()) printf("num_elements = %d, i = %d, *(pSrc + i) = %d, *(pSrc + i) >> 4= %d, *(pSrc + i) & 0x0f, quant_map[*(pSrc + i) >> 4] = %f, quant_map[src_value & 0x0f] = %f \n", num_elements, i, static_cast<int>(*(pSrc + i)), static_cast<int>(*(pSrc + i) >> 4), static_cast<int>(*(pSrc + i) & 0x0f), static_cast<int>(quant_map[*(pSrc + i) >> 4]), static_cast<int>(quant_map[src_value & 0x0f]), static_cast<int>(*(pDst + (2 * i))), static_cast<int>(*(pDst + (2 * i + 1))));
283+
if(cute::thread0())
284+
printf("num_elements = %d, i = %d, *(pSrc + i) = %d, *(pSrc + i) >> 4= %d, *(pSrc + i) & 0x0f, quant_map[*(pSrc + i) >> 4] = %f, quant_map[src_value & 0x0f] = %f \n", num_elements, i, static_cast<int>(*(pSrc + i)), static_cast<int>(*(pSrc + i) >> 4), static_cast<int>(*(pSrc + i) & 0x0f), static_cast<int>(quant_map[*(pSrc + i) >> 4]), static_cast<int>(quant_map[src_value & 0x0f]), static_cast<int>(*(pDst + (2 * i))), static_cast<int>(*(pDst + (2 * i + 1))));
276285
}
277286
#else
278287
using SrcArray = cutlass::Array<SrcType, pack>;
@@ -306,7 +315,9 @@ class kgemm_4bit_inference_cutlass_dequant {
306315
for (int i = 0; i < 4; ++i) {
307316
CUTLASS_PRAGMA_UNROLL
308317
for (int j = 0; j < 32; ++j) {
318+
if(cute::thread0()) printf("tCrA_mma(_, i, _)[j] = %f, i = %d, j = %d, tCrS_input(i) = %f\n",tCrA_mma(_, i, _)[j], i, j, tCrS_input(i));
309319
tCrA_mma(_, i, _)[j] *= tCrS_input(i);
320+
//if(cute::thread0()) printf("after scaling tCrA_mma(_, i, _)[j] = %f\n", tCrA_mma(_, i, _)[j]);
310321
}
311322
}
312323
#else
@@ -366,7 +377,7 @@ class kgemm_4bit_inference_cutlass_dequant {
366377

367378
//// Get the block level coordinate(indexing) for current block
368379
auto blk_shape = TileShape{}; //256,256,32
369-
auto blk_shape_4bit = Shape<_256, _256, _16>{}; //TileShape{}; //256,256,32
380+
//auto blk_shape_4bit = Shape<_256, _256, _16>{}; //TileShape{}; //256,256,32
370381
int m_coord, n_coord, l_coord; //block index
371382
if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
372383
if(cute::thread0()) printf("AlongN !!\n");
@@ -393,7 +404,7 @@ class kgemm_4bit_inference_cutlass_dequant {
393404
// gA: 逻辑视图(无实际内存分配)
394405
Tensor gA = local_tile(mA_mkl, select<0,2>(blk_shape), make_coord(m_coord,_,l_coord));
395406
Tensor gB = local_tile(mB_nkl, select<1,2>(blk_shape), make_coord(n_coord,_,l_coord));
396-
Tensor gB_4bit = local_tile(mB_nkl_4bit, select<1,2>(blk_shape_4bit), make_coord(n_coord,_,l_coord / 2));
407+
//Tensor gB_4bit = local_tile(mB_nkl_4bit, select<1,2>(blk_shape_4bit), make_coord(n_coord,_,l_coord / 2));
397408

398409
//// Allocate the tiled_mma and the accumulators for the (M,N) subgroup_tile_shape
399410
TiledMma tiled_mma;
@@ -412,8 +423,8 @@ class kgemm_4bit_inference_cutlass_dequant {
412423
// 对于单维度,坐标直接等于索引值。
413424
// 使用方式:int k = get<0>(coord); // k = 0
414425
// cute::make_coord_iterator(A, B): 生成起始坐标A,步长B的迭代器
415-
auto k_tile_iter = cute::make_coord_iterator(idx2crd(0, make_shape(K)), make_shape(K));
416-
int k_tile_count = ceil_div(K, get<2>(workgroup_shape));
426+
auto k_tile_iter = cute::make_coord_iterator(idx2crd(0, make_shape(K / 2)), make_shape(K / 2));
427+
int k_tile_count = ceil_div(K / 2, get<2>(workgroup_shape));
417428
if(cute::thread0()) printf("k_tile_count = %d\n", k_tile_count);
418429

419430
//////Run MainLoop//////
@@ -440,7 +451,7 @@ class kgemm_4bit_inference_cutlass_dequant {
440451
// 虽然每个线程参与多个 Atom 的计算,但 tCgB 的 shape 是针对单个Atom 的线程分片
441452
Tensor tCgA = thr_mma.partition_A(gA);
442453
Tensor tCgB = thr_mma.partition_B(gB);
443-
Tensor tCgB_4bit = thr_mma.partition_B(gB_4bit);
454+
//Tensor tCgB_4bit = thr_mma.partition_B(gB_4bit);
444455

445456
//// Create fragments:将全局或共享内存中的数据分块转换为适合硬件加速器(如 Tensor Core)计算的寄存器格式
446457
// make_fragment_layout: 为寄存器片段(Fragment)创建内存布局(Layout),确保数据在寄存器中的排布符合硬件指令(如 Tensor Core)的要求
@@ -452,7 +463,7 @@ class kgemm_4bit_inference_cutlass_dequant {
452463
Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{}); // 创建scale 寄存器张量
453464

454465
// narrow input fragment
455-
Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_b_4bit, tCgB_4bit(_,_,_,0).shape()));
466+
Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_b_4bit, tCgB(_,_,_,0).shape()));
456467
Tensor quant_frag = make_tensor<ElementQuant>(decltype(mma_B_4bit.layout()){});
457468

458469
static_assert(std::is_same_v<typename decltype(quant_frag)::value_type, ElementQuant>);
@@ -515,7 +526,7 @@ class kgemm_4bit_inference_cutlass_dequant {
515526
// partition_S: 生成逻辑视图(源布局),不实际移动数据
516527
// partition_D: 实际复制数据到目标布局(如共享内存→寄存器)
517528
auto pAgA = thr_prefetch_A.partition_S(gA);
518-
auto pBgB = thr_prefetch_B.partition_S(gB_4bit);
529+
auto pBgB = thr_prefetch_B.partition_S(gB);
519530

520531
////
521532
//// Mainloop
@@ -528,7 +539,7 @@ class kgemm_4bit_inference_cutlass_dequant {
528539

529540
Tensor copy_iter_s = [&](){
530541
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)), // 初始坐标:(n_coord, 0, l_coord),表示从 N 维的 n_coord 开始,K 维从 0 开始
531-
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count/2), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
542+
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
532543
make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); // 步长 [16, 32, 0, 1]:
533544
// E<0>{} * _16{}: 第一维度(N)的步长为 16;
534545
// E<0>{} * _32{}:第二维度(K)的步长为 32;
@@ -538,7 +549,7 @@ class kgemm_4bit_inference_cutlass_dequant {
538549
// E<N>:一个模板类,表示第 N 维的步长或索引,通常用于动态形状或步长的占位符
539550
// E<0>{}:表示第 0 维(最内层维度)的动态步长或索引值,具体值在运行时确定
540551
}();
541-
552+
#if 0
542553
#define CUTLASS_ENABLE_DEBUG_PRINTS 1
543554
#if CUTLASS_ENABLE_DEBUG_PRINTS
544555
#define PRINT(x) print(#x ": "); print(x); print("\n");
@@ -571,9 +582,10 @@ class kgemm_4bit_inference_cutlass_dequant {
571582
}
572583
#undef PRINT
573584
#endif
585+
#endif
574586

575587
// crd2idx: 将多维逻辑坐标转换为线性索引
576-
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
588+
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K / 2));
577589
int prefetch_k = 0;
578590

579591
CUTLASS_PRAGMA_UNROLL
@@ -593,7 +605,7 @@ class kgemm_4bit_inference_cutlass_dequant {
593605

594606
copy(tiled_copy_scale, copy_iter_s(_, _, _, k_start_idx + (k_tile / k_reload_factor)), copy_tCrS);
595607
//dequant(quant_frag, mma_B_expanded, fragment_scale_input, quant_map);
596-
dequant(quant_frag, mma_B, fragment_scale_input, quant_map);
608+
dequant(quant_frag, mma_B, mma_A, fragment_scale_input, quant_map);
597609

598610
if(prefetch_k < k_tile_count) {
599611
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));

tests/test_xpu.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ class TestXPU:
4747
[torch.uint8],
4848
ids=describe_dtype,
4949
)
50-
@pytest.mark.parametrize("dim", [32], ids=id_formatter("dim"))
50+
@pytest.mark.parametrize("dim", [256], ids=id_formatter("dim"))
5151
def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
5252
errs1 = []
5353
errs2 = []
@@ -86,7 +86,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8686
)
8787
##pdb.set_trace()
8888
C3 = torch.matmul(A, B.t())
89-
#pdb.set_trace()
89+
pdb.set_trace()
9090
C2 = F.gemv_4bit(A, qB.t(), state=state)
9191
#pdb.set_trace()
9292
print(C3[0])

0 commit comments

Comments
 (0)