@@ -52,7 +52,7 @@ using ElementOutput = float;
5252
5353using ProblemShape = Shape<int , int , int , int >;
5454
55- using TileShape = Shape<_64, _128, _64 >;
55+ using TileShape = Shape<_64, _128, _32 >;
5656using TiledMma =
5757 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
5858 Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -310,10 +310,10 @@ class gemm_4bit_cutlass_kernel {
310310 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
311311 }
312312
313- int start_lut_id = sg_idx % LUT_NUM ;
313+ // int start_lut_id = sg_idx % LUT_NUM;
314314
315315#if 1
316- auto dequant = [& ](decltype (dequant_frag_a)* dequant_frag_, decltype (fragment_scale_a)* fragment_scale_, decltype (mma_B_a)* mma_B_) {
316+ auto dequant = [](decltype (dequant_frag_a)* dequant_frag_, decltype (fragment_scale_a)* fragment_scale_, decltype (mma_B_a)* mma_B_, float (*quant_map)[ 16 ] ) {
317317 constexpr int N = decltype (cute::size<1 >(*mma_B_))::value;
318318 constexpr int K = decltype (cute::size (*mma_B_))::value / N;
319319
@@ -330,7 +330,7 @@ class gemm_4bit_cutlass_kernel {
330330
331331 ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
332332
333- int lut_id = start_lut_id;
333+ int lut_id = syclcompat::get_nd_item< 1 >(). get_sub_group (). get_group_linear_id () % LUT_NUM ; // start_lut_id;
334334 #pragma unroll
335335 for (int n = 0 ; n < N; n++) {
336336
@@ -339,7 +339,6 @@ class gemm_4bit_cutlass_kernel {
339339
340340 #pragma unroll
341341 for (int v = 0 ; v < src_vec_size; v++) {
342- // src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
343342 src_compress_type src_value = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag_->data ()))[n*src_loop_num + l][v];
344343 int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
345344
@@ -348,7 +347,7 @@ class gemm_4bit_cutlass_kernel {
348347 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
349348 float scale_value = (*fragment_scale_)((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE )));
350349
351- dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map_ [lut_id][bit_value] * scale_value);
350+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map [lut_id][bit_value] * scale_value);
352351 lut_id = (lut_id + 1 ) % LUT_NUM ;
353352 }
354353 }
@@ -371,16 +370,24 @@ class gemm_4bit_cutlass_kernel {
371370 }
372371 prefetch_k++;
373372
373+ int buf_idx = 0 ;
374+
374375 for (int k_tile = k_start_idx + 1 , k_s = 1 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
375- const int buf_idx = k_tile % 2 ;
376+ buf_idx ^= 1 ; // k_tile % 2;
376377
377378 // dequant(start_lut_id, 1 - buf_idx);
378379 // if(buf_idx == 1) {
379380 // dequant(start_lut_id, 0);
380381 // } else {
381382 // dequant(start_lut_id, 1);
382383 // }
383- dequant (dequant_frag[1 - buf_idx], fragment_scale[1 - buf_idx], mma_B[1 - buf_idx]);
384+
385+ dequant (dequant_frag[1 - buf_idx], fragment_scale[1 - buf_idx], mma_B[1 - buf_idx], quant_map_);
386+ // if(buf_idx == 1) {
387+ // dequant(dequant_frag[0], fragment_scale[0], mma_B[0]);
388+ // } else {
389+ // dequant(dequant_frag[1], fragment_scale[1], mma_B[1]);
390+ // }
384391
385392 copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), *frag_copy_B[buf_idx]);
386393 copy (params.tiled_copy_scale , tSgS (_,_,_,(k_start_idx+k_s)*BLK_K /params.group_size ), *frag_copy_Scale[buf_idx]);
@@ -392,6 +399,7 @@ class gemm_4bit_cutlass_kernel {
392399 }
393400
394401 cute::gemm (tiled_mma, *mma_A[1 - buf_idx], *mma_B[1 - buf_idx], accumulators);
402+
395403 barrier_wait (3 );
396404 }
397405 cute::gemm (tiled_mma, *mma_A[1 ], *mma_B[1 ], accumulators);
0 commit comments