@@ -85,6 +85,7 @@ constexpr int PipelineStages = 2;
8585 constexpr int PipelineStages = 4 ;
8686#endif
8787
// Shape aliases shared by the kernel below.
// MmaAtomShape is the AtomShape_MNK exposed by TiledMma; the store path divides
// the sub-group tile extents by it to count fragments per sub-group.
88+ using MmaAtomShape = typename TiledMma::AtomShape_MNK;
// The work-group tile is simply TileShape; BLK_M / BLK_N are its first two extents.
8889using WorkgroupTileShape = TileShape;
8990static constexpr auto BLK_M = get<0 >(WorkgroupTileShape{}); // 256 //16
9091static constexpr auto BLK_N = get<1 >(WorkgroupTileShape{}); // 256 //64
@@ -118,11 +119,13 @@ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); //8*4*1*16=512
118119using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
119120static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // 16
120121
// The epilogue design aliases below are compiled out by this #if 0 / #endif pair.
// NOTE(review): they appear to be kept only for reference while the kernel stores
// the accumulators directly through XE_Copy_D — confirm before deleting.
122+ #if 0
121123// Design Epilogue
122124using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
123125using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
124126using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
125127using SharedStorage = FusionCallBacks::SharedStorage;
128+ #endif
126129
127130// Design Scheduler
128131using TileScheduler_ = PersistentScheduler;
@@ -132,6 +135,7 @@ using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelec
132135using TileSchedulerArguments = typename TileScheduler::Arguments;
133136using TileSchedulerParams = typename TileScheduler::Params;
134137
138+ #if 0
135139// Define Epilogue
136140using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
137141 EpilogueDispatchPolicy,
@@ -146,6 +150,7 @@ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
146150 XE_2D_U32x4x16_ST_N, // The copy atom used to store matrix D
147151 void, void>;
148152using EpilogueParams = typename CollectiveEpilogue::Params;
153+ #endif
149154
150155using ClusterShape = typename DispatchPolicy::ClusterShape;
151156
@@ -179,18 +184,24 @@ using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale>;
179184using val_layout_load_scale = decltype (make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{})));
180185using Copy_Scale = decltype (make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, val_layout_load_scale{})); // group-wise scale
181186
182- using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
187+ // using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
183188using StrideD = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
184189
// Types for the tiled copy that writes matrix D.
// GmemTiledCopyD names the store operation (XE_2D_U32x4x16_ST_N).
190+ using GmemTiledCopyD = XE_2D_U32x4x16_ST_N;
// Copy traits for that store operation, parameterised on the D stride type.
191+ using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>;
// Value layout: the traits' BlockShape divided by CopyThreadShape.
192+ using val_layout_store_D = decltype (make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
// Tiled copy assembled from the atom (element type ElementOutput), the thread
// layout, and the value layout above.
193+ using XE_Copy_D = decltype (make_tiled_copy(Copy_Atom<Trait_D, ElementOutput>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));
194+
185195template <typename T, int BITS >
186196class gemm_4bit_cutlass_kernel {
187197public:
188198 // Kernel level shared memory storage
199+ #if 0
189200 struct SharedStorage {
190201 using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
191202 EpilogueTensorStorage epilogue;
192203 };
193-
204+ # endif
194205 struct Params {
195206 int m, n, k, l;
196207 T* A;
@@ -206,7 +217,8 @@ class gemm_4bit_cutlass_kernel {
206217 Copy_B tiled_copy_b;
207218 Copy_Scale tiled_copy_scale;
208219
209- EpilogueParams epilogue{};
220+ // EpilogueParams epilogue{};
221+ XE_Copy_D xe_store_d;
210222 KernelHardwareInfo hw_info{};
211223 TileSchedulerParams scheduler{};
212224 };
@@ -362,7 +374,7 @@ class gemm_4bit_cutlass_kernel {
362374
363375 static_assert (cute::rank (StrideA{}) == 3 , " StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>." );
364376 static_assert (cute::rank (StrideB{}) == 3 , " StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>." );
365- static_assert (cute::rank (StrideC{}) == 3 , " StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>." );
377+ // static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
366378 static_assert (cute::rank (StrideD{}) == 3 , " StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>." );
367379
368380 int thread_idx = int (ThreadIdxX ());
@@ -557,7 +569,8 @@ class gemm_4bit_cutlass_kernel {
557569 cute::gemm (tiled_mma, mma_A, mma_B, accumulators);
558570 barrier_wait (3 );
559571 }
560-
572+
573+ #if 0
561574 SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>((char*)nullptr);
562575 CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
563576 auto problem_shape_MNKL = append<4>(problem_size, 1);
@@ -569,6 +582,36 @@ class gemm_4bit_cutlass_kernel {
569582 tiled_mma,
570583 thread_idx
571584 );
585+ #else
586+
// Direct store of the accumulators to D through params.xe_store_d, used in
// place of the CollectiveEpilogue call that is compiled out above.
// FragsM / FragsN: sub-group tile extent divided by the MMA atom extent along
// M and N; they are the trip counts of the two store loops below.
587+ static constexpr int FragsM = get<0 >(SubgroupTileShape{}) / get<0 >(MmaAtomShape ()); // MMA atoms along M per sub_group tile
588+ static constexpr int FragsN = get<1 >(SubgroupTileShape{}) / get<1 >(MmaAtomShape ()); // MMA atoms along N per sub_group tile
589+
// NOTE(review): FragmentSize is not referenced anywhere in the visible part of
// this branch — confirm whether it is still needed.
590+ static constexpr int FragmentSize = (get<0 >(MmaAtomShape ()) * get<1 >(MmaAtomShape ())) / SubgroupSize;
591+
// Split the linear sub-group id into an (m_sg, n_sg) coordinate, with ATOM_N
// sub-groups per row.
592+ auto m_sg = get_sub_group_id () / ATOM_N ;
593+ auto n_sg = get_sub_group_id () % ATOM_N ;
594+
595+ // Represent the full output tensor
596+ Tensor mD_mnl = cute::get_pvc_tensor (make_shape (M,N,L));
597+
598+ // Tile the output tensor per WG and select the tile for current WG
599+ Tensor g_wg_D = local_tile (mD_mnl , take<0 ,2 >(WorkgroupTileShape{}), make_coord (m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)
600+
601+ // Tile the output tensor per SG and select tile for the current SG
602+ Tensor gD = local_tile (g_wg_D, take<0 ,2 >(SubgroupTileShape{}), make_coord (m_sg,n_sg)); // (SG_M,SG_N)
603+
// Take this thread's slice of the store copy and partition the sub-group tile
// of D with it; tCgD is indexed by (_, epi_m, epi_n) in the loops below.
604+ auto thread_xe_store_d = params.xe_store_d .get_thread_slice (thread_idx);
605+ Tensor tCgD = thread_xe_store_d.partition_D (gD );
606+
// One copy per (epi_m, epi_n) fragment, N outermost. The accumulators are
// passed to the copy as-is; no other operation is applied in this branch.
// NOTE(review): the unroll pragmas are commented out — confirm this is intended.
607+ // CUTLASS_PRAGMA_UNROLL
608+ for (int epi_n = 0 ; epi_n < FragsN; ++epi_n) {
609+ // CUTLASS_PRAGMA_UNROLL
610+ for (int epi_m = 0 ; epi_m < FragsM; ++epi_m) {
611+ copy (params.xe_store_d , accumulators (_, epi_m, epi_n), tCgD (_, epi_m, epi_n));
612+ }
613+ }
614+ #endif
572615 }
573616};
574617
@@ -618,9 +661,12 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
618661 cutlass::KernelHardwareInfo hw_info;
619662 hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count (hw_info.device_id );
620663 auto problem_shape_MNKL = problem_size;
664+
665+ #if 0
621666 float alpha=1.0f;
622667 float beta=0.f;
623668 StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
669+ #endif
624670 StrideD stride_D = cutlass::make_cute_packed_stride (StrideD{}, cute::make_shape (m, n, l));
625671
626672#if 0
@@ -643,7 +689,15 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
643689#endif
644690
645691 params.hw_info = hw_info;
692+
693+ #if 0
646694 params.epilogue = CollectiveEpilogue::to_underlying_arguments(problem_size, {{alpha, beta}, nullptr, stride_C, out, stride_D}, nullptr);
695+ #else
// Build the D store copy on the host and hand it to the kernel via params.
// Start from a default-constructed tiled copy.
696+ XE_Copy_D xe_store_d = {};
// View `out` as a tensor of shape (m, n, l) with stride_D.
697+ auto mD = make_tensor (make_gmem_ptr (out), make_layout (make_shape (m, n, l), stride_D));
// Brace-initialise the tiled copy from the result of with(mD).
// NOTE(review): presumably with() binds the copy to the D tensor — confirm
// against the Copy_Traits definition.
698+ xe_store_d = {xe_store_d.with (mD )};
699+ params.xe_store_d = xe_store_d;
700+ #endif
647701
648702 TileSchedulerArguments scheduler{};
649703 params.scheduler = TileScheduler::to_underlying_arguments (
0 commit comments