@@ -162,97 +162,7 @@ void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B,
162162// std::cout<<"this is gemv_4bit_inference cutlass...\n";
163163 cutlass::KernelHardwareInfo hw_info;
164164 hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count (hw_info.device_id );
165- #if 0
166- // The code section below describes datatype for input, output matrices and computation between
167- // elements in input matrices.
168- using ElementAccumulator = float; // <- data type of accumulator
169- using ElementComputeEpilogue = float; // <- data type of epilogue operations
170- using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
171- using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B
172- using ElementOutput = float; // <- data type of elements in output matrix D
173-
174- using LayoutA = cutlass::layout::RowMajor;
175- using LayoutB = cutlass::layout::RowMajor;
176- using LayoutC = cutlass::layout::RowMajor;
177- using LayoutD = cutlass::layout::RowMajor;
178-
179- // The 2D block copy operations used for the A and B matrices
180- using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
181- using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
182-
183- // Workgroup-level tile
184- using TileShape = Shape<_256, _256, _32>;
185-
186-
187- // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
188- // hardware (sub-groups for Intel PVC) and iterations by each sub-group.
189- //
190- // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
191- // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
192- // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
193- // single contiguous chunk of the work-group TileShape. For this configuration, this implies that
194- // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
195- // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
196- // performance reasons.
197- using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
198- typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
199- Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
200-
201- // For Intel PVC, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
202- constexpr int PipelineStages = 2;
203- using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
204- using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
205-
206- // This is the 'default' epilogue operation (Linear Combination) which performs everything in:
207- // (D = alpha * (A*B) + beta * C)
208- // aside from the (A*B), which is handled by the GEMM. See 05_pvc_gemm_with_epilogues for more
209- // complex epilogue examples.
210- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
211- ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
212-
213- // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
214- // policy/architecture) and defines the epilogue arguments.
215- using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
216- decltype(tile_shape(TiledMma()))>;
217- // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
218- // auxiliary data required
219- using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
220- EpilogueDispatchPolicy,
221- TileShape,
222- ElementAccumulator,
223- cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
224- ElementOutput,
225- cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
226- FusionCallBacks,
227- XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
228- void, void,
229- XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
230- void, void>;
231-
232- // GEMM Mainloop - iteration over blocks in K dimension
233- using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
234- GEMMDispatchPolicy,
235- TileShape,
236- ElementInputA,
237- cutlass::gemm::TagToStrideA_t<LayoutA>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
238- ElementInputB,
239- cutlass::gemm::TagToStrideB_t<LayoutB>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
240- TiledMma,
241- GmemTiledCopyA, void, void, cute::identity, // A
242- GmemTiledCopyB, void, void, cute::identity // B
243- >;
244165
245- // Define the whole kernel (mainloop and epilogue)
246- using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
247- Shape<int, int, int, int>, // Defer global problem shape definition to runtime
248- CollectiveMainloop,
249- CollectiveEpilogue
250- >;
251-
252- // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g.
253- // persistent scratch memory if required.
254- using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
255- #endif
256166 ProblemShapeType problem_size = ProblemShapeType{m, n, k, ldb};
257167
258168 initialize (problem_size);
0 commit comments