Skip to content

Commit 140e602

Browse files
committed
save code, cute::array double buffer
1 parent 125db28 commit 140e602

1 file changed

Lines changed: 16 additions & 16 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -262,22 +262,21 @@ class gemm_4bit_cutlass_kernel {
262262
Tensor frag_copy_B_b = thr_copy_B.retile_D(dequant_frag_b);
263263
Tensor frag_copy_Scale_b = thr_copy_scale.retile_D(fragment_scale_b);
264264

265-
// cute::tuple<decltype(mma_A_a), decltype(mma_A_b)> mma_A(mma_A_a, mma_A_b);
266-
// cute::tuple<decltype(mma_B_a), decltype(mma_B_b)> mma_B(mma_B_a, mma_B_b);
267-
// cute::tuple<decltype(dequant_frag_a), decltype(dequant_frag_b)> dequant_frag(dequant_frag_a, dequant_frag_b);
268-
// cute::tuple<decltype(fragment_scale_a), decltype(fragment_scale_b)> fragment_scale(fragment_scale_a, fragment_scale_b);
269-
// cute::tuple<decltype(frag_copy_A_a), decltype(frag_copy_A_b)> frag_copy_A(frag_copy_A_a, frag_copy_A_b);
270-
// cute::tuple<decltype(frag_copy_B_a), decltype(frag_copy_B_b)> frag_copy_B(frag_copy_B_a, frag_copy_B_b);
271-
// cute::tuple<decltype(frag_copy_Scale_a), decltype(frag_copy_Scale_b)> frag_copy_Scale(frag_copy_Scale_a, frag_copy_Scale_b);
272-
////auto& mma_A_0 = cute::get<0>(mma_A_tuple); // 引用 mma_A_a
273-
////auto& mma_A_1 = cute::get<1>(mma_A_tuple); // 引用 mma_A_b
274-
decltype(mma_A_a)* mma_A[] = {&mma_A_a, &mma_A_b};
275-
decltype(mma_B_a)* mma_B[] = {&mma_B_a, &mma_B_b};
276-
decltype(dequant_frag_a)* dequant_frag[] = {&dequant_frag_a, &dequant_frag_b};
277-
decltype(fragment_scale_a)* fragment_scale[] = {&fragment_scale_a, &fragment_scale_b};
278-
decltype(frag_copy_A_a)* frag_copy_A[] = {&frag_copy_A_a, &frag_copy_A_b};
279-
decltype(frag_copy_B_a)* frag_copy_B[] = {&frag_copy_B_a, &frag_copy_B_b};
280-
decltype(frag_copy_Scale_a)* frag_copy_Scale[] = {&frag_copy_Scale_a, &frag_copy_Scale_b};
265+
//decltype(mma_A_a)* mma_A[] = {&mma_A_a, &mma_A_b};
266+
//decltype(mma_B_a)* mma_B[] = {&mma_B_a, &mma_B_b};
267+
//decltype(dequant_frag_a)* dequant_frag[] = {&dequant_frag_a, &dequant_frag_b};
268+
//decltype(fragment_scale_a)* fragment_scale[] = {&fragment_scale_a, &fragment_scale_b};
269+
//decltype(frag_copy_A_a)* frag_copy_A[] = {&frag_copy_A_a, &frag_copy_A_b};
270+
//decltype(frag_copy_B_a)* frag_copy_B[] = {&frag_copy_B_a, &frag_copy_B_b};
271+
//decltype(frag_copy_Scale_a)* frag_copy_Scale[] = {&frag_copy_Scale_a, &frag_copy_Scale_b};
272+
273+
cute::array<decltype(mma_A_a)*, 2> mma_A = {&mma_A_a, &mma_A_b};
274+
cute::array<decltype(mma_B_a)*, 2> mma_B = {&mma_B_a, &mma_B_b};
275+
cute::array<decltype(dequant_frag_a)*, 2> dequant_frag = {&dequant_frag_a, &dequant_frag_b};
276+
cute::array<decltype(fragment_scale_a)*, 2> fragment_scale = {&fragment_scale_a, &fragment_scale_b};
277+
cute::array<decltype(frag_copy_A_a)*, 2> frag_copy_A = {&frag_copy_A_a, &frag_copy_A_b};
278+
cute::array<decltype(frag_copy_B_a)*, 2> frag_copy_B = {&frag_copy_B_a, &frag_copy_B_b};
279+
cute::array<decltype(frag_copy_Scale_a)*, 2> frag_copy_Scale = {&frag_copy_Scale_a, &frag_copy_Scale_b};
281280

282281
#endif
283282
Tensor tAgA = thr_copy_A.retile_S(tCgA);
@@ -340,6 +339,7 @@ class gemm_4bit_cutlass_kernel {
340339

341340
#pragma unroll
342341
for (int v = 0; v < src_vec_size; v++) {
342+
//src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
343343
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
344344
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
345345

0 commit comments

Comments
 (0)