@@ -504,24 +504,16 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
504504 using dst_compress_type = uint32_t ;
505505 constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
506506 constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
507- constexpr int src_vec_size = 4 ; // (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508- constexpr int dst_vec_size = 4 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
507+ constexpr int src_vec_size = 16 ;
509508 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510- constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511509 src_compress_type src[src_vec_size];
510+
511+ #if 1
512+ constexpr int dst_vec_size = src_vec_size;
513+ constexpr int dst_loop_num = src_loop_num;
512514 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
513515
514- // if(cute::thread0()) {
515- // printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_vec_size = %d, src_loop_num = %d, dst_loop_num = %d\n", src_compress_size, dst_compress_size, src_vec_size, dst_vec_size, src_loop_num, dst_loop_num);
516- // }
517516 int lut_id = start_lut_id;
518- // if(sg_idx == 0){
519- // for (int i = 0; i < 64; i++){
520- // printf("tid = %d, dequant_frag ptr[%d] = %x, mma_B ptr[%d] = %x\n",thread_idx, i, cute::raw_pointer_cast(dequant_frag.data()+i),i, cute::raw_pointer_cast(mma_B.data()+i));
521- // }
522- // }
523-
524- int pre_num = 1 ;
525517 #pragma unroll
526518 for (int n = 0 ; n < N; n++) {
527519 #pragma unroll
@@ -533,81 +525,29 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
533525 // src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l][v]; //src[v];
534526 src_compress_type src_value = src[v];
535527 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
536- #if 0
537- int c = 0;
538- uint16_t high_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
539- lut_id = (lut_id + 1) % LUT_NUM;
540- uint16_t low_bits_1 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
541-
542- c++;
543- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_1) << 16) | high_bits_1;
544- uint16_t high_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
545- lut_id = (lut_id + 1) % LUT_NUM;
546- uint16_t low_bits_2 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
547-
548- c++;
549- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_2) << 16) | high_bits_2;
550- uint16_t high_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
551- lut_id = (lut_id + 1) % LUT_NUM;
552- uint16_t low_bits_3 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
553-
554- c++;
555- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_3) << 16) | high_bits_3;
556- uint16_t high_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
557- lut_id = (lut_id + 1) % LUT_NUM;
558- uint16_t low_bits_4 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
559-
560- c++;
561- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_4) << 16) | high_bits_4;
562- uint16_t high_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
563- lut_id = (lut_id + 1) % LUT_NUM;
564- uint16_t low_bits_5 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
565-
566- c++;
567- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_5) << 16) | high_bits_5;
568- uint16_t high_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
569- lut_id = (lut_id + 1) % LUT_NUM;
570- uint16_t low_bits_6 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
571-
572- c++;
573- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-pre_num] = (static_cast<uint32_t>(low_bits_6) << 16) | high_bits_6;
574- uint16_t high_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
575- lut_id = (lut_id + 1) % LUT_NUM;
576- uint16_t low_bits_7 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
577-
578- c++;
579- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c-1] = (static_cast<uint32_t>(low_bits_7) << 16) | high_bits_7;
580- uint16_t high_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
581- lut_id = (lut_id + 1) % LUT_NUM;
582- uint16_t low_bits_8 = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
583- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits_8) << 16) | high_bits_8;
584- #elif 0
585- for (int c = 0; c < src_compress_size/2; c++) {
586- uint16_t high_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2 + 1))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c) / GROUP_SIZE)));
587- lut_id = (lut_id + 1) % LUT_NUM;
588- uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
589- reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
590- }
591- #elif 0
592528 #pragma unroll
593529 for (int c = 0 ; c < src_compress_size; c++) {
594530 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
595531 float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
596532 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
597533 lut_id = (lut_id + 1 ) % LUT_NUM ;
598534 }
599- // reinterpret_cast<sycl::vec<dst_compress_type, 1 >*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast<sycl::vec<dst_compress_type, 1 >*>(dst)[l * src_vec_size + v];
535+ reinterpret_cast <sycl::vec<dst_compress_type, src_compress_size / dst_compress_size >*>(cute::raw_pointer_cast (mma_B.data ()))[n*src_loop_num + l * src_vec_size + v] = reinterpret_cast <sycl::vec<dst_compress_type, src_compress_size / dst_compress_size >*>(dst)[l * src_vec_size + v];
600536 }
601- //reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size >*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size >*>(dst)[l];
537+ // reinterpret_cast<sycl::vec<dst_compress_type, src_vec_size * src_compress_size / dst_compress_size >*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, src_vec_size * src_compress_size / dst_compress_size >*>(dst)[l];
602538 }
603- #pragma unroll
604- for (int l = 0; l < dst_loop_num; l++) {
605- reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
606- //reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
607- //reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
608-
609- }
610539#else
540+ constexpr int dst_vec_size = 4 ;
541+ constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
542+ ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
543+
544+ int lut_id = start_lut_id;
545+ #pragma unroll
546+ for (int n = 0 ; n < N; n++) {
547+ #pragma unroll
548+ for (int l = 0 ; l < src_loop_num; l++) {
549+ reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l];
550+
611551 #pragma unroll
612552 for (int c = 0 ; c < src_compress_size; c++) {
613553 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
0 commit comments