Skip to content

Commit d238a6a

Browse files
committed
refine code
1 parent 572ef63 commit d238a6a

3 files changed

Lines changed: 54 additions & 19 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _gemv_4bit_impl(
8484
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
8585
ldc = m
8686

87-
absmax = absmax * 10
87+
#absmax = absmax * 10
8888
pdb.set_trace()
8989

9090
stream = _get_tensor_stream(A)

csrc/xpu_cutlass_fusion.cpp

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B:
146146
using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout<CopyThreadShape>{}, val_layout_load_B{}));
147147

148148
using GmemTiledCopyScale = XE_2D_U16x1x32_LD_N; //XE_2D_U16x1x16_LD_N;
149-
//using GmemTiledCopyScale = XE_2D_U16x1x16_LD_N;
150149
static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
151150
using StrideScale = cute::Stride<_1, int64_t, int64_t>; //dynamic stride
152151
using traits_load_scale = Copy_Traits<GmemTiledCopyScale, StrideScale>;
@@ -245,7 +244,7 @@ class kgemm_4bit_inference_cutlass_dequant {
245244
int scale_number = decltype(size(tCrS_input))::value;
246245
for(int i=0; i<scale_number; i++){
247246
auto s_value = tCrS_input(i);
248-
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, s_value);
247+
if(cute::thread0()) printf("scale_number = %d, tCrS_input[%d] = %f\n",scale_number, i, static_cast<float>(s_value));
249248
}
250249
#else
251250
static constexpr auto N = decltype(size<1>(in))::value;
@@ -298,6 +297,21 @@ class kgemm_4bit_inference_cutlass_dequant {
298297
int K = params.k;
299298
int L = 1;
300299

300+
const int BLK_M = 16;
301+
const int BLK_N = 64;
302+
const int BLK_K = 64;
303+
304+
const int ATOM_M = 1;
305+
const int ATOM_N = 2;
306+
const int ATOM_K = 1;
307+
308+
const int SG_M = ceil_div(BLK_M, ATOM_M);
309+
const int SG_N = ceil_div(BLK_N, ATOM_N);
310+
const int SG_K = ceil_div(BLK_K, ATOM_K);
311+
312+
const int Num_SGs = ATOM_N * ATOM_M * ATOM_K;
313+
static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
314+
301315
T* A = params.A;
302316
uint8_t* B = params.B;
303317
float* out = params.out;
@@ -383,9 +397,8 @@ class kgemm_4bit_inference_cutlass_dequant {
383397

384398
Tensor dequant_frag = make_tensor<ElementB>(mma_B.layout());
385399

386-
const int SubgroupSize = 16;
387-
const int SG_QNT_WIDTH = 32;
388-
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize;
400+
//const int SubgroupSize = 16;
401+
static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize; //SubgroupSize;
389402
static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
390403
using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
391404
Tensor fragment_scale = make_tensor<ElementScale>(FragScaleLayout{});
@@ -405,7 +418,6 @@ class kgemm_4bit_inference_cutlass_dequant {
405418
Tensor tBgB = thr_copy_B.retile_S(tCgB);
406419

407420
//// Prepare for prefetch
408-
const int BLK_K = 64;
409421
auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(tiled_copy_a);;
410422
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(tiled_copy_b);;
411423
auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
@@ -416,21 +428,22 @@ class kgemm_4bit_inference_cutlass_dequant {
416428
auto pBgB = thr_prefetch_B.partition_S(gB);
417429

418430
// Run mainloop
419-
const int BLK_N = 64;
420-
const int ATOM_N = 2;
421-
const int SG_N = 32;
422-
auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
423-
const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
424-
const int l_coord_s = l_idx;
431+
//auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
432+
//const int n_coord_s = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
433+
//const int l_coord_s = l_idx;
425434

426-
if(cute::thread0()) printf("get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",get_sub_group_id(), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
435+
//if(cute::thread0()) printf("get_sub_group_id() = %d, m_idx = %d, n_idx = %d, k_idx = %d, l_idx = %d, n_coord_s = %d, l_coord_s = %d\n",get_sub_group_id(), m_idx, n_idx, k_idx, l_idx, n_coord_s, l_coord_s);
427436

428437
auto copy_iter_s = [&](){
429-
return make_tensor(make_inttuple_iter(make_coord(n_coord_s, 0, l_coord_s)),
438+
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)),
430439
make_layout(make_shape(Int<scale_traits_size>{}, Int<scale_traits_num>{}, _1{}, k_tile_count),
431440
make_stride(E<0>{} * _16{}, E<0>{} * decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value, _0{}, E<1>{} * _1{})));
432441

433442
}();
443+
444+
//using ExpectedLayout = typename decltype(tiled_copy_scale)::TiledLayout::dst_layout; //decltype(tiled_copy_scale.dst_layout()); //decltype(tiled_copy_scale.atom_layout_dst());
445+
//static_assert(is_same<decltype(frag_copy_Scale.layout()), ExpectedLayout>::value, "布局不匹配");
446+
434447
#if 1
435448
#define PRINT(x) print(#x ": "); print(x); print("\n");
436449
if (cutlass::thread(LOG_THREAD, LOG_GROUP)) {
@@ -450,7 +463,8 @@ class kgemm_4bit_inference_cutlass_dequant {
450463
print(" dequant_frag : "); print(dequant_frag); print("\n");
451464

452465
print("===================== D :\n");
453-
print(" frag_copy_ScaleB : "); print(frag_copy_Scale); print("\n");
466+
print(" tiled_copy_scale : "); print(tiled_copy_scale); print("\n");
467+
print(" frag_copy_Scale : "); print(frag_copy_Scale); print("\n");
454468
print(" copy_iter_s: "); print(copy_iter_s); print("\n");
455469

456470
print("===================== D :\n");
@@ -484,6 +498,7 @@ class kgemm_4bit_inference_cutlass_dequant {
484498
copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
485499

486500
const int k_reload_factor = ceil_div(params.group_size, BLK_K);
501+
//const int k_reload_factor = params.group_size / BLK_K;
487502

488503
if(cute::thread0()) printf("params.group_size = %d, BLK_K = %d, k_reload_factor = %d\n",params.group_size, BLK_K, k_reload_factor);
489504

@@ -575,6 +590,26 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
575590
// T* absmax = (T*)absmax_;
576591
// T* absmax = (T*)absmax_;
577592

593+
//std::vector<T> host_data(n * k / blocksize);
594+
#if 0
595+
int element_size_A = m * k;
596+
auto scale_host_A = sycl::aligned_alloc_host<T>(512, element_size_A, q);
597+
q.memcpy(scale_host_A, A, element_size_A * sizeof(T)).wait();
598+
for (int i = 0; i < element_size_A; ++i) {
599+
//std::cout << scale_host[i] << " ";
600+
printf("%f ",static_cast<float>(scale_host_A[i]));
601+
}
602+
std::cout << std::endl;
603+
604+
int element_size = n * k / blocksize;
605+
auto scale_host = sycl::aligned_alloc_host<T>(512, element_size, q);
606+
q.memcpy(scale_host, absmax_, element_size * sizeof(T)).wait();
607+
for (int i = 0; i < element_size; ++i) {
608+
//std::cout << scale_host[i] << " ";
609+
printf("%f ",static_cast<float>(scale_host[i]));
610+
}
611+
std::cout << std::endl;
612+
#endif
578613
#if 1
579614
// Init Params
580615
using Params = GemmKernel::Params;

tests/test_xpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
6767
#pdb.set_trace()
6868
if kind == "fc1":
6969
A = torch.ones(32, dim, dtype=dtype, device=device)
70-
B = torch.ones(dim, dim, dtype=dtype, device=device) # / math.sqrt(dim)
70+
B = torch.ones(dim, dim, dtype=dtype, device=device) / math.sqrt(dim)
7171
elif kind == "fc2":
7272
A = torch.randn(1, 4 * dim, dtype=dtype, device=device)
7373
B = torch.randn(dim, 4 * dim, dtype=dtype, device=device) / math.sqrt(dim)
@@ -83,13 +83,13 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8383
quant_type=storage_type,
8484
compress_statistics=double_quant,
8585
quant_storage=quant_storage,
86-
blocksize=32,
86+
blocksize=64,
8787
)
8888

8989
##pdb.set_trace()
9090
C3 = torch.matmul(A, B.t())
9191
#pdb.set_trace()
92-
C2 = F.gemv_4bit(A, qB, state=state)
92+
C2 = F.gemv_4bit(A, qB.t(), state=state)
9393
#pdb.set_trace()
9494
print("C3.sum() = ", C3.sum())
9595
print("C2.sum() = ", C2.sum())

0 commit comments

Comments
 (0)