Skip to content

Commit 9d4978c

Browse files
committed
enable multi-batch
1 parent bdfe1ec commit 9d4978c

4 files changed

Lines changed: 17 additions & 11 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,14 @@ def _gemv_4bit_impl(
7474
blocksize: int,
7575
out: torch.Tensor,
7676
) -> None:
77+
#import pdb
7778
m = ct.c_int32(A.shape[-2])#ct.c_int32(1)
7879
n = ct.c_int32(shapeB[0])
7980
k = ct.c_int32(shapeB[1])
80-
#import pdb
81+
l = 1
82+
#pdb.set_trace()
83+
if A.dim() == 3:
84+
l = A.shape[0]
8185
lda = m
8286
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
8387
ldc = m
@@ -106,6 +110,7 @@ def _gemv_4bit_impl(
106110
m,
107111
n,
108112
k,
113+
l,
109114
get_ptr(A),
110115
get_ptr(B),
111116
get_ptr(absmax),

csrc/pythonInterface.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,10 @@ void gemv_4bit_inference_fp16(
381381

382382
#if 1
383383
void gemm_4bit_inference_bf16(
384-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, float * out,
384+
int m, int n, int k, int l, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, float * out,
385385
int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
386386
) {
387-
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
387+
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, l, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
388388
}
389389
#endif
390390

@@ -826,10 +826,10 @@ void cgemv_4bit_inference_fp16(
826826

827827
#if 1
828828
void cgemv_4bit_inference_bf16(
829-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype,
829+
int m, int n, int k, int l, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype,
830830
float * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
831831
) {
832-
gemm_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
832+
gemm_4bit_inference_bf16(m, n, k, l, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
833833
}
834834
#else
835835
void cgemv_4bit_inference_bf16(

csrc/xpu_cutlass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ void gemv_4bit_inference_cutlass_cute(int m, int n, int k, T *A, T *B,
108108
int ldb, int ldc, int blocksize, sycl::queue *stream);
109109

110110
template <typename T, int BITS>
111-
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
111+
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, int l, T *A, unsigned char *B,
112112
float *absmax, float *datatype, float *out, int lda,
113113
int ldb, int ldc, int blocksize, sycl::queue *stream);
114114

csrc/xpu_cutlass_fusion.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class kgemm_4bit_inference_cutlass_dequant {
163163
};
164164

165165
struct Params {
166-
int m, n, k;
166+
int m, n, k, l;
167167
T* A;
168168
uint8_t* B;
169169
float* out;
@@ -278,7 +278,7 @@ class kgemm_4bit_inference_cutlass_dequant {
278278
int M = params.m;
279279
int N = params.n;
280280
int K = params.k;
281-
int L = 1;
281+
int L = params.l;
282282

283283
//Total Threads number
284284
static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; //32 //2
@@ -578,7 +578,7 @@ printf("\n");
578578
};
579579

580580
template <typename T, int BITS>
581-
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
581+
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, int l, T *A, unsigned char *B,
582582
float *absmax_, float *datatype, float *out, int lda,
583583
int ldb, int ldc, int blocksize, sycl::queue *stream) {
584584
////std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
@@ -599,7 +599,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
599599

600600
//static constexpr int smem_size= 512; // (16 * 32) for quant_map
601601
static constexpr int smem_size= 256; // (16 * 16) for quant_map
602-
int l = 1;
602+
//int l = 1;
603603

604604
auto problem_size = ProblemShape{m, n, k, l};
605605

@@ -610,6 +610,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
610610
params.m = m;
611611
params.n = n;
612612
params.k = k;
613+
params.l = l;
613614
params.A = A;
614615
params.B = B;
615616
params.out = out;
@@ -701,7 +702,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
701702
}
702703

703704
template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(
704-
int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, unsigned char *B,
705+
int m, int n, int k, int l, sycl::ext::oneapi::bfloat16 *A, unsigned char *B,
705706
float *absmax, float *datatype, float *out, int lda,
706707
int ldb, int ldc, int blocksize, sycl::queue *stream);
707708

0 commit comments

Comments
 (0)