Skip to content

Commit c3edbbc

Browse files
committed
clean code
1 parent 0276ced commit c3edbbc

1 file changed

Lines changed: 5 additions & 229 deletions

File tree

csrc/xpu_cutlass-cute.cpp

Lines changed: 5 additions & 229 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@ using namespace cute;
3434
using namespace cutlass;
3535
using namespace cutlass::gemm;
3636

37-
#if 1
3837
using ElementA = bfloat16_t;
3938
using ElementB = bfloat16_t;
4039
using ElementC = float;
@@ -65,14 +64,9 @@ using TiledMma =
6564
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
6665
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
6766
using DispatchPolicy = MainloopIntelPVC<Stages, KernelPVC /*Schedule*/>;
68-
#endif
69-
//using TiledMma =
70-
// typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
71-
// Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
72-
//static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
73-
//using TileShape = Shape<_256, _256, _32>;
74-
//using ProblemShape = Shape<int, int, int, int>;
75-
#if 1
67+
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /*data_type of GEMM output*/, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
68+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
69+
7670
static dim3
7771
get_block_shape() {
7872
return dim3(MaxThreadsPerBlock, 1, 1);
@@ -91,122 +85,10 @@ get_tiled_cta_shape_mnl(ProblemShape problem_shape) {
9185
};
9286
}
9387

