Skip to content

Commit a2bd43b

Browse files
committed
save code
1 parent be85daa commit a2bd43b

1 file changed

Lines changed: 12 additions & 10 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -468,16 +468,18 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
468468
src_compress_type src_value = src[v];
469469
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
470470
#pragma unroll
471-
for (int c = 0; c < src_compress_size/2; c++) {
472-
//uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
473-
//float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
474-
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
475-
uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
476-
uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
477-
float ts_high = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE);
478-
float ts_low = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE);
479-
dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
480-
dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
471+
for (int c = 0; c < src_compress_size; c++) {
472+
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
473+
//float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) / GROUP_SIZE);
474+
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
475+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
476+
477+
//uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
478+
//uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
479+
//float ts_high = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE);
480+
//float ts_low = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE);
481+
//dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
482+
//dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
481483
#if 0
482484
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
483485

0 commit comments

Comments
 (0)