Skip to content

Commit 3b170ef

Browse files
committed
save code
1 parent e71b808 commit 3b170ef

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -448,8 +448,9 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
448448
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
449449
}
450450
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
451-
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)];
452-
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)+1] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)+1];
451+
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)];
452+
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)+1] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)+1];
453+
reinterpret_cast<sycl::vec<dst_compress_type, dst_loop_num>*>(cute::raw_pointer_cast(mma_B.data()))[v-1] = reinterpret_cast<sycl::vec<dst_compress_type, dst_loop_num>*>(dst)[v-1];
453454
}
454455
src_2 = src_1;
455456
int dst_base_idx = v * src_compress_size;
@@ -466,12 +467,13 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
466467
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
467468
}
468469
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
469-
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v] = reinterpret_cast<dst_compress_type*>(dst)[2*v];
470-
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v+1] = reinterpret_cast<dst_compress_type*>(dst)[2*v+1];
470+
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v] = reinterpret_cast<dst_compress_type*>(dst)[2*v];
471+
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v+1] = reinterpret_cast<dst_compress_type*>(dst)[2*v+1];
472+
reinterpret_cast<sycl::vec<dst_compress_type, dst_loop_num>*>(cute::raw_pointer_cast(mma_B.data()))[v] = reinterpret_cast<sycl::vec<dst_compress_type, dst_loop_num>*>(dst)[v];
471473

472474
// reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[1] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[1];
473475
// scale_value = fragment_scale(1);
474-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
476+
//reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
475477

476478
// #pragma unroll
477479
// for (int v = src_vec_size; v < src_loop_num * src_vec_size; v++) {

0 commit comments

Comments (0)