Skip to content

Commit b148d2a

Browse files
committed
add more debug logs
1 parent 06f8ab2 commit b148d2a

4 files changed

Lines changed: 113 additions & 42 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,13 @@ using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{}));
129129

130130
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; //XE_2D_U16x16x32_LD_N;
131131
using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
132-
//using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
133132
using traits_load_A = Copy_Traits<GmemTiledCopyA, StrideA>;
134133
using atom_load_A = Copy_Atom<traits_load_A, ElementA>;
135134
using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{})));
136135
using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout<CopyThreadShape>{}, val_layout_load_A{}));
137136

138137
using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
139138
using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::ColumnMajor>;
140-
//using StrideB = Stride<int64_t, int64_t, int64_t>;
141-
//using Copy_B = typename Copy_Traits<GmemTiledCopyB, StrideB>::template DefaultTiledCopy<ElementB>;
142139
using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
143140
using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
144141
using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
@@ -148,12 +145,6 @@ using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}
148145
using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
149146
using StrideScale = cute::Stride<_1, int64_t, int64_t>; //dynamic stride
150147
using traits_load_scale = Copy_Traits<GmemTiledCopyScale, StrideScale>;
151-
//using AtomLayout = Layout<
152-
// Shape<_16, _2>, // matches the BlockShape of XE_2D_U16x1x32_LD_N
153-
// Stride<_1, _16> // contiguous storage, stride 16
154-
//>;
155-
//using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale, AtomLayout>;
156-
//using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, AtomLayout{})); //group-wise scale
157148
using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale>;
158149
using val_layout_load_scale = decltype(make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{})));
159150
using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, val_layout_load_scale{})); //group-wise scale
@@ -245,17 +236,20 @@ class kgemm_4bit_inference_cutlass_dequant {
245236
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
246237
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
247238

239+
CUTLASS_PRAGMA_UNROLL
248240
for (int n = 0; n < N; n++) {
249241
const auto ts = tCrS_input(n);
250242

251243
auto& src = *(cute::array<format_type, loop_cnt / scalar>*)(s_tensor(_, n).data());
252244

245+
CUTLASS_PRAGMA_UNROLL
253246
for (int s = 0; s < splits; s++) {
254247
auto idx = vec_size * s / scalar;
255248
auto format_data = src[idx];
256249

257250
auto& dst = *(cute::array<DstType, vec_size>*)(d_tensor(_, s, n).data());
258251

252+
CUTLASS_PRAGMA_UNROLL
259253
for (int i = 0; i < vec_size; i++) {
260254
uint8_t value = (format_data >> (src_bits * i)) & 0xf;
261255
if(i % 2 != 0) { //1,3, high_4bit
@@ -271,27 +265,18 @@ class kgemm_4bit_inference_cutlass_dequant {
271265
CUTLASS_DEVICE
272266
void operator()(Params const& params, char* smem_buf) {
273267
//if(cute::thread0()) printf("this is fusion kernel...........\n");
274-
275268
int M = params.m;
276269
int N = params.n;
277270
int K = params.k;
278271
int L = 1;
279-
280-
const int BLK_M = 256;
281-
const int BLK_N = 256;
282-
const int BLK_K = 32;
283272

284-
const int ATOM_M = 8;
285-
const int ATOM_N = 4;
286-
const int ATOM_K = 1;
287-
288-
const int SG_M = ceil_div(BLK_M, ATOM_M);
289-
const int SG_N = ceil_div(BLK_N, ATOM_N);
290-
const int SG_K = ceil_div(BLK_K, ATOM_K);
291-
292-
const int Num_SGs = ATOM_N * ATOM_M * ATOM_K;
273+
//Total number of threads
274+
static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; //32 //2
275+
293276
static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
294277

278+
if(cute::thread0()) printf("BLK_M = %d, BLK_N = %d, BLK_K = %d, ATOM_M = %d, ATOM_N = %d, ATOM_K = %d, SG_M = %d, SG_N = %d, SG_K = %d, Num_SGs = %d, SG_QNT_WIDTH = %d\n", static_cast<int>(BLK_M), static_cast<int>(BLK_N), static_cast<int>(BLK_K), static_cast<int>(ATOM_M), static_cast<int>(ATOM_N), static_cast<int>(ATOM_K), static_cast<int>(SG_M), static_cast<int>(SG_N), static_cast<int>(SG_K), static_cast<int>(Num_SGs), static_cast<int>(SG_QNT_WIDTH));
279+
295280
T* A = params.A;
296281
uint8_t* B = params.B;
297282
float* out = params.out;
@@ -401,14 +386,14 @@ class kgemm_4bit_inference_cutlass_dequant {
401386
auto pBgB = thr_prefetch_B.partition_S(gB);
402387

403388
// Run mainloop
404-
auto copy_iter_s = [&](){
389+
auto tSgS = [&](){
405390
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
406391
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
407392
make_stride(E<0>{} * _16{}, E<0>{} * decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1>{} * _1{})));
408393

409394
}();
410395

411-
#if 0
396+
#if 1
412397
#define PRINT(x) print(#x ": "); print(x); print("\n");
413398
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
414399
print("\n\n======================= A: \n");
@@ -426,11 +411,17 @@ class kgemm_4bit_inference_cutlass_dequant {
426411
print(" frag_copy_B : "); print(frag_copy_B); print("\n");
427412
print(" dequant_frag : "); print(dequant_frag); print("\n");
428413

429-
print("===================== D :\n");
430-
print(" tiled_copy_scale : "); print(tiled_copy_scale); print("\n");
414+
print("===================== Scale :\n");
415+
//print(" traits_load_scale::BlockShape{} : "); print(traits_load_scale::BlockShape{}); print("\n");
416+
//print(" CopyThreadShapeRev{} : "); print(CopyThreadShapeRev{}); print("\n");
417+
//print(" val_layout_load_scale{} : "); print(val_layout_load_scale{}); print("\n");
418+
//print(" atom_load_scale{} : "); print(atom_load_scale{}); print("\n");
419+
//print(" Layout<CopyThreadShapeRev>{} : "); print(Layout<CopyThreadShapeRev>{}); print("\n");
420+
//print(" Copy_Scale{} : "); print(Copy_Scale{}); print("\n");
421+
//print(" tiled_copy_scale : "); print(tiled_copy_scale); print("\n");
431422
print(" fragment_scale : "); print(fragment_scale); print("\n");
432423
print(" frag_copy_Scale : "); print(frag_copy_Scale); print("\n");
433-
print(" copy_iter_s: "); print(copy_iter_s); print("\n");
424+
print(" tSgS : "); print(tSgS); print("\n");
434425

435426
print("===================== D :\n");
436427
print(" accumulators : "); print(accumulators); print("\n");
@@ -439,9 +430,25 @@ class kgemm_4bit_inference_cutlass_dequant {
439430
print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
440431
print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
441432

433+
print("===================== Config: \n");
434+
print(" tiled_mma : "); print(tiled_mma); print("\n");
435+
436+
print("===================== Config: \n");
437+
print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
438+
439+
print("===================== Config: \n");
440+
print(" thr_mma : "); print(thr_mma); print("\n");
441+
442+
print("===================== Config: \n");
442443
print(" tiled_prefetch_a : "); print(tiled_prefetch_a); print("\n");
444+
445+
print("===================== Config: \n");
443446
print(" tiled_prefetch_b : "); print(tiled_prefetch_b); print("\n");
447+
448+
print("===================== Config: \n");
444449
print(" pAgA : "); print(pAgA); print("\n");
450+
451+
print("===================== Config: \n");
445452
print(" pBgB : "); print(pBgB); print("\n\n\n");
446453
}
447454
#undef PRINT
@@ -450,7 +457,7 @@ class kgemm_4bit_inference_cutlass_dequant {
450457
int prefetch_k = k_start_idx;
451458

452459
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
453-
//if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n",params.group_size, BLK_K, k_reload_factor);
460+
if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, static_cast<int>(BLK_K), k_reload_factor);
454461

455462
CUTLASS_PRAGMA_UNROLL
456463
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
@@ -465,9 +472,9 @@ class kgemm_4bit_inference_cutlass_dequant {
465472
copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
466473
copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
467474

468-
const int s_step = k_start_idx + (k_s / k_reload_factor);
469-
//if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
470-
copy(tiled_copy_scale, copy_iter_s(_, _, _, s_step), frag_copy_Scale);
475+
const int s_idx = (k_start_idx + k_s) / k_reload_factor;
476+
if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %d, s_idx = %d\n",k_start_idx, k_s, k_reload_factor, s_idx);
477+
copy(tiled_copy_scale, tSgS(_, _, _, s_idx), frag_copy_Scale);
471478

472479
if(prefetch_k < k_tile_count) {
473480
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
@@ -591,7 +598,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
591598
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
592599
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
593600

594-
#if 0
601+
#if 1
595602
#define PRINT(x) print(#x ": "); print(x); print("\n");
596603
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
597604
print("===================== stride :\n");

include/cute/atom/copy_traits_xe.hpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,13 @@ struct XE_2D_LD_Unpack {
210210
// It means (M, N):(N, 1) convention if 'is_convention_MN' is true, (N, M):(1, N) convention otherwise.
211211
static constexpr bool is_convention_MN = !(is_need_reversed ^ is_column_major);
212212

213-
// 2d copy parameters
213+
// 2d copy parameters
214214
const void *base_ptr;
215215
uint32_t width;
216216
uint32_t height;
217217
uint32_t pitch;
218218
uint32_t stride_l = 0;
219219

220-
221-
222220
XE_2D_LD_Unpack(const void *ptr, uint32_t y,
223221
uint32_t x, uint32_t p = 0) : base_ptr(ptr) {
224222
if constexpr (is_need_reversed) {
@@ -235,6 +233,15 @@ struct XE_2D_LD_Unpack {
235233

236234
template <class... TensorArgs>
237235
XE_2D_LD_Unpack(Tensor<TensorArgs...> const &tensor) {
236+
#if 1
237+
if(cute::thread0()){
238+
print("===============================\n");
239+
print("is_column_major : "); print(is_column_major); print("\n");
240+
print("is_need_reversed : "); print(is_need_reversed); print("\n");
241+
print("is_convention_MN : "); print(is_convention_MN); print("\n");
242+
print("===============================\n");
243+
}
244+
#endif
238245
base_ptr = raw_pointer_cast(tensor.data());
239246

240247
if constexpr (is_need_reversed)
@@ -430,7 +437,25 @@ CUTE_HOST_DEVICE constexpr auto make_fragment_layout(TiledCopy &tiled_copy,
430437
auto order = std::conditional_t<TiledCopy::is_convention_MN,
431438
Step<Step<_0, _1>, Step<_2, _4>, Step<_3, _5>>,
432439
Step<Step<_0, _1>, Step<_3, _5>, Step<_2, _4>>>{};
433-
440+
#if 1
441+
if(cute::thread0()){
442+
print("========================make_fragment_layout: \n");
443+
print("fragment_top_level_shape: "); print(fragment_top_level_shape); print("\n");
444+
print("mma_atom_shape: "); print(mma_atom_shape); print("\n");
445+
print("total_mma_atom_iters_M: "); print(total_mma_atom_iters_M); print("\n");
446+
print("total_mma_atom_iters_N: "); print(total_mma_atom_iters_N); print("\n");
447+
print("ThreadLayout_: "); print(ThreadLayout_{}); print("\n");
448+
print("ThreadLayout: "); print(ThreadLayout{}); print("\n");
449+
print("thread_copy_shape: "); print(thread_copy_shape); print("\n");
450+
print("mma_atom_iters_in_copy_M: "); print(mma_atom_iters_in_copy_M); print("\n");
451+
print("mma_atom_iters_in_copy_N: "); print(mma_atom_iters_in_copy_N); print("\n");
452+
print("copy_iters_M: "); print(copy_iters_M); print("\n");
453+
print("copy_iters_N: "); print(copy_iters_N); print("\n");
454+
print("order: "); print(order); print("\n");
455+
print("mma_atom_shape_2d: "); print(mma_atom_shape_2d); print("\n");
456+
print("============================================== \n");
457+
}
458+
#endif
434459
return make_ordered_layout(make_shape(mma_atom_shape_2d,
435460
make_shape(mma_atom_iters_in_copy_M, copy_iters_M),
436461
make_shape(mma_atom_iters_in_copy_N, copy_iters_N)),

include/cute/atom/mma_atom.hpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,25 @@ struct TiledMMA : MMA_Atom
313313
make_tile(make_layout(size<1>(thr_layout_vmnk_)),
314314
make_layout(size<3>(thr_layout_vmnk_))));
315315
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK)))
316-
316+
#if 1
317+
if(cute::thread0()){
318+
print("========================thrfrg_A: \n");
319+
print("atensor: "); print(atensor); print("\n");
320+
print("permutation_mnk<0>: "); print(permutation_mnk<0>()); print("\n");
321+
print("permutation_mnk<2>: "); print(permutation_mnk<2>()); print("\n");
322+
print("t_tile: "); print(t_tile); print("\n");
323+
print("t_tensor: "); print(t_tensor); print("\n");
324+
print("AtomShape_MNK: "); print(AtomShape_MNK{}); print("\n");
325+
print("a_tile: "); print(a_tile); print("\n");
326+
print("a_tensor: "); print(a_tensor); print("\n");
327+
print("AtomLayoutA_TV: "); print(AtomLayoutA_TV{}); print("\n");
328+
print("tv_tensor: "); print(tv_tensor); print("\n");
329+
print("thr_layout_vmnk_: "); print(thr_layout_vmnk_); print("\n");
330+
print("thr_tile: "); print(thr_tile); print("\n");
331+
print("thr_tensor: "); print(thr_tensor); print("\n");
332+
print("==================================== \n");
333+
}
334+
#endif
317335
return thr_tensor;
318336
}
319337

@@ -352,7 +370,25 @@ struct TiledMMA : MMA_Atom
352370
make_tile(make_layout(size<2>(thr_layout_vmnk_)),
353371
make_layout(size<3>(thr_layout_vmnk_))));
354372
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK)))
355-
373+
#if 1
374+
if(cute::thread0()){
375+
print("========================thrfrg_B: \n");
376+
print("permutation_mnk<1>: "); print(permutation_mnk<1>()); print("\n");
377+
print("permutation_mnk<2>: "); print(permutation_mnk<2>()); print("\n");
378+
print("t_tile: "); print(t_tile); print("\n");
379+
print("btensor: "); print(btensor); print("\n");
380+
print("t_tensor: "); print(t_tensor); print("\n");
381+
print("AtomShape_MNK: "); print(AtomShape_MNK{}); print("\n");
382+
print("b_tile: "); print(b_tile); print("\n");
383+
print("b_tensor: "); print(b_tensor); print("\n");
384+
print("AtomLayoutB_TV: "); print(AtomLayoutB_TV{}); print("\n");
385+
print("tv_tensor: "); print(tv_tensor); print("\n");
386+
print("thr_layout_vmnk_: "); print(thr_layout_vmnk_); print("\n");
387+
print("thr_tile: "); print(thr_tile); print("\n");
388+
print("thr_tensor: "); print(thr_tensor); print("\n");
389+
print("==================================== \n");
390+
}
391+
#endif
356392
return thr_tensor;
357393
}
358394

@@ -523,6 +559,7 @@ struct ThrMMA : TiledMMA
523559
auto thr_tensor = make_tensor(static_cast<CTensor&&>(ctensor).data(), this->thrfrg_C(ctensor.layout()));
524560

525561
auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_)));
562+
//if(cute::thread0()) printf("partition_C: get<0>(thr_vmnk_) = %d, get<1>(thr_vmnk_) = %d, get<2>(thr_vmnk_) = %d\n", static_cast<int>(get<0>(thr_vmnk_)),static_cast<int>(get<1>(thr_vmnk_)),static_cast<int>(get<2>(thr_vmnk_)));
526563
return thr_tensor(thr_vmn, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
527564
}
528565

@@ -534,6 +571,7 @@ struct ThrMMA : TiledMMA
534571
auto thr_tensor = make_tensor(static_cast<ATensor&&>(atensor).data(), this->thrfrg_A(atensor.layout()));
535572

536573
auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_)));
574+
//if(cute::thread0()) printf("partition_A: get<0>(thr_vmnk_) = %d, get<1>(thr_vmnk_) = %d, get<3>(thr_vmnk_) = %d\n", static_cast<int>(get<0>(thr_vmnk_)),static_cast<int>(get<1>(thr_vmnk_)),static_cast<int>(get<3>(thr_vmnk_)));
537575
return thr_tensor(thr_vmk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
538576
}
539577

@@ -545,6 +583,7 @@ struct ThrMMA : TiledMMA
545583
auto thr_tensor = make_tensor(static_cast<BTensor&&>(btensor).data(), this->thrfrg_B(btensor.layout()));
546584

547585
auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_)));
586+
//if(cute::thread0()) printf("partition_B: get<0>(thr_vmnk_) = %d, get<2>(thr_vmnk_) = %d, get<3>(thr_vmnk_) = %d\n", static_cast<int>(get<0>(thr_vmnk_)),static_cast<int>(get<2>(thr_vmnk_)),static_cast<int>(get<3>(thr_vmnk_)));
548587
return thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
549588
}
550589

tests/test_xpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8383
double_quant=False
8484
block_size = 16
8585
elif kind == "fc1":
86-
dim=4096
86+
dim=256
8787
A = torch.randn(64, dim, dtype=dtype, device=device)
8888
#A = torch.arange(1, 32 * 256 + 1).reshape(32, 256).bfloat16().xpu()
8989
B = torch.randn(dim , dim, dtype=dtype, device=device) / math.sqrt(dim)
9090
double_quant=False
91-
block_size = 32
91+
block_size = 64
9292
elif kind == "fc2":
9393
A = torch.randn(1, 4 * dim, dtype=dtype, device=device)
9494
B = torch.randn(dim, 4 * dim, dtype=dtype, device=device) / math.sqrt(dim)
@@ -144,7 +144,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
144144
#print("C3.sum() = ", C3.sum())
145145
#print("C2.sum() = ", C2.sum())
146146
diff = abs(C2-C3.bfloat16())
147-
print("diff = ", diff[0])
147+
print("diff/C2 = ", diff[0]/C2[0])
148148
print(C3[0])
149149
print(C2[0])
150150
#print(C3)

0 commit comments

Comments
 (0)