@@ -241,11 +241,11 @@ class kgemm_4bit_inference_cutlass_dequant {
241241 // if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
242242 // printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
243243 }
244- int scale_number = decltype ( size (tCrS_input))::value;
245- for ( int i= 0 ; i< scale_number; i++){
246- auto s_value = tCrS_input (i);
247- if ( cute::thread0 ()) printf ( " scale_number = %d, tCrS_input[%d] = %f \n " ,scale_number, i, s_value );
248- }
244+ }
245+ int scale_number = decltype ( size (tCrS_input))::value;
246+ for ( int i= 0 ; i<scale_number; i++){
247+ auto s_value = tCrS_input (i );
248+ if ( cute::thread0 ()) printf ( " scale_number = %d, tCrS_input[%d] = %f \n " ,scale_number, i, s_value);
249249 }
250250#else
251251 static constexpr auto N = decltype(size<1>(in))::value;
@@ -297,6 +297,7 @@ class kgemm_4bit_inference_cutlass_dequant {
297297 int N = params.n ;
298298 int K = params.k ;
299299 int L = 1 ;
300+ static constexpr int BLK_K = 64 ;
300301
301302 T* A = params.A ;
302303 uint8_t * B = params.B ;
@@ -383,11 +384,11 @@ class kgemm_4bit_inference_cutlass_dequant {
383384
384385 Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
385386
386- static constexpr auto scale_traits_size = decltype ( size ( typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
387- static constexpr auto scale_traits_num = SG_QNT_WIDTH / size< 1 >( typename GmemTiledCopyScale::BlockShape{}) ;
387+ static constexpr auto scale_traits_size = 16 / SubgroupSize;
388+ static constexpr auto scale_traits_num = 64 / 16 ;
388389 using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
389390 Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
390- if (cute::thread0 ()) printf (" scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d \n " , scale_traits_size, scale_traits_num, SG_QNT_WIDTH );
391+ if (cute::thread0 ()) printf (" scale_traits_size = %d, scale_traits_num = %d\n " , scale_traits_size, scale_traits_num);
391392
392393 static_assert (std::is_same_v<typename decltype (dequant_frag)::value_type, ElementQuant>);
393394 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
@@ -414,15 +415,15 @@ class kgemm_4bit_inference_cutlass_dequant {
414415
415416// Run mainloop
416417 auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
417- const int n_coord_s = n_idx * BLK_N + (get_sub_group_id () % ATOM_N ) * SG_N ;
418+ const int n_coord_s = n_idx * 64 + (get_sub_group_id () % 2 ) * 32 ;
418419 const int l_coord_s = l_idx;
419420
420421 if (cute::thread0 ()) printf (" m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n " ,m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
421422
422423 auto copy_iter_s = [&](){
423424 return make_tensor (make_inttuple_iter (make_coord (n_coord_s, 0 , l_coord_s)),
424425 make_layout (make_shape (Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
425- make_stride (E<0 >{} * _16{}, E<0 >{} * size< 1 >( typename GmemTiledCopyScale::BlockShape{}) , _0{}, E<1 >{} * _1{})));
426+ make_stride (E<0 >{} * _16{}, E<0 >{} * 16 , _0{}, E<1 >{} * _1{})));
426427
427428 }();
428429#if 1
0 commit comments