Skip to content

Commit 9cc08c0

Browse files
committed
save code
1 parent 050a138 commit 9cc08c0

1 file changed

Lines changed: 16 additions & 19 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 16 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -588,7 +588,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
588588
uint16_t low_bits = sycl::bit_cast<uint16_t>(static_cast<ElementMMA>(quant_map_[lut_id][(src_value >> (4 * (c * 2))) & 0xf] * fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + 2 * c + 1) / GROUP_SIZE)));
589589
reinterpret_cast<uint32_t*>(cute::raw_pointer_cast(mma_B.data()))[n*src_loop_num*src_compress_size/2 + l * src_vec_size*src_compress_size/2 + v*src_compress_size/2 + c] = (static_cast<uint32_t>(low_bits) << 16) | high_bits;
590590
}
591-
#elif 1
591+
#elif 0
592592
#pragma unroll
593593
for (int c = 0; c < src_compress_size; c++) {
594594
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
@@ -607,24 +607,21 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
607607
//reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
608608

609609
}
610-
//#else
611-
// #pragma unroll
612-
// for (int c = 0; c < src_compress_size; c++) {
613-
// uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
614-
// float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
615-
// dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
616-
// lut_id = (lut_id + 1) % LUT_NUM;
617-
// }
618-
// }
619-
// }
620-
//
621-
// #pragma unroll
622-
// for (int l = 0; l < dst_loop_num; l++) {
623-
// reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
624-
// //reinterpret_cast<dst_compress_type*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<dst_compress_type*>(dst)[l];
625-
// //reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, 2>*>(dst)[l];
626-
//
627-
// }
610+
#else
611+
#pragma unroll
612+
for (int c = 0; c < src_compress_size; c++) {
613+
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
614+
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
615+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
616+
lut_id = (lut_id + 1) % LUT_NUM;
617+
}
618+
}
619+
}
620+
621+
#pragma unroll
622+
for (int l = 0; l < dst_loop_num; l++) {
623+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
624+
}
628625
#endif
629626
}
630627
};

0 commit comments

Comments
 (0)