94-
//struct Arguments {
95-
// GemmUniversalMode mode{};
96-
// ProblemShape problem_shape{};
97-
// //MainloopArguments mainloop{};
98-
// ElementA const* ptr_A;
99-
// StrideA dA;
100-
// ElementB const* ptr_B;
101-
// StrideB dB;
102-
//
103-
// //EpilogueArguments epilogue{};
104-
// //typename FusionCallbacks::Arguments thread{};
105-
// ElementC const* ptr_C;
106-
// StrideC dC;
107-
// ElementD* ptr_D;
108-
// StrideD dD;
109-
//
110-
// cutlass::KernelHardwareInfo hw_info{};
111-
// //TileSchedulerArguments scheduler{};
112-
//};
113-
114-
//static size_t get_workspace_size(Arguments const& args) {
115-
// size_t workspace_bytes = 0;
116-
// if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
117-
// workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{}));
118-
// }
119-
//
120-
// //TODO: Check it!!
121-
// workspace_bytes += 0; //GemmKernel::get_workspace_size(args);
122-
//
123-
// CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
124-
//
125-
// return workspace_bytes;
126-
//}
127-
#endif
128-
12988
template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
13089
size_t SUBG_SIZE, int BITS>
13190
class kgemv_4bit_inference_cutlass {
13291
public:
133-
// kgemv_4bit_inference_cutlass(int M_, int N_, int K_, T *A_, T *B_,
134-
// float *absmax_, const float *datatype_, float *out_,
135-
// int lda_, int ldb_, int ldc_, int blocksize_)
136-
// : M(M_), N(N_), K(K_), A(A_), B(B_),
137-
// absmax(absmax_), out(out_), datatype(datatype_),
138-
// lda(lda_), ldb(ldb_), ldc(ldc_), blocksize(blocksize_) {}
139-
140-
//SYCL_EXTERNAL
141-
//void operator()(sycl::nd_item<1> item) const {
142-
143-
// Specific setting
144-
#if 1
145-
using ElementA = bfloat16_t;
146-
using ElementB = bfloat16_t;
147-
using ElementC = float;
148-
using ElementD = float;
149-
using ElementAccumulator = float; // data_type of accumulator
150-
using ElementComputeEpilogue = float; // data_type of epilogue operations
151-
using ElementOutput = float;
152-
static constexpr int Stages = 2;
153-
154-
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
155-
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
156-
using CopyOpG2R = XE_2D_U32x8x16_LD_N;
157-
using CopyOpR2G = XE_2D_U32x8x16_ST_N;
158-
using GmemTiledCopyC = CopyOpG2R;
159-
using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
160-
CopyOpR2G, XE_2D_U32x8x16_ST_N>;
161-
162-
using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
163-
using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
164-
using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
165-
using StrideD = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
166-
using ProblemShape = Shape<int, int, int, int>;
167-
168-
// int L = 1;
169-
// StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
170-
// StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
171-
// StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
172-
//cutlass::KernelHardwareInfo hw_info;
173-
//hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
174-
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /*data_type of GEMM output*/, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
175-
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
176-
177-
using TileShape = Shape<_256, _256, _32>;
178-
using WorkgroupTileShape = TileShape;
179-
using TiledMma =
180-
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
181-
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
182-
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
183-
using DispatchPolicy = MainloopIntelPVC<Stages, KernelPVC /*Schedule*/>;
184-
#endif
185-
#if 1
186-
//using ClusterShape = cutlass::gemm::GemmShape<
187-
// cute::size<0>(typename DispatchPolicy::ClusterShape{}),
188-
// cute::size<1>(typename DispatchPolicy::ClusterShape{}),
189-
// cute::size<2>(typename DispatchPolicy::ClusterShape{})>;
190-
191-
//new Functions
192-
//static dim3
193-
// get_block_shape() {
194-
// return dim3(MaxThreadsPerBlock, 1, 1);
195-
//}
196-
//
197-
//static dim3
198-
//get_tiled_cta_shape_mnl(ProblemShape problem_shape) {
199-
// using cta_shape = TileShape;
200-
// auto cta_m = (get<0>(problem_shape) + get<0>(cta_shape{}) - 1) / get<0>(cta_shape{});
201-
// auto cta_n = (get<1>(problem_shape) + get<1>(cta_shape{}) - 1) / get<1>(cta_shape{});
202-
//
203-
// return {
204-
// static_cast<uint32_t>(cta_m),
205-
// static_cast<uint32_t>(cta_n),
206-
// static_cast<uint32_t>(get<3>(problem_shape))
207-
// };
208-
//}
209-
//SYCL_EXTERNAL
21092
struct Arguments {
21193
GemmUniversalMode mode{};
21294
ProblemShape problem_shape{};
@@ -247,11 +129,6 @@ static size_t get_workspace_size(Arguments const& args) {
247129
// CudaHostAdapter* cuda_adapter = nullptr) {
248130
// return Status::kSuccess;
249131
//}
250-
#endif
251-
//template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
252-
// size_t SUBG_SIZE, int BITS>
253-
//class kgemv_4bit_inference_cutlass {
254-
//public:
255132
#if 1
256133
kgemv_4bit_inference_cutlass(int M_, int N_, int K_, T *A_, T *B_,
257134
float *absmax_, const float *datatype_, float *out_,
@@ -276,11 +153,9 @@ static size_t get_workspace_size(Arguments const& args) {
276153
sycl::local_accessor<T> quant_map;
277154
int SharedStorageSize = 0;
278155
public:
279-
// SYCL_EXTERNAL
280156
CUTLASS_DEVICE
281157
void operator()(sycl::nd_item<1> item) const {
282158

283-
//std::cout<<"this is kgemv_4bit_inference cutlass fusion path !!!\n";
284159
#else
285160
CUTLASS_DEVICE
286161
void operator()(int M, int N, int K, T *A, T *B,
@@ -693,91 +568,7 @@ CUTLASS_DEVICE
693568

694569
cst_callbacks.end();
695570
}
696-
697-
//private:
698-
// int M;
699-
// int N;
700-
// int K;
701-
// T *A;
702-
// T *B;
703-
// float *absmax;
704-
// const float *datatype;
705-
// float *out;
706-
// int lda;
707-
// int ldb;
708-
// int ldc;
709-
// int blocksize;
710-
// sycl::local_accessor<T> quant_map;
711-
// int SharedStorageSize = 0;
712571
};
713-
#if 0
714-
//TODO: replace with private kernel submit ???
715-
//Launch Kernel
716-
dim3 const block = GemmKernel::get_block_shape();
717-
dim3 const grid = get_grid_shape(params);
718-
719-
const syclcompat::dim3 sycl_block(block.x, block.y, block.z);
720-
const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z);
721-
722-
// configure smem size and carveout
723-
int smem_size = GemmKernel::SharedStorageSize;
724-
725-
Status launch_result{ Status::kSuccess };
726-
launch_result = Status::kSuccess;
727-
cutlass::arch::synclog_setup();
728-
729-
sycl::queue q = *stream; //stream ? *stream : syclcompat::get_default_queue();
730-
using namespace syclcompat::experimental;
731-
if constexpr (cute::is_same_v<DispatchPolicy, MainloopDeviceAgnostic>) {
732-
auto event = launch<device_kernel<GemmKernel>>(launch_policy{
733-
sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)}
734-
}, q, params);
735-
EventManager::getInstance().addEvent(event);
736-
} else {
737-
auto event = launch<device_kernel<GemmKernel>>(launch_policy{
738-
sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)}
739-
, kernel_properties{sycl_exp::sub_group_size<DispatchPolicy::SubgroupSize>}
740-
}, q, params);
741-
EventManager::getInstance().addEvent(event);
742-
}
743-
#endif
744-
745-
#if 0
746-
template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
747-
size_t SUBG_SIZE, int BITS>
748-
class kgemv_4bit_inference_cutlass {
749-
public:
750-
SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;
751-
752-
kgemv_4bit_inference_cutlass(int M_, int N_, int K_, T *A_, unsigned char *B_,
753-
float *absmax_, const float *datatype_, T *out_,
754-
int lda_, int ldb_, int ldc_, int blocksize_)
755-
: M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_),
756-
out(out_), lda(lda_), ldb(ldb_), ldc(ldc_), blocksize(blocksize_),
757-
quant_map(), SharedStorageSize() {}
758-
759-
void sycl_ker_local_memory_creation(sycl::handler &cgh) {
760-
quant_map = sycl::local_accessor<T>(16, cgh);
761-
}
762-
763-
private:
764-
int M;
765-
int N;
766-
int K;
767-
T *A;
768-
unsigned char *B;
769-
float *absmax;
770-
const float *datatype;
771-
T *out;
772-
int lda;
773-
int ldb;
774-
int ldc;
775-
int blocksize;
776-
sycl::local_accessor<T> quant_map;
777-
int SharedStorageSize = 0;
778-
};
779-
780-
#endif
781572

