Skip to content

Commit 9a3b980

Browse files
committed
Enable draft GEMM/GEMV kernel
1 parent f81099b commit 9a3b980

3 files changed

Lines changed: 282 additions & 74 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ if(BUILD_MPS)
312312
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
313313
endif()
314314
if(BUILD_XPU)
315-
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
316-
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")
315+
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=intel_gpu_pvc;-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier;-Xs; -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
316+
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=intel_gpu_pvc;-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier;")
317317

318318
set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
319319
target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})

csrc/xpu_cutlass.cpp

Lines changed: 252 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,263 @@
55

66
#include <sycl/sycl.hpp>
77

8-
#if 0
9-
template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD,
10-
size_t SUBG_SIZE, int BITS>
11-
void kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE,
12-
BITS>::operator()(sycl::nd_item<1> item) const {
13-
std::cout<<"this is kgemv_4bit_inference_cutlass ...\n";
14-
#if 0
15-
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
16-
auto [M, N, K, L] = problem_shape_MNKL;
17-
18-
// Complete the stride by combining static layout info (StrideA) with runtime size info (M,K,L)
19-
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
20-
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
21-
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
22-
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
23-
24-
block_A.reset(static_cast<std::size_t>(M) * K * L);
25-
block_B.reset(static_cast<std::size_t>(K) * N * L);
26-
block_C.reset(static_cast<std::size_t>(M) * N * L);
27-
block_D.reset(static_cast<std::size_t>(M) * N * L);
28-
block_ref_D.reset(static_cast<std::size_t>(M) * N * L);
29-
30-
initialize_block(block_A, seed + 2023);
31-
initialize_block(block_B, seed + 2022);
32-
initialize_block(block_C, seed + 2021);
8+
// The aliases below define the data types of the input and output matrices and of the
9+
// computation performed between elements of the input matrices.
10+
using ElementAccumulator = float; // <- data type of accumulator
11+
using ElementComputeEpilogue = float; // <- data type of epilogue operations
12+
using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
13+
using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B
14+
using ElementOutput = float; // <- data type of elements in output matrix D
15+
16+
using LayoutA = cutlass::layout::RowMajor;
17+
using LayoutB = cutlass::layout::RowMajor;
18+
using LayoutC = cutlass::layout::RowMajor;
19+
using LayoutD = cutlass::layout::RowMajor;
20+
21+
// The 2D block copy operations used for the A and B matrices
22+
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
23+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
24+
25+
// Workgroup-level tile
26+
using TileShape = Shape<_256, _256, _32>;
27+
28+
29+
// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
30+
// hardware (sub-groups for Intel PVC) and iterations by each sub-group.
31+
//
32+
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
33+
// (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
34+
// TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
35+
// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
36+
// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
37+
// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
38+
// performance reasons.
39+
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
40+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
41+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
42+
43+
// For Intel PVC, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
44+
constexpr int PipelineStages = 2;
45+
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
46+
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
47+
48+
// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
49+
// (D = alpha * (A*B) + beta * C)
50+
// aside from the (A*B), which is handled by the GEMM. See 05_pvc_gemm_with_epilogues for more
51+
// complex epilogue examples.
52+
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
53+
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
54+
55+
// FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
56+
// policy/architecture) and defines the epilogue arguments.
57+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
58+
decltype(tile_shape(TiledMma()))>;
59+
// GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
60+
// auxiliary data required
61+
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
62+
EpilogueDispatchPolicy,
63+
TileShape,
64+
ElementAccumulator,
65+
cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
66+
ElementOutput,
67+
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
68+
FusionCallBacks,
69+
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
70+
void, void,
71+
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
72+
void, void>;
73+
74+
// GEMM Mainloop - iteration over blocks in K dimension
75+
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
76+
GEMMDispatchPolicy,
77+
TileShape,
78+
ElementInputA,
79+
cutlass::gemm::TagToStrideA_t<LayoutA>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
80+
ElementInputB,
81+
cutlass::gemm::TagToStrideB_t<LayoutB>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
82+
TiledMma,
83+
GmemTiledCopyA, void, void, cute::identity, // A
84+
GmemTiledCopyB, void, void, cute::identity // B
85+
>;
86+
87+
// Define the whole kernel (mainloop and epilogue)
88+
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
89+
Shape<int, int, int, int>, // Defer global problem shape definition to runtime
90+
CollectiveMainloop,
91+
CollectiveEpilogue
92+
>;
93+
94+
// The GemmUniversalAdapter wraps the GEMM kernel defined above and handles its launch, as well
95+
// as any persistent scratch memory, if required.
96+
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
97+
98+
using StrideA = typename Gemm::GemmKernel::StrideA;
99+
using StrideB = typename Gemm::GemmKernel::StrideB;
100+
using StrideC = typename Gemm::GemmKernel::StrideC;
101+
using StrideD = typename Gemm::GemmKernel::StrideD;
102+
103+
using LayoutA = typename Gemm::LayoutA;
104+
using LayoutB = typename Gemm::LayoutB;
105+
using LayoutC = typename Gemm::LayoutC;
106+
using LayoutD = typename Gemm::LayoutD;
107+
108+
using ElementA = typename Gemm::ElementA;
109+
using ElementB = typename Gemm::ElementB;
110+
using ElementAcc = typename Gemm::ElementAccumulator;
111+
112+
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
113+
using ElementC = typename Gemm::ElementC;
114+
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
115+
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
116+
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
117+
118+
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
119+
120+
//
121+
// Data members
122+
//
123+
124+
/// Initialization
125+
StrideA stride_A;
126+
StrideB stride_B;
127+
StrideC stride_C;
128+
StrideD stride_D;
129+
uint64_t seed = 0;
130+
131+
cutlass::DeviceAllocation<ElementA> block_A;
132+
cutlass::DeviceAllocation<ElementB> block_B;
133+
cutlass::DeviceAllocation<ElementC> block_C;
134+
cutlass::DeviceAllocation<ElementOutput> block_D;
135+
cutlass::DeviceAllocation<ElementOutput> block_ref_D; // Reference GEMM result for verification
136+
137+
void initialize(const ProblemShapeType& problem_size) {
138+
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
139+
auto [M, N, K, L] = problem_shape_MNKL;
140+
141+
// Complete the stride by combining static layout info (StrideA) with runtime size info (M,K,L)
142+
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
143+
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
144+
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
145+
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
146+
147+
block_A.reset(M * K * L);
148+
block_B.reset(K * N * L);
149+
block_C.reset(M * N * L);
150+
block_D.reset(M * N * L);
151+
block_ref_D.reset(M * N * L);
152+
153+
initialize_block(block_A, seed + 2023);
154+
initialize_block(block_B, seed + 2022);
155+
initialize_block(block_C, seed + 2021);
156+
}
157+
158+
template <typename T, int BITS>
159+
void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B,
160+
float *absmax, float *datatype, T *out, int lda,
161+
int ldb, int ldc, int blocksize, sycl::queue *stream) {
162+
//std::cout<<"this is gemv_4bit_inference cutlass...\n";
163+
cutlass::KernelHardwareInfo hw_info;
164+
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
165+
#if 0
166+
// The aliases below define the data types of the input and output matrices and of the
167+
// computation performed between elements of the input matrices.
168+
using ElementAccumulator = float; // <- data type of accumulator
169+
using ElementComputeEpilogue = float; // <- data type of epilogue operations
170+
using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
171+
using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B
172+
using ElementOutput = float; // <- data type of elements in output matrix D
173+
174+
using LayoutA = cutlass::layout::RowMajor;
175+
using LayoutB = cutlass::layout::RowMajor;
176+
using LayoutC = cutlass::layout::RowMajor;
177+
using LayoutD = cutlass::layout::RowMajor;
178+
179+
// The 2D block copy operations used for the A and B matrices
180+
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
181+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
182+
183+
// Workgroup-level tile
184+
using TileShape = Shape<_256, _256, _32>;
185+
186+
187+
// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
188+
// hardware (sub-groups for Intel PVC) and iterations by each sub-group.
189+
//
190+
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
191+
// (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
192+
// TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
193+
// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
194+
// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
195+
// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
196+
// performance reasons.
197+
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
198+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
199+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
200+
201+
// For Intel PVC, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
202+
constexpr int PipelineStages = 2;
203+
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
204+
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
205+
206+
// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
207+
// (D = alpha * (A*B) + beta * C)
208+
// aside from the (A*B), which is handled by the GEMM. See 05_pvc_gemm_with_epilogues for more
209+
// complex epilogue examples.
210+
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
211+
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
212+
213+
// FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
214+
// policy/architecture) and defines the epilogue arguments.
215+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
216+
decltype(tile_shape(TiledMma()))>;
217+
// GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
218+
// auxiliary data required
219+
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
220+
EpilogueDispatchPolicy,
221+
TileShape,
222+
ElementAccumulator,
223+
cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
224+
ElementOutput,
225+
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
226+
FusionCallBacks,
227+
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
228+
void, void,
229+
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
230+
void, void>;
231+
232+
// GEMM Mainloop - iteration over blocks in K dimension
233+
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
234+
GEMMDispatchPolicy,
235+
TileShape,
236+
ElementInputA,
237+
cutlass::gemm::TagToStrideA_t<LayoutA>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
238+
ElementInputB,
239+
cutlass::gemm::TagToStrideB_t<LayoutB>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
240+
TiledMma,
241+
GmemTiledCopyA, void, void, cute::identity, // A
242+
GmemTiledCopyB, void, void, cute::identity // B
243+
>;
244+
245+
// Define the whole kernel (mainloop and epilogue)
246+
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
247+
Shape<int, int, int, int>, // Defer global problem shape definition to runtime
248+
CollectiveMainloop,
249+
CollectiveEpilogue
250+
>;
251+
252+
// The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g.
253+
// persistent scratch memory if required.
254+
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
255+
#endif
256+
ProblemShapeType problem_size = ProblemShapeType{m, n, k, ldb};
257+
258+
initialize(problem_size);
33259

34260
typename Gemm::GemmKernel::Arguments arguments{
35261
cutlass::gemm::GemmUniversalMode::kGemm,
36262
problem_size,
37263
{block_A.get(), stride_A, block_B.get(), stride_B},
38-
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
264+
{{1.f, 0.f}, block_C.get(), stride_C, block_D.get(), stride_D},
39265
hw_info
40266
};
41267

@@ -44,30 +270,12 @@ void kgemv_4bit_inference_cutlass<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE,
44270
size_t workspace_size = Gemm::get_workspace_size(arguments);
45271
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
46272

47-
if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){
48-
std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
49-
std::exit(1);
50-
}
51-
52273
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
53274

