Skip to content

Commit 7fce59f

Browse files
committed
delete epilogue
1 parent 8baa41c commit 7fce59f

2 files changed

Lines changed: 129 additions & 31 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ constexpr int PipelineStages = 2;
8585
constexpr int PipelineStages = 4;
8686
#endif
8787

88+
using MmaAtomShape = typename TiledMma::AtomShape_MNK;
8889
using WorkgroupTileShape = TileShape;
8990
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); //256 //16
9091
static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); //256 //64
@@ -118,11 +119,13 @@ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); //8*4*1*16=512
118119
using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
119120
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // 16
120121

122+
#if 0
121123
// Design Epilogue
122124
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
123125
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
124126
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
125127
using SharedStorage = FusionCallBacks::SharedStorage;
128+
#endif
126129

127130
// Design Scheduler
128131
using TileScheduler_ = PersistentScheduler;
@@ -132,6 +135,7 @@ using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelec
132135
using TileSchedulerArguments = typename TileScheduler::Arguments;
133136
using TileSchedulerParams = typename TileScheduler::Params;
134137

138+
#if 0
135139
// Define Epilogue
136140
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
137141
EpilogueDispatchPolicy,
@@ -146,6 +150,7 @@ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
146150
XE_2D_U32x4x16_ST_N, // The copy atom used to store matrix D
147151
void, void>;
148152
using EpilogueParams = typename CollectiveEpilogue::Params;
153+
#endif
149154

150155
using ClusterShape = typename DispatchPolicy::ClusterShape;
151156

@@ -179,18 +184,24 @@ using atom_load_scale = Copy_Atom<traits_load_scale, ElementScale>;
179184
using val_layout_load_scale = decltype(make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{})));
180185
using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThreadShapeRev>{}, val_layout_load_scale{})); //group-wise scale
181186

182-
using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
187+
//using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
183188
using StrideD = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
184189