782573
//template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
783574

@@ -794,6 +585,8 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
794585
size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD;
795586

796587
auto problem_shape = ProblemShape{m, n, k, 1};
588+
589+
#if 1
797590
dim3 const block = get_block_shape();
798591
//dim3 const grid = get_grid_shape(params);
799592
dim3 grid = get_tiled_cta_shape_mnl(problem_shape); //, TileShape{}); //, ClusterShape{});
@@ -802,7 +595,6 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
802595
const syclcompat::dim3 sycl_block(block.x, block.y, block.z);
803596
const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z);
804597

805-
#if 1
806598
auto &queue = *stream;
807599
kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS> kfn(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
808600
sycl_kernel_submit<decltype(kfn), 1, 32>(
@@ -848,19 +640,3 @@ template void gemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 16>(
848640
float *absmax, float *datatype, float *out, int lda,
849641
int ldb, int ldc, int blocksize, sycl::queue *stream);
850642

851-
//template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
852-
//template void gemv_4bit_fusion<sycl::half, 16>(
853-
// int m, int n, int k, sycl::half *A, unsigned char *B, float *absmax,
854-
// float *datatype, sycl::half *out, int lda, int ldb, int ldc, int blocksize,
855-
// sycl::queue *stream);
856-
//template void gemv_4bit_inference<sycl::ext::oneapi::bfloat16, 16>(
857-
// int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, sycl::ext::oneapi::bfloat16 *B,
858-
// float *absmax, float *datatype, float *out, int lda,
859-
// int ldb, int ldc, int blocksize, sycl::queue *stream);
860-
//template void gemv_4bit_inference<float, 32>(int m, int n, int k, float *A,
861-
// unsigned char *B, float *absmax,
862-
// float *datatype, float *out,
863-
// int lda, int ldb, int ldc,
864-
// int blocksize,
865-
// sycl::queue *stream);
866-
//

0 commit comments

Comments
 (0)