Skip to content

Commit 0ce659e

Browse files
committed
fix scale issue
1 parent 3706871 commit 0ce659e

8 files changed

Lines changed: 135 additions & 78 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,15 @@ def _gemv_4bit_impl(
7474
blocksize: int,
7575
out: torch.Tensor,
7676
) -> None:
77-
import pdb
78-
#pdb.set_trace()
79-
m = ct.c_int32(*A.shape[:-1]) #A.shape[1])
77+
m = ct.c_int32(1)
8078
n = ct.c_int32(shapeB[0])
8179
k = ct.c_int32(shapeB[1])
82-
80+
import pdb
8381
lda = m
8482
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
8583
ldc = m
86-
87-
#absmax = absmax * 10
8884
#pdb.set_trace()
89-
#print("A before kernel: ", A)
90-
#print("B before kernel: ", B)
85+
absmax = absmax.view(shapeB[0],int(shapeB[1]/blocksize)).transpose(0,1).contiguous()
9186
stream = _get_tensor_stream(A)
9287
if A.dtype == torch.float16:
9388
lib.cgemv_4bit_inference_fp16(
@@ -112,7 +107,7 @@ def _gemv_4bit_impl(
112107
k,
113108
get_ptr(A),
114109
get_ptr(B),
115-
get_ptr(absmax.bfloat16()),
110+
get_ptr(absmax),
116111
get_ptr(code),
117112
get_ptr(out),
118113
lda,
@@ -186,10 +181,8 @@ def _(
186181
blocksize: int,
187182
) -> torch.Tensor:
188183
shape = (*A.shape[:-1], shapeB[0])
189-
#import pdb
190-
#pdb.set_trace()
191-
out = torch.zeros(shape, device=A.device, dtype=torch.float32)
192-
_gemv_4bit_impl(A, B, shapeB, absmax.bfloat16(), code, blocksize, out=out)
184+
out = torch.empty(shape, device=A.device, dtype=A.dtype).float()
185+
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
193186
return out
194187

195188
@register_kernel("bitsandbytes::gemv_4bit.out", "xpu")

csrc/pythonInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ void gemv_4bit_inference_fp16(
381381

382382
#if 1
383383
void gemm_4bit_inference_bf16(
384-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float * out,
384+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, float * out,
385385
int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
386386
) {
387387
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
@@ -826,7 +826,7 @@ void cgemv_4bit_inference_fp16(
826826

827827
#if 1
828828
void cgemv_4bit_inference_bf16(
829-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype,
829+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype,
830830
float * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
831831
) {
832832
gemm_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);

csrc/xpu_cutlass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void gemv_4bit_inference_cutlass_cute(int m, int n, int k, T *A, T *B,
109109

110110
template <typename T, int BITS>
111111
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
112-
T *absmax, float *datatype, float *out, int lda,
112+
float *absmax, float *datatype, float *out, int lda,
113113
int ldb, int ldc, int blocksize, sycl::queue *stream);
114114

115115
template <typename T, int BITS>

csrc/xpu_cutlass_fusion.cpp

Lines changed: 109 additions & 56 deletions
Large diffs are not rendered by default.

include/cute/atom/copy_traits_xe.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ struct XE_2D_LD_Unpack {
289289

290290
constexpr auto inst_size_bits = detail::size_of_inst_bits<CopyOp, dtype>;
291291

292+
//if(cute::thread0()){
293+
// print("copy base_addr: "); print(base_addr); print("\n");
294+
//}
292295
CopyOp::copy(base_addr + l * traits.stride_l,
293296
(traits.width * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>, traits.height,
294297
(traits.pitch * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>,
@@ -314,7 +317,9 @@ struct XE_2D_LD_Unpack {
314317
int y = is_need_reversed ? n : m;
315318

316319
constexpr auto inst_size_bits = detail::size_of_inst_bits<CopyOp, dtype>;
317-
320+
//if(cute::thread0()){
321+
// print("prefetch base_addr: "); print(base_addr); print("\n");
322+
//}
318323
CopyOp::PREFETCH::copy(base_addr + l * atom.stride_l,
319324
(atom.width * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>, atom.height,
320325
(atom.pitch * sizeof_bits_v<dtype>) / sizeof_bits_v<int8_t>,

include/cute/atom/mma_atom.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,13 @@ struct ThrMMA : TiledMMA
583583
auto thr_tensor = make_tensor(static_cast<BTensor&&>(btensor).data(), this->thrfrg_B(btensor.layout()));
584584

585585
auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_)));
586-
//if(cute::thread0()) printf("partition_B: get<0>(thr_vmnk_) = %d, get<2>(thr_vmnk_) = %d, get<3>(thr_vmnk_) = %d\n", static_cast<int>(get<0>(thr_vmnk_)),static_cast<int>(get<2>(thr_vmnk_)),static_cast<int>(get<3>(thr_vmnk_)));
586+
#if 0
587+
if(int(ThreadIdxX()) == 16 && BlockIdxY()==0){
588+
printf("partition_B: get<0>(thr_vmnk_) = %d, get<2>(thr_vmnk_) = %d, get<3>(thr_vmnk_) = %d\n", static_cast<int>(get<0>(thr_vmnk_)),static_cast<int>(get<2>(thr_vmnk_)),static_cast<int>(get<3>(thr_vmnk_)));
589+
print(" thr_tensor : "); print(thr_tensor); print("\n");
590+
print(" thr_tensor_return : "); print(thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)))); print("\n");
591+
}
592+
#endif
587593
return thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
588594
}
589595

run_case.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@
3030
#gdb -args python -m pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
3131
#pytest -vs tests/test_xpu.py::TestXPU::test_gemm_4bit
3232
pytest -vs tests/test_xpu.py::TestXPU::test_gemv_4bit
33-
#python tests/test_xpu_db.py
33+
##python tests/test_xpu_db.py
3434
#gdb -args python tests/test_xpu_db.py
3535
#pytest tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=256-uint8-bf16-fc1-nf4-DQ_True-xpu]

tests/test_xpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
118118
print("qB.t() = ",qB.t())
119119
C3 = torch.matmul(A, B.t())
120120
#pdb.set_trace()
121-
C2 = F.gemv_4bit(A, qB.t(), state=state)
121+
C2 = F.gemv_4bit(A, qB.t(), state=state).bfloat16()
122122
#pdb.set_trace()
123123
print("C3.sum() = ", C3.sum())
124124
print("C2.sum() = ", C2.sum())
125-
diff = C2.bfloat16()-C3
125+
diff = C2-C3
126126
print("diff/C2 = ", diff.sum()/C3.sum())
127127
print(C3)
128128
print(C2)
@@ -139,7 +139,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
139139
#print("B[0] = ",B[0])
140140
C3 = torch.matmul(A, B.t())
141141
#pdb.set_trace()
142-
C2 = F.gemv_4bit(A, qB.t(), state=state)
142+
C2 = F.gemv_4bit(A, qB.t(), state=state).bfloat16()
143143
pdb.set_trace()
144144
#print("C3.sum() = ", C3.sum())
145145
#print("C2.sum() = ", C2.sum())

0 commit comments

Comments
 (0)