@@ -67,6 +67,24 @@ using DispatchPolicy = MainloopIntelPVC<Stages, KernelPVC /*Schedule*/>;
6767 using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /* data_type of GEMM output*/ , ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
6868 using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape, decltype (tile_shape(TiledMma()))>;
6969
70+ // struct TensorStorageImpl: cute::tuple<SmemCStorage, SmemDStorage> {
71+ // using FusionStorage = typename FusionCallbacks::SharedStorage;
72+ // FusionStorage thread;
73+ // };
74+ //
75+ // struct SharedStorage {
76+ // using TensorStorage = TensorStorageImpl;
77+ //
78+ // TensorStorage tensors;
79+ // };
80+ // using TensorStorage = typename SharedStorage::TensorStorage;
81+ //
82+ // // Kernel level shared memory storage
83+ // struct SharedStorage {
84+ // using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
85+ // EpilogueTensorStorage epilogue;
86+ // };
87+ using SharedStorage = FusionCallBacks::SharedStorage;
7088static dim3
7189 get_block_shape () {
7290 return dim3 (MaxThreadsPerBlock, 1 , 1 );
@@ -89,6 +107,15 @@ template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
89107 size_t SUBG_SIZE , int BITS >
90108class kgemv_4bit_inference_cutlass {
91109public:
110+ struct Params {
111+ int m, n, k;
112+ T *A, *B;
113+ float *absmax, *out;
114+ const float *datatype;
115+ int lda, ldb, ldc;
116+ int blocksize;
117+ };
118+
92119 struct Arguments {
93120 GemmUniversalMode mode{};
94121 ProblemShape problem_shape{};
@@ -130,7 +157,7 @@ class kgemv_4bit_inference_cutlass {
130157 // return Status::kSuccess;
131158 // }
132159
133- #if 1
160+ #if 0
134161 kgemv_4bit_inference_cutlass(int M_, int N_, int K_, T *A_, T *B_,
135162 float *absmax_, const float *datatype_, float *out_,
136163 int lda_, int ldb_, int ldc_, int blocksize_)
@@ -151,19 +178,35 @@ class kgemv_4bit_inference_cutlass {
151178 int ldb;
152179 int ldc;
153180 int blocksize;
154- sycl::local_accessor<T> quant_map;
155181 int SharedStorageSize = 0;
156182
157183public:
158184 CUTLASS_DEVICE
159185 void operator()(sycl::nd_item<1> item) const {
160186
161- #else
187+ #elif 0
162188 CUTLASS_DEVICE
163189 void operator()(int M, int N, int K, T *A, T *B,
164190 float *absmax, const float *datatype, float *out,
165191 int lda, int ldb, int ldc, int blocksize) {//(sycl::nd_item<1> item) const {
192+ #else
193+ CUTLASS_DEVICE
194+ void operator ()(Params const & params, char * smem_buf) const {
195+ // SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
196+ auto M = params.m ;
197+ auto N = params.n ;
198+ auto K = params.k ;
199+ auto A = params.A ;
200+ auto B = params.B ;
201+ auto out = params.out ;
202+ auto absmax = params.absmax ;
203+ auto datatype = params.datatype ;
204+ auto lda = params.lda ;
205+ auto ldb = params.ldb ;
206+ auto ldc = params.ldc ;
207+ auto blocksize = params.blocksize ;
166208#endif
209+ #if 0
167210 int L = 1;
168211 StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
169212 StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
@@ -568,7 +611,7 @@ class kgemv_4bit_inference_cutlass {
568611 }
569612
570613 cst_callbacks.end();
571-
614+ # endif
572615 }
573616};
574617
@@ -588,7 +631,7 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
588631
589632 auto problem_shape = ProblemShape{m, n, k, 1 };
590633
591- #if 1
634+ #if 0
592635 dim3 const block = get_block_shape();
593636 //dim3 const grid = get_grid_shape(params);
594637 dim3 grid = get_tiled_cta_shape_mnl(problem_shape); //, TileShape{}); //, ClusterShape{});
@@ -605,7 +648,7 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
605648 queue, kfn);
606649 queue.wait();
607650#else
608- using GemmKernel = kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS>;
651+ using GemmKernel = kgemv_4bit_inference_cutlass<T, GROUP_SIZE , NUM_PER_THREAD , SUBG_SIZE , BITS >;// (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
609652 using GemmKernel_t = GetUnderlyingKernel_t<GemmKernel>;
610653
611654 dim3 const block = get_block_shape ();
@@ -617,21 +660,62 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
617660 const syclcompat::dim3 sycl_grid (grid.x , grid.y , grid.z );
618661
619662 // configure smem size and carveout
620- const int smem_size = 0; //GemmKernel::SharedStorageSize;
663+ // const int smem_size = 0; //GemmKernel::SharedStorageSize;
664+ static constexpr int smem_size= 1 ;
621665
622666 // Status launch_result{ Status::kSuccess };
623667 // launch_result = Status::kSuccess;
624- cutlass::arch::synclog_setup();
668+ // cutlass::arch::synclog_setup();
625669
626670 sycl::queue q = *stream; // stream ? *stream : syclcompat::get_default_queue();
671+
672+ using Params = GemmKernel_t::Params;
673+ #if 0
674+ cutlass::kernel_launch<GemmKernel, Params>(
675+ grid, block, smem_size, stream, Params{m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize}, false);
676+ #else
627677 using namespace syclcompat ::experimental;
628678
629- auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(launch_policy{
630- sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)}
679+ // Params params{
680+ // .M = m, .N = n, .K = k,
681+ // .A = A, .B = B,
682+ // .out = out,
683+ // .lda = lda, .ldb = ldb, .ldc = ldc
684+ // };
685+ Params params;
686+ params.m = m;
687+ params.n = n;
688+ params.k = k;
689+ params.A = A;
690+ params.B = B;
691+ params.out = out;
692+ params.lda = lda;
693+ params.ldb = ldb;
694+ params.ldc = ldc;
695+ params.absmax = absmax;
696+ params.datatype = datatype;
697+ params.blocksize = blocksize;
698+ auto event = launch<device_kernel<GemmKernel_t>>(launch_policy{
699+ sycl_grid, sycl_block// , local_mem_size{static_cast<std::size_t>(smem_size)}
631700 , kernel_properties{sycl_exp::sub_group_size<DispatchPolicy::SubgroupSize>}
632- }, q);//, params);
701+ }, q, params);
702+ // 计算执行范围
703+ // size_t local_size = 256;
704+ // size_t global_size = (m + local_size - 1) / local_size * local_size;
705+ //
706+ // // 启动内核
707+ // auto event = syclcompat::experimental::launch<
708+ // GemmKernel>(
709+ // launch_policy{
710+ // sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)},//sycl::nd_range<1>(global_size, local_size),
711+ // kernel_properties{sycl_exp::sub_group_size<DispatchPolicy::SubgroupSize>}
712+ // },
713+ // q,
714+ // params
715+ // );
633716 EventManager::getInstance ().addEvent (event);
634717#endif
718+ #endif
635719}
636720
637721// template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
0 commit comments