Skip to content

Commit dd2c366

Browse files
committed
save code
1 parent 876606c commit dd2c366

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -408,12 +408,19 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
408 408
#pragma unroll
409 409
for (int v = 0; v < src_vec_size; v++) {
410 410
int dst_base_idx = v * src_compress_size;
411+
int c = 0;
412+
uint8_t bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
413+
float scale_value = fragment_scale((dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
414+
float converted_value = quant_map[bit_value];
411 415
#pragma unroll
412-
for (int c = 0; c < src_compress_size; c++) {
413-
uint8_t bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
414-
float scale_value = fragment_scale((dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
415-
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
416+
for (; c < src_compress_size-1;) {
417+
c++;
418+
bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
419+
scale_value = fragment_scale((dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
420+
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value * scale_value);
421+
converted_value = quant_map[bit_value];
416 422
}
423+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
417 424
}
418 425

419 426
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];

0 commit comments

Comments (0)