Skip to content

Commit ee3fbef

Browse files
committed
save code
1 parent 14f83eb commit ee3fbef

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -495,7 +495,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
495495

496496
};
497497
#else
498-
auto dequant = [&] (float* quant_map){
498+
//auto dequant = [&] (float* quant_map){
499+
auto dequant = [&] (int start_lut_id){
499500
constexpr int N = decltype(cute::size<1>(mma_B))::value;
500501
constexpr int K = decltype(cute::size(mma_B))::value / N;
501502
//if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
@@ -528,7 +529,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
528529
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
529530
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
530531
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
531-
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
532+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[start_lut_id][bit_value] * scale_value);
532533
//dst[dst_base_idx + c] = static_cast<ElementMMA>(params.quant_map_const[bit_value] * scale_value);
533534

534535
// uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
@@ -559,7 +560,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
559560

560561
//int map_offset = 16 * (sg_idx % 4);
561562
//int map_offset = 16 * ((sg_idx ^ (sg_idx >> 2)) % 4);
562-
int lut_id = sg_idx % 4;
563+
//int lut_id = sg_idx % 4;
564+
int start_lut_id = sg_idx % 4;
563565

564566
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
565567
#if 1 //SLM: 0, register: 1
@@ -568,7 +570,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
568570
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
569571
//dequant((sg_idx % 4 ) < 2 ? quant_map_1 : quant_map_2);
570572
//dequant(quant_map_ + map_offset);
571-
dequant(quant_map_[lut_id]);
573+
//dequant(quant_map_[lut_id]);
574+
dequant(start_lut_id);
572575
#else
573576
copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
574577
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);

0 commit comments

Comments
 (0)