File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -468,10 +468,16 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
468468 src_compress_type src_value = src[v];
469469 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
470470 #pragma unroll
471- for (int c = 0 ; c < src_compress_size; c++) {
472- uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
473- float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + c) / GROUP_SIZE );
474- dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
471+ for (int c = 0 ; c < src_compress_size/2 ; c++) {
472+ // uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
473+ // float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
474+ // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
475+ uint8_t high = (src_value >> (4 * (c * 2 + 1 ))) & 0xf ;
476+ uint8_t low = (src_value >> (4 * (c * 2 ))) & 0xf ;
477+ float ts_high = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c) / GROUP_SIZE );
478+ float ts_low = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + 2 * c + 1 ) / GROUP_SIZE );
479+ dst[dst_base_idx + 2 * c] = static_cast <ElementMMA>(quant_map[high] * ts_high);
480+ dst[dst_base_idx + 2 * c + 1 ] = static_cast <ElementMMA>(quant_map[low] * ts_low);
475481#if 0
476482 //dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
477483
You can’t perform that action at this time.
0 commit comments