Skip to content

Commit de7bc02

Browse files
committed
save code
1 parent 9d3a0ba commit de7bc02

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,8 @@ inline float dDequantizeNF4(unsigned char val) {
383383

384384
//if(cute::thread0()) printf("params.group_size = %d, k_reload_factor = %d, k_tile_count = %d, N = %d, K = %d, src_compress_size = %d, src_vec_size = %d, dst_compress_size = %d, dst_vec_size = %d\n",params.group_size, k_reload_factor, k_tile_count, N, K, src_compress_size, src_vec_size, dst_compress_size, dst_vec_size);
385385

386-
src_compress_type src[N*src_loop_num*src_vec_size];
387-
ElementMMA dst[N*K];
386+
src_compress_type src[src_vec_size];
387+
ElementMMA dst[dst_compress_size * dst_vec_size];
388388

389389
#pragma unroll
390390
for (int n = 0; n < N; n++) {
@@ -393,11 +393,11 @@ inline float dDequantizeNF4(unsigned char val) {
393393
for (int l = 0; l < src_loop_num; l++) {
394394
//src_compress_type src[src_vec_size];
395395
//ElementMMA dst[K/dst_loop_num];
396-
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[n*src_loop_num + l] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l];
396+
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l];
397397
#pragma unroll
398398
for (int v = 0; v < src_vec_size; v++) {
399-
src_compress_type src_value = src[(n*src_loop_num + l)*src_vec_size + v];
400-
int dst_idx = ((n*src_loop_num + l)* src_vec_size + v) * src_compress_size;
399+
src_compress_type src_value = src[v];
400+
int dst_idx = v * src_compress_size;
401401
#pragma unroll
402402
for (int c = 0; c < src_compress_size; c++) {
403403
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
@@ -408,7 +408,7 @@ inline float dDequantizeNF4(unsigned char val) {
408408

409409
#pragma unroll
410410
for (int l = 0; l < dst_loop_num; l++) {
411-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[n*dst_loop_num + l];
411+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
412412
}
413413
}
414414
};

0 commit comments

Comments
 (0)