Skip to content

Commit 4248edf

Browse files
committed
save code
1 parent ee3fbef commit 4248edf

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
512512
src_compress_type src[src_vec_size];
513513
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
514514

515+
int lut_id = start_lut_id;
516+
515517

516518
#pragma unroll
517519
for (int n = 0; n < N; n++) {
@@ -529,7 +531,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
529531
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
530532
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
531533
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
532-
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[start_lut_id][bit_value] * scale_value);
534+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
535+
//lut_id = (lut_id + 1) % 4;
533536
//dst[dst_base_idx + c] = static_cast<ElementMMA>(params.quant_map_const[bit_value] * scale_value);
534537

535538
// uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
@@ -539,6 +542,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
539542
// dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
540543
// dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
541544
}
545+
lut_id = (lut_id + 1) % 4;
542546
}
543547
}
544548

0 commit comments

Comments
 (0)