@@ -496,11 +496,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
496496
497497 };
498498#else
499- // auto dequant = [&] (float* quant_map){
500499 auto dequant = [&] (int start_lut_id){
501500 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
502501 constexpr int K = decltype (cute::size (mma_B))::value / N;
503- // if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
504502
505503 using src_compress_type = uint64_t ;
506504 using dst_compress_type = uint64_t ;
@@ -518,7 +516,6 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
518516
519517 #pragma unroll
520518 for (int n = 0 ; n < N; n++) {
521- // float scale_value = fragment_scale(0);
522519 #pragma unroll
523520 for (int l = 0 ; l < src_loop_num; l++) {
524521 reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l];
@@ -531,20 +528,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
531528 for (int c = 0 ; c < src_compress_size; c++) {
532529 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
533530 float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
534- // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
535531 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
536- // printf("sg_idx = %d, thread_idx = %d, dst_id = %d, start_lut_id = %d, lut_id = %d\n", sg_idx, thread_idx, dst_base_idx + c, start_lut_id, lut_id);
537532 lut_id = (lut_id + 1 ) % LUT_NUM ;
538- // dst[dst_base_idx + c] = static_cast<ElementMMA>(params.quant_map_const[bit_value] * scale_value);
539-
540- // uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
541- // uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
542- // float ts_high = fragment_scale((n * BLK_K + dst_base_idx + 2 * c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
543- // float ts_low = fragment_scale((n * BLK_K + dst_base_idx + 2 * c + 1) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
544- // dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
545- // dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
546533 }
547- // lut_id = (lut_id + 1) % LUT_NUM;
548534 }
549535 }
550536
0 commit comments