@@ -468,16 +468,18 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
468468 src_compress_type src_value = src[v];
469469 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
470470 #pragma unroll
471- for (int c = 0 ; c < src_compress_size/2 ; c++) {
472- // uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
473- // float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
474- // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
475- uint8_t high = (src_value >> (4 * (c * 2 + 1 ))) & 0xf ;
476- uint8_t low = (src_value >> (4 * (c * 2 ))) & 0xf ;
477- float ts_high = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE );
478- float ts_low = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE );
479- dst[dst_base_idx + 2 * c] = static_cast <ElementMMA>(quant_map[high] * ts_high);
480- dst[dst_base_idx + 2 * c + 1 ] = static_cast <ElementMMA>(quant_map[low] * ts_low);
471+ for (int c = 0 ; c < src_compress_size; c++) {
472+ uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
473+ // float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) / GROUP_SIZE);
474+ float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
475+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
476+
477+ // uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
478+ // uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
479+ // float ts_high = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE);
480+ // float ts_low = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE);
481+ // dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
482+ // dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
481483#if 0
482484 //dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
483485
0 commit comments