@@ -500,12 +500,12 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
500500 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
501501 constexpr int K = decltype (cute::size (mma_B))::value / N;
502502
503- using src_compress_type = uint64_t ;
503+ using src_compress_type = uint16_t ;
504504 using dst_compress_type = uint64_t ;
505505 constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // 16
506506 constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // 4
507- constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 4, 16 -> max vec_size of sycl::vec
508- constexpr int dst_vec_size = 1 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
507+ constexpr int src_vec_size = 2 ; // (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508+ constexpr int dst_vec_size = 2 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
509509 constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510510 constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511511 src_compress_type src[src_vec_size];
@@ -607,10 +607,10 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
607607 }
608608
609609 #pragma unroll
610- for (int l = 0 ; l < dst_loop_num / 4 ; l++) {
611- // reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
610+ for (int l = 0 ; l < dst_loop_num; l++) {
611+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
612612 // reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
613- reinterpret_cast <sycl::vec<dst_compress_type, 4 >*>(cute::raw_pointer_cast (mma_B.data ()))[n*dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, 4 >*>(dst)[l];
613+ // reinterpret_cast<sycl::vec<dst_compress_type, 2 >*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2 >*>(dst)[l];
614614
615615 }
616616#endif
0 commit comments