@@ -500,17 +500,20 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
500500 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
501501 constexpr int K = decltype (cute::size (mma_B))::value / N;
502502
503- using src_compress_type = uint16_t ;
504- using dst_compress_type = uint64_t ;
503+ using src_compress_type = uint8_t ;
504+ using dst_compress_type = uint32_t ;
505505 constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
506506 constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
507- constexpr int src_vec_size = 2 ; // (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508- constexpr int dst_vec_size = 2 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
507+ constexpr int src_vec_size = 4 ; // (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508+ constexpr int dst_vec_size = 4 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
509509 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510510 constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511511 src_compress_type src[src_vec_size];
512512 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
513513
514+ // if(cute::thread0()) {
515+ // printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
516+ // }
514517 int lut_id = start_lut_id;
515518// if(sg_idx == 0){
516519// for (int i = 0; i < 64; i++){
@@ -585,34 +588,43 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
585588 uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
586589 reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
587590 }
588- #elif 0
591+ #elif 1
592+ #pragma unroll
589593 for (int c = 0; c < src_compress_size; c++) {
590594 uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
591595 float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
592596 dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
593597 lut_id = (lut_id + 1) % LUT_NUM;
594598 }
595- reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(dst)[v];
596- }
597- }
598- #else
599- #pragma unroll
600- for (int c = 0 ; c < src_compress_size; c++) {
601- uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
602- float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
603- dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
604- lut_id = (lut_id + 1 ) % LUT_NUM ;
605- }
599+ //reinterpret_cast<sycl::vec<dst_compress_type, 1>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 1>*>(dst)[l * src_vec_size + v];
606600 }
601+ //reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
607602 }
608-
609603 #pragma unroll
610604 for (int l = 0; l < dst_loop_num; l++) {
611605 reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
612606 //reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
613607 //reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
614608
615- }
609+ }
610+ //#else
611+ // #pragma unroll
612+ // for (int c = 0; c < src_compress_size; c++) {
613+ // uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
614+ // float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
615+ // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
616+ // lut_id = (lut_id + 1) % LUT_NUM;
617+ // }
618+ // }
619+ // }
620+ //
621+ // #pragma unroll
622+ // for (int l = 0; l < dst_loop_num; l++) {
623+ // reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
624+ // //reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
625+ // //reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
626+ //
627+ // }
616628#endif
617629 }
618630 };
0 commit comments