Skip to content

Commit fcf0f8f

Browse files
committed
save code
1 parent 4576a87 commit fcf0f8f

1 file changed

Lines changed: 12 additions & 11 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -313,9 +313,9 @@ class gemm_4bit_cutlass_kernel {
313313
int start_lut_id = sg_idx % LUT_NUM;
314314

315315
#if 1
316-
auto dequant = [&](int start_lut_id, const int buffer_idx) {
317-
constexpr int N = decltype(cute::size<1>(*mma_B[buffer_idx]))::value;
318-
constexpr int K = decltype(cute::size(*mma_B[buffer_idx]))::value / N;
316+
auto dequant = [&](decltype(dequant_frag_a)* dequant_frag_, decltype(fragment_scale_a)* fragment_scale_, decltype(mma_B_a)* mma_B_) {
317+
constexpr int N = decltype(cute::size<1>(*mma_B_))::value;
318+
constexpr int K = decltype(cute::size(*mma_B_))::value / N;
319319

320320
using src_compress_type = uint32_t;
321321
using dst_compress_type = uint32_t;
@@ -340,13 +340,13 @@ class gemm_4bit_cutlass_kernel {
340340
#pragma unroll
341341
for (int v = 0; v < src_vec_size; v++) {
342342
//src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
343-
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast((*dequant_frag[buffer_idx]).data()))[n*src_loop_num + l][v];
343+
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag_->data()))[n*src_loop_num + l][v];
344344
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
345345

346346
#pragma unroll
347347
for (int c = 0; c < src_compress_size; c++) {
348348
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
349-
float scale_value = (*fragment_scale[buffer_idx])((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
349+
float scale_value = (*fragment_scale_)((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
350350

351351
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
352352
lut_id = (lut_id + 1) % LUT_NUM;
@@ -356,7 +356,7 @@ class gemm_4bit_cutlass_kernel {
356356

357357
#pragma unroll
358358
for (int l = 0; l < dst_loop_num; l++) {
359-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast((*mma_B[buffer_idx]).data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
359+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B_->data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
360360
}
361361
}
362362
};
@@ -375,11 +375,12 @@ class gemm_4bit_cutlass_kernel {
375375
const int buf_idx = k_tile % 2;
376376

377377
//dequant(start_lut_id, 1 - buf_idx);
378-
if(buf_idx == 1) {
379-
dequant(start_lut_id, 0);
380-
} else {
381-
dequant(start_lut_id, 1);
382-
}
378+
//if(buf_idx == 1) {
379+
// dequant(start_lut_id, 0);
380+
//} else {
381+
// dequant(start_lut_id, 1);
382+
//}
383+
dequant(dequant_frag[1 - buf_idx], fragment_scale[1 - buf_idx], mma_B[1 - buf_idx]);
383384

384385
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), *frag_copy_B[buf_idx]);
385386
copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), *frag_copy_Scale[buf_idx]);

0 commit comments

Comments
 (0)