@@ -233,31 +233,7 @@ class kgemm_4bit_inference_cutlass_dequant {
233233 using SrcType = typename EngineIn::value_type;
234234 using DstType = typename EngineOut::value_type;
235235 using ScaleType = typename EngineScales::value_type;
236- #if 0
237- static constexpr auto N = decltype(size<1>(in))::value;
238- static constexpr auto loop_cnt = decltype(size(out))::value / N;
239- for (int n = 0; n < N; n++) {
240- auto s_value = tCrS_input(i);
241- for (int l = 0; s < loop_cnt; l++) {
242-
243- // int numbers = decltype(size(in))::value;
244- // for(int i=0; i<numbers / N; i++){
245- // //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
246- // //out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
247- // uint8_t value = in[i].get();
248- // out[i] = static_cast<DstType>(quant_map[value]);
249- // int thread_idx = int(ThreadIdxX());
250- // if(cute::thread0()){
251- // //if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
252- // //printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
253- // }
254- // }
255- // int scale_number = decltype(size(tCrS_input))::value;
256- // for(int i=0; i<scale_number; i++){
257- // auto s_value = tCrS_input(i);
258- // if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
259- // }
260- #else
236+
261237 static constexpr auto N = decltype (size<1 >(in))::value;
262238
263239 using format_type = ushort; // 16
@@ -275,13 +251,6 @@ class kgemm_4bit_inference_cutlass_dequant {
275251 auto s_tensor = make_tensor ((format_type*)(raw_pointer_cast (in.data ())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
276252 auto d_tensor = make_tensor (out.data (), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
277253
278- int scale_number = decltype (size (tCrS_input))::value;
279- for (int i=0 ; i<scale_number; i++){
280- auto s_value = tCrS_input (i);
281- if (cute::thread0 ()) printf (" scale_number = %d, tCrS_input[%d] = %f\n " ,scale_number, i, static_cast <float >(s_value));
282- }
283- // printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
284-
285254 for (int n = 0 ; n < N; n++) {
286255 const auto ts = tCrS_input (n);
287256
@@ -300,17 +269,14 @@ for(int i=0; i<scale_number; i++){
300269 } else {
301270 dst[i+1 ] = static_cast <DstType>(quant_map[value] * static_cast <float >(ts));
302271 }
303- if (cute::thread0 ())
304- printf (" tid = %d, n = %d, s = %d, i = %d, format_data = %d, value = %d, quant_map[value] = %f, ts = %f, dst = %f\n " ,ThreadIdxX (), n, s, i, static_cast <int >(format_data), static_cast <int >(value), quant_map[value], static_cast <float >(ts), static_cast <float >(dst[i]));
305272 }
306273 }
307274 }
308- #endif
309275 }
310276
311277 CUTLASS_DEVICE
312278 void operator ()(Params const & params, char * smem_buf) {
313- if (cute::thread0 ()) printf (" this is fusion kernel...........\n " );
279+ // if(cute::thread0()) printf("this is fusion kernel...........\n");
314280
315281 int M = params.m ;
316282 int N = params.n ;
@@ -363,21 +329,17 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
363329 auto blk_shape = TileShape{}; // 16,64,64
364330 int m_coord, n_coord, l_coord; // block index
365331 if (params.scheduler .raster_order_ == TileScheduler::RasterOrder::AlongN) {
366- if (cute::thread0 ()) printf (" AlongN !!\n " );
332+ // if(cute::thread0()) printf("AlongN !!\n");
367333 m_coord = BlockIdxY ();
368334 n_coord = BlockIdxX ();
369335 l_coord = BlockIdxZ ();
370336 } else {
371- if (cute::thread0 ()) printf (" not AlongN !!\n " );
337+ // if(cute::thread0()) printf("not AlongN !!\n");
372338 m_coord = BlockIdxX ();
373339 n_coord = BlockIdxY ();
374340 l_coord = BlockIdxZ ();
375341 }
376342 auto blk_coord_mnkl = make_coord (m_coord, n_coord, _, l_coord);
377- if (cute::thread0 ()) {
378- printf (" M = %d, N=%d, K=%d, L=%d\n " , M, N, K, L);
379- printf (" thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n " ,thread_idx, m_coord, n_coord, l_coord, BlockIdxX (), BlockIdxY (), BlockIdxZ ());
380- }
381343 constexpr auto workgroup_shape = WorkgroupTileShape{}; // 256, 256, 32
382344 constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // 32, 64, 32 (number of atom level workgroup: 256/8=32, 256/4=64, 32/2=32)
383345
@@ -395,7 +357,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
395357// // Create K slicing tiling iterator and count
396358 auto k_tile_iter = cute::make_coord_iterator (idx2crd (0 , make_shape (K)), make_shape (K));
397359 int k_tile_count = ceil_div (K, get<2 >(workgroup_shape)); // inner_loop number
398- if (cute::thread0 ()) printf (" k_tile_count = %d\n " , k_tile_count);
399360
400361
401362// //// MainLoop //////
@@ -417,13 +378,10 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
417378
418379 Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
419380
420- // const int SubgroupSize = 16;
421381 static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; // SubgroupSize;
422382 static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
423383 using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
424- // using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>, Stride<_1,_1,_0>>;
425384 Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
426- if (cute::thread0 ()) printf (" scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d, BlockShape = %d, BlockShape_1= %d\n " , scale_traits_size, scale_traits_num, SG_QNT_WIDTH , decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value, decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value);
427385
428386 static_assert (std::is_same_v<typename decltype (dequant_frag)::value_type, ElementQuant>);
429387 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
@@ -433,15 +391,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
433391 Tensor frag_copy_A = thr_copy_A.retile_D (mma_A);
434392 Tensor frag_copy_B = thr_copy_B.retile_D (dequant_frag);
435393 Tensor frag_copy_Scale = thr_copy_scale.retile_D (fragment_scale);
436- // auto frag_layout = make_layout(
437- // make_shape(_2{}, _1{}, _1{}), // 形状 (_2, _1, _1)
438- // make_stride(_1{}, _1{}, _0{}) // 步长 (_1, _1, _0)
439- // );
440- // Tensor frag_copy_Scale = thr_copy_scale.retile_D(make_tensor(fragment_scale.data(), frag_layout));
441-
442- // using FragLayout = Layout<Shape<_2,_1,_1>, Stride<_1,_1,_0>>;
443- // Tensor fragment_scale = make_tensor<ElementScale>(FragLayout{});
444- // Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
445394
446395// // Retile global counting tensors for copies:
447396 Tensor tAgA = thr_copy_A.retile_S (tCgA);
@@ -458,26 +407,14 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
458407 auto pBgB = thr_prefetch_B.partition_S (gB );
459408
460409// Run mainloop
461- // auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
462- // const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
463- // const int l_coord_s = l_idx;
464-
465- // if(cute::thread0()) printf("get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",get_sub_group_id(), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
466-
467410 auto copy_iter_s = [&](){
468411 return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)),
469412 make_layout (make_shape (Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
470413 make_stride (E<0 >{} * _16{}, E<0 >{} * decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1 >{} * _1{})));
471414
472415 }();
473416
474- // auto copy_iter_s = [&](){
475- // return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
476- // make_layout(make_shape(Int<decltype(size<0>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, Int<decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, _1{}, k_tile_count),
477- // make_stride(_16{}, _32{}, _0{}, _1{})));
478- // }();
479-
480- #if 1
417+ #if 0
481418 #define PRINT(x) print(#x ": "); print(x); print("\n");
482419 if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
483420 print("\n\n======================= A: \n");
@@ -518,10 +455,9 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
518455 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
519456 int prefetch_k = k_start_idx;
520457
521- #if 1
522458 const int k_reload_factor = ceil_div (params.group_size , BLK_K );
523- if (cute::thread0 ()) printf (" params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n " ,params.group_size , BLK_K , k_reload_factor);
524- # endif
459+ // if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
460+
525461 CUTLASS_PRAGMA_UNROLL
526462 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
527463 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
@@ -534,19 +470,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
534470 // Copy gmem to rmem for the first k_tile
535471 copy (tiled_copy_a, tAgA (_,_,_,k_tile), frag_copy_A);
536472 copy (tiled_copy_b, tBgB (_,_,_,k_tile), frag_copy_B);
537- #if 1
538- const int s_step = k_start_idx + (k_s / k_reload_factor); // 1 + k_tile / k_reload_factor;
539- if (cute::thread0 ()) printf (" k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n " ,k_start_idx, k_s, k_reload_factor, s_step);
540- copy (tiled_copy_scale, copy_iter_s (_, _, _, s_step), frag_copy_Scale);
541- #else
542- const int k_reload_factor = ceil_div(params.group_size, BLK_K);
543- //const int k_reload_factor = params.group_size / BLK_K;
544473
545- //if(cute::thread0())
546- printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
474+ const int s_step = k_start_idx + (k_s / k_reload_factor);
475+ // if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %d, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
476+ copy (tiled_copy_scale, copy_iter_s (_, _, _, s_step), frag_copy_Scale);
547477
548- copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
549- #endif
550478 if (prefetch_k < k_tile_count) {
551479 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
552480 }
@@ -617,41 +545,16 @@ template <typename T, int BITS>
617545void gemm_4bit_inference_cutlass_dequant (int m, int n, int k, T *A, unsigned char *B,
618546 T *absmax_, float *datatype, float *out, int lda,
619547 int ldb, int ldc, int blocksize, sycl::queue *stream) {
620- std::cout<<" this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n " ;
548+ // std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
621549
622550 sycl::queue q = *stream;
623551 using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS >;
624552
625553 static constexpr int smem_size= 512 ; // (16 * 32) for quant_map
626554 int l = 1 ;
627555
628- // TODO(Xiaoli): FIX ME?? auto problem_size = ProblemShape{m, n, k};
629556 auto problem_size = ProblemShape{m, n, k, l};
630- // TODO(Xiaoli): FIX ME
631- // T* absmax = (T*)absmax_;
632- // T* absmax = (T*)absmax_;
633557
634- // std::vector<T> host_data(n * k / blocksize);
635- #if 0
636- int element_size_A = m * k;
637- auto scale_host_A = sycl::aligned_alloc_host<T>(512, element_size_A, q);
638- q.memcpy(scale_host_A, A, element_size_A * sizeof(T)).wait();
639- for (int i = 0; i < element_size_A; ++i) {
640- //std::cout << scale_host[i] << " ";
641- printf("%f ",static_cast<float>(scale_host_A[i]));
642- }
643- std::cout << std::endl;
644-
645- int element_size = n * k / blocksize;
646- auto scale_host = sycl::aligned_alloc_host<T>(512, element_size, q);
647- q.memcpy(scale_host, absmax_, element_size * sizeof(T)).wait();
648- for (int i = 0; i < element_size; ++i) {
649- //std::cout << scale_host[i] << " ";
650- printf("%f ",static_cast<float>(scale_host[i]));
651- }
652- std::cout << std::endl;
653- #endif
654- #if 1
655558 // Init Params
656559 using Params = GemmKernel::Params;
657560 Params params;
@@ -678,7 +581,7 @@ std::cout << std::endl;
678581
679582 const int scale_k = cute::ceil_div (k, blocksize);
680583 StrideScale stride_S = cutlass::make_cute_packed_stride (StrideScale{}, cute::make_shape (n, scale_k, l));
681- std::cout<<" n = " <<n<<" k = " <<k<<" blocksize = " <<blocksize<<" scale_k = " <<scale_k<<std::endl;
584+ // std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
682585 auto mScale = make_tensor (
683586 make_gmem_ptr (absmax_),
684587 make_layout (make_shape (n, scale_k, l), stride_S));
@@ -694,6 +597,7 @@ std::cout << std::endl;
694597 StrideC stride_C = cutlass::make_cute_packed_stride (StrideC{}, cute::make_shape (m, n, l));
695598 StrideD stride_D = cutlass::make_cute_packed_stride (StrideD{}, cute::make_shape (m, n, l));
696599
600+ #if 0
697601 #define PRINT(x) print(#x ": "); print(x); print("\n");
698602 if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
699603 print("===================== stride :\n");
@@ -705,6 +609,7 @@ std::cout << std::endl;
705609 print("===================== stride :\n");
706610 }
707611 #undef PRINT
612+ #endif
708613
709614 params.hw_info = hw_info;
710615 params.epilogue = CollectiveEpilogue::to_underlying_arguments (problem_size, {{alpha, beta}, nullptr , stride_C, out, stride_D}, nullptr );
@@ -721,8 +626,8 @@ std::cout << std::endl;
721626
722627 const syclcompat::dim3 sycl_block (block.x , block.y , block.z ); // workgroup_size: 1*2*1*16, 1, 1
723628 const syclcompat::dim3 sycl_grid (grid.x , grid.y , grid.z ); // workgroup_number (problem_size / tile_size): N/64, M/16, 1
724- printf (" Host Grid: (%d, %d, %d)\n " , grid.x , grid.y , grid.z );
725- printf (" Host Block: (%d, %d, %d)\n " , block.x , block.y , block.z );
629+ // printf("Host Grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
630+ // printf("Host Block: (%d, %d, %d)\n", block.x, block.y, block.z);
726631
727632 auto kernel_props = [] {
728633 return syclcompat::experimental::kernel_properties{
@@ -739,7 +644,6 @@ std::cout << std::endl;
739644 auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(policy, q, params);
740645 EventManager::getInstance ().addEvent (event);
741646 // syclcompat::wait();
742- #endif
743647}
744648
745649template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16 >(
0 commit comments