@@ -313,9 +313,9 @@ class gemm_4bit_cutlass_kernel {
313313 int start_lut_id = sg_idx % LUT_NUM ;
314314
315315#if 1
316- auto dequant = [&](int start_lut_id, const int buffer_idx ) {
317- constexpr int N = decltype (cute::size<1 >(*mma_B[buffer_idx] ))::value;
318- constexpr int K = decltype (cute::size (*mma_B[buffer_idx] ))::value / N;
316+ auto dequant = [&](decltype (dequant_frag_a)* dequant_frag_, decltype (fragment_scale_a)* fragment_scale_, decltype (mma_B_a)* mma_B_ ) {
317+ constexpr int N = decltype (cute::size<1 >(*mma_B_ ))::value;
318+ constexpr int K = decltype (cute::size (*mma_B_ ))::value / N;
319319
320320 using src_compress_type = uint32_t ;
321321 using dst_compress_type = uint32_t ;
@@ -340,13 +340,13 @@ class gemm_4bit_cutlass_kernel {
340340 #pragma unroll
341341 for (int v = 0 ; v < src_vec_size; v++) {
342342 // src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
343- src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast ((*dequant_frag[buffer_idx]). data ()))[n*src_loop_num + l][v];
343+ src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag_-> data ()))[n*src_loop_num + l][v];
344344 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
345345
346346 #pragma unroll
347347 for (int c = 0 ; c < src_compress_size; c++) {
348348 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
349- float scale_value = (*fragment_scale[buffer_idx] )((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
349+ float scale_value = (*fragment_scale_ )((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
350350
351351 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
352352 lut_id = (lut_id + 1 ) % LUT_NUM ;
@@ -356,7 +356,7 @@ class gemm_4bit_cutlass_kernel {
356356
357357 #pragma unroll
358358 for (int l = 0 ; l < dst_loop_num; l++) {
359- reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast ((*mma_B[buffer_idx]). data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
359+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B_-> data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
360360 }
361361 }
362362 };
@@ -375,11 +375,12 @@ class gemm_4bit_cutlass_kernel {
375375 const int buf_idx = k_tile % 2 ;
376376
377377 // dequant(start_lut_id, 1 - buf_idx);
378- if (buf_idx == 1 ) {
379- dequant (start_lut_id, 0 );
380- } else {
381- dequant (start_lut_id, 1 );
382- }
378+ // if(buf_idx == 1) {
379+ // dequant(start_lut_id, 0);
380+ // } else {
381+ // dequant(start_lut_id, 1);
382+ // }
383+ dequant (dequant_frag[1 - buf_idx], fragment_scale[1 - buf_idx], mma_B[1 - buf_idx]);
383384
384385 copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), *frag_copy_B[buf_idx]);
385386 copy (params.tiled_copy_scale , tSgS (_,_,_,(k_start_idx+k_s)*BLK_K /params.group_size ), *frag_copy_Scale[buf_idx]);
0 commit comments