Skip to content

Commit 4c933c6

Browse files
committed
refine code
1 parent 4b8f1b0 commit 4c933c6

2 files changed

Lines changed: 51 additions & 51 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _gemv_4bit_impl(
7676
) -> None:
7777
import pdb
7878
pdb.set_trace()
79-
m = 1 #ct.c_int32(*A.shape[:-1])
79+
m = ct.c_int32(*A.shape[:-1])
8080
n = ct.c_int32(shapeB[0])
8181
k = ct.c_int32(shapeB[1])
8282

csrc/xpu_cutlass_fusion.cpp

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ using TiledMma =
6969
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
7070

7171
// Define Mainloop dispatch policy
72-
constexpr int PipelineStages = 3;
72+
constexpr int PipelineStages = 0;
7373
using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
7474
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // sub_group size
7575

@@ -140,7 +140,7 @@ using GmemTiledCopyC = CopyOpG2R;
140140
using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
141141
CopyOpR2G, XE_2D_U32x8x16_ST_N>;
142142

143-
//TODO(Xiaoli): Maybe legacy, refine me.
143+
// Calculate subgroup_tile_shape (reminder: not the same thing with "subgroup_size" in sycl!!)
144144
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{});
145145
static constexpr auto BLK_N = get<1>(WorkgroupTileShape{});
146146
static constexpr auto BLK_K = get<2>(WorkgroupTileShape{});
@@ -174,16 +174,14 @@ class kgemm_4bit_inference_cutlass_dequant {
174174
int m, n, k;
175175
T* A;
176176
uint8_t* B;
177-
float *absmax; //TODO(Xiaoli): FIX ME
178177
float* out;
179-
float *datatype;
178+
float *datatype; //LUT
180179

181-
//GemmUniversalMode mode{};
182180
ProblemShape problem_shape{};
183-
184-
//inloopParams mainloop{};
181+
185182
Copy_A tiled_copy_a;
186183
Copy_B tiled_copy_b;
184+
Copy_B tiled_copy_b_4bit;
187185
Copy_Scale tiled_copy_scale;
188186
int group_size;
189187

@@ -309,45 +307,41 @@ class kgemm_4bit_inference_cutlass_dequant {
309307

310308
CUTLASS_DEVICE
311309
void operator()(Params const& params, char* smem_buf) {
310+
if(cute::thread0()) printf("this is fusion kernel...........\n");
311+
312312
int M = params.m;
313313
int N = params.n;
314314
int K = params.k;
315315
T* A = params.A;
316316
uint8_t* B = params.B;
317317
float* out = params.out;
318318
float* datatype = params.datatype;
319-
//int blocksize = params.blocksize;
320319
auto tiled_copy_a = params.tiled_copy_a;
321320
auto tiled_copy_b = params.tiled_copy_b;
322-
auto tiled_copy_scale = params.tiled_copy_scale;
323-
if(cute::thread0())
324-
printf("this is fusion kernel...........\n");
321+
auto tiled_copy_b_4bit = params.tiled_copy_b_4bit;
322+
auto tiled_copy_scale = params.tiled_copy_scale;
323+
325324
int L = 1;
326325
auto problem_size = ProblemShape{M, N, K, L};
327-
328-
//TODO(Xiaoli): FIX ME
329-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
330326

331-
float* quant_map = reinterpret_cast<float*>(smem_buf);
332327
// Preconditions
333328
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
334329
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
335330
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
336331
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
337332

338-
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
339333
int thread_idx = int(ThreadIdxX());
340-
//#if 0
341-
//Load Dequat table
334+
335+
// Load Dequatize LUT and save to SLM, 16 for 4bits
336+
float* quant_map = reinterpret_cast<float*>(smem_buf);
342337
if (thread_idx < 16) {
343-
quant_map[thread_idx] = datatype[thread_idx]; //T(datatype[thread_idx]);
338+
quant_map[thread_idx] = datatype[thread_idx];
344339
printf("quant_map[thread_idx] = %f\n", quant_map[thread_idx]);
345340
}
346341
barrier_wait(1);
347342

348-
#if 1
349-
auto blk_shape = TileShape{};
350-
int m_coord, n_coord, l_coord;
343+
auto blk_shape = TileShape{}; //256,256,32
344+
int m_coord, n_coord, l_coord; //block index
351345
if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
352346
if(cute::thread0()) printf("AlongN !!\n");
353347
m_coord = BlockIdxY();
@@ -359,25 +353,23 @@ class kgemm_4bit_inference_cutlass_dequant {
359353
n_coord = BlockIdxY();
360354
l_coord = BlockIdxZ();
361355
}
356+
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
362357
if(cute::thread0()) printf("M = %d, N=%d, K=%d, L=%d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n",M, N, K, L, m_coord, n_coord, l_coord, BlockIdxX(), BlockIdxY(), BlockIdxZ());
363358

364-
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
365-
constexpr auto workgroup_shape = WorkgroupTileShape{};
366-
constexpr auto subgroup_shape = SubgroupTileShape{};
367-
if(cute::thread0())
368-
printf("BLK_M = %d, BLK_N = %d, BLK_K = %d, ATOM_M = %d, ATOM_N = %d, ATOM_K = %d, SG_M = %d, SG_N = %d, SG_K = %d\n", BLK_M, BLK_N, BLK_K, ATOM_M, ATOM_N, ATOM_K, SG_M, SG_N, SG_K);
359+
constexpr auto workgroup_shape = WorkgroupTileShape{}; //256, 256, 32
360+
constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // 256/8=32, 256/16=16, 32/16=2
369361

370-
Tensor mA_mkl = cute::get_pvc_tensor(make_shape(M,K,L)); //(m,k,l)
371-
Tensor mB_nkl = cute::get_pvc_tensor(make_shape(N,K,L)); //(n,k,l)
362+
Tensor mA_mkl = cute::get_pvc_tensor(make_shape(M,K,L)); //coordinate tensor: 0,1,2....
363+
Tensor mB_nkl = cute::get_pvc_tensor(make_shape(N,K,L)); //coordinate tensor: 0,1,2....
372364

373365
Tensor gA = local_tile(mA_mkl, select<0,2>(blk_shape), make_coord(m_coord,_,l_coord));
374366
Tensor gB = local_tile(mB_nkl, select<1,2>(blk_shape), make_coord(n_coord,_,l_coord));
375367

376-
// Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape
368+
// Allocate the tiled_mma and the accumulators for the (M,N) subgroup_tile_shape
377369
TiledMma tiled_mma;
378370

379-
auto expanded_shape = replace<1>(blk_shape, cute::C<2>{} * get<1>(blk_shape));
380-
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(expanded_shape));
371+
//auto expanded_shape = replace<1>(blk_shape, cute::C<2>{} * get<1>(blk_shape));
372+
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape));
381373
clear(accumulators);
382374

