Skip to content

Commit be85daa

Browse files
committed
save code
1 parent 51ebf6e commit be85daa

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,16 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
468468
src_compress_type src_value = src[v];
469469
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
470470
#pragma unroll
471-
for (int c = 0; c < src_compress_size; c++) {
472-
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
473-
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
474-
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
471+
for (int c = 0; c < src_compress_size/2; c++) {
472+
//uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
473+
//float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
474+
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
475+
uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
476+
uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
477+
float ts_high = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE);
478+
float ts_low = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE);
479+
dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
480+
dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
475481
#if 0
476482
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
477483

0 commit comments

Comments
 (0)