Skip to content

Commit 12a484a

Browse files
committed
make it work!
1 parent 11c4760 commit 12a484a

3 files changed

Lines changed: 123 additions & 51 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def _gemv_4bit_impl(
7474
blocksize: int,
7575
out: torch.Tensor,
7676
) -> None:
77-
import pdb
78-
pdb.set_trace()
77+
#import pdb
78+
#pdb.set_trace()
7979
m = ct.c_int32(*A.shape[:-1])
8080
n = ct.c_int32(shapeB[0])
8181
k = ct.c_int32(shapeB[1])
@@ -85,8 +85,9 @@ def _gemv_4bit_impl(
8585
ldc = m
8686

8787
#absmax = absmax * 10
88-
pdb.set_trace()
89-
88+
#pdb.set_trace()
89+
print("A before kernel: ", A)
90+
print("B before kernel: ", B)
9091
stream = _get_tensor_stream(A)
9192
if A.dtype == torch.float16:
9293
lib.cgemv_4bit_inference_fp16(
@@ -185,8 +186,8 @@ def _(
185186
blocksize: int,
186187
) -> torch.Tensor:
187188
shape = (*A.shape[:-1], shapeB[0])
188-
import pdb
189-
pdb.set_trace()
189+
#import pdb
190+
#pdb.set_trace()
190191
out = torch.zeros(shape, device=A.device, dtype=torch.float32)
191192
_gemv_4bit_impl(A, B, shapeB, absmax.bfloat16(), code, blocksize, out=out)
192193
return out

csrc/xpu_cutlass_fusion.cpp

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -234,23 +234,29 @@ class kgemm_4bit_inference_cutlass_dequant {
234234
using DstType = typename EngineOut::value_type;
235235
using ScaleType = typename EngineScales::value_type;
236236
#if 0
237-
int numbers = decltype(size(in))::value;
238-
for(int i=0; i<numbers; i++){
239-
//auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
240-
//out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
241-
uint8_t value = in[i].get();
242-
out[i] = static_cast<DstType>(quant_map[value]);
243-
int thread_idx = int(ThreadIdxX());
244-
if(cute::thread0()){
245-
//if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
246-
//printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
247-
}
248-
}
249-
int scale_number = decltype(size(tCrS_input))::value;
250-
for(int i=0; i<scale_number; i++){
237+
static constexpr auto N = decltype(size<1>(in))::value;
238+
static constexpr auto loop_cnt = decltype(size(out))::value / N;
239+
for (int n = 0; n < N; n++) {
251240
auto s_value = tCrS_input(i);
252-
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
253-
}
241+
for (int l = 0; s < loop_cnt; l++) {
242+
243+
// int numbers = decltype(size(in))::value;
244+
// for(int i=0; i<numbers / N; i++){
245+
// //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
246+
// //out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
247+
// uint8_t value = in[i].get();
248+
// out[i] = static_cast<DstType>(quant_map[value]);
249+
// int thread_idx = int(ThreadIdxX());
250+
// if(cute::thread0()){
251+
// //if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
252+
// //printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
253+
// }
254+
// }
255+
// int scale_number = decltype(size(tCrS_input))::value;
256+
// for(int i=0; i<scale_number; i++){
257+
// auto s_value = tCrS_input(i);
258+
// if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
259+
// }
254260
#else
255261
static constexpr auto N = decltype(size<1>(in))::value;
256262

@@ -269,7 +275,11 @@ class kgemm_4bit_inference_cutlass_dequant {
269275
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
270276
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
271277

272-
//if(cute::thread0())
278+
int scale_number = decltype(size(tCrS_input))::value;
279+
for(int i=0; i<scale_number; i++){
280+
auto s_value = tCrS_input(i);
281+
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
282+
}
273283
// printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
274284

275285
for (int n = 0; n < N; n++) {
@@ -285,8 +295,13 @@ class kgemm_4bit_inference_cutlass_dequant {
285295

286296
for (int i = 0; i < vec_size; i++) {
287297
uint8_t value = (format_data >> (src_bits * i)) & 0xf;
288-
dst[i] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
289-
//if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
298+
if(i % 2 != 0) { //1,3, high_4bit
299+
dst[i-1] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
300+
} else {
301+
dst[i+1] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
302+
}
303+
if(cute::thread0())
304+
printf("tid = %d, n = %d, s = %d, i = %d, format_data = %d, value = %d, quant_map[value] = %f, ts = %f, dst = %f\n",ThreadIdxX(), n, s, i, static_cast<int>(format_data), static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
290305
}
291306
}
292307
}
@@ -500,29 +515,38 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
500515
}
501516
#undef PRINT
502517
#endif
503-
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
518+
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
504519
int prefetch_k = k_start_idx;
505520

521+
#if 1
522+
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
523+
if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n",params.group_size, BLK_K, k_reload_factor);
524+
#endif
506525
CUTLASS_PRAGMA_UNROLL
507526
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
508527
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
509528
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
510529
}
511530

512-
for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
531+
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++, k_s++) {
513532
barrier_arrive(2);
514533

515534
// Copy gmem to rmem for the first k_tile
516535
copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
517536
copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
518-
537+
#if 1
538+
const int s_step = k_start_idx + (k_s / k_reload_factor); //1 + k_tile / k_reload_factor;
539+
if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
540+
copy(tiled_copy_scale, copy_iter_s(_, _, _, s_step), frag_copy_Scale);
541+
#else
519542
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
520543
//const int k_reload_factor = params.group_size / BLK_K;
521544
522-
if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
545+
//if(cute::thread0())
546+
printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
523547
524548
copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
525-
549+
#endif
526550
if(prefetch_k < k_tile_count) {
527551
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
528552
}
@@ -563,12 +587,10 @@ if (cute::thread0()) {
563587
// 打印输出
564588
debug_print("Accumulators (After GEMM)", accumulators);
565589

566-
barrier_wait(2);
567590
}
568591
#endif
569592
#if 0
570593
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
571-
barrier_wait(2);
572594

573595
for (int i = 0; i < accumulators.size(); ++i) {
574596
printf("Thread (%d, %d): accumulators[%d] =%f\n", syclcompat::global_id::x() , syclcompat::global_id::y(), i, static_cast<float>(accumulators[i]));

tests/test_xpu.py

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class TestXPU:
4040
@pytest.mark.parametrize("device", ["xpu"])#get_available_devices())
4141
@pytest.mark.parametrize("double_quant", [True], ids=lambda double_quant: f"DQ_{double_quant}")
4242
@pytest.mark.parametrize("storage_type", ["nf4"])
43-
@pytest.mark.parametrize("kind", ["fc1"])#, "attn_packed"])
43+
@pytest.mark.parametrize("kind", ["fc0"])#, "attn_packed"])
4444
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=describe_dtype)
4545
@pytest.mark.parametrize(
4646
"quant_storage",
@@ -65,10 +65,30 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
6565

6666
#for i in range(iters):
6767
#pdb.set_trace()
68-
if kind == "fc1":
68+
if kind == "fc0":
69+
dim = 16
70+
#A = torch.arange(32, 0, -2).reshape(1, dim).bfloat16().xpu() * torch.randn(1, dim, dtype=dtype, device=device) * 10
71+
#shuffled_indices = torch.randperm(dim)
72+
#A = A[:, shuffled_indices] # 直接索引列
73+
74+
#B = torch.arange(0, 32, 1).reshape(2, dim).bfloat16().xpu() * torch.randn(2, dim, dtype=dtype, device=device) / 10
75+
#shuffled_indices = torch.randperm(dim)
76+
#B = B[:, shuffled_indices].contiguous() # 直接索引列
77+
78+
#A = torch.ones(1, dim, dtype=dtype, device=device)
79+
#B = torch.ones(2, dim, dtype=dtype, device=device) # / math.sqrt(dim)
80+
81+
A = torch.randn(1, dim, dtype=dtype, device=device) * 10
82+
B = torch.randn(2, dim, dtype=dtype, device=device) / math.sqrt(dim)
83+
double_quant=False
84+
block_size = 16
85+
elif kind == "fc1":
86+
dim=256
6987
A = torch.randn(32, dim, dtype=dtype, device=device) * 10
7088
#A = torch.arange(1, 32 * 256 + 1).reshape(32, 256).bfloat16().xpu()
71-
B = torch.randn(dim, dim, dtype=dtype, device=device) / math.sqrt(dim)
89+
B = torch.randn(dim, dim, dtype=dtype, device=device) # / math.sqrt(dim)
90+
double_quant=False
91+
block_size = 32
7292
elif kind == "fc2":
7393
A = torch.randn(1, 4 * dim, dtype=dtype, device=device)
7494
B = torch.randn(dim, 4 * dim, dtype=dtype, device=device) / math.sqrt(dim)
@@ -84,24 +104,53 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
84104
quant_type=storage_type,
85105
compress_statistics=double_quant,
86106
quant_storage=quant_storage,
87-
blocksize=64,
107+
blocksize=block_size,
88108
)
89109

90-
##pdb.set_trace()
91-
C3 = torch.matmul(A, B.t())
92-
#pdb.set_trace()
93-
C2 = F.gemv_4bit(A, qB.t(), state=state)
94-
#pdb.set_trace()
95-
print("C3.sum() = ", C3.sum())
96-
print("C2.sum() = ", C2.sum())
97-
diff = abs(C2-C3)
98-
print("diff = ", diff.sum())
99-
print(C3[0])
100-
print(C2[0])
101-
#print(C3)
102-
#print(C2)
103-
#A.requires_grad = True
104-
#C1 = bnb.matmul_4bit(A, qB.t(), state)
110+
if kind == "fc0":
111+
pdb.set_trace()
112+
print("")
113+
print("absmax = ", state.absmax)
114+
print("A = ",A)
115+
print("B = ",B)
116+
print("qB = ",qB)
117+
print("B.t() = ",B.t())
118+
print("qB.t() = ",qB.t())
119+
C3 = torch.matmul(A, B.t())
120+
#pdb.set_trace()
121+
C2 = F.gemv_4bit(A, qB.t(), state=state)
122+
#pdb.set_trace()
123+
print("C3.sum() = ", C3.sum())
124+
print("C2.sum() = ", C2.sum())
125+
diff = abs(C2-C3)
126+
print("diff = ", diff.sum())
127+
print(C3)
128+
print(C2)
129+
#exit()
130+
#print(C3)
131+
#print(C2)
132+
#A.requires_grad = True
133+
#C1 = bnb.matmul_4bit(A, qB.t(), state)
134+
else:
135+
pdb.set_trace()
136+
print("")
137+
print("absmax = ", state.absmax)
138+
print("A[0] = ",A[0])
139+
print("B[0] = ",B[0])
140+
C3 = torch.matmul(A, B.t())
141+
#pdb.set_trace()
142+
C2 = F.gemv_4bit(A, qB.t(), state=state)
143+
#pdb.set_trace()
144+
print("C3.sum() = ", C3.sum())
145+
print("C2.sum() = ", C2.sum())
146+
diff = abs(C2-C3)
147+
print("diff = ", diff.sum())
148+
print(C3[0])
149+
print(C2[0])
150+
#print(C3)
151+
#print(C2)
152+
#A.requires_grad = True
153+
#C1 = bnb.matmul_4bit(A, qB.t(), state)
105154

106155
@pytest.mark.parametrize("device", ["xpu"]) #get_available_devices())
107156
@pytest.mark.parametrize("embedding_dim", [64, 65])

0 commit comments

Comments
 (0)