Skip to content

Commit efda0eb

Browse files
committed
refine code
1 parent d361323 commit efda0eb

2 files changed

Lines changed: 26 additions & 11 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def _gemv_4bit_impl(
8484
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
8585
ldc = m
8686

87+
absmax = absmax * 10
88+
pdb.set_trace()
89+
8790
stream = _get_tensor_stream(A)
8891
if A.dtype == torch.float16:
8992
lib.cgemv_4bit_inference_fp16(
@@ -108,7 +111,7 @@ def _gemv_4bit_impl(
108111
k,
109112
get_ptr(A),
110113
get_ptr(B),
111-
get_ptr(absmax),
114+
get_ptr(absmax.bfloat16()),
112115
get_ptr(code),
113116
get_ptr(out),
114117
lda,

csrc/xpu_cutlass_fusion.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
145145
using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
146146
using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}, val_layout_load_B{}));
147147

148+
//using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N; //XE_2D_U16x1x16_LD_N;
148149
using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
149150
static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
150151
using StrideScale = cute::Stride<_1, int64_t, int64_t>; //dynamic stride
@@ -171,7 +172,7 @@ class kgemm_4bit_inference_cutlass_dequant {
171172
T* A;
172173
uint8_t* B;
173174
float* out;
174-
T *absmax;
175+
//T *absmax;
175176
float *datatype; //LUT
176177
int group_size;
177178

@@ -228,7 +229,7 @@ class kgemm_4bit_inference_cutlass_dequant {
228229
using SrcType = typename EngineIn::value_type;
229230
using DstType = typename EngineOut::value_type;
230231
using ScaleType = typename EngineScales::value_type;
231-
#if 0
232+
#if 1
232233
int numbers = decltype(size(in))::value;
233234
for(int i=0; i<numbers; i++){
234235
//auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
@@ -240,6 +241,11 @@ class kgemm_4bit_inference_cutlass_dequant {
240241
//if(syclcompat::global_id::x() == 2 && syclcompat::global_id::y() ==0 && syclcompat::global_id::z() ==0 )
241242
//printf("syclcompat::global_id::x() = %d, syclcompat::global_id::y() = %d, syclcompat::global_id::z() = %d, thread_idx = %d, i = %d, in[i].ptr_ = %x, in[i].idx_=%x, value_bit = %x, value = %d, quant_map[value] = %f, out[i] = %f\n",syclcompat::global_id::x(), syclcompat::global_id::y(), syclcompat::global_id::z(), thread_idx, i, in[i].ptr_, in[i].idx_, value, static_cast<int>(value), quant_map[value], static_cast<float>(out[i]));
242243
}
244+
int scale_number = decltype(size(tCrS_input))::value;
245+
for(int i=0; i<scale_number; i++){
246+
auto s_value = tCrS_input[i];
247+
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
248+
}
243249
}
244250
#else
245251
static constexpr auto N = decltype(size<1>(in))::value;
@@ -275,8 +281,8 @@ class kgemm_4bit_inference_cutlass_dequant {
275281
276282
for (int i = 0; i < vec_size; i++) {
277283
uint8_t value = (format_data >> (src_bits * i)) & 0xf;
278-
dst[i] = (static_cast<DstType>(quant_map[value]));// * ts;
279-
//if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, dst = %f\n", n, s, i, static_cast<int>(value), static_cast<float>(dst[i]));
284+
dst[i] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
285+
if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
280286
}
281287
}
282288
}
@@ -334,11 +340,9 @@ class kgemm_4bit_inference_cutlass_dequant {
334340
l_coord = BlockIdxZ();
335341
}
336342
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
337-
if(0){//cute::thread0()) {
343+
if(cute::thread0()) {
338344
printf("M = %d, N=%d, K=%d, L=%d\n", M, N, K, L);
339-
//}
340345
printf("thread_idx = %d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n",thread_idx, m_coord, n_coord, l_coord, BlockIdxX(), BlockIdxY(), BlockIdxZ());
341-
342346
}
343347
constexpr auto workgroup_shape = WorkgroupTileShape{}; //256, 256, 32
344348
constexpr auto subgroup_tile_shape = SubgroupTileShape{}; //32, 64, 32 (number of atom level workgroup: 256/8=32, 256/4=64, 32/2=32)
@@ -383,7 +387,8 @@ class kgemm_4bit_inference_cutlass_dequant {
383387
static constexpr auto scale_traits_num = SG_QNT_WIDTH / size<1>(typename GmemTiledCopyScale::BlockShape{});
384388
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
385389
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
386-
390+
if(cute::thread0()) printf("scale_traits_size = %d, scale_traits_num = %d, SG_QNT_WIDTH = %d\n", scale_traits_size, scale_traits_num, SG_QNT_WIDTH);
391+
387392
static_assert(std::is_same_v<typename decltype(dequant_frag)::value_type, ElementQuant>);
388393
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
389394
static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
@@ -412,6 +417,8 @@ class kgemm_4bit_inference_cutlass_dequant {
412417
const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
413418
const int l_coord_s = l_idx;
414419

420+
if(cute::thread0()) printf("m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
421+
415422
auto copy_iter_s = [&](){
416423
return make_tensor(make_inttuple_iter(make_coord(n_coord_s, 0, l_coord_s)),
417424
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
@@ -436,6 +443,10 @@ class kgemm_4bit_inference_cutlass_dequant {
436443
print(" frag_copy_B : "); print(frag_copy_B); print("\n");
437444
print(" dequant_frag : "); print(dequant_frag); print("\n");
438445

446+
print("===================== D :\n");
447+
print(" frag_copy_ScaleB : "); print(frag_copy_Scale); print("\n");
448+
print(" copy_iter_s: "); print(copy_iter_s); print("\n");
449+
439450
print("===================== D :\n");
440451
print(" accumulators : "); print(accumulators); print("\n");
441452

@@ -468,6 +479,8 @@ class kgemm_4bit_inference_cutlass_dequant {
468479

469480
const int k_reload_factor = params.group_size / BLK_K;
470481

482+
if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
483+
471484
copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), frag_copy_Scale);
472485

473486
if(prefetch_k < k_tile_count) {
@@ -574,7 +587,6 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
574587
auto mA_mkl = make_tensor(make_gmem_ptr(A), make_layout(make_shape(m, k, l), stride_A));
575588
Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
576589

577-
//StrideB stride_B = make_stride(int64_t{n}, cute::Int<1>{}, int64_t{0});
578590
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
579591
auto mB_nkl = make_tensor(cute::subbyte_iterator<ElementB>(B), make_layout(make_shape(n, k, l), stride_B));
580592
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
@@ -595,7 +607,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
595607
std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
596608

597609
auto mScale = make_tensor(
598-
make_gmem_ptr(absmax_), //static_cast<ElementScale *>(absmax)),
610+
make_gmem_ptr(reinterpret_cast<ElementScale *>(absmax_)),
599611
make_layout(make_shape(n, scale_k, l), stride_S));
600612
Copy_Scale tiled_copy_scale = {Copy_Scale{}.with(mScale)};
601613

0 commit comments

Comments
 (0)