Skip to content

Commit 47ae47e

Browse files
committed
save code
1 parent 9cc08c0 commit 47ae47e

1 file changed

Lines changed: 18 additions & 78 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 18 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -504,24 +504,16 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
504504
using dst_compress_type = uint32_t;
505505
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
506506
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
507-
constexpr int src_vec_size = 4; //(K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508-
constexpr int dst_vec_size = 4; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
507+
constexpr int src_vec_size = 16;
509508
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510-
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511509
src_compress_type src[src_vec_size];
510+
511+
#if 1
512+
constexpr int dst_vec_size = src_vec_size;
513+
constexpr int dst_loop_num = src_loop_num;
512514
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
513515

514-
//if(cute::thread0()) {
515-
//printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
516-
//}
517516
int lut_id = start_lut_id;
518-
//if(sg_idx == 0){
519-
// for (int i = 0; i < 64; i++){
520-
// printf("tid = %d, dequant_frag ptr[%d] = %x, mma_B ptr[%d] = %x\n",thread_idx, i, cute::raw_pointer_cast(dequant_frag.data()+i),i, cute::raw_pointer_cast(mma_B.data()+i));
521-
// }
522-
//}
523-
524-
int pre_num = 1;
525517
#pragma unroll
526518
for (int n = 0; n < N; n++) {
527519
#pragma unroll
@@ -533,81 +525,29 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
533525
//src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l][v]; //src[v];
534526
src_compress_type src_value = src[v];
535527
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
536-
#if 0
537-
int c = 0;
538-
uint16_t high_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
539-
lut_id = (lut_id + 1) % LUT_NUM;
540-
uint16_t low_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
541-
542-
c++;
543-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_1) << 16) | high_bits_1;
544-
uint16_t high_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
545-
lut_id = (lut_id + 1) % LUT_NUM;
546-
uint16_t low_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
547-
548-
c++;
549-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_2) << 16) | high_bits_2;
550-
uint16_t high_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
551-
lut_id = (lut_id + 1) % LUT_NUM;
552-
uint16_t low_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
553-
554-
c++;
555-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_3) << 16) | high_bits_3;
556-
uint16_t high_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
557-
lut_id = (lut_id + 1) % LUT_NUM;
558-
uint16_t low_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
559-
560-
c++;
561-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_4) << 16) | high_bits_4;
562-
uint16_t high_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
563-
lut_id = (lut_id + 1) % LUT_NUM;
564-
uint16_t low_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
565-
566-
c++;
567-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_5) << 16) | high_bits_5;
568-
uint16_t high_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
569-
lut_id = (lut_id + 1) % LUT_NUM;
570-
uint16_t low_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
571-
572-
c++;
573-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_6) << 16) | high_bits_6;
574-
uint16_t high_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
575-
lut_id = (lut_id + 1) % LUT_NUM;
576-
uint16_t low_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
577-
578-
c++;
579-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-1] = (static_cast<uint32_t>(low_bits_7) << 16) | high_bits_7;
580-
uint16_t high_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
581-
lut_id = (lut_id + 1) % LUT_NUM;
582-
uint16_t low_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
583-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits_8) << 16) | high_bits_8;
584-
#elif 0
585-
for (int c = 0; c < src_compress_size/2; c++) {
586-
uint16_t high_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
587-
lut_id = (lut_id + 1) % LUT_NUM;
588-
uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
589-
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
590-
}
591-
#elif 0
592528
#pragma unroll
593529
for (int c = 0; c < src_compress_size; c++) {
594530
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
595531
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
596532
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
597533
lut_id = (lut_id + 1) % LUT_NUM;
598534
}
599-
//reinterpret_cast<sycl::vec<dst_compress_type, 1>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 1>*>(dst)[l * src_vec_size + v];
535+
reinterpret_cast<sycl::vec<dst_compress_type, src_compress_size / dst_compress_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, src_compress_size / dst_compress_size>*>(dst)[l * src_vec_size + v];
600536
}
601-
//reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
537+
//reinterpret_cast<sycl::vec<dst_compress_type, src_vec_size * src_compress_size / dst_compress_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, src_vec_size * src_compress_size / dst_compress_size>*>(dst)[l];
602538
}
603-
#pragma unroll
604-
for (int l = 0; l < dst_loop_num; l++) {
605-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
606-
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
607-
//reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
608-
609-
}
610539
#else
540+
constexpr int dst_vec_size = 4;
541+
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
542+
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
543+
544+
int lut_id = start_lut_id;
545+
#pragma unroll
546+
for (int n = 0; n < N; n++) {
547+
#pragma unroll
548+
for (int l = 0; l < src_loop_num; l++) {
549+
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l];
550+
611551
#pragma unroll
612552
for (int c = 0; c < src_compress_size; c++) {
613553
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;

0 commit comments

Comments
 (0)