54275
// Run the GEMM
55276
CUTLASS_CHECK(gemm_op.run());
56277

57278
syclcompat::wait();
58-
#endif
59-
}
60-
61-
template class kgemv_4bit_inference_cutlass<sycl::half, 128, 4, 32, 16>;
62-
template class kgemv_4bit_inference_cutlass<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
63-
template class kgemv_4bit_inference_cutlass<float, 128, 4, 32, 32>;
64-
65-
#endif
66-
template <typename T, int BITS>
67-
void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B,
68-
float *absmax, float *datatype, T *out, int lda,
69-
int ldb, int ldc, int blocksize, sycl::queue *stream) {
70-
std::cout<<"this is gemv_4bit_inference cutlass...\n";
71279
}
72280

73281
template void gemv_4bit_inference<sycl::half, 16>(

csrc/xpu_cutlass.h

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,36 @@
22

33
#include <float.h>
44

5-
//#include "cutlass/epilogue/collective/default_epilogue.hpp"
6-
//#include "cutlass/epilogue/collective/xe_epilogue.hpp"
7-
//#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
8-
//#include "cutlass/gemm/device/gemm_universal.h"
9-
//#include "cutlass/gemm/device/gemm_universal_adapter.h"
10-
//#include "cutlass/gemm/collective/collective_mma.hpp"
11-
//#include "cutlass/util/GPU_Clock.hpp"
12-
//#include "cutlass/epilogue/dispatch_policy.hpp"
13-
//#include <cute/atom/copy_traits_xe.hpp>
14-
//
15-
//#include <cute/tensor.hpp>
16-
//#include <random>
17-
//
18-
//#include "cutlass/util/command_line.h"
19-
//#include "cutlass/util/device_memory.h"
20-
//#include "cutlass/util/packed_stride.hpp"
21-
//#include "cutlass/util/reference/device/gemm_complex.h"
22-
//#include "cutlass/util/reference/device/tensor_compare.h"
23-
//#include "sycl_common.hpp"
24-
//#include "helper.h"
5+
#include "cutlass/epilogue/collective/default_epilogue.hpp"
6+
#include "cutlass/epilogue/collective/xe_epilogue.hpp"
7+
#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
8+
#include "cutlass/gemm/device/gemm_universal.h"
9+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
10+
#include "cutlass/gemm/collective/collective_mma.hpp"
11+
#include "cutlass/util/GPU_Clock.hpp"
12+
#include "cutlass/epilogue/dispatch_policy.hpp"
13+
#include <cute/atom/copy_traits_xe.hpp>
2514

26-
#include "cutlass/cutlass.h"
27-
#include "cutlass/gemm/dispatch_policy.hpp"
28-
#include "cutlass/gemm/gemm.h"
29-
#include "cutlass/kernel_hardware_info.hpp"
15+
#include <cute/tensor.hpp>
16+
#include <random>
3017

31-
#include "cute/algorithm/functional.hpp"
32-
#include "cute/atom/mma_atom.hpp"
33-
#include "cute/algorithm/gemm.hpp"
34-
#include "cute/tensor_predicate.hpp"
18+
#include "cutlass/util/command_line.h"
19+
#include "cutlass/util/device_memory.h"
20+
#include "cutlass/util/packed_stride.hpp"
21+
#include "cutlass/util/reference/device/gemm_complex.h"
22+
#include "cutlass/util/reference/device/tensor_compare.h"
23+
#include "sycl_common.hpp"
24+
#include "helper.h"
25+
26+
//#include "cutlass/cutlass.h"
27+
//#include "cutlass/gemm/dispatch_policy.hpp"
28+
//#include "cutlass/gemm/gemm.h"
29+
//#include "cutlass/kernel_hardware_info.hpp"
30+
//
31+
//#include "cute/algorithm/functional.hpp"
32+
//#include "cute/atom/mma_atom.hpp"
33+
//#include "cute/algorithm/gemm.hpp"
34+
//#include "cute/tensor_predicate.hpp"
3535

3636
using namespace cute;
3737

0 commit comments

Comments
 (0)