Skip to content

Commit 4ef320c

Browse files
committed
save code, better perf
1 parent 6f64ec9 commit 4ef320c

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -500,12 +500,12 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
500500
constexpr int N = decltype(cute::size<1>(mma_B))::value;
501501
constexpr int K = decltype(cute::size(mma_B))::value / N;
502502

503-
using src_compress_type = uint64_t;
503+
using src_compress_type = uint16_t;
504504
using dst_compress_type = uint64_t;
505505
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
506506
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
507-
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508-
constexpr int dst_vec_size = 1; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
507+
constexpr int src_vec_size = 2; //(K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
508+
constexpr int dst_vec_size = 2; //(K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
509509
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
510510
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
511511
src_compress_type src[src_vec_size];
@@ -607,10 +607,10 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
607607
}
608608

609609
#pragma unroll
610-
for (int l = 0; l < dst_loop_num / 4; l++) {
611-
//reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
610+
for (int l = 0; l < dst_loop_num; l++) {
611+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
612612
//reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
613-
reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 4>*>(dst)[l];
613+
//reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
614614

615615
}
616616
#endif

0 commit comments

Comments
 (0)