@@ -448,8 +448,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
448448 dst[dst_base_idx + c-1 ] = static_cast <ElementMMA>(converted_value_2 * scale_value);
449449 }
450450 dst[dst_base_idx + c] = static_cast <ElementMMA>(converted_value_1 * scale_value);
451- reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *(v-1 )] = reinterpret_cast <dst_compress_type*>(dst)[2 *(v-1 )];
452- reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *(v-1 )+1 ] = reinterpret_cast <dst_compress_type*>(dst)[2 *(v-1 )+1 ];
451+ // reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)];
452+ // reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)+1] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)+1];
453+ reinterpret_cast <sycl::vec<dst_compress_type, dst_loop_num>*>(cute::raw_pointer_cast (mma_B.data ()))[v-1 ] = reinterpret_cast <sycl::vec<dst_compress_type, dst_loop_num>*>(dst)[v-1 ];
453454 }
454455 src_2 = src_1;
455456 int dst_base_idx = v * src_compress_size;
@@ -466,12 +467,13 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
466467 dst[dst_base_idx + c-1 ] = static_cast <ElementMMA>(converted_value_2 * scale_value);
467468 }
468469 dst[dst_base_idx + c] = static_cast <ElementMMA>(converted_value_1 * scale_value);
469- reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *v] = reinterpret_cast <dst_compress_type*>(dst)[2 *v];
470- reinterpret_cast <dst_compress_type*>(cute::raw_pointer_cast (mma_B.data ()))[2 *v+1 ] = reinterpret_cast <dst_compress_type*>(dst)[2 *v+1 ];
470+ // reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v] = reinterpret_cast<dst_compress_type*>(dst)[2*v];
471+ // reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v+1] = reinterpret_cast<dst_compress_type*>(dst)[2*v+1];
472+ reinterpret_cast <sycl::vec<dst_compress_type, dst_loop_num>*>(cute::raw_pointer_cast (mma_B.data ()))[v] = reinterpret_cast <sycl::vec<dst_compress_type, dst_loop_num>*>(dst)[v];
471473
472474// reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[1] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[1];
473475// scale_value = fragment_scale(1);
474- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[0 ] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0 ];
476+ // reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
475477
476478// #pragma unroll
477479// for (int v = src_vec_size; v < src_loop_num * src_vec_size; v++) {
0 commit comments