@@ -146,7 +146,6 @@ using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B:
146146using Copy_B = decltype (make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}, val_layout_load_B{}));
147147
148148using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N; // XE_2D_U16x1x16_LD_N;
149- // using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
150149static constexpr auto SG_QNT_WIDTH = Int<SG_N >{};
151150using StrideScale = cute::Stride<_1, int64_t , int64_t >; // dynamic stride
152151using traits_load_scale = Copy_Traits<GmemTiledCopyScale, StrideScale>;
@@ -245,7 +244,7 @@ class kgemm_4bit_inference_cutlass_dequant {
245244 int scale_number = decltype (size (tCrS_input))::value;
246245 for (int i=0 ; i<scale_number; i++){
247246 auto s_value = tCrS_input (i);
248- if (cute::thread0 ()) printf (" scale_number = %d, tCrS_input[%d] = %f\n " ,scale_number, i, s_value);
247+ if (cute::thread0 ()) printf (" scale_number = %d, tCrS_input[%d] = %f\n " ,scale_number, i, static_cast < float >( s_value) );
249248 }
250249#else
251250 static constexpr auto N = decltype(size<1>(in))::value;
@@ -298,6 +297,21 @@ class kgemm_4bit_inference_cutlass_dequant {
298297 int K = params.k ;
299298 int L = 1 ;
300299
300+ const int BLK_M = 16 ;
301+ const int BLK_N = 64 ;
302+ const int BLK_K = 64 ;
303+
304+ const int ATOM_M = 1 ;
305+ const int ATOM_N = 2 ;
306+ const int ATOM_K = 1 ;
307+
308+ const int SG_M = ceil_div (BLK_M , ATOM_M );
309+ const int SG_N = ceil_div (BLK_N , ATOM_N );
310+ const int SG_K = ceil_div (BLK_K , ATOM_K );
311+
312+ const int Num_SGs = ATOM_N * ATOM_M * ATOM_K ;
313+ static constexpr auto SG_QNT_WIDTH = Int<SG_N >{};
314+
301315 T* A = params.A ;
302316 uint8_t * B = params.B ;
303317 float * out = params.out ;
@@ -383,9 +397,8 @@ class kgemm_4bit_inference_cutlass_dequant {
383397
384398 Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
385399
386- const int SubgroupSize = 16 ;
387- const int SG_QNT_WIDTH = 32 ;
388- static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
400+ // const int SubgroupSize = 16;
401+ static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; // SubgroupSize;
389402 static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
390403 using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
391404 Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
@@ -405,7 +418,6 @@ class kgemm_4bit_inference_cutlass_dequant {
405418 Tensor tBgB = thr_copy_B.retile_S (tCgB);
406419
407420// // Prepare for prefetch
408- const int BLK_K = 64 ;
409421 auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M >,Int<BLK_K >>, Num_SGs>(tiled_copy_a);;
410422 auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N >,Int<BLK_K >>, Num_SGs>(tiled_copy_b);;
411423 auto thr_prefetch_A = tiled_prefetch_a.get_slice (thread_idx);
@@ -416,21 +428,22 @@ class kgemm_4bit_inference_cutlass_dequant {
416428 auto pBgB = thr_prefetch_B.partition_S (gB );
417429
418430// Run mainloop
419- const int BLK_N = 64 ;
420- const int ATOM_N = 2 ;
421- const int SG_N = 32 ;
422- auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
423- const int n_coord_s = n_idx * BLK_N + (get_sub_group_id () % ATOM_N ) * SG_N ;
424- const int l_coord_s = l_idx;
431+ // auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
432+ // const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
433+ // const int l_coord_s = l_idx;
425434
426- if (cute::thread0 ()) printf (" get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n " ,get_sub_group_id (), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
435+ // if(cute::thread0()) printf("get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",get_sub_group_id(), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
427436
428437 auto copy_iter_s = [&](){
429- return make_tensor (make_inttuple_iter (make_coord (n_coord_s , 0 , l_coord_s )),
438+ return make_tensor (make_inttuple_iter (make_coord (n_coord , 0 , l_coord )),
430439 make_layout (make_shape (Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
431440 make_stride (E<0 >{} * _16{}, E<0 >{} * decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1 >{} * _1{})));
432441
433442 }();
443+
444+ // using ExpectedLayout = typename decltype(tiled_copy_scale)::TiledLayout::dst_layout; //decltype(tiled_copy_scale.dst_layout()); //decltype(tiled_copy_scale.atom_layout_dst());
445+ // static_assert(is_same<decltype(frag_copy_Scale.layout()), ExpectedLayout>::value, "布局不匹配");
446+
434447#if 1
435448 #define PRINT (x ) print(#x " : " ); print(x); print(" \n " );
436449 if (cutlass::thread (LOG_THREAD , LOG_GROUP )) {
@@ -450,7 +463,8 @@ class kgemm_4bit_inference_cutlass_dequant {
450463 print (" dequant_frag : " ); print (dequant_frag); print (" \n " );
451464
452465 print (" ===================== D :\n " );
453- print (" frag_copy_ScaleB : " ); print (frag_copy_Scale); print (" \n " );
466+ print (" tiled_copy_scale : " ); print (tiled_copy_scale); print (" \n " );
467+ print (" frag_copy_Scale : " ); print (frag_copy_Scale); print (" \n " );
454468 print (" copy_iter_s: " ); print (copy_iter_s); print (" \n " );
455469
456470 print (" ===================== D :\n " );
@@ -484,6 +498,7 @@ class kgemm_4bit_inference_cutlass_dequant {
484498 copy (tiled_copy_b, tBgB (_,_,_,k_tile), frag_copy_B);
485499
486500 const int k_reload_factor = ceil_div (params.group_size , BLK_K );
501+ // const int k_reload_factor = params.group_size / BLK_K;
487502
488503 if (cute::thread0 ()) printf (" params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n " ,params.group_size , BLK_K , k_reload_factor);
489504
@@ -575,6 +590,26 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
575590// T* absmax = (T*)absmax_;
576591// T* absmax = (T*)absmax_;
577592
593+ // std::vector<T> host_data(n * k / blocksize);
594+ #if 0
595+ int element_size_A = m * k;
596+ auto scale_host_A = sycl::aligned_alloc_host<T>(512, element_size_A, q);
597+ q.memcpy(scale_host_A, A, element_size_A * sizeof(T)).wait();
598+ for (int i = 0; i < element_size_A; ++i) {
599+ //std::cout << scale_host[i] << " ";
600+ printf("%f ",static_cast<float>(scale_host_A[i]));
601+ }
602+ std::cout << std::endl;
603+
604+ int element_size = n * k / blocksize;
605+ auto scale_host = sycl::aligned_alloc_host<T>(512, element_size, q);
606+ q.memcpy(scale_host, absmax_, element_size * sizeof(T)).wait();
607+ for (int i = 0; i < element_size; ++i) {
608+ //std::cout << scale_host[i] << " ";
609+ printf("%f ",static_cast<float>(scale_host[i]));
610+ }
611+ std::cout << std::endl;
612+ #endif
578613#if 1
579614 // Init Params
580615 using Params = GemmKernel::Params;
0 commit comments