Skip to content

Commit 7f2682d

Browse files
committed
clean code
1 parent 12a484a commit 7f2682d

3 files changed

Lines changed: 34 additions & 136 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def _gemv_4bit_impl(
7474
blocksize: int,
7575
out: torch.Tensor,
7676
) -> None:
77-
#import pdb
78-
#pdb.set_trace()
79-
m = ct.c_int32(*A.shape[:-1])
77+
import pdb
78+
pdb.set_trace()
79+
m = ct.c_int32(*A.shape[:-1]) #A.shape[1])
8080
n = ct.c_int32(shapeB[0])
8181
k = ct.c_int32(shapeB[1])
8282

@@ -86,8 +86,8 @@ def _gemv_4bit_impl(
8686

8787
#absmax = absmax * 10
8888
#pdb.set_trace()
89-
print("A before kernel: ", A)
90-
print("B before kernel: ", B)
89+
#print("A before kernel: ", A)
90+
#print("B before kernel: ", B)
9191
stream = _get_tensor_stream(A)
9292
if A.dtype == torch.float16:
9393
lib.cgemv_4bit_inference_fp16(

csrc/xpu_cutlass_fusion.cpp

Lines changed: 16 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ class kgemm_4bit_inference_cutlass_dequant {
176176
T* A;
177177
uint8_t* B;
178178
float* out;
179-
//T *absmax;
180179
float *datatype; //LUT
181180
int group_size;
182181

@@ -206,11 +205,6 @@ class kgemm_4bit_inference_cutlass_dequant {
206205
}
207206
}
208207

209-
/*float bfloat16_to_float(uint16_t bf16_bits) {
210-
uint32_t float_bits = (bf16_bits << 16); // 将 bfloat16 左移16位转为 float
211-
return reinterpret_cast<float&>(float_bits);
212-
}*/
213-
214208
/// Utilities to transform A.
215209
template <class EngineIn,
216210
class EngineOut,
@@ -233,31 +227,7 @@ class kgemm_4bit_inference_cutlass_dequant {
233227
using SrcType = typename EngineIn::value_type;
234228
using DstType = typename EngineOut::value_type;
235229
using ScaleType = typename EngineScales::value_type;
236-
#if 0
237-
static constexpr auto N = decltype(size<1>(in))::value;
238-
static constexpr auto loop_cnt = decltype(size(out))::value / N;
239-
for (int n = 0; n < N; n++) {
240-
auto s_value = tCrS_input(i);
241-
for (int l = 0; s < loop_cnt; l++) {
242-
243-
// int numbers = decltype(size(in))::value;
244-
// for(int i=0; i<numbers / N; i++){
245-
// //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
246-
// //out[i] = static_cast<DstType>(quant_map[in_ptr_8[i].data()]);
247-
// uint8_t value = in[i].get();
248-
// out[i] = static_cast<DstType>(quant_map[value]);
249-
// int thread_idx = int(ThreadIdxX());
250-
// if(cute::thread0()){
251-
// //if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
252-
// //printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
253-
// }
254-
// }
255-
// int scale_number = decltype(size(tCrS_input))::value;
256-
// for(int i=0; i<scale_number; i++){
257-
// auto s_value = tCrS_input(i);
258-
// if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
259-
// }
260-
#else
230+
261231
static constexpr auto N = decltype(size<1>(in))::value;
262232

263233
using format_type = ushort; //16
@@ -275,13 +245,6 @@ class kgemm_4bit_inference_cutlass_dequant {
275245
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
276246
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
277247

278-
int scale_number = decltype(size(tCrS_input))::value;
279-
for(int i=0; i<scale_number; i++){
280-
auto s_value = tCrS_input(i);
281-
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
282-
}
283-
// printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
284-
285248
for (int n = 0; n < N; n++) {
286249
const auto ts = tCrS_input(n);
287250

@@ -300,17 +263,14 @@ for(int i=0; i<scale_number; i++){
300263
} else {
301264
dst[i+1] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
302265
}
303-
if(cute::thread0())
304-
printf("tid = %d, n = %d, s = %d, i = %d, format_data = %d, value = %d, quant_map[value] = %f, ts = %f, dst = %f\n",ThreadIdxX(), n, s, i, static_cast<int>(format_data), static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
305266
}
306267
}
307268
}
308-
#endif
309269
}
310270

311271
CUTLASS_DEVICE
312272
void operator()(Params const& params, char* smem_buf) {
313-
if(cute::thread0()) printf("this is fusion kernel...........\n");
273+
//if(cute::thread0()) printf("this is fusion kernel...........\n");
314274

315275
int M = params.m;
316276
int N = params.n;
@@ -363,21 +323,17 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
363323
auto blk_shape = TileShape{}; //16,64,64
364324
int m_coord, n_coord, l_coord; //block index
365325
if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
366-
if(cute::thread0()) printf("AlongN !!\n");
326+
//if(cute::thread0()) printf("AlongN !!\n");
367327
m_coord = BlockIdxY();
368328
n_coord = BlockIdxX();
369329
l_coord = BlockIdxZ();
370330
} else {
371-
if(cute::thread0()) printf("not AlongN !!\n");
331+
//if(cute::thread0()) printf("not AlongN !!\n");
372332
m_coord = BlockIdxX();
373333
n_coord = BlockIdxY();
374334
l_coord = BlockIdxZ();
375335
}
376336
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
377-
if(cute::thread0()) {
378-
printf("M = %d, N=%d, K=%d, L=%d\n", M, N, K, L);
379-
printf("thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n",thread_idx, m_coord, n_coord, l_coord, BlockIdxX(), BlockIdxY(), BlockIdxZ());
380-
}
381337
constexpr auto workgroup_shape = WorkgroupTileShape{}; //256, 256, 32
382338
constexpr auto subgroup_tile_shape = SubgroupTileShape{}; //32, 64, 32 (number of atom level workgroup: 256/8=32, 256/4=64, 32/2=32)
383339

@@ -395,7 +351,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
395351
//// Create K slicing tiling iterator and count
396352
auto k_tile_iter = cute::make_coord_iterator(idx2crd(0, make_shape(K)), make_shape(K));
397353
int k_tile_count = ceil_div(K, get<2>(workgroup_shape)); //inner_loop number
398-
if(cute::thread0()) printf("k_tile_count = %d\n", k_tile_count);
399354

400355

401356
////// MainLoop //////
@@ -417,13 +372,10 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
417372

418373
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
419374

420-
//const int SubgroupSize = 16;
421375
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; //SubgroupSize;
422376
static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
423377
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
424-
//using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>, Stride<_1,_1,_0>>;
425378
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
426-
if(cute::thread0()) printf("scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d, BlockShape = %d, BlockShape_1= %d\n", scale_traits_size, scale_traits_num, SG_QNT_WIDTH, decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value, decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value);
427379

428380
static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
429381
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
@@ -433,15 +385,6 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
433385
Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
434386
Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
435387
Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
436-
//auto frag_layout = make_layout(
437-
// make_shape(_2{}, _1{}, _1{}), // 形状 (_2, _1, _1)
438-
// make_stride(_1{}, _1{}, _0{}) // 步长 (_1, _1, _0)
439-
//);
440-
//Tensor frag_copy_Scale = thr_copy_scale.retile_D(make_tensor(fragment_scale.data(), frag_layout));
441-
442-
//using FragLayout = Layout<Shape<_2,_1,_1>, Stride<_1,_1,_0>>;
443-
//Tensor fragment_scale = make_tensor<ElementScale>(FragLayout{});
444-
//Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
445388

446389
//// Retile global counting tensors for copies:
447390
Tensor tAgA = thr_copy_A.retile_S(tCgA);
@@ -458,26 +401,14 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
458401
auto pBgB = thr_prefetch_B.partition_S(gB);
459402

460403
// Run mainloop
461-
//auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
462-
//const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
463-
//const int l_coord_s = l_idx;
464-
465-
//if(cute::thread0()) printf("get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",get_sub_group_id(), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
466-
467404
auto copy_iter_s = [&](){
468405
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
469406
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
470407
make_stride(E<0>{} * _16{}, E<0>{} * decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1>{} * _1{})));
471408

472409
}();
473410

474-
//auto copy_iter_s = [&](){
475-
// return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
476-
// make_layout(make_shape(Int<decltype(size<0>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, Int<decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value>{}, _1{}, k_tile_count),
477-
// make_stride(_16{}, _32{}, _0{}, _1{})));
478-
//}();
479-
480-
#if 1
411+
#if 0
481412
#define PRINT(x) print(#x ": "); print(x); print("\n");
482413
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
483414
print("\n\n======================= A: \n");
@@ -518,10 +449,9 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
518449
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
519450
int prefetch_k = k_start_idx;
520451

521-
#if 1
522452
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
523-
if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n",params.group_size, BLK_K, k_reload_factor);
524-
#endif
453+
//if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %f\n",params.group_size, BLK_K, k_reload_factor);
454+
525455
CUTLASS_PRAGMA_UNROLL
526456
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
527457
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
@@ -534,19 +464,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
534464
// Copy gmem to rmem for the first k_tile
535465
copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
536466
copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
537-
#if 1
538-
const int s_step = k_start_idx + (k_s / k_reload_factor); //1 + k_tile / k_reload_factor;
539-
if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
540-
copy(tiled_copy_scale, copy_iter_s(_, _, _, s_step), frag_copy_Scale);
541-
#else
542-
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
543-
//const int k_reload_factor = params.group_size / BLK_K;
544467

545-
//if(cute::thread0())
546-
printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
468+
const int s_step = k_start_idx + (k_s / k_reload_factor);
469+
//if(cute::thread0()) printf("k_start_idx = %d, k_s = %d, k_reload_factor = %f, s_step = %d\n",k_start_idx, k_s, k_reload_factor, s_step);
470+
copy(tiled_copy_scale, copy_iter_s(_, _, _, s_step), frag_copy_Scale);
547471

548-
copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
549-
#endif
550472
if(prefetch_k < k_tile_count) {
551473
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
552474
}
@@ -617,41 +539,16 @@ template <typename T, int BITS>
617539
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
618540
T *absmax_, float *datatype, float *out, int lda,
619541
int ldb, int ldc, int blocksize, sycl::queue *stream) {
620-
std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
542+
//std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
621543

622544
sycl::queue q = *stream;
623545
using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS>;
624546

625547
static constexpr int smem_size= 512; // (16 * 32) for quant_map
626548
int l = 1;
627549

628-
//TODO(Xiaoli): FIX ME?? auto problem_size = ProblemShape{m, n, k};
629550
auto problem_size = ProblemShape{m, n, k, l};
630-
//TODO(Xiaoli): FIX ME
631-
// T* absmax = (T*)absmax_;
632-
// T* absmax = (T*)absmax_;
633551

634-
//std::vector<T> host_data(n * k / blocksize);
635-
#if 0
636-
int element_size_A = m * k;
637-
auto scale_host_A = sycl::aligned_alloc_host<T>(512, element_size_A, q);
638-
q.memcpy(scale_host_A, A, element_size_A * sizeof(T)).wait();
639-
for (int i = 0; i < element_size_A; ++i) {
640-
//std::cout << scale_host[i] << " ";
641-
printf("%f ",static_cast<float>(scale_host_A[i]));
642-
}
643-
std::cout << std::endl;
644-
645-
int element_size = n * k / blocksize;
646-
auto scale_host = sycl::aligned_alloc_host<T>(512, element_size, q);
647-
q.memcpy(scale_host, absmax_, element_size * sizeof(T)).wait();
648-
for (int i = 0; i < element_size; ++i) {
649-
//std::cout << scale_host[i] << " ";
650-
printf("%f ",static_cast<float>(scale_host[i]));
651-
}
652-
std::cout << std::endl;
653-
#endif
654-
#if 1
655552
// Init Params
656553
using Params = GemmKernel::Params;
657554
Params params;
@@ -678,7 +575,7 @@ std::cout << std::endl;
678575

679576
const int scale_k = cute::ceil_div(k, blocksize);
680577
StrideScale stride_S = cutlass::make_cute_packed_stride(StrideScale{}, cute::make_shape(n, scale_k, l));
681-
std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
578+
//std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
682579
auto mScale = make_tensor(
683580
make_gmem_ptr(absmax_),
684581
make_layout(make_shape(n, scale_k, l), stride_S));
@@ -694,6 +591,7 @@ std::cout << std::endl;
694591
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
695592
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
696593

594+
#if 0
697595
#define PRINT(x) print(#x ": "); print(x); print("\n");
698596
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
699597
print("===================== stride :\n");
@@ -705,6 +603,7 @@ std::cout << std::endl;
705603
print("===================== stride :\n");
706604
}
707605
#undef PRINT
606+
#endif
708607

709608
params.hw_info = hw_info;
710609
params.epilogue = CollectiveEpilogue::to_underlying_arguments(problem_size, {{alpha, beta}, nullptr, stride_C, out, stride_D}, nullptr);
@@ -721,8 +620,8 @@ std::cout << std::endl;
721620

722621
const syclcompat::dim3 sycl_block(block.x, block.y, block.z); //workgroup_size: 1*2*1*16, 1, 1
723622
const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z); //workgroup_number (problem_size / tile_size): N/64, M/16, 1
724-
printf("Host Grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
725-
printf("Host Block: (%d, %d, %d)\n", block.x, block.y, block.z);
623+
//printf("Host Grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
624+
//printf("Host Block: (%d, %d, %d)\n", block.x, block.y, block.z);
726625

727626
auto kernel_props = [] {
728627
return syclcompat::experimental::kernel_properties{
@@ -739,7 +638,6 @@ std::cout << std::endl;
739638
auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(policy, q, params);
740639
EventManager::getInstance().addEvent(event);
741640
//syclcompat::wait();
742-
#endif
743641
}
744642

745643
template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(

tests/test_xpu.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class TestXPU:
4040
@pytest.mark.parametrize("device", ["xpu"])#get_available_devices())
4141
@pytest.mark.parametrize("double_quant", [True], ids=lambda double_quant: f"DQ_{double_quant}")
4242
@pytest.mark.parametrize("storage_type", ["nf4"])
43-
@pytest.mark.parametrize("kind", ["fc0"])#, "attn_packed"])
43+
@pytest.mark.parametrize("kind", ["fc1"])#, "attn_packed"])
4444
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=describe_dtype)
4545
@pytest.mark.parametrize(
4646
"quant_storage",
@@ -83,10 +83,10 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8383
double_quant=False
8484
block_size = 16
8585
elif kind == "fc1":
86-
dim=256
87-
A = torch.randn(32, dim, dtype=dtype, device=device) * 10
86+
dim=4096
87+
A = torch.randn(64, dim, dtype=dtype, device=device)
8888
#A = torch.arange(1, 32 * 256 + 1).reshape(32, 256).bfloat16().xpu()
89-
B = torch.randn(dim, dim, dtype=dtype, device=device) # / math.sqrt(dim)
89+
B = torch.randn(dim , dim, dtype=dtype, device=device) / math.sqrt(dim)
9090
double_quant=False
9191
block_size = 32
9292
elif kind == "fc2":
@@ -133,18 +133,18 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
133133
#C1 = bnb.matmul_4bit(A, qB.t(), state)
134134
else:
135135
pdb.set_trace()
136-
print("")
137-
print("absmax = ", state.absmax)
138-
print("A[0] = ",A[0])
139-
print("B[0] = ",B[0])
136+
#print("")
137+
#print("absmax = ", state.absmax)
138+
#print("A[0] = ",A[0])
139+
#print("B[0] = ",B[0])
140140
C3 = torch.matmul(A, B.t())
141141
#pdb.set_trace()
142142
C2 = F.gemv_4bit(A, qB.t(), state=state)
143-
#pdb.set_trace()
144-
print("C3.sum() = ", C3.sum())
145-
print("C2.sum() = ", C2.sum())
146-
diff = abs(C2-C3)
147-
print("diff = ", diff.sum())
143+
pdb.set_trace()
144+
#print("C3.sum() = ", C3.sum())
145+
#print("C2.sum() = ", C2.sum())
146+
diff = abs(C2-C3.bfloat16())
147+
print("diff = ", diff[0])
148148
print(C3[0])
149149
print(C2[0])
150150
#print(C3)

0 commit comments

Comments
 (0)