@@ -495,7 +495,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
495495
496496 };
497497#else
498- auto dequant = [&] (float * quant_map){
498+ // auto dequant = [&] (float* quant_map){
499+ auto dequant = [&] (int start_lut_id){
499500 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
500501 constexpr int K = decltype (cute::size (mma_B))::value / N;
501502 // if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
@@ -528,7 +529,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
528529 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
529530 float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
530531 // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
531- dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map [bit_value] * scale_value);
532+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[start_lut_id] [bit_value] * scale_value);
532533 // dst[dst_base_idx + c] = static_cast<ElementMMA>(params.quant_map_const[bit_value] * scale_value);
533534
534535// uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
@@ -559,7 +560,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
559560
560561 // int map_offset = 16 * (sg_idx % 4);
561562 // int map_offset = 16 * ((sg_idx ^ (sg_idx >> 2)) % 4);
562- int lut_id = sg_idx % 4 ;
563+ // int lut_id = sg_idx % 4;
564+ int start_lut_id = sg_idx % 4 ;
563565
564566 for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
565567#if 1 // SLM: 0, register: 1
@@ -568,7 +570,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
568570 copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A);
569571 // dequant((sg_idx % 4 ) < 2 ? quant_map_1 : quant_map_2);
570572 // dequant(quant_map_ + map_offset);
571- dequant (quant_map_[lut_id]);
573+ // dequant(quant_map_[lut_id]);
574+ dequant (start_lut_id);
572575#else
573576 copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) * BLK_K/params.group_size), frag_copy_Scale);
574577 copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
0 commit comments