Skip to content

Commit 06d8edd

Browse files
committed
Add dequant gemm fusion kernel with cutlass
1 parent 33445c0 commit 06d8edd

5 files changed

Lines changed: 551 additions & 2 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
2828
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
2929
set(MPS_FILES csrc/mps_ops.mm)
3030
set(METAL_FILES csrc/mps_kernels.metal)
31-
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass.cpp csrc/xpu_cutlass-cute.cpp)
31+
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp csrc/xpu_cutlass.cpp csrc/xpu_cutlass-cute.cpp csrc/xpu_cutlass_fusion.cpp)
3232
# C++ sources are always included
3333
list(APPEND SRC_FILES ${CPP_FILES})
3434

build_xpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ cmake -DCMAKE_CXX_STANDARD=17 -DSYCL_INTEL_TARGET=ON -DENABLE_INTEL_XMX=ON -DCUT
33
#cmake -DCOMPUTE_BACKEND=xpu -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS_DEBUG="-O0 -g" -S .
44
#cmake --build . --config Release -j
55
make
6-
pip install -e .
6+
#pip install -e .

csrc/pythonInterface.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,13 @@ void gemv_4bit_inference_fp16(
379379
//gemv_4bit_fusion<sycl::half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
380380
}
381381

382+
void gemm_4bit_inference_bf16(
383+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, float * out,
384+
int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
385+
) {
386+
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
387+
}
388+
382389
void gemv_4bit_inference_bf16(
383390
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, sycl::ext::oneapi::bfloat16* B, float *absmax, float *datatype,
384391
float * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream

csrc/xpu_cutlass.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ void gemv_4bit_inference_cutlass_cute(int m, int n, int k, T *A, T *B,
107107
float *absmax, float *datatype, float *out, int lda,
108108
int ldb, int ldc, int blocksize, sycl::queue *stream);
109109

110+
template <typename T, int BITS>
111+
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
112+
float *absmax, float *datatype, float *out, int lda,
113+
int ldb, int ldc, int blocksize, sycl::queue *stream);
114+
110115
template <typename T, int BITS>
111116
void gemv_4bit_inference_cutlass(int m, int n, int k, T *A, T *B,
112117
float *absmax, float *datatype, float *out, int lda,

0 commit comments

Comments
 (0)