Skip to content

Commit d84f1d8

Browse files
committed
refine code
1 parent d2c1023 commit d84f1d8

1 file changed

Lines changed: 23 additions & 10 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ using ElementOutput = float;
5959
using ProblemShape = Shape<int, int, int, int>;
6060

6161
using TileShape = Shape<_16, _64, _64>;
62+
using TileShape_half = Shape<_16, _64, _32>;
6263
using TiledMma =
6364
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
6465
Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
@@ -138,6 +139,7 @@ using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout<CopyThreadShape>{}
138139

139140
using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
140141
using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::ColumnMajor>;
142+
//using StrideB = Stride<int64_t, int64_t, int64_t>;
141143
//using Copy_B = typename Copy_Traits<GmemTiledCopyB, StrideB>::template DefaultTiledCopy<ElementB>;
142144
using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
143145
using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
@@ -234,7 +236,10 @@ class kgemm_4bit_inference_cutlass_dequant {
234236
//out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
235237
uint8_t value = in[i].get();
236238
out[i] = static_cast<DstType>(quant_map[value]);
237-
if(cute::thread0()) printf("thread_idx = %d, i = %d, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",int(ThreadIdxX()), i, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
239+
int thread_idx = int(ThreadIdxX());
240+
//if(thread_idx == 0)
241+
if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
242+
printf("thread_idx = %d, i = %d, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",thread_idx, i, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
238243
}
239244
#else
240245
static constexpr auto N = decltype(size<1>(in))::value;
@@ -330,7 +335,7 @@ if(cute::thread0())
330335
l_coord = BlockIdxZ();
331336
}
332337
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
333-
if(cute::thread0()) {
338+
if(0){//cute::thread0()) {
334339
printf("M = %d, N=%d, K=%d, L=%d\n", M, N, K, L);
335340
//}
336341
printf("thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n",thread_idx, m_coord, n_coord, l_coord, BlockIdxX(), BlockIdxY(), BlockIdxZ());
@@ -414,7 +419,7 @@ if(cute::thread0())
414419
// make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
415420
//
416421
// }();
417-
422+
#if 0
418423
#define PRINT(x) print(#x ": "); print(x); print("\n");
419424
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
420425
print("======================= A: \n");
@@ -442,7 +447,7 @@ if(cute::thread0())
442447
print(" pBgB : "); print(pBgB); print("\n");
443448
}
444449
#undef PRINT
445-
450+
#endif
446451
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
447452
int prefetch_k = k_start_idx;
448453

@@ -466,7 +471,7 @@ if(cute::thread0())
466471
if(prefetch_k < k_tile_count) {
467472
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
468473
}
469-
if(prefetch_k < k_tile_count / 2) {
474+
if(prefetch_k < k_tile_count) {
470475
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
471476
}
472477

@@ -517,14 +522,11 @@ if(cute::thread0())
517522
};
518523

519524
template <typename T, int BITS>
520-
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k_, T *A, unsigned char *B,
525+
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
521526
T *absmax_, float *datatype, float *out, int lda,
522527
int ldb, int ldc, int blocksize, sycl::queue *stream) {
523528
std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
524529

525-
int k = k_;
526-
527-
528530
sycl::queue q = *stream;
529531
using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS>;
530532

@@ -555,7 +557,7 @@ int k = k_;
555557
auto mA_mkl = make_tensor(make_gmem_ptr(A), make_layout(make_shape(m, k, l), stride_A));
556558
Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
557559

558-
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
560+
//StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n/2, k, l));
559561
// auto stride_B_custom = cute::make_stride(
560562
// cute::Int<1>{}, // 连续维度步幅(字节)
561563
// (n * 4 + 7) / 8, // pitch = ceil(n * 4bit / 8bit)
@@ -568,9 +570,20 @@ int k = k_;
568570
// (n * 4 ) / 8,
569571
// (n * k * 4 ) / 8
570572
// );
573+
//int k_half = k/2;
574+
//StrideB stride_B = make_stride(int64_t{1}, int64_t{n}, int64_t{n * k});
575+
StrideB stride_B = make_stride(int64_t{n}, cute::Int<1>{}, int64_t{0});
571576
auto mB_nkl = make_tensor(cute::subbyte_iterator<uint4_t>(B), make_layout(make_shape(n, k, l), stride_B));
572577
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
573578

579+
#define PRINT(x) print(#x ": "); print(x); print("\n");
580+
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
581+
print("===================== B :\n");
582+
print(" stride_B : "); print(stride_B); print("\n");
583+
print("===================== B :\n");
584+
}
585+
#undef PRINT
586+
574587
params.tiled_copy_a = tiled_copy_a;
575588
params.tiled_copy_b = tiled_copy_b;
576589

0 commit comments

Comments
 (0)