Skip to content

Commit 050a138

Browse files
committed
save code
1 parent 4ef320c commit 050a138

1 file changed

Lines changed: 30 additions & 18 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -500,17 +500,20 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
500500
constexpr int N = decltype(cute::size<1>(mma_B))::value;
501501
constexpr int K = decltype(cute::size(mma_B))::value / N;
502502

503-
using src_compress_type = uint16_t;
504-
using dst_compress_type = uint64_t;
503+
using src_compress_type = uint8_t;
504+
using dst_compress_type = uint32_t;
505505
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
506506
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
507-
constexpr int src_vec_size = 2; //(K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508-
constexpr int dst_vec_size = 2; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
507+
constexpr int src_vec_size = 4; //(K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508+
constexpr int dst_vec_size = 4; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
509509
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510510
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511511
src_compress_type src[src_vec_size];
512512
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
513513

514+
//if(cute::thread0()) {
515+
//printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
516+
//}
514517
int lut_id = start_lut_id;
515518
//if(sg_idx == 0){
516519
// for (int i = 0; i < 64; i++){
@@ -585,34 +588,43 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
585588
uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
586589
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
587590
}
588-
#elif 0
591+
#elif 1
592+
#pragma unroll
589593
for (int c = 0; c < src_compress_size; c++) {
590594
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
591595
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
592596
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
593597
lut_id = (lut_id + 1) % LUT_NUM;
594598
}
595-
reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(dst)[v];
596-
}
597-
}
598-
#else
599-
#pragma unroll
600-
for (int c = 0; c < src_compress_size; c++) {
601-
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
602-
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
603-
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
604-
lut_id = (lut_id + 1) % LUT_NUM;
605-
}
599+
//reinterpret_cast<sycl::vec<dst_compress_type, 1>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 1>*>(dst)[l * src_vec_size + v];
606600
}
601+
//reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
607602
}
608-
609603
#pragma unroll
610604
for (int l = 0; l < dst_loop_num; l++) {
611605
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
612606
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
613607
//reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
614608

615-
}
609+
}
610+
//#else
611+
// #pragma unroll
612+
// for (int c = 0; c < src_compress_size; c++) {
613+
// uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
614+
// float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
615+
// dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
616+
// lut_id = (lut_id + 1) % LUT_NUM;
617+
// }
618+
// }
619+
// }
620+
//
621+
// #pragma unroll
622+
// for (int l = 0; l < dst_loop_num; l++) {
623+
// reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
624+
// //reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
625+
// //reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
626+
//
627+
// }
616628
#endif
617629
}
618630
};

0 commit comments

Comments
 (0)