@@ -262,15 +262,22 @@ class gemm_4bit_cutlass_kernel {
262262 Tensor frag_copy_B_b = thr_copy_B.retile_D (dequant_frag_b);
263263 Tensor frag_copy_Scale_b = thr_copy_scale.retile_D (fragment_scale_b);
264264
265- cute::tuple<decltype (mma_A_a), decltype (mma_A_b)> mma_A (mma_A_a, mma_A_b);
266- cute::tuple<decltype (mma_B_a), decltype (mma_B_b)> mma_B (mma_B_a, mma_B_b);
267- cute::tuple<decltype (dequant_frag_a), decltype (dequant_frag_b)> dequant_frag (dequant_frag_a, dequant_frag_b);
268- cute::tuple<decltype (fragment_scale_a), decltype (fragment_scale_b)> fragment_scale (fragment_scale_a, fragment_scale_b);
269- cute::tuple<decltype (frag_copy_A_a), decltype (frag_copy_A_b)> frag_copy_A (frag_copy_A_a, frag_copy_A_b);
270- cute::tuple<decltype (frag_copy_B_a), decltype (frag_copy_B_b)> frag_copy_B (frag_copy_B_a, frag_copy_B_b);
271- cute::tuple<decltype (frag_copy_Scale_a), decltype (frag_copy_Scale_b)> frag_copy_Scale (frag_copy_Scale_a, frag_copy_Scale_b);
272- // auto& mma_A_0 = cute::get<0>(mma_A_tuple); // 引用 mma_A_a
273- // auto& mma_A_1 = cute::get<1>(mma_A_tuple); // 引用 mma_A_b
265+ // cute::tuple<decltype(mma_A_a), decltype(mma_A_b)> mma_A(mma_A_a, mma_A_b);
266+ // cute::tuple<decltype(mma_B_a), decltype(mma_B_b)> mma_B(mma_B_a, mma_B_b);
267+ // cute::tuple<decltype(dequant_frag_a), decltype(dequant_frag_b)> dequant_frag(dequant_frag_a, dequant_frag_b);
268+ // cute::tuple<decltype(fragment_scale_a), decltype(fragment_scale_b)> fragment_scale(fragment_scale_a, fragment_scale_b);
269+ // cute::tuple<decltype(frag_copy_A_a), decltype(frag_copy_A_b)> frag_copy_A(frag_copy_A_a, frag_copy_A_b);
270+ // cute::tuple<decltype(frag_copy_B_a), decltype(frag_copy_B_b)> frag_copy_B(frag_copy_B_a, frag_copy_B_b);
271+ // cute::tuple<decltype(frag_copy_Scale_a), decltype(frag_copy_Scale_b)> frag_copy_Scale(frag_copy_Scale_a, frag_copy_Scale_b);
272+ // //auto& mma_A_0 = cute::get<0>(mma_A_tuple); // reference to mma_A_a
273+ // //auto& mma_A_1 = cute::get<1>(mma_A_tuple); // reference to mma_A_b
274+ decltype (mma_A_a)* mma_A[] = {&mma_A_a, &mma_A_b};
275+ decltype (mma_B_a)* mma_B[] = {&mma_B_a, &mma_B_b};
276+ decltype (dequant_frag_a)* dequant_frag[] = {&dequant_frag_a, &dequant_frag_b};
277+ decltype (fragment_scale_a)* fragment_scale[] = {&fragment_scale_a, &fragment_scale_b};
278+ decltype (frag_copy_A_a)* frag_copy_A[] = {&frag_copy_A_a, &frag_copy_A_b};
279+ decltype (frag_copy_B_a)* frag_copy_B[] = {&frag_copy_B_a, &frag_copy_B_b};
280+ decltype (frag_copy_Scale_a)* frag_copy_Scale[] = {&frag_copy_Scale_a, &frag_copy_Scale_b};
274281
275282#endif
276283 Tensor tAgA = thr_copy_A.retile_S (tCgA);
@@ -308,8 +315,8 @@ class gemm_4bit_cutlass_kernel {
308315
309316#if 1
310317 auto dequant = [&](int start_lut_id, const int buffer_idx) {
311- constexpr int N = decltype (cute::size<1 >(mma_B))::value;
312- constexpr int K = decltype (cute::size (mma_B))::value / N;
318+ constexpr int N = decltype (cute::size<1 >(* mma_B[buffer_idx] ))::value;
319+ constexpr int K = decltype (cute::size (* mma_B[buffer_idx] ))::value / N;
313320
314321 using src_compress_type = uint32_t ;
315322 using dst_compress_type = uint32_t ;
@@ -333,13 +340,13 @@ class gemm_4bit_cutlass_kernel {
333340
334341 #pragma unroll
335342 for (int v = 0 ; v < src_vec_size; v++) {
336- src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (cute::get< buffer_idx>(dequant_frag). data ()))[n*src_loop_num + l][v];
343+ src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag[ buffer_idx]-> data ()))[n*src_loop_num + l][v];
337344 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
338345
339346 #pragma unroll
340347 for (int c = 0 ; c < src_compress_size; c++) {
341348 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
342- float scale_value = cute::get<buffer_idx>(fragment_scale_a )((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
349+ float scale_value = (*fragment_scale[buffer_idx] )((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
343350
344351 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
345352 lut_id = (lut_id + 1 ) % LUT_NUM ;
@@ -349,14 +356,14 @@ class gemm_4bit_cutlass_kernel {
349356
350357 #pragma unroll
351358 for (int l = 0 ; l < dst_loop_num; l++) {
352- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (cute::get< buffer_idx>(mma_B_a). data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
359+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B[ buffer_idx]-> data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
353360 }
354361 }
355362 };
356363
357- copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), cute::get< 0 >( frag_copy_B) );
358- copy (params.tiled_copy_scale , tSgS (_,_,_,k_start_idx * BLK_K /params.group_size ), cute::get< 0 >( frag_copy_Scale) );
359- copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), cute::get< 0 >( frag_copy_A) );
364+ copy (params.tiled_copy_b , tBgB (_,_,_,k_start_idx), * frag_copy_B[ 0 ] );
365+ copy (params.tiled_copy_scale , tSgS (_,_,_,k_start_idx * BLK_K /params.group_size ), * frag_copy_Scale[ 0 ] );
366+ copy (params.tiled_copy_a , tAgA (_,_,_,k_start_idx), * frag_copy_A[ 0 ] );
360367
361368 if (prefetch_k < k_tile_count) {
362369 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
@@ -365,24 +372,23 @@ class gemm_4bit_cutlass_kernel {
365372 prefetch_k++;
366373
367374 for (int k_tile = k_start_idx + 1 , k_s = 1 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
368- constexpr int buf_idx = k_tile % 2 ;
375+ const int buf_idx = k_tile % 2 ;
369376
370- dequant (start_lut_id, buf_idx);
377+ dequant (start_lut_id, 1 - buf_idx);
371378
372- copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), cute::get<buf_idx>( frag_copy_B) );
373- copy (params.tiled_copy_scale , tSgS (_,_,_,(k_start_idx+k_s)*BLK_K /params.group_size ), cute::get<buf_idx>( frag_copy_Scale) );
374- copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), cute::get<buf_idx>( frag_copy_A) );
379+ copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), * frag_copy_B[buf_idx] );
380+ copy (params.tiled_copy_scale , tSgS (_,_,_,(k_start_idx+k_s)*BLK_K /params.group_size ), * frag_copy_Scale[buf_idx] );
381+ copy (params.tiled_copy_a , tAgA (_,_,_,k_tile), * frag_copy_A[buf_idx] );
375382
376383 if (prefetch_k < k_tile_count) {
377384 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
378385 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
379386 }
380387
381- constexpr int idx = 1 - buf_idx;
382- cute::gemm (tiled_mma, cute::get<idx>(frag_copy_A), cute::get<idx>(frag_copy_B), accumulators);
388+ cute::gemm (tiled_mma, *mma_A[1 - buf_idx], *mma_B[1 - buf_idx], accumulators);
383389 barrier_wait (3 );
384390 }
385- cute::gemm (tiled_mma, cute::get< 1 >(frag_copy_A), cute::get< 1 >(frag_copy_B) , accumulators);
391+ cute::gemm (tiled_mma, *mma_A[ 1 ], *mma_B[ 1 ] , accumulators);
386392
387393#else
388394
0 commit comments