Skip to content

Commit 4df6967

Browse files
committed
fix scale issue
1 parent 3706871 commit 4df6967

9 files changed

Lines changed: 146 additions & 82 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,16 @@ def _gemv_4bit_impl(
7474
blocksize: int,
7575
out: torch.Tensor,
7676
) -> None:
77-
import pdb
78-
#pdb.set_trace()
79-
m = ct.c_int32(*A.shape[:-1]) #A.shape[1])
77+
m = ct.c_int32(1)
8078
n = ct.c_int32(shapeB[0])
8179
k = ct.c_int32(shapeB[1])
82-
80+
import pdb
8381
lda = m
8482
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
8583
ldc = m
86-
87-
#absmax = absmax * 10
8884
#pdb.set_trace()
89-
#print("A before kernel: ", A)
90-
#print("B before kernel: ", B)
85+
#absmax = absmax.view(shapeB[0],int(shapeB[1]/blocksize)).transpose(0,1).contiguous()
86+
#pdb.set_trace()
9187
stream = _get_tensor_stream(A)
9288
if A.dtype == torch.float16:
9389
lib.cgemv_4bit_inference_fp16(
@@ -112,7 +108,7 @@ def _gemv_4bit_impl(
112108
k,
113109
get_ptr(A),
114110
get_ptr(B),
115-
get_ptr(absmax.bfloat16()),
111+
get_ptr(absmax),
116112
get_ptr(code),
117113
get_ptr(out),
118114
lda,
@@ -186,11 +182,9 @@ def _(
186182
blocksize: int,
187183
) -> torch.Tensor:
188184
shape = (*A.shape[:-1], shapeB[0])
189-
#import pdb
190-
#pdb.set_trace()
191-
out = torch.zeros(shape, device=A.device, dtype=torch.float32)
192-
_gemv_4bit_impl(A, B, shapeB, absmax.bfloat16(), code, blocksize, out=out)
193-
return out
185+
out = torch.empty(shape, device=A.device, dtype=A.dtype).float()
186+
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
187+
return out.bfloat16()
194188

195189
@register_kernel("bitsandbytes::gemv_4bit.out", "xpu")
196190
def _(

bitsandbytes/functional.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,12 @@ def quantize_4bit(
937937
quant_storage,
938938
)
939939

940+
#import pdb
941+
#pdb.set_trace()
942+
#print("_absmax = ", _absmax)
943+
_absmax = _absmax.view(input_shape[0],int(input_shape[1]/blocksize)).transpose(0,1).contiguous()
944+
#pdb.set_trace()
945+
940946
code = get_4bit_type(quant_type, device=A.device)
941947

942948
if compress_statistics:
@@ -969,7 +975,6 @@ def quantize_4bit(
969975
# TODO(matthewdouglas): Deprecate absmax kwarg
970976
if absmax is not None:
971977
state.absmax = absmax.copy_(state.absmax)
972-
973978
return out, state
974979

975980

csrc/pythonInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ void gemv_4bit_inference_fp16(
381381

382382
#if 1
383383
void gemm_4bit_inference_bf16(
384-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float * out,
384+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, float * out,
385385
int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
386386
) {
387387
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
@@ -826,7 +826,7 @@ void cgemv_4bit_inference_fp16(
826826

827827
#if 1
828828
void cgemv_4bit_inference_bf16(
829-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype,
829+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype,
830830
float * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
831831
) {
832832
gemm_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);

csrc/xpu_cutlass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void gemv_4bit_inference_cutlass_cute(int m, int n, int k, T *A, T *B,
109109

110110
template <typename T, int BITS>
111111
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
112-
T *absmax, float *datatype, float *out, int lda,
112+
float *absmax, float *datatype, float *out, int lda,
113113
int ldb, int ldc, int blocksize, sycl::queue *stream);
114114

115115
template <typename T, int BITS>

csrc/xpu_cutlass_fusion.cpp

Lines changed: 109 additions & 56 deletions
Large diffs are not rendered by default.

include/cute/atom/copy_traits_xe.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ struct XE_2D_LD_Unpack {
289289

290290
constexpr auto inst_size_bits = detail::size_of_inst_bits<CopyOp, dtype>;
291291

292+
//if(cute::thread0()){
293+
// print("copy base_addr: "); print(base_addr); print("\n");
294+
//}
292295
CopyOp::copy(base_addr + l * traits.stride_l,
293296
(traits.width * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>, traits.height,
294297
(traits.pitch * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>,
@@ -314,7 +317,9 @@ struct XE_2D_LD_Unpack {
314317
int y = is_need_reversed ? n : m;
315318

316319
constexpr auto inst_size_bits = detail::size_of_inst_bits<CopyOp, dtype>;
317-
320+
//if(cute::thread0()){
321+
// print("prefetch base_addr: "); print(base_addr); print("\n");
322+
//}
318323
CopyOp::PREFETCH::copy(base_addr + l * atom.stride_l,
319324
(atom.width * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>, atom.height,
320325
(atom.pitch * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>,

include/cute/atom/mma_atom.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,13 @@ struct ThrMMA : TiledMMA
583583
auto thr_tensor = make_tensor(static_cast<BTensor&&>(btensor).data(), this->thrfrg_B(btensor.layout()));
584584

585585
auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_)));
586-
//if(cute::thread0()) printf("partition_B: get<0>(thr_vmnk_) = %d, get<2>(thr_vmnk_) = %d, get<3>(thr_vmnk_) = %d\n", static_cast<int>(get<0>(thr_vmnk_)),static_cast<int>(get<2>(thr_vmnk_)),static_cast<int>(get<3>(thr_vmnk_)));
586+
#if 0
587+
if(int(ThreadIdxX()) == 16 && BlockIdxY()==0){
588+
printf("partition_B: get<0>(thr_vmnk_) = %d, get<2>(thr_vmnk_) = %d, get<3>(thr_vmnk_) = %d\n", static_cast<int>(get<0>(thr_vmnk_)),static_cast<int>(get<2>(thr_vmnk_)),static_cast<int>(get<3>(thr_vmnk_)));
589+
print(" thr_tensor : "); print(thr_tensor); print("\n");
590+
print(" thr_tensor_return : "); print(thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)))); print("\n");
591+
}
592+
#endif
587593
return thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
588594
}
589595

run_case.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@
3030
#gdb -args python -m pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
3131
#pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
3232
pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33-
#python tests/test_xpu_db.py
33+
##python tests/test_xpu_db.py
3434
#gdb -args python tests/test_xpu_db.py
3535
#pytest tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=256-uint8-bf16-fc1-nf4-DQ_True-xpu]

tests/test_xpu.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
118118
print("qB.t() = ",qB.t())
119119
C3 = torch.matmul(A, B.t())
120120
#pdb.set_trace()
121-
C2 = F.gemv_4bit(A, qB.t(), state=state)
121+
C2 = F.gemv_4bit(A, qB.t(), state=state).bfloat16()
122122
#pdb.set_trace()
123123
print("C3.sum() = ", C3.sum())
124124
print("C2.sum() = ", C2.sum())
125-
diff = C2.bfloat16()-C3
125+
diff = C2-C3
126126
print("diff/C2 = ", diff.sum()/C3.sum())
127127
print(C3)
128128
print(C2)
@@ -139,7 +139,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
139139
#print("B[0] = ",B[0])
140140
C3 = torch.matmul(A, B.t())
141141
#pdb.set_trace()
142-
C2 = F.gemv_4bit(A, qB.t(), state=state)
142+
C2 = F.gemv_4bit(A, qB.t(), state=state).bfloat16()
143143
pdb.set_trace()
144144
#print("C3.sum() = ", C3.sum())
145145
#print("C2.sum() = ", C2.sum())
@@ -294,6 +294,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
294294
A = torch.randn(1, dim, dtype=dtype, device=device)
295295
B = torch.randn(dim * 3, dim, dtype=dtype, device=device) / math.sqrt(dim)
296296

297+
#pdb.set_trace()
297298
qB, state = F.quantize_4bit(
298299
B,
299300
quant_type=storage_type,
@@ -303,10 +304,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
303304
#pdb.set_trace()
304305
C3 = torch.matmul(A, B.t())
305306
#pdb.set_trace()
306-
C2 = F.gemv_4bit(A, qB.t(), state=state).bfloat16()
307+
C2 = F.gemv_4bit(A, qB.t(), state=state)
307308
#print("C2[0] = ", C2[0])
308309
A.requires_grad = True
309-
C1 = bnb.matmul_4bit(A, qB.t(), state)#.bfloat16()
310+
C1 = F.gemv_4bit(A, qB.t(), state=state) #bnb.matmul_4bit(A, qB.t(), state)
310311
#pdb.set_trace()
311312

312313
err1 = (C1 - C2).abs().float()

0 commit comments

Comments
 (0)