Skip to content

Commit 6f64ec9

Browse files
committed
save code
1 parent 71f3b6e commit 6f64ec9

1 file changed

Lines changed: 21 additions & 5 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -578,26 +578,42 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
578578
lut_id = (lut_id + 1) % LUT_NUM;
579579
uint16_t low_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
580580
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits_8) << 16) | high_bits_8;
581-
#else
582-
#if 0
581+
#elif 0
583582
for (int c = 0; c < src_compress_size/2; c++) {
584583
uint16_t high_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
585584
lut_id = (lut_id + 1) % LUT_NUM;
586585
uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
587586
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
588587
}
589-
#else
588+
#elif 0
590589
for (int c = 0; c < src_compress_size; c++) {
591590
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
592591
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
593592
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
594593
lut_id = (lut_id + 1) % LUT_NUM;
595594
}
596595
reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(dst)[v];
597-
#endif
598-
#endif
599596
}
600597
}
598+
#else
599+
#pragma unroll
600+
for (int c = 0; c < src_compress_size; c++) {
601+
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
602+
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
603+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
604+
lut_id = (lut_id + 1) % LUT_NUM;
605+
}
606+
}
607+
}
608+
609+
#pragma unroll
610+
for (int l = 0; l < dst_loop_num / 4; l++) {
611+
//reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
612+
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
613+
reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(dst)[l];
614+
615+
}
616+
#endif
601617
}
602618
};
603619
#endif

0 commit comments

Comments
 (0)