Skip to content

Commit 998f482

Browse files
committed
clean code
1 parent 7fce59f commit 998f482

1 file changed

Lines changed: 6 additions & 61 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 6 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,6 @@ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); //8*4*1*16=512
119119
using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
120120
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // 16
121121

122-
#if 0
123-
// Design Epilogue
124-
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
125-
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
126-
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
127-
using SharedStorage = FusionCallBacks::SharedStorage;
128-
#endif
129-
130122
// Design Scheduler
131123
using TileScheduler_ = PersistentScheduler;
132124
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>, "Intel PVC does not support specializing the tile scheduler.");
@@ -135,23 +127,6 @@ using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelec
135127
using TileSchedulerArguments = typename TileScheduler::Arguments;
136128
using TileSchedulerParams = typename TileScheduler::Params;
137129

138-
#if 0
139-
// Define Epilogue
140-
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
141-
EpilogueDispatchPolicy,
142-
TileShape,
143-
ElementAccumulator,
144-
cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
145-
ElementOutput,
146-
cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
147-
FusionCallBacks,
148-
XE_2D_U32x4x16_LD_N, // The copy atom used to load matrix C
149-
void, void,
150-
XE_2D_U32x4x16_ST_N, // The copy atom used to store matrix D
151-
void, void>;
152-
using EpilogueParams = typename CollectiveEpilogue::Params;
153-
#endif
154-
155130
using ClusterShape = typename DispatchPolicy::ClusterShape;
156131

157132
// Define Copy
@@ -196,12 +171,6 @@ template <typename T, int BITS>
196171
class gemm_4bit_cutlass_kernel {
197172
public:
198173
// Kernel level shared memory storage
199-
#if 0
200-
struct SharedStorage {
201-
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
202-
EpilogueTensorStorage epilogue;
203-
};
204-
#endif
205174
struct Params {
206175
int m, n, k, l;
207176
T* A;
@@ -413,10 +382,12 @@ class gemm_4bit_cutlass_kernel {
413382
auto blk_shape = TileShape{};
414383
int m_coord, n_coord, l_coord;
415384
if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
385+
if(cute::thread0()) printf("log1 ....\n");
416386
m_coord = BlockIdxY();
417387
n_coord = BlockIdxX();
418388
l_coord = BlockIdxZ();
419389
} else {
390+
if(cute::thread0()) printf("log2 ....\n");
420391
m_coord = BlockIdxX();
421392
n_coord = BlockIdxY();
422393
l_coord = BlockIdxZ();
@@ -570,20 +541,6 @@ class gemm_4bit_cutlass_kernel {
570541
barrier_wait(3);
571542
}
572543

573-
#if 0
574-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>((char*)nullptr);
575-
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
576-
auto problem_shape_MNKL = append<4>(problem_size, 1);
577-
epilogue(
578-
problem_shape_MNKL,
579-
subgroup_tile_shape,
580-
blk_coord_mnkl,
581-
accumulators,
582-
tiled_mma,
583-
thread_idx
584-
);
585-
#else
586-
587544
static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
588545
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
589546

@@ -611,7 +568,6 @@ class gemm_4bit_cutlass_kernel {
611568
copy(params.xe_store_d, accumulators(_, epi_m, epi_n), tCgD(_, epi_m, epi_n));
612569
}
613570
}
614-
#endif
615571
}
616572
};
617573

@@ -662,19 +618,17 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
662618
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
663619
auto problem_shape_MNKL = problem_size;
664620

665-
#if 0
666-
float alpha=1.0f;
667-
float beta=0.f;
668-
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
669-
#endif
670621
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
622+
XE_Copy_D xe_store_d = {};
623+
auto mD = make_tensor(make_gmem_ptr(out), make_layout(make_shape(m, n, l), stride_D));
624+
xe_store_d = {xe_store_d.with(mD)};
625+
params.xe_store_d = xe_store_d;
671626

672627
#if 0
673628
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
674629
print("===================== stride :\n");
675630
print(" stride_A : "); print(stride_A); print("\n");
676631
print(" stride_B : "); print(stride_B); print("\n");
677-
print(" stride_C : "); print(stride_C); print("\n");
678632
print(" stride_D : "); print(stride_D); print("\n");
679633
print(" stride_S : "); print(stride_S); print("\n");
680634
print("===================== mScale :\n");
@@ -690,15 +644,6 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
690644

691645
params.hw_info = hw_info;
692646

693-
#if 0
694-
params.epilogue = CollectiveEpilogue::to_underlying_arguments(problem_size, {{alpha, beta}, nullptr, stride_C, out, stride_D}, nullptr);
695-
#else
696-
XE_Copy_D xe_store_d = {};
697-
auto mD = make_tensor(make_gmem_ptr(out), make_layout(make_shape(m, n, l), stride_D));
698-
xe_store_d = {xe_store_d.with(mD)};
699-
params.xe_store_d = xe_store_d;
700-
#endif
701-
702647
TileSchedulerArguments scheduler{};
703648
params.scheduler = TileScheduler::to_underlying_arguments(
704649
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, scheduler, nullptr);

0 commit comments

Comments
 (0)