Skip to content

Commit 0214749

Browse files
committed
fix build
1 parent 107cc09 commit 0214749

1 file changed

Lines changed: 95 additions & 11 deletions

File tree

csrc/xpu_cutlass-cute.cpp

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,24 @@ using DispatchPolicy = MainloopIntelPVC<Stages, KernelPVC /*Schedule*/>;
6767
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /*data_type of GEMM output*/, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
6868
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
6969

70+
// struct TensorStorageImpl: cute::tuple<SmemCStorage, SmemDStorage> {
71+
// using FusionStorage = typename FusionCallbacks::SharedStorage;
72+
// FusionStorage thread;
73+
// };
74+
//
75+
// struct SharedStorage {
76+
// using TensorStorage = TensorStorageImpl;
77+
//
78+
// TensorStorage tensors;
79+
// };
80+
// using TensorStorage = typename SharedStorage::TensorStorage;
81+
//
82+
// // Kernel level shared memory storage
83+
// struct SharedStorage {
84+
// using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
85+
// EpilogueTensorStorage epilogue;
86+
// };
87+
using SharedStorage = FusionCallBacks::SharedStorage;
7088
static dim3
7189
get_block_shape() {
7290
return dim3(MaxThreadsPerBlock, 1, 1);
@@ -89,6 +107,15 @@ template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
89107
size_t SUBG_SIZE, int BITS>
90108
class kgemv_4bit_inference_cutlass {
91109
public:
110+
struct Params {
111+
int m, n, k;
112+
T *A, *B;
113+
float *absmax, *out;
114+
const float *datatype;
115+
int lda, ldb, ldc;
116+
int blocksize;
117+
};
118+
92119
struct Arguments {
93120
GemmUniversalMode mode{};
94121
ProblemShape problem_shape{};
@@ -130,7 +157,7 @@ class kgemv_4bit_inference_cutlass {
130157
// return Status::kSuccess;
131158
//}
132159

133-
#if 1
160+
#if 0
134161
kgemv_4bit_inference_cutlass(int M_, int N_, int K_, T *A_, T *B_,
135162
float *absmax_, const float *datatype_, float *out_,
136163
int lda_, int ldb_, int ldc_, int blocksize_)
@@ -151,19 +178,35 @@ class kgemv_4bit_inference_cutlass {
151178
int ldb;
152179
int ldc;
153180
int blocksize;
154-
sycl::local_accessor<T> quant_map;
155181
int SharedStorageSize = 0;
156182

157183
public:
158184
CUTLASS_DEVICE
159185
void operator()(sycl::nd_item<1> item) const {
160186

161-
#else
187+
#elif 0
162188
CUTLASS_DEVICE
163189
void operator()(int M, int N, int K, T *A, T *B,
164190
float *absmax, const float *datatype, float *out,
165191
int lda, int ldb, int ldc, int blocksize) {//(sycl::nd_item<1> item) const {
192+
#else
193+
CUTLASS_DEVICE
194+
void operator()(Params const& params, char* smem_buf) const {
195+
//SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
196+
auto M = params.m;
197+
auto N = params.n;
198+
auto K = params.k;
199+
auto A = params.A;
200+
auto B = params.B;
201+
auto out = params.out;
202+
auto absmax = params.absmax;
203+
auto datatype = params.datatype;
204+
auto lda = params.lda;
205+
auto ldb = params.ldb;
206+
auto ldc = params.ldc;
207+
auto blocksize = params.blocksize;
166208
#endif
209+
#if 0
167210
int L = 1;
168211
StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
169212
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
@@ -568,7 +611,7 @@ class kgemv_4bit_inference_cutlass {
568611
}
569612

570613
cst_callbacks.end();
571-
614+
#endif
572615
}
573616
};
574617

@@ -588,7 +631,7 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
588631

589632
auto problem_shape = ProblemShape{m, n, k, 1};
590633

591-
#if 1
634+
#if 0
592635
dim3 const block = get_block_shape();
593636
//dim3 const grid = get_grid_shape(params);
594637
dim3 grid = get_tiled_cta_shape_mnl(problem_shape); //, TileShape{}); //, ClusterShape{});
@@ -605,7 +648,7 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
605648
queue, kfn);
606649
queue.wait();
607650
#else
608-
using GemmKernel = kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS>;
651+
using GemmKernel = kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS>;//(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
609652
using GemmKernel_t = GetUnderlyingKernel_t<GemmKernel>;
610653

611654
dim3 const block = get_block_shape();
@@ -617,21 +660,62 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
617660
const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z);
618661

619662
// configure smem size and carveout
620-
const int smem_size = 0; //GemmKernel::SharedStorageSize;
663+
//const int smem_size = 0; //GemmKernel::SharedStorageSize;
664+
static constexpr int smem_size= 1;
621665

622666
//Status launch_result{ Status::kSuccess };
623667
// launch_result = Status::kSuccess;
624-
cutlass::arch::synclog_setup();
668+
//cutlass::arch::synclog_setup();
625669

626670
sycl::queue q = *stream; //stream ? *stream : syclcompat::get_default_queue();
671+
672+
using Params = GemmKernel_t::Params;
673+
#if 0
674+
cutlass::kernel_launch<GemmKernel, Params>(
675+
grid, block, smem_size, stream, Params{m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize}, false);
676+
#else
627677
using namespace syclcompat::experimental;
628678

629-
auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(launch_policy{
630-
sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)}
679+
// Params params{
680+
// .M = m, .N = n, .K = k,
681+
// .A = A, .B = B,
682+
// .out = out,
683+
// .lda = lda, .ldb = ldb, .ldc = ldc
684+
// };
685+
Params params;
686+
params.m = m;
687+
params.n = n;
688+
params.k = k;
689+
params.A = A;
690+
params.B = B;
691+
params.out = out;
692+
params.lda = lda;
693+
params.ldb = ldb;
694+
params.ldc = ldc;
695+
params.absmax = absmax;
696+
params.datatype = datatype;
697+
params.blocksize = blocksize;
698+
auto event = launch<device_kernel<GemmKernel_t>>(launch_policy{
699+
sycl_grid, sycl_block//, local_mem_size{static_cast<std::size_t>(smem_size)}
631700
, kernel_properties{sycl_exp::sub_group_size<DispatchPolicy::SubgroupSize>}
632-
}, q);//, params);
701+
}, q, params);
702+
// 计算执行范围
703+
//size_t local_size = 256;
704+
//size_t global_size = (m + local_size - 1) / local_size * local_size;
705+
//
706+
//// 启动内核
707+
//auto event = syclcompat::experimental::launch<
708+
// GemmKernel>(
709+
// launch_policy{
710+
// sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)},//sycl::nd_range<1>(global_size, local_size),
711+
// kernel_properties{sycl_exp::sub_group_size<DispatchPolicy::SubgroupSize>}
712+
// },
713+
// q,
714+
// params
715+
//);
633716
EventManager::getInstance().addEvent(event);
634717
#endif
718+
#endif
635719
}
636720

637721
//template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;

0 commit comments

Comments
 (0)