Skip to content

Commit 125db28

Browse files
committed
save code, double buffer decltype(mma_A_a)* mma_A[]
1 parent bc45998 commit 125db28

1 file changed

Lines changed: 31 additions & 25 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -262,15 +262,22 @@ class gemm_4bit_cutlass_kernel {
262262
Tensor frag_copy_B_b = thr_copy_B.retile_D(dequant_frag_b);
263263
Tensor frag_copy_Scale_b = thr_copy_scale.retile_D(fragment_scale_b);
264264

265-
cute::tuple<decltype(mma_A_a), decltype(mma_A_b)> mma_A(mma_A_a, mma_A_b);
266-
cute::tuple<decltype(mma_B_a), decltype(mma_B_b)> mma_B(mma_B_a, mma_B_b);
267-
cute::tuple<decltype(dequant_frag_a), decltype(dequant_frag_b)> dequant_frag(dequant_frag_a, dequant_frag_b);
268-
cute::tuple<decltype(fragment_scale_a), decltype(fragment_scale_b)> fragment_scale(fragment_scale_a, fragment_scale_b);
269-
cute::tuple<decltype(frag_copy_A_a), decltype(frag_copy_A_b)> frag_copy_A(frag_copy_A_a, frag_copy_A_b);
270-
cute::tuple<decltype(frag_copy_B_a), decltype(frag_copy_B_b)> frag_copy_B(frag_copy_B_a, frag_copy_B_b);
271-
cute::tuple<decltype(frag_copy_Scale_a), decltype(frag_copy_Scale_b)> frag_copy_Scale(frag_copy_Scale_a, frag_copy_Scale_b);
272-
//auto& mma_A_0 = cute::get<0>(mma_A_tuple); // 引用 mma_A_a
273-
//auto& mma_A_1 = cute::get<1>(mma_A_tuple); // 引用 mma_A_b
265+
// cute::tuple<decltype(mma_A_a), decltype(mma_A_b)> mma_A(mma_A_a, mma_A_b);
266+
// cute::tuple<decltype(mma_B_a), decltype(mma_B_b)> mma_B(mma_B_a, mma_B_b);
267+
// cute::tuple<decltype(dequant_frag_a), decltype(dequant_frag_b)> dequant_frag(dequant_frag_a, dequant_frag_b);
268+
// cute::tuple<decltype(fragment_scale_a), decltype(fragment_scale_b)> fragment_scale(fragment_scale_a, fragment_scale_b);
269+
// cute::tuple<decltype(frag_copy_A_a), decltype(frag_copy_A_b)> frag_copy_A(frag_copy_A_a, frag_copy_A_b);
270+
// cute::tuple<decltype(frag_copy_B_a), decltype(frag_copy_B_b)> frag_copy_B(frag_copy_B_a, frag_copy_B_b);
271+
// cute::tuple<decltype(frag_copy_Scale_a), decltype(frag_copy_Scale_b)> frag_copy_Scale(frag_copy_Scale_a, frag_copy_Scale_b);
272+
////auto& mma_A_0 = cute::get<0>(mma_A_tuple); // 引用 mma_A_a
273+
////auto& mma_A_1 = cute::get<1>(mma_A_tuple); // 引用 mma_A_b
274+
decltype(mma_A_a)* mma_A[] = {&mma_A_a, &mma_A_b};
275+
decltype(mma_B_a)* mma_B[] = {&mma_B_a, &mma_B_b};
276+
decltype(dequant_frag_a)* dequant_frag[] = {&dequant_frag_a, &dequant_frag_b};
277+
decltype(fragment_scale_a)* fragment_scale[] = {&fragment_scale_a, &fragment_scale_b};
278+
decltype(frag_copy_A_a)* frag_copy_A[] = {&frag_copy_A_a, &frag_copy_A_b};
279+
decltype(frag_copy_B_a)* frag_copy_B[] = {&frag_copy_B_a, &frag_copy_B_b};
280+
decltype(frag_copy_Scale_a)* frag_copy_Scale[] = {&frag_copy_Scale_a, &frag_copy_Scale_b};
274281

275282
#endif
276283
Tensor tAgA = thr_copy_A.retile_S(tCgA);
@@ -308,8 +315,8 @@ class gemm_4bit_cutlass_kernel {
308315

309316
#if 1
310317
auto dequant = [&](int start_lut_id, const int buffer_idx) {
311-
constexpr int N = decltype(cute::size<1>(mma_B))::value;
312-
constexpr int K = decltype(cute::size(mma_B))::value / N;
318+
constexpr int N = decltype(cute::size<1>(*mma_B[buffer_idx]))::value;
319+
constexpr int K = decltype(cute::size(*mma_B[buffer_idx]))::value / N;
313320

314321
using src_compress_type = uint32_t;
315322
using dst_compress_type = uint32_t;
@@ -333,13 +340,13 @@ class gemm_4bit_cutlass_kernel {
333340

334341
#pragma unroll
335342
for (int v = 0; v < src_vec_size; v++) {
336-
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(cute::get<buffer_idx>(dequant_frag).data()))[n*src_loop_num + l][v];
343+
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
337344
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
338345

339346
#pragma unroll
340347
for (int c = 0; c < src_compress_size; c++) {
341348
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
342-
float scale_value = cute::get<buffer_idx>(fragment_scale_a)((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
349+
float scale_value = (*fragment_scale[buffer_idx])((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
343350

344351
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
345352
lut_id = (lut_id + 1) % LUT_NUM;
@@ -349,14 +356,14 @@ class gemm_4bit_cutlass_kernel {
349356

350357
#pragma unroll
351358
for (int l = 0; l < dst_loop_num; l++) {
352-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(cute::get<buffer_idx>(mma_B_a).data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
359+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B[buffer_idx]->data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
353360
}
354361
}
355362
};
356363

357-
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), cute::get<0>(frag_copy_B));
358-
copy(params.tiled_copy_scale, tSgS(_,_,_,k_start_idx * BLK_K/params.group_size), cute::get<0>(frag_copy_Scale));
359-
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), cute::get<0>(frag_copy_A));
364+
copy(params.tiled_copy_b, tBgB(_,_,_,k_start_idx), *frag_copy_B[0]);
365+
copy(params.tiled_copy_scale, tSgS(_,_,_,k_start_idx * BLK_K/params.group_size), *frag_copy_Scale[0]);
366+
copy(params.tiled_copy_a, tAgA(_,_,_,k_start_idx), *frag_copy_A[0]);
360367

361368
if (prefetch_k < k_tile_count) {
362369
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
@@ -365,24 +372,23 @@ class gemm_4bit_cutlass_kernel {
365372
prefetch_k++;
366373

367374
for (int k_tile = k_start_idx + 1, k_s = 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
368-
constexpr int buf_idx = k_tile % 2;
375+
const int buf_idx = k_tile % 2;
369376

370-
dequant(start_lut_id, buf_idx);
377+
dequant(start_lut_id, 1 - buf_idx);
371378

372-
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), cute::get<buf_idx>(frag_copy_B));
373-
copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), cute::get<buf_idx>(frag_copy_Scale));
374-
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), cute::get<buf_idx>(frag_copy_A));
379+
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), *frag_copy_B[buf_idx]);
380+
copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), *frag_copy_Scale[buf_idx]);
381+
copy(params.tiled_copy_a, tAgA(_,_,_,k_tile), *frag_copy_A[buf_idx]);
375382

376383
if (prefetch_k < k_tile_count) {
377384
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
378385
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
379386
}
380387

381-
constexpr int idx = 1 - buf_idx;
382-
cute::gemm(tiled_mma, cute::get<idx>(frag_copy_A), cute::get<idx>(frag_copy_B), accumulators);
388+
cute::gemm(tiled_mma, *mma_A[1 - buf_idx], *mma_B[1 - buf_idx], accumulators);
383389
barrier_wait(3);
384390
}
385-
cute::gemm(tiled_mma, cute::get<1>(frag_copy_A), cute::get<1>(frag_copy_B), accumulators);
391+
cute::gemm(tiled_mma, *mma_A[1], *mma_B[1], accumulators);
386392

387393
#else
388394

0 commit comments

Comments
 (0)