Skip to content

Commit 684d6a4

Browse files
committed
clean code
1 parent 12a484a commit 684d6a4

3 files changed

Lines changed: 34 additions & 130 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def _gemv_4bit_impl(
7474
blocksize: int,
7575
out: torch.Tensor,
7676
) -> None:
77-
#import pdb
78-
#pdb.set_trace()
79-
m = ct.c_int32(*A.shape[:-1])
77+
import pdb
78+
pdb.set_trace()
79+
m = ct.c_int32(*A.shape[:-1]) #A.shape[1])
8080
n = ct.c_int32(shapeB[0])
8181
k = ct.c_int32(shapeB[1])
8282

@@ -86,8 +86,8 @@ def _gemv_4bit_impl(
8686

8787
#absmax = absmax * 10
8888
#pdb.set_trace()
89-
print("A before kernel: ", A)
90-
print("B before kernel: ", B)
89+
#print("A before kernel: ", A)
90+
#print("B before kernel: ", B)
9191
stream = _get_tensor_stream(A)
9292
if A.dtype == torch.float16:
9393
lib.cgemv_4bit_inference_fp16(

csrc/xpu_cutlass_fusion.cpp

Lines changed: 16 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -233,31 +233,7 @@ class kgemm_4bit_inference_cutlass_dequant {
233233
using SrcType = typename EngineIn::value_type;
234234
using DstType = typename EngineOut::value_type;
235235
using ScaleType = typename EngineScales::value_type;
236-
#if 0
237-
static constexpr auto N = decltype(size<1>(in))::value;
238-
static constexpr auto loop_cnt = decltype(size(out))::value / N;
239-
for (int n = 0; n < N; n++) {
240-
auto s_value = tCrS_input(i);
241-
for (int l = 0; s < loop_cnt; l++) {
242-
243-
// int numbers = decltype(size(in))::value;
244-
// for(int i=0; i<numbers / N; i++){
245-
// //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
246-
// //out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
247-
// uint8_t value = in[i].get();
248-
// out[i] = static_cast<DstType>(quant_map[value]);
249-
// int thread_idx = int(ThreadIdxX());
250-
// if(cute::thread0()){
251-
// //if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
252-
// //printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
253-
// }
254-
// }
255-
// int scale_number = decltype(size(tCrS_input))::value;
256-
// for(int i=0; i<scale_number; i++){
257-
// auto s_value = tCrS_input(i);
258-
// if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
259-
// }
260-
#else
236+
261237
static constexpr auto N = decltype(size<1>(in))::value;
262238

263239
using format_type = ushort; //16
@@ -275,13 +251,6 @@ class kgemm_4bit_inference_cutlass_dequant {
275251
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
276252
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
277253

278-
int scale_number = decltype(size(tCrS_input))::value;
279-
for(int i=0; i<scale_number; i++){
280-
auto s_value = tCrS_input(i);
281-
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
282-
}
283-
// printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
284-
285254
for (int n = 0; n < N; n++) {
286255
const auto ts = tCrS_input(n);
287256

@@ -300,17 +269,14 @@ for(int i=0; i<scale_number; i++){
300269
} else {
301270
dst[i+1] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
302271
}
303-
if(cute::thread0())
304-
printf("tid = %d, n = %d, s = %d, i = %d, format_data = %d, value = %d, quant_map[value] = %f, ts = %f, dst = %f\n",ThreadIdxX(), n, s, i, static_cast<int>(format_data), static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
305272
}
306273
}
307274
}
308-
#endif
309275
}
310276

311277
CUTLASS_DEVICE
312278
void operator()(Params const& params, char* smem_buf) {
313-
if(cute::thread0()) printf("this is fusion kernel...........\n");
279+
//if(cute::thread0()) printf("this is fusion kernel...........\n");
314280

315281
int M = params.m;
316282
int N = params.n;
@@ -363,21 +329,17 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
363329
auto blk_shape = TileShape{}; //16,64,64
364330
int m_coord, n_coord, l_coord; //block index
365331
if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
366-
if(cute::thread0()) printf("AlongN !!\n");
332+
//if(cute::thread0()) printf("AlongN !!\n");
367333
m_coord = BlockIdxY();
368334
n_coord = BlockIdxX();
369335
l_coord = BlockIdxZ();
370336
} else {
371-
if(cute::thread0()) printf("not AlongN !!\n");
337+
//if(cute::thread0()) printf("not AlongN !!\n");
372338
m_coord = BlockIdxX();
373339
n_coord = BlockIdxY();
374340
l_coord = BlockIdxZ();
375341
}
376342
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
377-
if(cute::thread0()) {
378-
printf("M = %d, N=%d, K=%d, L=%d\n", M, N, K, L);
379-
printf("thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n",thread_idx, m_coord, n_coord, l_coord, BlockIdxX(), BlockIdxY(), BlockIdxZ());
380-
}
381343
constexpr auto workgroup_shape = WorkgroupTileShape{}; //256, 256, 32
382344
constexpr auto subgroup_tile_shape = SubgroupTileShape{}; //32, 64, 32 (number of atom level workgroup: 256/8=32, 256/4=64, 32/2=32)
383345

@@ -395,7 +357,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
395357
//// Create K slicing tiling iterator and count
396358
auto k_tile_iter = cute::make_coord_iterator(idx2crd(0, make_shape(K)), make_shape(K));
397359
int k_tile_count = ceil_div(K, get<2>(workgroup_shape)); //inner_loop number
398-
if(cute::thread0()) printf("k_tile_count = %d\n", k_tile_count);
399360

400361

401362
////// MainLoop //////
@@ -417,13 +378,10 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
417378

418379
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
419380

420-
//const int SubgroupSize = 16;
421381
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; //SubgroupSize;
422382
static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
423383
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
424-
//using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>, Stride<_1,_1,_0>>;
425384
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
426-
if(cute::thread0()) printf("scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d, BlockShape = %d, BlockShape_1= %d\n", scale_traits_size, scale_traits_num, SG_QNT_WIDTH, decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value, decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value);
427385

428386
static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
429387
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
@@ -433,15 +391,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
433391
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
434392
Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
435393
Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
436-
//auto frag_layout = make_layout(
437-
// make_shape(_2{}, _1{}, _1{}), // 形状 (_2, _1, _1)
438-
// make_stride(_1{}, _1{}, _0{}) // 步长 (_1, _1, _0)
439-
//);
440-
//Tensor frag_copy_Scale = thr_copy_scale.retile_D(make_tensor(fragment_scale.data(), frag_layout));
441-
442-
//using FragLayout = Layout<Shape<_2,_1,_1>, Stride<_1,_1,_0>>;
443-
//Tensor fragment_scale = make_tensor<ElementScale>(FragLayout{});
444-
//Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
445394

446395
//// Retile global counting tensors for copies:
447396
Tensor tAgA = thr_copy_A.retile_S(tCgA);
@@ -458,26 +407,14 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
458407
auto pBgB = thr_prefetch_B.partition_S(gB);
459408

460409
// Run mainloop
461-
//auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
462-
//const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
463-
//const int l_coord_s = l_idx;
464-
465-
//if(cute::thread0()) printf("get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",get_sub_group_id(), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
466-
467410
auto copy_iter_s = [&](){
468411
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
469412
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
470413
make_stride(E<0>{} * _16{}, E<0>{} * decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1>{} * _1{})));
471414

472415
}();
473416

474-
//auto copy_iter_s = [&](){
475-
// return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
476-
// make_layout(make_shape(Int<decltype(size<0>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, Int<decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, _1{}, k_tile_count),
477-
// make_stride(_16{}, _32{}, _0{}, _1{})));
478-
//}();
479-
480-
#if 1
417+
#if 0
481418
#define PRINT(x) print(#x ": "); print(x); print("\n");
482419
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
483420
print("\n\n======================= A: \n");
@@ -518,10 +455,9 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
518455
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
519456
int prefetch_k = k_start_idx;
520457

521-
#if 1
522458
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
523-
if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n",params.group_size, BLK_K, k_reload_factor);
524-
#endif
459+
//if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n",params.group_size, BLK_K, k_reload_factor);
460+
525461
CUTLASS_PRAGMA_UNROLL
526462
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
527463
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
@@ -534,19 +470,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
534470
// Copy gmem to rmem for the first k_tile
535471
copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
536472
copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
537-
#if 1
538-
const int s_step = k_start_idx + (k_s / k_reload_factor); //1 + k_tile / k_reload_factor;
539-
if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
540-
copy(tiled_copy_scale, copy_iter_s(_, _, _, s_step), frag_copy_Scale);
541-
#else
542-
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
543-
//const int k_reload_factor = params.group_size / BLK_K;
544473

545-
//if(cute::thread0())
546-
printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
474+
const int s_step = k_start_idx + (k_s / k_reload_factor);
475+
//if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
476+
copy(tiled_copy_scale, copy_iter_s(_, _, _, s_step), frag_copy_Scale);
547477

548-
copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
549-
#endif
550478
if(prefetch_k < k_tile_count) {
551479
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
552480
}
@@ -617,41 +545,16 @@ template <typename T, int BITS>
617545
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
618546
T *absmax_, float *datatype, float *out, int lda,
619547
int ldb, int ldc, int blocksize, sycl::queue *stream) {
620-
std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
548+
//std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
621549

622550
sycl::queue q = *stream;
623551
using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS>;
624552

625553
static constexpr int smem_size= 512; // (16 * 32) for quant_map
626554
int l = 1;
627555

628-
//TODO(Xiaoli): FIX ME?? auto problem_size = ProblemShape{m, n, k};
629556
auto problem_size = ProblemShape{m, n, k, l};
630-
//TODO(Xiaoli): FIX ME
631-
// T* absmax = (T*)absmax_;
632-
// T* absmax = (T*)absmax_;
633557

634-
//std::vector<T> host_data(n * k / blocksize);
635-
#if 0
636-
int element_size_A = m * k;
637-
auto scale_host_A = sycl::aligned_alloc_host<T>(512, element_size_A, q);
638-
q.memcpy(scale_host_A, A, element_size_A * sizeof(T)).wait();
639-
for (int i = 0; i < element_size_A; ++i) {
640-
//std::cout << scale_host[i] << " ";
641-
printf("%f ",static_cast<float>(scale_host_A[i]));
642-
}
643-
std::cout << std::endl;
644-
645-
int element_size = n * k / blocksize;
646-
auto scale_host = sycl::aligned_alloc_host<T>(512, element_size, q);
647-
q.memcpy(scale_host, absmax_, element_size * sizeof(T)).wait();
648-
for (int i = 0; i < element_size; ++i) {
649-
//std::cout << scale_host[i] << " ";
650-
printf("%f ",static_cast<float>(scale_host[i]));
651-
}
652-
std::cout << std::endl;
653-
#endif
654-
#if 1
655558
// Init Params
656559
using Params = GemmKernel::Params;
657560
Params params;
@@ -678,7 +581,7 @@ std::cout << std::endl;
678581

679582
const int scale_k = cute::ceil_div(k, blocksize);
680583
StrideScale stride_S = cutlass::make_cute_packed_stride(StrideScale{}, cute::make_shape(n, scale_k, l));
681-
std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
584+
//std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
682585
auto mScale = make_tensor(
683586
make_gmem_ptr(absmax_),
684587
make_layout(make_shape(n, scale_k, l), stride_S));
@@ -694,6 +597,7 @@ std::cout << std::endl;
694597
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
695598
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
696599

600+
#if 0
697601
#define PRINT(x) print(#x ": "); print(x); print("\n");
698602
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
699603
print("===================== stride :\n");
@@ -705,6 +609,7 @@ std::cout << std::endl;
705609
print("===================== stride :\n");
706610
}
707611
#undef PRINT
612+
#endif
708613

709614
params.hw_info = hw_info;
710615
params.epilogue = CollectiveEpilogue::to_underlying_arguments(problem_size, {{alpha, beta}, nullptr, stride_C, out, stride_D}, nullptr);
@@ -721,8 +626,8 @@ std::cout << std::endl;
721626

722627
const syclcompat::dim3 sycl_block(block.x, block.y, block.z); //workgroup_size: 1*2*1*16, 1, 1
723628
const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z); //workgroup_number (problem_size / tile_size): N/64, M/16, 1
724-
printf("Host Grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
725-
printf("Host Block: (%d, %d, %d)\n", block.x, block.y, block.z);
629+
//printf("Host Grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
630+
//printf("Host Block: (%d, %d, %d)\n", block.x, block.y, block.z);
726631

727632
auto kernel_props = [] {
728633
return syclcompat::experimental::kernel_properties{
@@ -739,7 +644,6 @@ std::cout << std::endl;
739644
auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(policy, q, params);
740645
EventManager::getInstance().addEvent(event);
741646
//syclcompat::wait();
742-
#endif
743647
}
744648

745649
template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(

tests/test_xpu.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class TestXPU:
4040
@pytest.mark.parametrize("device", ["xpu"])#get_available_devices())
4141
@pytest.mark.parametrize("double_quant", [True], ids=lambda double_quant: f"DQ_{double_quant}")
4242
@pytest.mark.parametrize("storage_type", ["nf4"])
43-
@pytest.mark.parametrize("kind", ["fc0"])#, "attn_packed"])
43+
@pytest.mark.parametrize("kind", ["fc1"])#, "attn_packed"])
4444
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=describe_dtype)
4545
@pytest.mark.parametrize(
4646
"quant_storage",
@@ -83,10 +83,10 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8383
double_quant=False
8484
block_size = 16
8585
elif kind == "fc1":
86-
dim=256
87-
A = torch.randn(32, dim, dtype=dtype, device=device) * 10
86+
dim=4096
87+
A = torch.randn(64, dim, dtype=dtype, device=device)
8888
#A = torch.arange(1, 32 * 256 + 1).reshape(32, 256).bfloat16().xpu()
89-
B = torch.randn(dim, dim, dtype=dtype, device=device) # / math.sqrt(dim)
89+
B = torch.randn(dim , dim, dtype=dtype, device=device) / math.sqrt(dim)
9090
double_quant=False
9191
block_size = 32
9292
elif kind == "fc2":
@@ -133,18 +133,18 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
133133
#C1 = bnb.matmul_4bit(A, qB.t(), state)
134134
else:
135135
pdb.set_trace()
136-
print("")
137-
print("absmax = ", state.absmax)
138-
print("A[0] = ",A[0])
139-
print("B[0] = ",B[0])
136+
#print("")
137+
#print("absmax = ", state.absmax)
138+
#print("A[0] = ",A[0])
139+
#print("B[0] = ",B[0])
140140
C3 = torch.matmul(A, B.t())
141141
#pdb.set_trace()
142142
C2 = F.gemv_4bit(A, qB.t(), state=state)
143-
#pdb.set_trace()
144-
print("C3.sum() = ", C3.sum())
145-
print("C2.sum() = ", C2.sum())
146-
diff = abs(C2-C3)
147-
print("diff = ", diff.sum())
143+
pdb.set_trace()
144+
#print("C3.sum() = ", C3.sum())
145+
#print("C2.sum() = ", C2.sum())
146+
diff = abs(C2-C3.bfloat16())
147+
print("diff = ", diff[0])
148148
print(C3[0])
149149
print(C2[0])
150150
#print(C3)

0 commit comments

Comments
 (0)