Skip to content

Commit 51ebf6e

Browse files
committed
save code
1 parent 678eeaa commit 51ebf6e

1 file changed

Lines changed: 32 additions & 0 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,8 @@ if(cute::thread0()) printf("N = %d, K = %d, src_compress_size = %d, dst_compress
446446
src_compress_type src[src_vec_size];
447447
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
448448

449+
sycl::vec<float, 16> loaded = *(sycl::vec<float, 16>*)&quant_map[0];
450+
449451
#pragma unroll
450452
for (int n = 0; n < N; n++) {
451453
#pragma unroll
@@ -470,7 +472,37 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
470472
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
471473
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
472474
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
475+
#if 0
473476
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
477+
478+
constexpr uint8_t VEC_WIDTH = 4;
479+
uint8_t base_offset = (bit_value / VEC_WIDTH) * VEC_WIDTH;
480+
sycl::vec<float, VEC_WIDTH> loaded = *(sycl::vec<float, VEC_WIDTH>*)&quant_map[base_offset];
481+
//auto mask = (sycl::vec<int, VEC_WIDTH>(0,1,2,3) == (bit_value % VEC_WIDTH));
482+
//float convert_value = loaded[0] * static_cast<float>(mask[0]) +
483+
// loaded[1] * static_cast<float>(mask[1]) +
484+
// loaded[2] * static_cast<float>(mask[2]) +
485+
// loaded[3] * static_cast<float>(mask[3]);
486+
auto lane = bit_value % VEC_WIDTH;
487+
float convert_value = loaded[lane];
488+
dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
489+
//#endif
490+
auto mask = (sycl::vec<uint8_t, 16>(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15) == sycl::vec<uint8_t, 16>(bit_value));
491+
float convert_value = 0.0f;
492+
#pragma unroll
493+
for (int i = 0; i < 16; ++i) {
494+
convert_value += loaded[i] * static_cast<float>(mask[i]);
495+
}
496+
497+
//auto sg = sycl::ext::oneapi::experimental::this_sub_group();
498+
//float convert_value = sycl::select_from_group(
499+
// sg,
500+
// loaded,
501+
// bit_value // use bit_value directly as the index
502+
//);
503+
dst[dst_base_idx + c] = static_cast<ElementMMA>(convert_value * scale_value);
504+
#endif
505+
474506
#if 0
475507
if(thread_idx==60 && m_coord==0 && n_coord==0 && l_coord==0){
476508
printf("tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n", thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);

0 commit comments

Comments
 (0)