Skip to content

Commit 3e41231

Browse files
committed
clean code
1 parent 5312f0a commit 3e41231

1 file changed

Lines changed: 0 additions & 14 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -496,11 +496,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
496496

497497
};
498498
#else
499-
//auto dequant = [&] (float* quant_map){
500499
auto dequant = [&] (int start_lut_id){
501500
constexpr int N = decltype(cute::size<1>(mma_B))::value;
502501
constexpr int K = decltype(cute::size(mma_B))::value / N;
503-
//if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
504502

505503
using src_compress_type = uint64_t;
506504
using dst_compress_type = uint64_t;
@@ -518,7 +516,6 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
518516

519517
#pragma unroll
520518
for (int n = 0; n < N; n++) {
521-
//float scale_value = fragment_scale(0);
522519
#pragma unroll
523520
for (int l = 0; l < src_loop_num; l++) {
524521
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l];
@@ -531,20 +528,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
531528
for (int c = 0; c < src_compress_size; c++) {
532529
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
533530
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
534-
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
535531
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
536-
//printf("sg_idx = %d, thread_idx = %d, dst_id = %d, start_lut_id = %d, lut_id = %d\n", sg_idx, thread_idx, dst_base_idx + c, start_lut_id, lut_id);
537532
lut_id = (lut_id + 1) % LUT_NUM;
538-
//dst[dst_base_idx + c] = static_cast<ElementMMA>(params.quant_map_const[bit_value] * scale_value);
539-
540-
// uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
541-
// uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
542-
// float ts_high = fragment_scale((n * BLK_K + dst_base_idx + 2 * c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
543-
// float ts_low = fragment_scale((n * BLK_K + dst_base_idx + 2 * c + 1) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
544-
// dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
545-
// dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
546533
}
547-
//lut_id = (lut_id + 1) % LUT_NUM;
548534
}
549535
}
550536

0 commit comments

Comments
 (0)