@@ -413,13 +413,15 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
413413 constexpr int src_compress_size = 8 ; // cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
414414 constexpr int dst_compress_size = 4 ; // cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
415415 constexpr int src_vec_size = 8 ; // (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
416- constexpr int dst_vec_size = 16 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
416+ constexpr int dst_vec_size = 8 ; // (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
417417 constexpr int src_loop_num = 1 ; // K / src_vec_size / src_compress_size;
418- constexpr int dst_loop_num = 1 ; // K / dst_vec_size / dst_compress_size;
418+ constexpr int dst_loop_num = 2 ; // K / dst_vec_size / dst_compress_size;
419419
420420 // src_compress_type src[src_loop_num * src_vec_size];
421421 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
422422
423+ // ElementMMA dst[dst_loop_num * dst_compress_size];
424+
423425 // reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
424426 float scale_value = fragment_scale (0 );// (dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
425427
@@ -446,6 +448,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
446448 dst[dst_base_idx + c-1 ] = static_cast <ElementMMA>(converted_value_2 * scale_value);
447449 }
448450 dst[dst_base_idx + c] = static_cast <ElementMMA>(converted_value_1 * scale_value);
451+ reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *(v-1 )] = reinterpret_cast <dst_compress_type*>(dst)[2 *(v-1 )];
452+ reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *(v-1 )+1 ] = reinterpret_cast <dst_compress_type*>(dst)[2 *(v-1 )+1 ];
449453 }
450454 src_2 = src_1;
451455 int dst_base_idx = v * src_compress_size;
@@ -462,6 +466,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
462466 dst[dst_base_idx + c-1 ] = static_cast <ElementMMA>(converted_value_2 * scale_value);
463467 }
464468 dst[dst_base_idx + c] = static_cast <ElementMMA>(converted_value_1 * scale_value);
469+ reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *v] = reinterpret_cast <dst_compress_type*>(dst)[2 *v];
470+ reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *v+1 ] = reinterpret_cast <dst_compress_type*>(dst)[2 *v+1 ];
465471
466472// reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[1] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[1];
467473// scale_value = fragment_scale(1);
0 commit comments