190+
using GmemTiledCopyD = XE_2D_U32x4x16_ST_N;
191+
using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>;
192+
using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
193+
using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementOutput>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));
194+
185195
template <typename T, int BITS>
186196
class gemm_4bit_cutlass_kernel {
187197
public:
188198
// Kernel level shared memory storage
199+
#if 0
189200
struct SharedStorage {
190201
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
191202
EpilogueTensorStorage epilogue;
192203
};
193-
204+
#endif
194205
struct Params {
195206
int m, n, k, l;
196207
T* A;
@@ -206,7 +217,8 @@ class gemm_4bit_cutlass_kernel {
206217
Copy_B tiled_copy_b;
207218
Copy_Scale tiled_copy_scale;
208219

209-
EpilogueParams epilogue{};
220+
//EpilogueParams epilogue{};
221+
XE_Copy_D xe_store_d;
210222
KernelHardwareInfo hw_info{};
211223
TileSchedulerParams scheduler{};
212224
};
@@ -362,7 +374,7 @@ class gemm_4bit_cutlass_kernel {
362374

363375
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
364376
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
365-
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
377+
//static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
366378
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
367379

368380
int thread_idx = int(ThreadIdxX());
@@ -557,7 +569,8 @@ class gemm_4bit_cutlass_kernel {
557569
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
558570
barrier_wait(3);
559571
}
560-
572+
573+
#if 0
561574
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>((char*)nullptr);
562575
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
563576
auto problem_shape_MNKL = append<4>(problem_size, 1);
@@ -569,6 +582,36 @@ class gemm_4bit_cutlass_kernel {
569582
tiled_mma,
570583
thread_idx
571584
);
585+
#else
586+
587+
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
588+
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
589+
590+
static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
591+
592+
auto m_sg = get_sub_group_id() / ATOM_N;
593+
auto n_sg = get_sub_group_id() % ATOM_N;
594+
595+
// Represent the full output tensor
596+
Tensor mD_mnl = cute::get_pvc_tensor(make_shape(M,N,L));
597+
598+
// Tile the output tensor per WG and select the tile for current WG
599+
Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(WorkgroupTileShape{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)
600+
601+
// Tile the output tensor per SG and select tile for the current SG
602+
Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N)
603+
604+
auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
605+
Tensor tCgD = thread_xe_store_d.partition_D(gD);
606+
607+
//CUTLASS_PRAGMA_UNROLL
608+
for (int epi_n = 0; epi_n < FragsN; ++epi_n) {
609+
//CUTLASS_PRAGMA_UNROLL
610+
for (int epi_m = 0; epi_m < FragsM; ++epi_m) {
611+
copy(params.xe_store_d, accumulators(_, epi_m, epi_n), tCgD(_, epi_m, epi_n));
612+
}
613+
}
614+
#endif
572615
}
573616
};
574617

@@ -618,9 +661,12 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
618661
cutlass::KernelHardwareInfo hw_info;
619662
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
620663
auto problem_shape_MNKL = problem_size;
664+
665+
#if 0
621666
float alpha=1.0f;
622667
float beta=0.f;
623668
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
669+
#endif
624670
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
625671

626672
#if 0
@@ -643,7 +689,15 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
643689
#endif
644690

645691
params.hw_info = hw_info;
692+
693+
#if 0
646694
params.epilogue = CollectiveEpilogue::to_underlying_arguments(problem_size, {{alpha, beta}, nullptr, stride_C, out, stride_D}, nullptr);
695+
#else
696+
XE_Copy_D xe_store_d = {};
697+
auto mD = make_tensor(make_gmem_ptr(out), make_layout(make_shape(m, n, l), stride_D));
698+
xe_store_d = {xe_store_d.with(mD)};
699+
params.xe_store_d = xe_store_d;
700+
#endif
647701

648702
TileSchedulerArguments scheduler{};
649703
params.scheduler = TileScheduler::to_underlying_arguments(

include/cutlass/epilogue/collective/xe_epilogue.hpp

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,8 @@ class CollectiveEpilogue<
346346
thread_idx,
347347
};
348348
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
349-
350-
cst_callbacks.begin();
349+
#if 1
350+
//cst_callbacks.begin();
351351

352352
auto acc_frag = recast<Array<ElementOutput, FragmentSize>>(accumulators);
353353
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);
@@ -356,38 +356,82 @@ class CollectiveEpilogue<
356356
FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
357357
constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
358358
static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );
359+
360+
//copy(params.xe_store_d, accumulators, tCgD(_, FragsM, FragsN));
361+
CUTLASS_PRAGMA_UNROLL
362+
for (int epi_n = 0; epi_n < FragsN; ++epi_n) {
363+
CUTLASS_PRAGMA_UNROLL
364+
for (int epi_m = 0; epi_m < FragsM; ++epi_m) {
365+
// 拷贝当前分块到目标位置
366+
copy(params.xe_store_d,
367+
accumulators(_, epi_m, epi_n), // 源分块
368+
tCgD(_, epi_m, epi_n)); // 目标分块
369+
}
370+
}
371+
//
372+
// auto synchronize = [&] () {};
373+
// CUTLASS_PRAGMA_UNROLL
374+
// for (int epi_n = 0; epi_n < FragsN; epi_n++) {
375+
// CUTLASS_PRAGMA_UNROLL
376+
// for (int epi_m = 0; epi_m < FragsM; epi_m++) {
377+
//#if 1
378+
// //cst_callbacks.begin_loop(epi_m, epi_n);
379+
//
380+
//// if (is_C_load_needed) {
381+
//// //cordinates for C and D are the same
382+
//// copy(params.xe_load_c, tCgD(_, epi_m, epi_n), trC);
383+
//// }
384+
//
385+
// //cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
386+
//
387+
// auto acc_frag_mn = acc_frag(_, epi_m, epi_n);
388+
//
389+
// CUTLASS_PRAGMA_UNROLL
390+
// for (int epi_v = 0; epi_v < size<0>(trD_frag); ++epi_v) {
391+
// trD_frag(epi_v) = acc_frag_mn(epi_v); //cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
392+
// }
393+
// //cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_frag);
394+
//#endif
395+
// if constexpr (is_destination_supported) {
396+
// copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
397+
// }
398+
//
399+
// //cst_callbacks.end_loop(epi_m, epi_n);
400+
// }
401+
// }
402+
403+
//cst_callbacks.end();
404+
#else
405+
using OutFragment = Array<float, FragmentSize>; // 根据实际类型调整
406+
407+
// 2. 移除所有回调相关逻辑,直接处理累加器
408+
auto acc_frag = recast<OutFragment>(accumulators);
409+
auto trD_frag = make_fragment_like<OutFragment>();
359410
360-
auto synchronize = [&] () {};
411+
// 3. 简化主循环(保留必要的分块逻辑)
361412
CUTLASS_PRAGMA_UNROLL
362-
for (int epi_n = 0; epi_n < FragsN; epi_n++) {
413+
for (int epi_n = 0; epi_n < FragsN; ++epi_n) {
363414
CUTLASS_PRAGMA_UNROLL
364-
for (int epi_m = 0; epi_m < FragsM; epi_m++) {
365-
cst_callbacks.begin_loop(epi_m, epi_n);
366-
367-
if (is_C_load_needed) {
368-
//cordinates for C and D are the same
369-
copy(params.xe_load_c, tCgD(_, epi_m, epi_n), trC);
370-
}
371-
372-
cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
373-
374-
auto acc_frag_mn = acc_frag(_, epi_m, epi_n);
375-
415+
for (int epi_m = 0; epi_m < FragsM; ++epi_m) {
416+
417+
// 直接规约操作(示例:求和)
418+
float reduce_sum = 0;
376419
CUTLASS_PRAGMA_UNROLL
377-
for (int epi_v = 0; epi_v < size<0>(trD_frag); ++epi_v) {
378-
trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
420+
for (int i = 0; i < FragmentSize; ++i) {
421+
reduce_sum += acc_frag(_, epi_m, epi_n)[i];
379422
}
380-
cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_frag);
381-
382-
if constexpr (is_destination_supported) {
383-
copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
384-
}
385-
386-
cst_callbacks.end_loop(epi_m, epi_n);
423+
424+
// 存储结果(根据实际需求调整)
425+
trD_frag.fill(reduce_sum); // 或直接写入特定位置
426+
427+
// 直接存储到全局内存(跳过临时寄存器)
428+
copy(params.destination_ptr,
429+
trD_frag,
430+
tCgD(_, epi_m, epi_n)); // 需适配实际坐标计算
387431
}
388432
}
433+
#endif
389434

390-
cst_callbacks.end();
391435
}
392436

393437
private:

0 commit comments

Comments
 (0)