@@ -176,7 +176,6 @@ class kgemm_4bit_inference_cutlass_dequant {
176176 T* A;
177177 uint8_t * B;
178178 float * out;
179- // T *absmax;
180179 float *datatype; // LUT
181180 int group_size;
182181
@@ -206,11 +205,6 @@ class kgemm_4bit_inference_cutlass_dequant {
206205 }
207206 }
208207
209- /* float bfloat16_to_float(uint16_t bf16_bits) {
210- uint32_t float_bits = (bf16_bits << 16); // 将 bfloat16 左移16位转为 float
211- return reinterpret_cast<float&>(float_bits);
212- }*/
213-
214208 // / Utilities to transform A.
215209 template <class EngineIn ,
216210 class EngineOut ,
@@ -233,31 +227,7 @@ class kgemm_4bit_inference_cutlass_dequant {
233227 using SrcType = typename EngineIn::value_type;
234228 using DstType = typename EngineOut::value_type;
235229 using ScaleType = typename EngineScales::value_type;
236- #if 0
237- static constexpr auto N = decltype(size<1>(in))::value;
238- static constexpr auto loop_cnt = decltype(size(out))::value / N;
239- for (int n = 0; n < N; n++) {
240- auto s_value = tCrS_input(i);
241- for (int l = 0; s < loop_cnt; l++) {
242-
243- // int numbers = decltype(size(in))::value;
244- // for(int i=0; i<numbers / N; i++){
245- // //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
246- // //out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
247- // uint8_t value = in[i].get();
248- // out[i] = static_cast<DstType>(quant_map[value]);
249- // int thread_idx = int(ThreadIdxX());
250- // if(cute::thread0()){
251- // //if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
252- // //printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
253- // }
254- // }
255- // int scale_number = decltype(size(tCrS_input))::value;
256- // for(int i=0; i<scale_number; i++){
257- // auto s_value = tCrS_input(i);
258- // if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
259- // }
260- #else
230+
261231 static constexpr auto N = decltype (size<1 >(in))::value;
262232
263233 using format_type = ushort; // 16
@@ -275,13 +245,6 @@ class kgemm_4bit_inference_cutlass_dequant {
275245 auto s_tensor = make_tensor ((format_type*)(raw_pointer_cast (in.data ())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
276246 auto d_tensor = make_tensor (out.data (), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
277247
278- int scale_number = decltype (size (tCrS_input))::value;
279- for (int i=0 ; i<scale_number; i++){
280- auto s_value = tCrS_input (i);
281- if (cute::thread0 ()) printf (" scale_number = %d, tCrS_input[%d] = %f\n " ,scale_number, i, static_cast <float >(s_value));
282- }
283- // printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
284-
285248 for (int n = 0 ; n < N; n++) {
286249 const auto ts = tCrS_input (n);
287250
@@ -300,17 +263,14 @@ for(int i=0; i<scale_number; i++){
300263 } else {
301264 dst[i+1 ] = static_cast <DstType>(quant_map[value] * static_cast <float >(ts));
302265 }
303- if (cute::thread0 ())
304- printf (" tid = %d, n = %d, s = %d, i = %d, format_data = %d, value = %d, quant_map[value] = %f, ts = %f, dst = %f\n " ,ThreadIdxX (), n, s, i, static_cast <int >(format_data), static_cast <int >(value), quant_map[value], static_cast <float >(ts), static_cast <float >(dst[i]));
305266 }
306267 }
307268 }
308- #endif
309269 }
310270
311271 CUTLASS_DEVICE
312272 void operator ()(Params const & params, char * smem_buf) {
313- if (cute::thread0 ()) printf (" this is fusion kernel...........\n " );
273+ // if(cute::thread0()) printf("this is fusion kernel...........\n");
314274
315275 int M = params.m ;
316276 int N = params.n ;
@@ -363,21 +323,17 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
363323 auto blk_shape = TileShape{}; // 16,64,64
364324 int m_coord, n_coord, l_coord; // block index
365325 if (params.scheduler .raster_order_ == TileScheduler::RasterOrder::AlongN) {
366- if (cute::thread0 ()) printf (" AlongN !!\n " );
326+ // if(cute::thread0()) printf("AlongN !!\n");
367327 m_coord = BlockIdxY ();
368328 n_coord = BlockIdxX ();
369329 l_coord = BlockIdxZ ();
370330 } else {
371- if (cute::thread0 ()) printf (" not AlongN !!\n " );
331+ // if(cute::thread0()) printf("not AlongN !!\n");
372332 m_coord = BlockIdxX ();
373333 n_coord = BlockIdxY ();
374334 l_coord = BlockIdxZ ();
375335 }
376336 auto blk_coord_mnkl = make_coord (m_coord, n_coord, _, l_coord);
377- if (cute::thread0 ()) {
378- printf (" M = %d, N=%d, K=%d, L=%d\n " , M, N, K, L);
379- printf (" thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n " ,thread_idx, m_coord, n_coord, l_coord, BlockIdxX (), BlockIdxY (), BlockIdxZ ());
380- }
381337 constexpr auto workgroup_shape = WorkgroupTileShape{}; // 256, 256, 32
382338 constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // 32, 64, 32 (number of atom level workgroup: 256/8=32, 256/4=64, 32/2=32)
383339
@@ -395,7 +351,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
395351// // Create K slicing tiling iterator and count
396352 auto k_tile_iter = cute::make_coord_iterator (idx2crd (0 , make_shape (K)), make_shape (K));
397353 int k_tile_count = ceil_div (K, get<2 >(workgroup_shape)); // inner_loop number
398- if (cute::thread0 ()) printf (" k_tile_count = %d\n " , k_tile_count);
399354
400355
401356// //// MainLoop //////
@@ -417,13 +372,10 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
417372
418373 Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout ());
419374
420- // const int SubgroupSize = 16;
421375 static constexpr auto scale_traits_size = decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; // SubgroupSize;
422376 static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value;
423377 using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
424- // using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>, Stride<_1,_1,_0>>;
425378 Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
426- if (cute::thread0 ()) printf (" scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d, BlockShape = %d, BlockShape_1= %d\n " , scale_traits_size, scale_traits_num, SG_QNT_WIDTH , decltype (size (typename GmemTiledCopyScale::BlockShape{}))::value, decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value);
427379
428380 static_assert (std::is_same_v<typename decltype (dequant_frag)::value_type, ElementQuant>);
429381 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
@@ -433,15 +385,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
433385 Tensor frag_copy_A = thr_copy_A.retile_D (mma_A);
434386 Tensor frag_copy_B = thr_copy_B.retile_D (dequant_frag);
435387 Tensor frag_copy_Scale = thr_copy_scale.retile_D (fragment_scale);
436- // auto frag_layout = make_layout(
437- // make_shape(_2{}, _1{}, _1{}), // 形状 (_2, _1, _1)
438- // make_stride(_1{}, _1{}, _0{}) // 步长 (_1, _1, _0)
439- // );
440- // Tensor frag_copy_Scale = thr_copy_scale.retile_D(make_tensor(fragment_scale.data(), frag_layout));
441-
442- // using FragLayout = Layout<Shape<_2,_1,_1>, Stride<_1,_1,_0>>;
443- // Tensor fragment_scale = make_tensor<ElementScale>(FragLayout{});
444- // Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
445388
446389// // Retile global counting tensors for copies:
447390 Tensor tAgA = thr_copy_A.retile_S (tCgA);
@@ -458,26 +401,14 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
458401 auto pBgB = thr_prefetch_B.partition_S (gB );
459402
460403// Run mainloop
461- // auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
462- // const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
463- // const int l_coord_s = l_idx;
464-
465- // if(cute::thread0()) printf("get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",get_sub_group_id(), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
466-
467404 auto copy_iter_s = [&](){
468405 return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)),
469406 make_layout (make_shape (Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
470407 make_stride (E<0 >{} * _16{}, E<0 >{} * decltype (size<1 >(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1 >{} * _1{})));
471408
472409 }();
473410
474- // auto copy_iter_s = [&](){
475- // return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
476- // make_layout(make_shape(Int<decltype(size<0>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, Int<decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, _1{}, k_tile_count),
477- // make_stride(_16{}, _32{}, _0{}, _1{})));
478- // }();
479-
480- #if 1
411+ #if 0
481412 #define PRINT(x) print(#x ": "); print(x); print("\n");
482413 if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
483414 print("\n\n======================= A: \n");
@@ -518,10 +449,9 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
518449 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (K));
519450 int prefetch_k = k_start_idx;
520451
521- #if 1
522452 const int k_reload_factor = ceil_div (params.group_size , BLK_K );
523- if (cute::thread0 ()) printf (" params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n " ,params.group_size , BLK_K , k_reload_factor);
524- # endif
453+ // if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n",params.group_size, BLK_K, k_reload_factor);
454+
525455 CUTLASS_PRAGMA_UNROLL
526456 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
527457 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
@@ -534,19 +464,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
534464 // Copy gmem to rmem for the first k_tile
535465 copy (tiled_copy_a, tAgA (_,_,_,k_tile), frag_copy_A);
536466 copy (tiled_copy_b, tBgB (_,_,_,k_tile), frag_copy_B);
537- #if 1
538- const int s_step = k_start_idx + (k_s / k_reload_factor); // 1 + k_tile / k_reload_factor;
539- if (cute::thread0 ()) printf (" k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n " ,k_start_idx, k_s, k_reload_factor, s_step);
540- copy (tiled_copy_scale, copy_iter_s (_, _, _, s_step), frag_copy_Scale);
541- #else
542- const int k_reload_factor = ceil_div(params.group_size, BLK_K);
543- //const int k_reload_factor = params.group_size / BLK_K;
544467
545- //if(cute::thread0())
546- printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
468+ const int s_step = k_start_idx + (k_s / k_reload_factor);
469+ // if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
470+ copy (tiled_copy_scale, copy_iter_s (_, _, _, s_step), frag_copy_Scale);
547471
548- copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
549- #endif
550472 if (prefetch_k < k_tile_count) {
551473 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
552474 }
@@ -617,41 +539,16 @@ template <typename T, int BITS>
617539void gemm_4bit_inference_cutlass_dequant (int m, int n, int k, T *A, unsigned char *B,
618540 T *absmax_, float *datatype, float *out, int lda,
619541 int ldb, int ldc, int blocksize, sycl::queue *stream) {
620- std::cout<<" this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n " ;
542+ // std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
621543
622544 sycl::queue q = *stream;
623545 using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS >;
624546
625547 static constexpr int smem_size= 512 ; // (16 * 32) for quant_map
626548 int l = 1 ;
627549
628- // TODO(Xiaoli): FIX ME?? auto problem_size = ProblemShape{m, n, k};
629550 auto problem_size = ProblemShape{m, n, k, l};
630- // TODO(Xiaoli): FIX ME
631- // T* absmax = (T*)absmax_;
632- // T* absmax = (T*)absmax_;
633551
634- // std::vector<T> host_data(n * k / blocksize);
635- #if 0
636- int element_size_A = m * k;
637- auto scale_host_A = sycl::aligned_alloc_host<T>(512, element_size_A, q);
638- q.memcpy(scale_host_A, A, element_size_A * sizeof(T)).wait();
639- for (int i = 0; i < element_size_A; ++i) {
640- //std::cout << scale_host[i] << " ";
641- printf("%f ",static_cast<float>(scale_host_A[i]));
642- }
643- std::cout << std::endl;
644-
645- int element_size = n * k / blocksize;
646- auto scale_host = sycl::aligned_alloc_host<T>(512, element_size, q);
647- q.memcpy(scale_host, absmax_, element_size * sizeof(T)).wait();
648- for (int i = 0; i < element_size; ++i) {
649- //std::cout << scale_host[i] << " ";
650- printf("%f ",static_cast<float>(scale_host[i]));
651- }
652- std::cout << std::endl;
653- #endif
654- #if 1
655552 // Init Params
656553 using Params = GemmKernel::Params;
657554 Params params;
@@ -678,7 +575,7 @@ std::cout << std::endl;
678575
679576 const int scale_k = cute::ceil_div (k, blocksize);
680577 StrideScale stride_S = cutlass::make_cute_packed_stride (StrideScale{}, cute::make_shape (n, scale_k, l));
681- std::cout<<" n = " <<n<<" k = " <<k<<" blocksize = " <<blocksize<<" scale_k = " <<scale_k<<std::endl;
578+ // std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
682579 auto mScale = make_tensor (
683580 make_gmem_ptr (absmax_),
684581 make_layout (make_shape (n, scale_k, l), stride_S));
@@ -694,6 +591,7 @@ std::cout << std::endl;
694591 StrideC stride_C = cutlass::make_cute_packed_stride (StrideC{}, cute::make_shape (m, n, l));
695592 StrideD stride_D = cutlass::make_cute_packed_stride (StrideD{}, cute::make_shape (m, n, l));
696593
594+ #if 0
697595 #define PRINT(x) print(#x ": "); print(x); print("\n");
698596 if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
699597 print("===================== stride :\n");
@@ -705,6 +603,7 @@ std::cout << std::endl;
705603 print("===================== stride :\n");
706604 }
707605 #undef PRINT
606+ #endif
708607
709608 params.hw_info = hw_info;
710609 params.epilogue = CollectiveEpilogue::to_underlying_arguments (problem_size, {{alpha, beta}, nullptr , stride_C, out, stride_D}, nullptr );
@@ -721,8 +620,8 @@ std::cout << std::endl;
721620
722621 const syclcompat::dim3 sycl_block (block.x , block.y , block.z ); // workgroup_size: 1*2*1*16, 1, 1
723622 const syclcompat::dim3 sycl_grid (grid.x , grid.y , grid.z ); // workgroup_number (problem_size / tile_size): N/64, M/16, 1
724- printf (" Host Grid: (%d, %d, %d)\n " , grid.x , grid.y , grid.z );
725- printf (" Host Block: (%d, %d, %d)\n " , block.x , block.y , block.z );
623+ // printf("Host Grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
624+ // printf("Host Block: (%d, %d, %d)\n", block.x, block.y, block.z);
726625
727626 auto kernel_props = [] {
728627 return syclcompat::experimental::kernel_properties{
@@ -739,7 +638,6 @@ std::cout << std::endl;
739638 auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(policy, q, params);
740639 EventManager::getInstance ().addEvent (event);
741640 // syclcompat::wait();
742- #endif
743641}
744642
745643template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16 >(
0 commit comments