Skip to content

Commit 5a48421

Browse files
committed
save code
1 parent 3f63567 commit 5a48421

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -417,24 +417,24 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
417417
constexpr int src_loop_num = 1; //K / src_vec_size / src_compress_size;
418418
constexpr int dst_loop_num = 1; //K / dst_vec_size / dst_compress_size;
419419

420-
src_compress_type src[src_loop_num * src_vec_size];
420+
//src_compress_type src[src_loop_num * src_vec_size];
421421
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
422422

423-
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
423+
//reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
424424
float scale_value = fragment_scale(0);//(dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
425425

426426
#pragma unroll
427427
for (int v = 0; v < src_vec_size; v++) {
428428
int dst_base_idx = v * src_compress_size;
429429
int c = 0;
430-
uint8_t bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
430+
uint8_t bit_value = (reinterpret_cast<src_compress_type*>(cute::raw_pointer_cast(dequant_frag.data()))[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
431431
float converted_value_1 = quant_map[bit_value];
432432
float converted_value_2 = 0.f;
433433
#pragma unroll
434434
for (; c < src_compress_size-1;) {
435435
converted_value_2 = converted_value_1;
436436
c++;
437-
bit_value = (src[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
437+
bit_value = (reinterpret_cast<src_compress_type*>(cute::raw_pointer_cast(dequant_frag.data()))[v] >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
438438
converted_value_1 = quant_map[bit_value];
439439
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
440440
}

0 commit comments

Comments
 (0)