383375
auto k_tile_iter = cute::make_coord_iterator(idx2crd(0, make_shape(K)), make_shape(K));
@@ -387,6 +379,7 @@ class kgemm_4bit_inference_cutlass_dequant {
387379
//Run MainLoop
388380
auto thr_copy_A = tiled_copy_a.get_slice(thread_idx);
389381
auto thr_copy_B = tiled_copy_b.get_slice(thread_idx);
382+
auto thr_copy_B_4bit = tiled_copy_b_4bit.get_slice(thread_idx);
390383
auto thr_copy_scale = tiled_copy_scale.get_slice(thread_idx);
391384

392385
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
@@ -397,39 +390,39 @@ class kgemm_4bit_inference_cutlass_dequant {
397390
Tensor tCgA = thr_mma.partition_A(gA);
398391
Tensor tCgB = thr_mma.partition_B(gB);
399392

400-
// Create fragments
393+
// Create fragments
401394
Tensor mma_A = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape()));
402395
Tensor mma_B = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape()));
403396

404397
using FragScaleLayout = Layout<Shape<_2, _2, _1>>;
405398
Tensor fragment_scale_input = make_tensor<NonVoidElementScale>(FragScaleLayout{});
406399

407400
// narrow input fragment
408-
Tensor quant_frag = make_tensor<ElementQuant>(decltype(mma_B.layout()){});
401+
Tensor mma_B_4bit = make_tensor<ElementMMA>(make_fragment_layout(tiled_copy_b_4bit, tCgB(_,_,_,0).shape()));
402+
Tensor quant_frag = make_tensor<ElementQuant>(decltype(mma_B_4bit.layout()){});
409403

410-
auto original_shape = tCgB(_,_,_,0).shape();
411-
auto expanded_shape_2 = make_shape(cute::get<0>(original_shape), cute::C<2>{} * cute::get<1>(original_shape),cute::get<2>(original_shape));
412-
auto expanded_layout = make_fragment_layout(tiled_copy_b, expanded_shape_2);
413-
Tensor mma_B_expanded = make_tensor<ElementMMA>(expanded_layout);
404+
//auto original_shape = tCgB(_,_,_,0).shape();
405+
//auto expanded_shape_2 = make_shape(cute::get<0>(original_shape), cute::C<2>{} * cute::get<1>(original_shape),cute::get<2>(original_shape));
406+
//auto expanded_layout = make_fragment_layout(tiled_copy_b, expanded_shape_2);
407+
//Tensor mma_B_expanded = make_tensor<ElementMMA>(expanded_layout);
414408

