@@ -262,22 +262,21 @@ class gemm_4bit_cutlass_kernel {
262262 Tensor frag_copy_B_b = thr_copy_B.retile_D (dequant_frag_b);
263263 Tensor frag_copy_Scale_b = thr_copy_scale.retile_D (fragment_scale_b);
264264
265- // cute::tuple<decltype(mma_A_a), decltype(mma_A_b)> mma_A(mma_A_a, mma_A_b);
266- // cute::tuple<decltype(mma_B_a), decltype(mma_B_b)> mma_B(mma_B_a, mma_B_b);
267- // cute::tuple<decltype(dequant_frag_a), decltype(dequant_frag_b)> dequant_frag(dequant_frag_a, dequant_frag_b);
268- // cute::tuple<decltype(fragment_scale_a), decltype(fragment_scale_b)> fragment_scale(fragment_scale_a, fragment_scale_b);
269- // cute::tuple<decltype(frag_copy_A_a), decltype(frag_copy_A_b)> frag_copy_A(frag_copy_A_a, frag_copy_A_b);
270- // cute::tuple<decltype(frag_copy_B_a), decltype(frag_copy_B_b)> frag_copy_B(frag_copy_B_a, frag_copy_B_b);
271- // cute::tuple<decltype(frag_copy_Scale_a), decltype(frag_copy_Scale_b)> frag_copy_Scale(frag_copy_Scale_a, frag_copy_Scale_b);
272- // //auto& mma_A_0 = cute::get<0>(mma_A_tuple); // 引用 mma_A_a
273- // //auto& mma_A_1 = cute::get<1>(mma_A_tuple); // 引用 mma_A_b
274- decltype (mma_A_a)* mma_A[] = {&mma_A_a, &mma_A_b};
275- decltype (mma_B_a)* mma_B[] = {&mma_B_a, &mma_B_b};
276- decltype (dequant_frag_a)* dequant_frag[] = {&dequant_frag_a, &dequant_frag_b};
277- decltype (fragment_scale_a)* fragment_scale[] = {&fragment_scale_a, &fragment_scale_b};
278- decltype (frag_copy_A_a)* frag_copy_A[] = {&frag_copy_A_a, &frag_copy_A_b};
279- decltype (frag_copy_B_a)* frag_copy_B[] = {&frag_copy_B_a, &frag_copy_B_b};
280- decltype (frag_copy_Scale_a)* frag_copy_Scale[] = {&frag_copy_Scale_a, &frag_copy_Scale_b};
265+ // decltype(mma_A_a)* mma_A[] = {&mma_A_a, &mma_A_b};
266+ // decltype(mma_B_a)* mma_B[] = {&mma_B_a, &mma_B_b};
267+ // decltype(dequant_frag_a)* dequant_frag[] = {&dequant_frag_a, &dequant_frag_b};
268+ // decltype(fragment_scale_a)* fragment_scale[] = {&fragment_scale_a, &fragment_scale_b};
269+ // decltype(frag_copy_A_a)* frag_copy_A[] = {&frag_copy_A_a, &frag_copy_A_b};
270+ // decltype(frag_copy_B_a)* frag_copy_B[] = {&frag_copy_B_a, &frag_copy_B_b};
271+ // decltype(frag_copy_Scale_a)* frag_copy_Scale[] = {&frag_copy_Scale_a, &frag_copy_Scale_b};
272+
273+ cute::array<decltype (mma_A_a), 2 > mma_A = {mma_A_a, mma_A_b};
274+ cute::array<decltype (mma_B_a), 2 > mma_B = {mma_B_a, mma_B_b};
275+ cute::array<decltype (dequant_frag_a), 2 > dequant_frag = {dequant_frag_a, dequant_frag_b};
276+ cute::array<decltype (fragment_scale_a), 2 > fragment_scale = {fragment_scale_a, fragment_scale_b};
277+ cute::array<decltype (frag_copy_A_a), 2 > frag_copy_A = {frag_copy_A_a, frag_copy_A_b};
278+ cute::array<decltype (frag_copy_B_a), 2 > frag_copy_B = {frag_copy_B_a, frag_copy_B_b};
279+ cute::array<decltype (frag_copy_Scale_a), 2 > frag_copy_Scale = {frag_copy_Scale_a, frag_copy_Scale_b};
281280
282281#endif
283282 Tensor tAgA = thr_copy_A.retile_S (tCgA);
@@ -315,8 +314,8 @@ class gemm_4bit_cutlass_kernel {
315314
316315#if 1
317316 auto dequant = [&](int start_lut_id, const int buffer_idx) {
318- constexpr int N = decltype (cute::size<1 >(* mma_B[buffer_idx]))::value;
319- constexpr int K = decltype (cute::size (* mma_B[buffer_idx]))::value / N;
317+ constexpr int N = decltype (cute::size<1 >(mma_B[buffer_idx]))::value;
318+ constexpr int K = decltype (cute::size (mma_B[buffer_idx]))::value / N;
320319
321320 using src_compress_type = uint32_t ;
322321 using dst_compress_type = uint32_t ;
@@ -340,13 +339,14 @@ class gemm_4bit_cutlass_kernel {
340339
341340 #pragma unroll
342341 for (int v = 0 ; v < src_vec_size; v++) {
343- src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag[buffer_idx]->data ()))[n*src_loop_num + l][v];
342+ // src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
343+ src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag[buffer_idx].data ()))[n*src_loop_num + l][v];
344344 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
345345
346346 #pragma unroll
347347 for (int c = 0 ; c < src_compress_size; c++) {
348348 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
349- float scale_value = (* fragment_scale[buffer_idx])((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
349+ float scale_value = (fragment_scale[buffer_idx])((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
350350
351351 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
352352 lut_id = (lut_id + 1 ) % LUT_NUM ;
@@ -356,14 +356,14 @@ class gemm_4bit_cutlass_kernel {
356356
357357 #pragma unroll
358358 for (int l = 0 ; l < dst_loop_num; l++) {
359- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B[buffer_idx]-> data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
359+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B[buffer_idx]. data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
360360 }
361361 }
362362 };
363363
364- copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), * frag_copy_B[0 ]);
365- copy (params.tiled_copy_scale , tSgS (_,_,_,k_start_idx * BLK_K /params.group_size ), * frag_copy_Scale[0 ]);
366- copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), * frag_copy_A[0 ]);
364+ copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), frag_copy_B[0 ]);
365+ copy (params.tiled_copy_scale , tSgS (_,_,_,k_start_idx * BLK_K /params.group_size ), frag_copy_Scale[0 ]);
366+ copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), frag_copy_A[0 ]);
367367
368368 if (prefetch_k < k_tile_count) {
369369 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
@@ -376,19 +376,19 @@ class gemm_4bit_cutlass_kernel {
376376
377377 dequant (start_lut_id, 1 - buf_idx);
378378
379- copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), * frag_copy_B[buf_idx]);
380- copy (params.tiled_copy_scale , tSgS (_,_,_,(k_start_idx+k_s)*BLK_K /params.group_size ), * frag_copy_Scale[buf_idx]);
381- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), * frag_copy_A[buf_idx]);
379+ copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B[buf_idx]);
380+ copy (params.tiled_copy_scale , tSgS (_,_,_,(k_start_idx+k_s)*BLK_K /params.group_size ), frag_copy_Scale[buf_idx]);
381+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), frag_copy_A[buf_idx]);
382382
383383 if (prefetch_k < k_tile_count) {
384384 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
385385 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
386386 }
387387
388- cute::gemm (tiled_mma, * mma_A[1 - buf_idx], * mma_B[1 - buf_idx], accumulators);
388+ cute::gemm (tiled_mma, mma_A[1 - buf_idx], mma_B[1 - buf_idx], accumulators);
389389 barrier_wait (3 );
390390 }
391- cute::gemm (tiled_mma, * mma_A[1 ], * mma_B[1 ], accumulators);
391+ cute::gemm (tiled_mma, mma_A[1 ], mma_B[1 ], accumulators);
392392
393393#else
394394
0 commit comments