Skip to content

Commit c442182

Browse files
committed
refine code
1 parent 08fe237 commit c442182

1 file changed

Lines changed: 11 additions & 10 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,11 @@ class kgemm_4bit_inference_cutlass_dequant {
241241
//if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
242242
//printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
243243
}
244-
int scale_number = decltype(size(tCrS_input))::value;
245-
for(int i=0; i<scale_number; i++){
246-
auto s_value = tCrS_input(i);
247-
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, s_value);
248-
}
244+
}
245+
int scale_number = decltype(size(tCrS_input))::value;
246+
for(int i=0; i<scale_number; i++){
247+
auto s_value = tCrS_input(i);
248+
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, s_value);
249249
}
250250
#else
251251
static constexpr auto N = decltype(size<1>(in))::value;
@@ -297,6 +297,7 @@ class kgemm_4bit_inference_cutlass_dequant {
297297
int N = params.n;
298298
int K = params.k;
299299
int L = 1;
300+
static constexpr int BLK_K = 64;
300301

301302
T* A = params.A;
302303
uint8_t* B = params.B;
@@ -383,11 +384,11 @@ class kgemm_4bit_inference_cutlass_dequant {
383384

384385
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
385386

386-
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
387-
static constexpr auto scale_traits_num = SG_QNT_WIDTH / size<1>(typename GmemTiledCopyScale::BlockShape{});
387+
static constexpr auto scale_traits_size = 16 / SubgroupSize;
388+
static constexpr auto scale_traits_num = 64 / 16;
388389
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
389390
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
390-
if(cute::thread0()) printf("scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d\n", scale_traits_size, scale_traits_num, SG_QNT_WIDTH);
391+
if(cute::thread0()) printf("scale_traits_size = %d, scale_traits_num = %d\n", scale_traits_size, scale_traits_num);
391392

392393
static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
393394
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
@@ -414,15 +415,15 @@ class kgemm_4bit_inference_cutlass_dequant {
414415

415416
// Run mainloop
416417
auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
417-
const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
418+
const int n_coord_s = n_idx * 64 + (get_sub_group_id() % 2) * 32;
418419
const int l_coord_s = l_idx;
419420

420421
if(cute::thread0()) printf("m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
421422

422423
auto copy_iter_s = [&](){
423424
return make_tensor(make_inttuple_iter(make_coord(n_coord_s, 0, l_coord_s)),
424425
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
425-
make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{})));
426+
make_stride(E<0>{} * _16{}, E<0>{} * 16, _0{}, E<1>{} * _1{})));
426427

427428
}();
428429
#if 1

0 commit comments

Comments
 (0)