415409
static_assert(std::is_same_v<typename decltype(quant_frag)::value_type, ElementQuant>);
416410
static_assert(std::is_same_v<typename decltype(mma_A)::value_type, ElementMMA>);
417411
static_assert(std::is_same_v<typename decltype(mma_B)::value_type, ElementMMA>);
418412

419413
// Retile for copy
420414
auto [frag_copy_A, frag_copy_B] = [&](){
421-
return std::make_pair(thr_copy_A.retile_D(mma_A), thr_copy_B.retile_D(quant_frag));
415+
return std::make_pair(thr_copy_A.retile_D(mma_A), thr_copy_B_4bit.retile_D(quant_frag));
422416
}();
423417

424418
Tensor copy_tCrS = thr_copy_scale.retile_D(fragment_scale_input);
425-
//Tensor copy_tCrZ = thr_copy_zero.retile_D(fragment_zero_input);
426419

427420
// Retile global counting tensors for copies
428421
Tensor tAgA = thr_copy_A.retile_S(tCgA);
429-
Tensor tBgB = thr_copy_B.retile_S(tCgB);
422+
Tensor tBgB = thr_copy_B_4bit.retile_S(tCgB);
430423

431424
auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(tiled_copy_a);
432-
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(tiled_copy_b);
425+
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(tiled_copy_b_4bit);
433426
auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
434427
auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
435428

@@ -460,37 +453,39 @@ class kgemm_4bit_inference_cutlass_dequant {
460453
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
461454
}
462455

463-
const int k_reload_factor = params.group_size / BLK_K;
456+
const int k_reload_factor = params.group_size / BLK_K / 2;
464457
if(cute::thread0()) printf("k_reload_factor = %d\n", k_reload_factor);
465458

466459
CUTLASS_PRAGMA_UNROLL
467460
for (int k_tile = 0, k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) {
468461
// Copy gmem to rmem for the first k_tile
469462
copy(tiled_copy_a, tAgA(_,_,_,k), frag_copy_A);
470-
copy(tiled_copy_b, tBgB(_,_,_,k), frag_copy_B);
463+
copy(tiled_copy_b_4bit, tBgB(_,_,_,k), frag_copy_B);
471464

472465
copy(tiled_copy_scale, copy_iter_s(_, _, _, k_start_idx + (k_tile / k_reload_factor)), copy_tCrS);
473-
dequant(quant_frag, mma_B_expanded, fragment_scale_input, quant_map);
466+
//dequant(quant_frag, mma_B_expanded, fragment_scale_input, quant_map);
467+
dequant(quant_frag, mma_B, fragment_scale_input, quant_map);
474468

475469
if(prefetch_k < k_tile_count) {
476470
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
477471
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
478472
}
479473

480-
cute::gemm(tiled_mma, mma_A, mma_B_expanded, accumulators);
474+
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
481475
}
476+
477+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>((char*)nullptr);
482478
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
483-
auto expanded_problem_size = ProblemShape{M, 2 * N, K, 1};
484-
auto problem_shape_MNKL = append<4>(expanded_problem_size, 1);
479+
//auto expanded_problem_size = ProblemShape{M, 2 * N, K, 1};
480+
auto problem_shape_MNKL = append<4>(problem_size, 1);
485481
epilogue(
486482
problem_shape_MNKL,
487-
subgroup_shape, // TODO(codeplay): Inconsistency here w/ blk_coord_mnkl
483+
subgroup_tile_shape,
488484
blk_coord_mnkl,
489485
accumulators,
490486
tiled_mma,
491487
thread_idx
492488
);
493-
#endif
494489
}
495490
};
496491

@@ -532,9 +527,14 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
532527
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
533528
auto mB_nkl = make_tensor(make_gmem_ptr(B), make_layout(make_shape(n, k, l), stride_B));
534529
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
530+
531+
StrideB stride_B_4bit = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k/2, l));
532+
auto mB_nkl_4bit = make_tensor(make_gmem_ptr(B), make_layout(make_shape(n, k/2, l), stride_B));
533+
Copy_B tiled_copy_b_4bit{Copy_B{}.with(mB_nkl_4bit)};
535534

536535
params.tiled_copy_a = tiled_copy_a;
537536
params.tiled_copy_b = tiled_copy_b;
537+
params.tiled_copy_b_4bit = tiled_copy_b_4bit;
538538

539539
const int scale_k = cute::ceil_div(k, blocksize);
540540
const int dq_mn_size = n;

0 commit comments

Comments
 (0)