@@ -69,7 +69,7 @@ using TiledMma =
6969 Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
7070
7171// Define Mainloop dispatch policy
72- constexpr int PipelineStages = 3 ;
72+ constexpr int PipelineStages = 0 ;
7373using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
7474static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // sub_group size
7575
@@ -140,7 +140,7 @@ using GmemTiledCopyC = CopyOpG2R;
140140using GmemTiledCopyD = cute::conditional_t <not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
141141 CopyOpR2G, XE_2D_U32x8x16_ST_N>;
142142
143- // TODO(Xiaoli): Maybe legacy, refine me.
143+ // Calculate subgroup_tile_shape (reminder: this is not the same thing as "subgroup_size" in SYCL!)
144144static constexpr auto BLK_M = get<0 >(WorkgroupTileShape{});
145145static constexpr auto BLK_N = get<1 >(WorkgroupTileShape{});
146146static constexpr auto BLK_K = get<2 >(WorkgroupTileShape{});
@@ -174,16 +174,14 @@ class kgemm_4bit_inference_cutlass_dequant {
174174 int m, n, k;
175175 T* A;
176176 uint8_t * B;
177- float *absmax; // TODO(Xiaoli): FIX ME
178177 float * out;
179- float *datatype;
178+ float *datatype; // LUT
180179
181- // GemmUniversalMode mode{};
182180 ProblemShape problem_shape{};
183-
184- // inloopParams mainloop{};
181+
185182 Copy_A tiled_copy_a;
186183 Copy_B tiled_copy_b;
184+ Copy_B tiled_copy_b_4bit;
187185 Copy_Scale tiled_copy_scale;
188186 int group_size;
189187
@@ -309,45 +307,41 @@ class kgemm_4bit_inference_cutlass_dequant {
309307
310308 CUTLASS_DEVICE
311309 void operator ()(Params const & params, char * smem_buf) {
310+ if (cute::thread0 ()) printf (" this is fusion kernel...........\n " );
311+
312312 int M = params.m ;
313313 int N = params.n ;
314314 int K = params.k ;
315315 T* A = params.A ;
316316 uint8_t * B = params.B ;
317317 float * out = params.out ;
318318 float * datatype = params.datatype ;
319- // int blocksize = params.blocksize;
320319 auto tiled_copy_a = params.tiled_copy_a ;
321320 auto tiled_copy_b = params.tiled_copy_b ;
322- auto tiled_copy_scale = params.tiled_copy_scale ;
323- if ( cute::thread0 ())
324- printf ( " this is fusion kernel........... \n " );
321+ auto tiled_copy_b_4bit = params.tiled_copy_b_4bit ;
322+ auto tiled_copy_scale = params. tiled_copy_scale ;
323+
325324 int L = 1 ;
326325 auto problem_size = ProblemShape{M, N, K, L};
327-
328- // TODO(Xiaoli): FIX ME
329- SharedStorage& shared_storage = *reinterpret_cast <SharedStorage*>(smem_buf);
330326
331- float * quant_map = reinterpret_cast <float *>(smem_buf);
332327 // Preconditions
333328 static_assert (cute::rank (StrideA{}) == 3 , " StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>." );
334329 static_assert (cute::rank (StrideB{}) == 3 , " StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>." );
335330 static_assert (cute::rank (StrideC{}) == 3 , " StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>." );
336331 static_assert (cute::rank (StrideD{}) == 3 , " StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>." );
337332
338- // Get the appropriate blocks for this sub_group -- potential for sub_group locality
339333 int thread_idx = int (ThreadIdxX ());
340- // #if 0
341- // Load Dequat table
334+
335+ // Load the dequantization LUT and save it to SLM; 16 entries for 4-bit values
336+ float * quant_map = reinterpret_cast <float *>(smem_buf);
342337 if (thread_idx < 16 ) {
343- quant_map[thread_idx] = datatype[thread_idx]; // T(datatype[thread_idx]);
338+ quant_map[thread_idx] = datatype[thread_idx];
344339 printf (" quant_map[thread_idx] = %f\n " , quant_map[thread_idx]);
345340 }
346341 barrier_wait (1 );
347342
348- #if 1
349- auto blk_shape = TileShape{};
350- int m_coord, n_coord, l_coord;
343+ auto blk_shape = TileShape{}; // 256,256,32
344+ int m_coord, n_coord, l_coord; // block index
351345 if (params.scheduler .raster_order_ == TileScheduler::RasterOrder::AlongN) {
352346 if (cute::thread0 ()) printf (" AlongN !!\n " );
353347 m_coord = BlockIdxY ();
@@ -359,25 +353,23 @@ class kgemm_4bit_inference_cutlass_dequant {
359353 n_coord = BlockIdxY ();
360354 l_coord = BlockIdxZ ();
361355 }
356+ auto blk_coord_mnkl = make_coord (m_coord, n_coord, _, l_coord);
362357 if (cute::thread0 ()) printf (" M = %d, N=%d, K=%d, L=%d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n " ,M, N, K, L, m_coord, n_coord, l_coord, BlockIdxX (), BlockIdxY (), BlockIdxZ ());
363358
364- auto blk_coord_mnkl = make_coord (m_coord, n_coord, _, l_coord);
365- constexpr auto workgroup_shape = WorkgroupTileShape{};
366- constexpr auto subgroup_shape = SubgroupTileShape{};
367- if (cute::thread0 ())
368- printf (" BLK_M = %d, BLK_N = %d, BLK_K = %d, ATOM_M = %d, ATOM_N = %d, ATOM_K = %d, SG_M = %d, SG_N = %d, SG_K = %d\n " , BLK_M , BLK_N , BLK_K , ATOM_M , ATOM_N , ATOM_K , SG_M , SG_N , SG_K );
359+ constexpr auto workgroup_shape = WorkgroupTileShape{}; // 256, 256, 32
360+ constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // 256/8=32, 256/16=16, 32/16=2
369361
370- Tensor mA_mkl = cute::get_pvc_tensor (make_shape (M,K,L)); // (m,k,l)
371- Tensor mB_nkl = cute::get_pvc_tensor (make_shape (N,K,L)); // (n,k,l)
362+ Tensor mA_mkl = cute::get_pvc_tensor (make_shape (M,K,L)); // coordinate tensor: 0,1,2....
363+ Tensor mB_nkl = cute::get_pvc_tensor (make_shape (N,K,L)); // coordinate tensor: 0,1,2....
372364
373365 Tensor gA = local_tile (mA_mkl , select<0 ,2 >(blk_shape), make_coord (m_coord,_,l_coord));
374366 Tensor gB = local_tile (mB_nkl , select<1 ,2 >(blk_shape), make_coord (n_coord,_,l_coord));
375367
376- // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape
368+ // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_tile_shape
377369 TiledMma tiled_mma;
378370
379- auto expanded_shape = replace<1 >(blk_shape, cute::C<2 >{} * get<1 >(blk_shape));
380- Tensor accumulators = partition_fragment_C (tiled_mma, take<0 ,2 >(expanded_shape ));
371+ // auto expanded_shape = replace<1>(blk_shape, cute::C<2>{} * get<1>(blk_shape));
372+ Tensor accumulators = partition_fragment_C (tiled_mma, take<0 ,2 >(blk_shape ));
381373 clear (accumulators);
382374
383375 auto k_tile_iter = cute::make_coord_iterator (idx2crd (0 , make_shape (K)), make_shape (K));
@@ -387,6 +379,7 @@ class kgemm_4bit_inference_cutlass_dequant {
387379// Run MainLoop
388380 auto thr_copy_A = tiled_copy_a.get_slice (thread_idx);
389381 auto thr_copy_B = tiled_copy_b.get_slice (thread_idx);
382+ auto thr_copy_B_4bit = tiled_copy_b_4bit.get_slice (thread_idx);
390383 auto thr_copy_scale = tiled_copy_scale.get_slice (thread_idx);
391384
392385 auto sg = syclcompat::get_nd_item<1 >().get_sub_group ();
@@ -397,39 +390,39 @@ class kgemm_4bit_inference_cutlass_dequant {
397390 Tensor tCgA = thr_mma.partition_A (gA );
398391 Tensor tCgB = thr_mma.partition_B (gB );
399392
400- // Create fragments
393+ // Create fragments
401394 Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_a, tCgA (_,_,_,0 ).shape ()));
402395 Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_b, tCgB (_,_,_,0 ).shape ()));
403396
404397 using FragScaleLayout = Layout<Shape<_2, _2, _1>>;
405398 Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{});
406399
407400 // narrow input fragment
408- Tensor quant_frag = make_tensor<ElementQuant>(decltype (mma_B.layout ()){});
401+ Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout (tiled_copy_b_4bit, tCgB (_,_,_,0 ).shape ()));
402+ Tensor quant_frag = make_tensor<ElementQuant>(decltype (mma_B_4bit.layout ()){});
409403
410- auto original_shape = tCgB (_,_,_,0 ).shape ();
411- auto expanded_shape_2 = make_shape (cute::get<0 >(original_shape), cute::C<2 >{} * cute::get<1 >(original_shape),cute::get<2 >(original_shape));
412- auto expanded_layout = make_fragment_layout (tiled_copy_b, expanded_shape_2);
413- Tensor mma_B_expanded = make_tensor<ElementMMA>(expanded_layout);
404+ // auto original_shape = tCgB(_,_,_,0).shape();
405+ // auto expanded_shape_2 = make_shape(cute::get<0>(original_shape), cute::C<2>{} * cute::get<1>(original_shape),cute::get<2>(original_shape));
406+ // auto expanded_layout = make_fragment_layout(tiled_copy_b, expanded_shape_2);
407+ // Tensor mma_B_expanded = make_tensor<ElementMMA>(expanded_layout);
414408
415409 static_assert (std::is_same_v<typename decltype (quant_frag)::value_type, ElementQuant>);
416410 static_assert (std::is_same_v<typename decltype (mma_A)::value_type, ElementMMA>);
417411 static_assert (std::is_same_v<typename decltype (mma_B)::value_type, ElementMMA>);
418412
419413 // Retile for copy
420414 auto [frag_copy_A, frag_copy_B] = [&](){
421- return std::make_pair (thr_copy_A.retile_D (mma_A), thr_copy_B .retile_D (quant_frag));
415+ return std::make_pair (thr_copy_A.retile_D (mma_A), thr_copy_B_4bit .retile_D (quant_frag));
422416 }();
423417
424418 Tensor copy_tCrS = thr_copy_scale.retile_D (fragment_scale_input);
425- // Tensor copy_tCrZ = thr_copy_zero.retile_D(fragment_zero_input);
426419
427420 // Retile global counting tensors for copies
428421 Tensor tAgA = thr_copy_A.retile_S (tCgA);
429- Tensor tBgB = thr_copy_B .retile_S (tCgB);
422+ Tensor tBgB = thr_copy_B_4bit .retile_S (tCgB);
430423
431424 auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M >,Int<BLK_K >>, Num_SGs>(tiled_copy_a);
432- auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N >,Int<BLK_K >>, Num_SGs>(tiled_copy_b );
425+ auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N >,Int<BLK_K >>, Num_SGs>(tiled_copy_b_4bit );
433426 auto thr_prefetch_A = tiled_prefetch_a.get_slice (thread_idx);
434427 auto thr_prefetch_B = tiled_prefetch_b.get_slice (thread_idx);
435428
@@ -460,37 +453,39 @@ class kgemm_4bit_inference_cutlass_dequant {
460453 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
461454 }
462455
463- const int k_reload_factor = params.group_size / BLK_K ;
456+ const int k_reload_factor = params.group_size / BLK_K / 2 ;
464457 if (cute::thread0 ()) printf (" k_reload_factor = %d\n " , k_reload_factor);
465458
466459 CUTLASS_PRAGMA_UNROLL
467460 for (int k_tile = 0 , k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) {
468461 // Copy gmem to rmem for the first k_tile
469462 copy (tiled_copy_a, tAgA (_,_,_,k), frag_copy_A);
470- copy (tiled_copy_b , tBgB (_,_,_,k), frag_copy_B);
463+ copy (tiled_copy_b_4bit , tBgB (_,_,_,k), frag_copy_B);
471464
472465 copy (tiled_copy_scale, copy_iter_s (_, _, _, k_start_idx + (k_tile / k_reload_factor)), copy_tCrS);
473- dequant (quant_frag, mma_B_expanded, fragment_scale_input, quant_map);
466+ // dequant(quant_frag, mma_B_expanded, fragment_scale_input, quant_map);
467+ dequant (quant_frag, mma_B, fragment_scale_input, quant_map);
474468
475469 if (prefetch_k < k_tile_count) {
476470 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
477471 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
478472 }
479473
480- cute::gemm (tiled_mma, mma_A, mma_B_expanded , accumulators);
474+ cute::gemm (tiled_mma, mma_A, mma_B , accumulators);
481475 }
476+
477+ SharedStorage& shared_storage = *reinterpret_cast <SharedStorage*>((char *)nullptr );
482478 CollectiveEpilogue epilogue{params.epilogue , shared_storage.epilogue };
483- auto expanded_problem_size = ProblemShape{M, 2 * N, K, 1 };
484- auto problem_shape_MNKL = append<4 >(expanded_problem_size , 1 );
479+ // auto expanded_problem_size = ProblemShape{M, 2 * N, K, 1};
480+ auto problem_shape_MNKL = append<4 >(problem_size , 1 );
485481 epilogue (
486482 problem_shape_MNKL,
487- subgroup_shape, // TODO(codeplay): Inconsistency here w/ blk_coord_mnkl
483+ subgroup_tile_shape,
488484 blk_coord_mnkl,
489485 accumulators,
490486 tiled_mma,
491487 thread_idx
492488 );
493- #endif
494489 }
495490};
496491
@@ -532,9 +527,14 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
532527 StrideB stride_B = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k, l));
533528 auto mB_nkl = make_tensor (make_gmem_ptr (B), make_layout (make_shape (n, k, l), stride_B));
534529 Copy_B tiled_copy_b{Copy_B{}.with (mB_nkl )};
530+
531+ StrideB stride_B_4bit = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (n, k/2 , l)); // packed stride for the (n, k/2, l) view of B
532+ auto mB_nkl_4bit = make_tensor (make_gmem_ptr (B), make_layout (make_shape (n, k/2 , l), stride_B_4bit)); // use the stride computed for this shape; stride_B is packed for (n, k, l) and does not match it
533+ Copy_B tiled_copy_b_4bit{Copy_B{}.with (mB_nkl_4bit )}; // copy descriptor bound to the (n, k/2, l) view
535534
536535 params.tiled_copy_a = tiled_copy_a;
537536 params.tiled_copy_b = tiled_copy_b;
537+ params.tiled_copy_b_4bit = tiled_copy_b_4bit;
538538
539539 const int scale_k = cute::ceil_div (k, blocksize);
540540 const int dq_mn_size = n;
0 commit comments