Skip to content

Commit e71b808

Browse files
committed
save code
1 parent bf5a345 commit e71b808

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -413,13 +413,15 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
413413
constexpr int src_compress_size = 8; //cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
414414
constexpr int dst_compress_size = 4; //cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
415415
constexpr int src_vec_size = 8; //(K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
416-
constexpr int dst_vec_size = 16; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
416+
constexpr int dst_vec_size = 8; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
417417
constexpr int src_loop_num = 1; //K / src_vec_size / src_compress_size;
418-
constexpr int dst_loop_num = 1; //K / dst_vec_size / dst_compress_size;
418+
constexpr int dst_loop_num = 2; //K / dst_vec_size / dst_compress_size;
419419

420420
//src_compress_type src[src_loop_num * src_vec_size];
421421
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
422422

423+
//ElementMMA dst[dst_loop_num * dst_compress_size];
424+
423425
//reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
424426
float scale_value = fragment_scale(0);//(dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
425427

@@ -446,6 +448,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
446448
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
447449
}
448450
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
451+
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)];
452+
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*(v-1)+1] = reinterpret_cast<dst_compress_type*>(dst)[2*(v-1)+1];
449453
}
450454
src_2 = src_1;
451455
int dst_base_idx = v * src_compress_size;
@@ -462,6 +466,8 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
462466
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
463467
}
464468
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
469+
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v] = reinterpret_cast<dst_compress_type*>(dst)[2*v];
470+
reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[2*v+1] = reinterpret_cast<dst_compress_type*>(dst)[2*v+1];
465471

466472
// reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[1] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[1];
467473
// scale_value = fragment_scale(1);

0 commit comments

Comments
 (0)