Skip to content

Commit d361323

Browse files
committed
refine code
1 parent fb9106d commit d361323

1 file changed

Lines changed: 29 additions & 44 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 29 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,16 @@ class kgemm_4bit_inference_cutlass_dequant {
209209
/// Utilities to transform A.
210210
template <class EngineIn,
211211
class EngineOut,
212-
//class EngineScales,
212+
class EngineScales,
213213
class LayoutIn,
214214
class LayoutOut,
215-
//class LayoutScales,
215+
class LayoutScales,
216216
class... Ts>
217217
CUTLASS_DEVICE
218218
void dequant(
219219
Tensor<EngineIn, LayoutIn> const& in,
220220
Tensor<EngineOut, LayoutOut>& out,
221-
//Tensor<EngineScales, LayoutScales>& tCrS_input,
221+
Tensor<EngineScales, LayoutScales>& tCrS_input,
222222
float* quant_map
223223
) {
224224
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
@@ -227,7 +227,7 @@ class kgemm_4bit_inference_cutlass_dequant {
227227

228228
using SrcType = typename EngineIn::value_type;
229229
using DstType = typename EngineOut::value_type;
230-
//using ScaleType = typename EngineScales::value_type;
230+
using ScaleType = typename EngineScales::value_type;
231231
#if 0
232232
int numbers = decltype(size(in))::value;
233233
for(int i=0; i<numbers; i++){
@@ -263,7 +263,7 @@ class kgemm_4bit_inference_cutlass_dequant {
263263
// printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
264264

265265
for (int n = 0; n < N; n++) {
266-
//const auto ts = tCrS_input(n);
266+
const auto ts = tCrS_input(n);
267267

268268
auto& src = *(cute::array<format_type, loop_cnt / scalar>*)(s_tensor(_, n).data());
269269

@@ -299,7 +299,7 @@ class kgemm_4bit_inference_cutlass_dequant {
299299

300300
auto tiled_copy_a = params.tiled_copy_a;
301301
auto tiled_copy_b = params.tiled_copy_b;
302-
//auto tiled_copy_scale = params.tiled_copy_scale;
302+
auto tiled_copy_scale = params.tiled_copy_scale;
303303

304304
auto problem_size = ProblemShape{M, N, K, L};
305305

@@ -320,7 +320,7 @@ class kgemm_4bit_inference_cutlass_dequant {
320320
barrier_wait(1);
321321

322322
//// Get the block level coordinate(indexing) for current block
323-
auto blk_shape = TileShape{}; //256,256,32
323+
auto blk_shape = TileShape{}; //16,64,64
324324
int m_coord, n_coord, l_coord; //block index
325325
if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
326326
if(cute::thread0()) printf("AlongN !!\n");
@@ -363,7 +363,7 @@ class kgemm_4bit_inference_cutlass_dequant {
363363
////// MainLoop //////
364364
auto thr_copy_A = tiled_copy_a.get_slice(thread_idx);
365365
auto thr_copy_B = tiled_copy_b.get_slice(thread_idx);
366-
//auto thr_copy_scale = tiled_copy_scale.get_slice(thread_idx);
366+
auto thr_copy_scale = tiled_copy_scale.get_slice(thread_idx);
367367

368368
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
369369
auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
@@ -379,10 +379,10 @@ class kgemm_4bit_inference_cutlass_dequant {
379379

380380
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
381381

382-
//static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
383-
//static constexpr auto scale_traits_num = SG_QNT_WIDTH / size<1>(typename GmemTiledCopyScale::BlockShape{});
384-
//using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
385-
//Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
382+
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
383+
static constexpr auto scale_traits_num = SG_QNT_WIDTH / size<1>(typename GmemTiledCopyScale::BlockShape{});
384+
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
385+
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
386386

387387
static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
388388
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
@@ -391,7 +391,7 @@ class kgemm_4bit_inference_cutlass_dequant {
391391
//// Retile for copy
392392
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
393393
Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
394-
//Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
394+
Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
395395

396396
//// Retile global counting tensors for copies:
397397
Tensor tAgA = thr_copy_A.retile_S(tCgA);
@@ -407,17 +407,17 @@ class kgemm_4bit_inference_cutlass_dequant {
407407
auto pAgA = thr_prefetch_A.partition_S(gA);
408408
auto pBgB = thr_prefetch_B.partition_S(gB);
409409

410-
//// Run mainloop
411-
// auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
412-
// const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
413-
// const int l_coord_s = l_idx;
414-
//
415-
// auto copy_iter_s = [&](){
416-
// return make_tensor(make_inttuple_iter(make_coord(n_coord_s, 0, l_coord_s)),
417-
// make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
418-
// make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
419-
//
420-
// }();
410+
// Run mainloop
411+
auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
412+
const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
413+
const int l_coord_s = l_idx;
414+
415+
auto copy_iter_s = [&](){
416+
return make_tensor(make_inttuple_iter(make_coord(n_coord_s, 0, l_coord_s)),
417+
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
418+
make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
419+
420+
}();
421421
#if 1
422422
#define PRINT(x) print(#x ": "); print(x); print("\n");
423423
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
@@ -458,7 +458,7 @@ class kgemm_4bit_inference_cutlass_dequant {
458458
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
459459
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
460460
}
461-
//k_tile_count=1;
461+
462462
for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
463463
barrier_arrive(2);
464464

@@ -468,7 +468,7 @@ class kgemm_4bit_inference_cutlass_dequant {
468468

469469
const int k_reload_factor = params.group_size / BLK_K;
470470

471-
//copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
471+
copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
472472

473473
if(prefetch_k < k_tile_count) {
474474
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
@@ -477,9 +477,8 @@ class kgemm_4bit_inference_cutlass_dequant {
477477
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
478478
}
479479

480-
dequant(dequant_frag, mma_B, /*fragment_scale,*/ quant_map);
480+
dequant(dequant_frag, mma_B, fragment_scale, quant_map);
481481

482-
//barrier_wait(1);
483482

484483
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
485484
barrier_wait(2);
@@ -575,22 +574,8 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
575574
auto mA_mkl = make_tensor(make_gmem_ptr(A), make_layout(make_shape(m, k, l), stride_A));
576575
Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
577576

578-
//StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n/2, k, l));
579-
// auto stride_B_custom = cute::make_stride(
580-
// cute::Int<1>{}, // 连续维度步幅(字节)
581-
// (n * 4 + 7) / 8, // pitch = ceil(n * 4bit / 8bit)
582-
// cute::Int<0>{} // 无批量步幅(根据需求调整)
583-
//);
584-
// constexpr int stride_k = (n * 4 ) / 8;
585-
// constexpr int stride_l = (n * k * 4 ) / 8;
586-
// auto stride_B = cute::make_stride(
587-
// cute::Int<1>{},
588-
// (n * 4 ) / 8,
589-
// (n * k * 4 ) / 8
590-
// );
591-
//int k_half = k/2;
592-
//StrideB stride_B = make_stride(int64_t{1}, int64_t{n}, int64_t{n * k});
593-
StrideB stride_B = make_stride(int64_t{n}, cute::Int<1>{}, int64_t{0});
577+
//StrideB stride_B = make_stride(int64_t{n}, cute::Int<1>{}, int64_t{0});
578+
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
594579
auto mB_nkl = make_tensor(cute::subbyte_iterator<ElementB>(B), make_layout(make_shape(n, k, l), stride_B));
595580
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
596581

0 commit comments

Comments
 (0)