@@ -34,7 +34,6 @@ using namespace cute;
3434using namespace cutlass ;
3535using namespace cutlass ::gemm;
3636
37- #if 1
3837using ElementA = bfloat16_t ;
3938using ElementB = bfloat16_t ;
4039using ElementC = float ;
@@ -65,14 +64,9 @@ using TiledMma =
6564 Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
6665static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
6766using DispatchPolicy = MainloopIntelPVC<Stages, KernelPVC /* Schedule*/ >;
68- #endif
69- // using TiledMma =
70- // typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
71- // Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
72- // static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
73- // using TileShape = Shape<_256, _256, _32>;
74- // using ProblemShape = Shape<int, int, int, int>;
75- #if 1
67+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /* data_type of GEMM output*/ , ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
68+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape, decltype (tile_shape(TiledMma()))>;
69+
7670static dim3
7771 get_block_shape () {
7872 return dim3 (MaxThreadsPerBlock, 1 , 1 );
@@ -91,122 +85,10 @@ get_tiled_cta_shape_mnl(ProblemShape problem_shape) {
9185 };
9286}
9387
94- // struct Arguments {
95- // GemmUniversalMode mode{};
96- // ProblemShape problem_shape{};
97- // //MainloopArguments mainloop{};
98- // ElementA const* ptr_A;
99- // StrideA dA;
100- // ElementB const* ptr_B;
101- // StrideB dB;
102- //
103- // //EpilogueArguments epilogue{};
104- // //typename FusionCallbacks::Arguments thread{};
105- // ElementC const* ptr_C;
106- // StrideC dC;
107- // ElementD* ptr_D;
108- // StrideD dD;
109- //
110- // cutlass::KernelHardwareInfo hw_info{};
111- // //TileSchedulerArguments scheduler{};
112- // };
113-
114- // static size_t get_workspace_size(Arguments const& args) {
115- // size_t workspace_bytes = 0;
116- // if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
117- // workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{}));
118- // }
119- //
120- // //TODO: Check it!!
121- // workspace_bytes += 0; //GemmKernel::get_workspace_size(args);
122- //
123- // CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
124- //
125- // return workspace_bytes;
126- // }
127- #endif
128-
12988template <typename T, size_t GROUP_SIZE , size_t NUM_PER_THREAD ,
13089 size_t SUBG_SIZE , int BITS >
13190class kgemv_4bit_inference_cutlass {
13291public:
133- // kgemv_4bit_inference_cutlass(int M_, int N_, int K_, T *A_, T *B_,
134- // float *absmax_, const float *datatype_, float *out_,
135- // int lda_, int ldb_, int ldc_, int blocksize_)
136- // : M(M_), N(N_), K(K_), A(A_), B(B_),
137- // absmax(absmax_), out(out_), datatype(datatype_),
138- // lda(lda_), ldb(ldb_), ldc(ldc_), blocksize(blocksize_) {}
139-
140- // SYCL_EXTERNAL
141- // void operator()(sycl::nd_item<1> item) const {
142-
143- // Specific setting
144- #if 1
145- using ElementA = bfloat16_t ;
146- using ElementB = bfloat16_t ;
147- using ElementC = float ;
148- using ElementD = float ;
149- using ElementAccumulator = float ; // data_type of accumulator
150- using ElementComputeEpilogue = float ; // data_type of epilogue operations
151- using ElementOutput = float ;
152- static constexpr int Stages = 2 ;
153-
154- using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
155- using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
156- using CopyOpG2R = XE_2D_U32x8x16_LD_N;
157- using CopyOpR2G = XE_2D_U32x8x16_ST_N;
158- using GmemTiledCopyC = CopyOpG2R;
159- using GmemTiledCopyD = cute::conditional_t <not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
160- CopyOpR2G, XE_2D_U32x8x16_ST_N>;
161-
162- using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
163- using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
164- using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
165- using StrideD = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
166- using ProblemShape = Shape<int , int , int , int >;
167-
168- // int L = 1;
169- // StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
170- // StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
171- // StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
172- // cutlass::KernelHardwareInfo hw_info;
173- // hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
174- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<float /* data_type of GEMM output*/ , ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
175- using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::IntelPVCEpilogue, EpilogueOp, TileShape, decltype (tile_shape(TiledMma()))>;
176-
177- using TileShape = Shape<_256, _256, _32>;
178- using WorkgroupTileShape = TileShape;
179- using TiledMma =
180- typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
181- Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
182- static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
183- using DispatchPolicy = MainloopIntelPVC<Stages, KernelPVC /* Schedule*/ >;
184- #endif
185- #if 1
186- // using ClusterShape = cutlass::gemm::GemmShape<
187- // cute::size<0>(typename DispatchPolicy::ClusterShape{}),
188- // cute::size<1>(typename DispatchPolicy::ClusterShape{}),
189- // cute::size<2>(typename DispatchPolicy::ClusterShape{})>;
190-
191- // new Functions
192- // static dim3
193- // get_block_shape() {
194- // return dim3(MaxThreadsPerBlock, 1, 1);
195- // }
196- //
197- // static dim3
198- // get_tiled_cta_shape_mnl(ProblemShape problem_shape) {
199- // using cta_shape = TileShape;
200- // auto cta_m = (get<0>(problem_shape) + get<0>(cta_shape{}) - 1) / get<0>(cta_shape{});
201- // auto cta_n = (get<1>(problem_shape) + get<1>(cta_shape{}) - 1) / get<1>(cta_shape{});
202- //
203- // return {
204- // static_cast<uint32_t>(cta_m),
205- // static_cast<uint32_t>(cta_n),
206- // static_cast<uint32_t>(get<3>(problem_shape))
207- // };
208- // }
209- // SYCL_EXTERNAL
21092struct Arguments {
21193 GemmUniversalMode mode{};
21294 ProblemShape problem_shape{};
@@ -247,11 +129,6 @@ static size_t get_workspace_size(Arguments const& args) {
247129// CudaHostAdapter* cuda_adapter = nullptr) {
248130// return Status::kSuccess;
249131// }
250- #endif
251- // template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
252- // size_t SUBG_SIZE, int BITS>
253- // class kgemv_4bit_inference_cutlass {
254- // public:
255132#if 1
256133 kgemv_4bit_inference_cutlass (int M_ , int N_ , int K_ , T *A_ , T *B_ ,
257134 float *absmax_, const float *datatype_, float *out_,
@@ -276,11 +153,9 @@ static size_t get_workspace_size(Arguments const& args) {
276153 sycl::local_accessor<T> quant_map;
277154 int SharedStorageSize = 0 ;
278155public:
279- // SYCL_EXTERNAL
280156CUTLASS_DEVICE
281157 void operator ()(sycl::nd_item<1 > item) const {
282158
283- // std::cout<<"this is kgemv_4bit_inference cutlass fusion path !!!\n";
284159#else
285160 CUTLASS_DEVICE
286161 void operator ()(int M, int N, int K, T *A, T *B,
@@ -693,91 +568,7 @@ CUTLASS_DEVICE
693568
694569 cst_callbacks.end ();
695570}
696-
697- // private:
698- // int M;
699- // int N;
700- // int K;
701- // T *A;
702- // T *B;
703- // float *absmax;
704- // const float *datatype;
705- // float *out;
706- // int lda;
707- // int ldb;
708- // int ldc;
709- // int blocksize;
710- // sycl::local_accessor<T> quant_map;
711- // int SharedStorageSize = 0;
712571};
713- #if 0
714- //TODO: replace with private kernel submit ???
715- //Launch Kernel
716- dim3 const block = GemmKernel::get_block_shape();
717- dim3 const grid = get_grid_shape(params);
718-
719- const syclcompat::dim3 sycl_block(block.x, block.y, block.z);
720- const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z);
721-
722- // configure smem size and carveout
723- int smem_size = GemmKernel::SharedStorageSize;
724-
725- Status launch_result{ Status::kSuccess };
726- launch_result = Status::kSuccess;
727- cutlass::arch::synclog_setup();
728-
729- sycl::queue q = *stream; //stream ? *stream : syclcompat::get_default_queue();
730- using namespace syclcompat::experimental;
731- if constexpr (cute::is_same_v<DispatchPolicy, MainloopDeviceAgnostic>) {
732- auto event = launch<device_kernel<GemmKernel>>(launch_policy{
733- sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)}
734- }, q, params);
735- EventManager::getInstance().addEvent(event);
736- } else {
737- auto event = launch<device_kernel<GemmKernel>>(launch_policy{
738- sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)}
739- , kernel_properties{sycl_exp::sub_group_size<DispatchPolicy::SubgroupSize>}
740- }, q, params);
741- EventManager::getInstance().addEvent(event);
742- }
743- #endif
744-
745- #if 0
746- template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
747- size_t SUBG_SIZE, int BITS>
748- class kgemv_4bit_inference_cutlass {
749- public:
750- SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;
751-
752- kgemv_4bit_inference_cutlass(int M_, int N_, int K_, T *A_, unsigned char *B_,
753- float *absmax_, const float *datatype_, T *out_,
754- int lda_, int ldb_, int ldc_, int blocksize_)
755- : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_),
756- out(out_), lda(lda_), ldb(ldb_), ldc(ldc_), blocksize(blocksize_),
757- quant_map(), SharedStorageSize() {}
758-
759- void sycl_ker_local_memory_creation(sycl::handler &cgh) {
760- quant_map = sycl::local_accessor<T>(16, cgh);
761- }
762-
763- private:
764- int M;
765- int N;
766- int K;
767- T *A;
768- unsigned char *B;
769- float *absmax;
770- const float *datatype;
771- T *out;
772- int lda;
773- int ldb;
774- int ldc;
775- int blocksize;
776- sycl::local_accessor<T> quant_map;
777- int SharedStorageSize = 0;
778- };
779-
780- #endif
781572
782573// template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
783574
@@ -794,6 +585,8 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
794585 size_t workgroup_num = (n + NUM_PER_THREAD - 1 ) / NUM_PER_THREAD ;
795586
796587 auto problem_shape = ProblemShape{m, n, k, 1 };
588+
589+ #if 1
797590 dim3 const block = get_block_shape ();
798591 // dim3 const grid = get_grid_shape(params);
799592 dim3 grid = get_tiled_cta_shape_mnl (problem_shape); // , TileShape{}); //, ClusterShape{});
@@ -802,7 +595,6 @@ void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
802595 const syclcompat::dim3 sycl_block (block.x , block.y , block.z );
803596 const syclcompat::dim3 sycl_grid (grid.x , grid.y , grid.z );
804597
805- #if 1
806598 auto &queue = *stream;
807599 kgemv_4bit_inference_cutlass<T, GROUP_SIZE , NUM_PER_THREAD , SUBG_SIZE , BITS > kfn (m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
808600 sycl_kernel_submit<decltype (kfn), 1 , 32 >(
@@ -848,19 +640,3 @@ template void gemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 16>(
848640 float *absmax, float *datatype, float *out, int lda,
849641 int ldb, int ldc, int blocksize, sycl::queue *stream);
850642
851- // template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
852- // template void gemv_4bit_fusion<sycl::half, 16>(
853- // int m, int n, int k, sycl::half *A, unsigned char *B, float *absmax,
854- // float *datatype, sycl::half *out, int lda, int ldb, int ldc, int blocksize,
855- // sycl::queue *stream);
856- // template void gemv_4bit_inference<sycl::ext::oneapi::bfloat16, 16>(
857- // int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, sycl::ext::oneapi::bfloat16 *B,
858- // float *absmax, float *datatype, float *out, int lda,
859- // int ldb, int ldc, int blocksize, sycl::queue *stream);
860- // template void gemv_4bit_inference<float, 32>(int m, int n, int k, float *A,
861- // unsigned char *B, float *absmax,
862- // float *datatype, float *out,
863- // int lda, int ldb, int ldc,
864- // int blocksize,
865- // sycl::queue *stream);
866- //
0 commit comments