55
66#include < sycl/sycl.hpp>
77
8- #if 0
9- template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
10- size_t SUBG_SIZE, int BITS>
11- void kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE,
12- BITS>::operator()(sycl::nd_item<1> item) const {
13- std::cout<<"this is kgemv_4bit_inference_cutlass ...\n";
14- #if 0
15- auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
16- auto [M, N, K, L] = problem_shape_MNKL;
17-
18- // Complete the stride by combining static layout info (StrideA) with runtime size info (M,K,L)
19- stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
20- stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
21- stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
22- stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
23-
24- block_A.reset(static_cast<std::size_t>(M) * K * L);
25- block_B.reset(static_cast<std::size_t>(K) * N * L);
26- block_C.reset(static_cast<std::size_t>(M) * N * L);
27- block_D.reset(static_cast<std::size_t>(M) * N * L);
28- block_ref_D.reset(static_cast<std::size_t>(M) * N * L);
29-
30- initialize_block(block_A, seed + 2023);
31- initialize_block(block_B, seed + 2022);
32- initialize_block(block_C, seed + 2021);
// Compile-time configuration of the CUTLASS GEMM type `Gemm` used further down in this file:
// bf16 inputs A and B, f32 accumulation / epilogue / output, row-major layout tags for A, B, C
// and D, and a 256x256x32 work-group tile.
// NOTE(review): names such as Shape, _256, Layout, Stride, MMA_Atom, TiledMMAHelper, bfloat16_t
// and the XE_2D_* copy atoms are used unqualified here; presumably a using-directive or
// declarations outside this view bring them into scope -- confirm.
8+ // The aliases below fix the element types of the input and output matrices and of the
9+ // arithmetic performed between elements of the input matrices.
10+ using ElementAccumulator = float ; // <- data type of accumulator
11+ using ElementComputeEpilogue = float ; // <- data type of epilogue operations
12+ using ElementInputA = bfloat16_t ; // <- data type of elements in input matrix A
13+ using ElementInputB = bfloat16_t ; // <- data type of elements in input matrix B
14+ using ElementOutput = float ; // <- data type of elements in output matrix D
15+
// Layout tags for the four matrices; all four are row-major.
16+ using LayoutA = cutlass::layout::RowMajor;
17+ using LayoutB = cutlass::layout::RowMajor;
18+ using LayoutC = cutlass::layout::RowMajor;
19+ using LayoutD = cutlass::layout::RowMajor;
20+
21+ // The 2D block copy operations used for the A and B matrices
22+ using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
23+ using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
24+
25+ // Workgroup-level tile
26+ using TileShape = Shape<_256, _256, _32>;
27+
28+
29+ // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
30+ // hardware (sub-groups for Intel PVC) and iterations by each sub-group.
31+ //
32+ // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
33+ // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
34+ // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
35+ // single contiguous chunk of the work-group TileShape. For this configuration, this implies that
36+ // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
37+ // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
38+ // performance reasons.
39+ using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
40+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
41+ Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
42+
43+ // For Intel PVC, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
44+ constexpr int PipelineStages = 2 ;
45+ using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
46+ using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
47+
48+ // This is the 'default' epilogue operation (Linear Combination) which performs everything in:
49+ // (D = alpha * (A*B) + beta * C)
50+ // aside from the (A*B), which is handled by the GEMM. See 05_pvc_gemm_with_epilogues for more
51+ // complex epilogue examples.
52+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
53+ ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
54+
55+ // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
56+ // policy/architecture) and defines the epilogue arguments.
57+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
58+ decltype (tile_shape(TiledMma()))>;
59+ // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
60+ // auxiliary data required
61+ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
62+ EpilogueDispatchPolicy,
63+ TileShape,
64+ ElementAccumulator,
65+ cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
66+ ElementOutput,
67+ cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
68+ FusionCallBacks,
69+ XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
70+ void , void ,
71+ XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
72+ void , void >;
73+
74+ // GEMM Mainloop - iteration over blocks in K dimension
75+ using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
76+ GEMMDispatchPolicy,
77+ TileShape,
78+ ElementInputA,
79+ cutlass::gemm::TagToStrideA_t<LayoutA>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
80+ ElementInputB,
81+ cutlass::gemm::TagToStrideB_t<LayoutB>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
82+ TiledMma,
83+ GmemTiledCopyA, void , void , cute::identity, // A
84+ GmemTiledCopyB, void , void , cute::identity // B
85+ >;
86+
87+ // Define the whole kernel (mainloop and epilogue)
88+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
89+ Shape<int , int , int , int >, // Defer global problem shape definition to runtime
90+ CollectiveMainloop,
91+ CollectiveEpilogue
92+ >;
93+
94+ // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g.
95+ // persistent scratch memory if required.
96+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
97+
// Convenience aliases pulled back out of the assembled Gemm type, followed by the state that
// initialize() fills in and the GEMM launch reads.
// Stride types as seen by the kernel.
98+ using StrideA = typename Gemm::GemmKernel::StrideA;
99+ using StrideB = typename Gemm::GemmKernel::StrideB;
100+ using StrideC = typename Gemm::GemmKernel::StrideC;
101+ using StrideD = typename Gemm::GemmKernel::StrideD;
102+
// NOTE(review): LayoutA..LayoutD are already declared earlier in this file as
// cutlass::layout::RowMajor. Re-declaring an alias is only well-formed when both declarations
// name the same type -- confirm Gemm::LayoutX is RowMajor, or drop one set of declarations.
103+ using LayoutA = typename Gemm::LayoutA;
104+ using LayoutB = typename Gemm::LayoutB;
105+ using LayoutC = typename Gemm::LayoutC;
106+ using LayoutD = typename Gemm::LayoutD;
107+
108+ using ElementA = typename Gemm::ElementA;
109+ using ElementB = typename Gemm::ElementB;
110+ using ElementAcc = typename Gemm::ElementAccumulator;
111+
// NOTE(review): CollectiveEpilogue, ElementOutput and ElementAccumulator also repeat names
// declared earlier in this file; the same same-type requirement applies to them.
112+ using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
113+ using ElementC = typename Gemm::ElementC;
114+ using ElementOutput = typename CollectiveEpilogue::ElementOutput;
115+ using ElementCompute = typename CollectiveEpilogue::ElementCompute;
116+ using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
117+
// Runtime problem shape type; the kernel above is instantiated with Shape<int, int, int, int>.
118+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
119+
120+ //
121+ // Data members
122+ //
// NOTE(review): no enclosing struct/class is visible in this view, so these look like
// namespace-scope mutable objects shared by every call -- confirm that is intended.
123+
124+ /// Initialization
// Strides for A, B, C and D; assigned in initialize() from the runtime problem extents.
125+ StrideA stride_A;
126+ StrideB stride_B;
127+ StrideC stride_C;
128+ StrideD stride_D;
// Base seed; initialize() passes seed + 2023 / 2022 / 2021 to initialize_block for A / B / C.
129+ uint64_t seed = 0 ;
130+
// Device buffers for the GEMM operands and results; resized in initialize().
131+ cutlass::DeviceAllocation<ElementA> block_A;
132+ cutlass::DeviceAllocation<ElementB> block_B;
133+ cutlass::DeviceAllocation<ElementC> block_C;
134+ cutlass::DeviceAllocation<ElementOutput> block_D;
135+ cutlass::DeviceAllocation<ElementOutput> block_ref_D; // Reference GEMM result for verification (no verification step is visible in this view)
136+
137+ void initialize (const ProblemShapeType& problem_size) {
138+ auto problem_shape_MNKL = cute::append<4 >(problem_size, 1 );
139+ auto [M, N, K, L] = problem_shape_MNKL;
140+
141+ // Complete the stride by combining static layout info (StrideA) with runtime size info (M,K,L)
142+ stride_A = cutlass::make_cute_packed_stride (StrideA{}, cute::make_shape (M, K, L));
143+ stride_B = cutlass::make_cute_packed_stride (StrideB{}, cute::make_shape (N, K, L));
144+ stride_C = cutlass::make_cute_packed_stride (StrideC{}, cute::make_shape (M, N, L));
145+ stride_D = cutlass::make_cute_packed_stride (StrideD{}, cute::make_shape (M, N, L));
146+
147+ block_A.reset (M * K * L);
148+ block_B.reset (K * N * L);
149+ block_C.reset (M * N * L);
150+ block_D.reset (M * N * L);
151+ block_ref_D.reset (M * N * L);
152+
153+ initialize_block (block_A, seed + 2023 );
154+ initialize_block (block_B, seed + 2022 );
155+ initialize_block (block_C, seed + 2021 );
156+ }
157+
158+ template <typename T, int BITS >
159+ void gemv_4bit_inference (int m, int n, int k, T *A, unsigned char *B,
160+ float *absmax, float *datatype, T *out, int lda,
161+ int ldb, int ldc, int blocksize, sycl::queue *stream) {
162+ // std::cout<<"this is gemv_4bit_inference cutlass...\n";
163+ cutlass::KernelHardwareInfo hw_info;
164+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count (hw_info.device_id );
165+ #if 0
166+ // The code section below describes datatype for input, output matrices and computation between
167+ // elements in input matrices.
168+ using ElementAccumulator = float; // <- data type of accumulator
169+ using ElementComputeEpilogue = float; // <- data type of epilogue operations
170+ using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
171+ using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B
172+ using ElementOutput = float; // <- data type of elements in output matrix D
173+
174+ using LayoutA = cutlass::layout::RowMajor;
175+ using LayoutB = cutlass::layout::RowMajor;
176+ using LayoutC = cutlass::layout::RowMajor;
177+ using LayoutD = cutlass::layout::RowMajor;
178+
179+ // The 2D block copy operations used for the A and B matrices
180+ using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
181+ using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
182+
183+ // Workgroup-level tile
184+ using TileShape = Shape<_256, _256, _32>;
185+
186+
187+ // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
188+ // hardware (sub-groups for Intel PVC) and iterations by each sub-group.
189+ //
190+ // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
191+ // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
192+ // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
193+ // single contiguous chunk of the work-group TileShape. For this configuration, this implies that
194+ // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
195+ // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
196+ // performance reasons.
197+ using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
198+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
199+ Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
200+
201+ // For Intel PVC, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
202+ constexpr int PipelineStages = 2;
203+ using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
204+ using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
205+
206+ // This is the 'default' epilogue operation (Linear Combination) which performs everything in:
207+ // (D = alpha * (A*B) + beta * C)
208+ // aside from the (A*B), which is handled by the GEMM. See 05_pvc_gemm_with_epilogues for more
209+ // complex epilogue examples.
210+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
211+ ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
212+
213+ // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
214+ // policy/architecture) and defines the epilogue arguments.
215+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
216+ decltype(tile_shape(TiledMma()))>;
217+ // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
218+ // auxiliary data required
219+ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
220+ EpilogueDispatchPolicy,
221+ TileShape,
222+ ElementAccumulator,
223+ cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
224+ ElementOutput,
225+ cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
226+ FusionCallBacks,
227+ XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
228+ void, void,
229+ XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
230+ void, void>;
231+
232+ // GEMM Mainloop - iteration over blocks in K dimension
233+ using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
234+ GEMMDispatchPolicy,
235+ TileShape,
236+ ElementInputA,
237+ cutlass::gemm::TagToStrideA_t<LayoutA>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
238+ ElementInputB,
239+ cutlass::gemm::TagToStrideB_t<LayoutB>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
240+ TiledMma,
241+ GmemTiledCopyA, void, void, cute::identity, // A
242+ GmemTiledCopyB, void, void, cute::identity // B
243+ >;
244+
245+ // Define the whole kernel (mainloop and epilogue)
246+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
247+ Shape<int, int, int, int>, // Defer global problem shape definition to runtime
248+ CollectiveMainloop,
249+ CollectiveEpilogue
250+ >;
251+
252+ // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g.
253+ // persistent scratch memory if required.
254+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
255+ #endif
256+ ProblemShapeType problem_size = ProblemShapeType{m, n, k, ldb};
257+
258+ initialize (problem_size);
33259
34260 typename Gemm::GemmKernel::Arguments arguments{
35261 cutlass::gemm::GemmUniversalMode::kGemm ,
36262 problem_size,
37263 {block_A.get (), stride_A, block_B.get (), stride_B},
38- {{options.alpha, options.beta }, block_C.get(), stride_C, block_D.get(), stride_D},
264+ {{1 . f , 0 . f }, block_C.get (), stride_C, block_D.get (), stride_D},
39265 hw_info
40266 };
41267
@@ -44,30 +270,12 @@ void kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE,
44270 size_t workspace_size = Gemm::get_workspace_size (arguments);
45271 cutlass::device_memory::allocation<uint8_t > workspace (workspace_size);
46272
47- if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){
48- std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
49- std::exit(1);
50- }
51-
52273 CUTLASS_CHECK (gemm_op.initialize (arguments, workspace.get ()));
53274
54275 // Run the GEMM
55276 CUTLASS_CHECK (gemm_op.run ());
56277
57278 syclcompat::wait ();
58- #endif
59- }
60-
61- template class kgemv_4bit_inference_cutlass<sycl::half, 128, 4, 32, 16>;
62- template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
63- template class kgemv_4bit_inference_cutlass<float, 128, 4, 32, 32>;
64-
65- #endif
66- template <typename T, int BITS >
67- void gemv_4bit_inference (int m, int n, int k, T *A, unsigned char *B,
68- float *absmax, float *datatype, T *out, int lda,
69- int ldb, int ldc, int blocksize, sycl::queue *stream) {
70- std::cout<<" this is gemv_4bit_inference cutlass...\n " ;
71279}
72280
73281template void gemv_4bit_inference<sycl::half, 16 >(
0 commit comments