Skip to content

Commit bdfe1ec

Browse files
committed
fix scale issue
1 parent 3706871 commit bdfe1ec

9 files changed

Lines changed: 153 additions & 87 deletions

File tree

bitsandbytes/autograd/_functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,8 @@ def matmul_4bit(
436436
bias: Optional[torch.Tensor] = None,
437437
):
438438
assert quant_state is not None
439-
439+
#import pdb
440+
#pdb.set_trace()
440441
if A.device.type == "cpu" and A.requires_grad == False:
441442
if getattr(quant_state, "ipex", False):
442443
# IPEX CPU will change weight to 4D so don't need transpose
@@ -447,7 +448,8 @@ def matmul_4bit(
447448
return out
448449
else:
449450
return MatMul4Bit.apply(A, B, out, bias, quant_state)
450-
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
451+
#if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
452+
if A.requires_grad == False and A.device.type != "hpu":
451453
if A.shape[-1] % quant_state.blocksize != 0:
452454
warn(
453455
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

bitsandbytes/backends/xpu/ops.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,16 @@ def _gemv_4bit_impl(
7474
blocksize: int,
7575
out: torch.Tensor,
7676
) -> None:
77-
import pdb
78-
#pdb.set_trace()
79-
m = ct.c_int32(*A.shape[:-1]) #A.shape[1])
77+
m = ct.c_int32(A.shape[-2])#ct.c_int32(1)
8078
n = ct.c_int32(shapeB[0])
8179
k = ct.c_int32(shapeB[1])
82-
80+
#import pdb
8381
lda = m
8482
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
8583
ldc = m
86-
87-
#absmax = absmax * 10
8884
#pdb.set_trace()
89-
#print("A before kernel: ", A)
90-
#print("B before kernel: ", B)
85+
absmax = absmax.view(shapeB[0],int(shapeB[1]/blocksize)).transpose(0,1).contiguous()
86+
#pdb.set_trace()
9187
stream = _get_tensor_stream(A)
9288
if A.dtype == torch.float16:
9389
lib.cgemv_4bit_inference_fp16(
@@ -112,7 +108,7 @@ def _gemv_4bit_impl(
112108
k,
113109
get_ptr(A),
114110
get_ptr(B),
115-
get_ptr(absmax.bfloat16()),
111+
get_ptr(absmax),
116112
get_ptr(code),
117113
get_ptr(out),
118114
lda,
@@ -186,11 +182,9 @@ def _(
186182
blocksize: int,
187183
) -> torch.Tensor:
188184
shape = (*A.shape[:-1], shapeB[0])
189-
#import pdb
190-
#pdb.set_trace()
191-
out = torch.zeros(shape, device=A.device, dtype=torch.float32)
192-
_gemv_4bit_impl(A, B, shapeB, absmax.bfloat16(), code, blocksize, out=out)
193-
return out
185+
out = torch.empty(shape, device=A.device, dtype=A.dtype).float()
186+
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
187+
return out.bfloat16()
194188

195189
@register_kernel("bitsandbytes::gemv_4bit.out", "xpu")
196190
def _(

bitsandbytes/functional.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,12 @@ def quantize_4bit(
937937
quant_storage,
938938
)
939939

940+
#import pdb
941+
#pdb.set_trace()
942+
#print("_absmax = ", _absmax)
943+
#_absmax = _absmax.view(input_shape[0],int(input_shape[1]/blocksize)).transpose(0,1).contiguous()
944+
#pdb.set_trace()
945+
940946
code = get_4bit_type(quant_type, device=A.device)
941947

942948
if compress_statistics:
@@ -969,7 +975,6 @@ def quantize_4bit(
969975
# TODO(matthewdouglas): Deprecate absmax kwarg
970976
if absmax is not None:
971977
state.absmax = absmax.copy_(state.absmax)
972-
973978
return out, state
974979

975980

csrc/pythonInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ void gemv_4bit_inference_fp16(
381381

382382
#if 1
383383
void gemm_4bit_inference_bf16(
384-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float * out,
384+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, float * out,
385385
int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
386386
) {
387387
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
@@ -826,7 +826,7 @@ void cgemv_4bit_inference_fp16(
826826

827827
#if 1
828828
void cgemv_4bit_inference_bf16(
829-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype,
829+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype,
830830
float * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
831831
) {
832832
gemm_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);

csrc/xpu_cutlass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void gemv_4bit_inference_cutlass_cute(int m, int n, int k, T *A, T *B,
109109

110110
template <typename T, int BITS>
111111
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
112-
T *absmax, float *datatype, float *out, int lda,
112+
float *absmax, float *datatype, float *out, int lda,
113113
int ldb, int ldc, int blocksize, sycl::queue *stream);
114114

115115
template <typename T, int BITS>

0 commit comments

Comments
 (0)