Skip to content

Commit 89da289

Browse files
committed
refine code
1 parent 4475756 commit 89da289

2 files changed

Lines changed: 22 additions & 23 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,6 @@ using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A:
126126
// val_layout_load_A: register fragment layout
127127
using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout<CopyThreadShape>{}, val_layout_load_A{}));
128128

129-
using GmemTiledCopyB_4bit = XE_2D_U8x32x32_LD_V; // U8 (1-byte) block copy for 8bit-B (narrower type)
130-
using StrideB_4bit = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
131-
using traits_load_B_4bit = Copy_Traits<GmemTiledCopyB_4bit, StrideB_4bit>;
132-
using atom_load_B_4bit = Copy_Atom<traits_load_B_4bit, ElementB>;
133-
using val_layout_load_B_4bit = decltype(make_layout(shape_div(typename traits_load_B_4bit::BlockShape{}, CopyThreadShape{})));
134-
using Copy_B_4bit = decltype(make_tiled_copy(atom_load_B_4bit{}, Layout<CopyThreadShape>{}, val_layout_load_B_4bit{}));
135-
136129
using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; // U8 (1-byte) block copy for 8bit-B (narrower type)
137130
using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
138131
using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
@@ -200,7 +193,7 @@ class kgemm_4bit_inference_cutlass_dequant {
200193

201194
Copy_A tiled_copy_a;
202195
Copy_B tiled_copy_b;
203-
Copy_B_4bit tiled_copy_b_4bit;
196+
Copy_B tiled_copy_b_4bit;
204197
Copy_Scale tiled_copy_scale;
205198
int group_size;
206199

@@ -293,13 +286,13 @@ class kgemm_4bit_inference_cutlass_dequant {
293286
// 2 x 16 of these are same K
294287
// 4 different scale/zero values per thread, no exchange needed
295288
//CUTLASS_PRAGMA_UNROLL
296-
for (int i = 0; i < 4; ++i) {
297-
//CUTLASS_PRAGMA_UNROLL
298-
for (int j = 0; j < 32; ++j) {
299-
tCrB_dst(_, i, _)[j] *= tCrS(i);
300-
//printf("thread_idx = %d, i = %d, j = %d, scale_value = %f\n", thread_idx, i, j, tCrS(i));
301-
}
302-
}
289+
// for (int i = 0; i < 4; ++i) {
290+
// //CUTLASS_PRAGMA_UNROLL
291+
// for (int j = 0; j < 32; ++j) {
292+
// tCrB_dst(_, i, _)[j] *= tCrS(i);
293+
// //printf("thread_idx = %d, i = %d, j = %d, scale_value = %f\n", thread_idx, i, j, tCrS(i));
294+
// }
295+
// }
303296

304297
#if 0
305298
for(int i=0; i<num_elements_dst; i++){
@@ -386,7 +379,7 @@ class kgemm_4bit_inference_cutlass_dequant {
386379

387380
Tensor gA = local_tile(mA_mkl, select<0,2>(blk_shape), make_coord(m_coord,_,l_coord));
388381
Tensor gB = local_tile(mB_nkl, select<1,2>(blk_shape), make_coord(n_coord,_,l_coord));
389-
Tensor gB_4bit = local_tile(mB_nkl_4bit, select<1,2>(blk_shape /*blk_shape_4bit*/), make_coord(n_coord,_,l_coord));
382+
Tensor gB_4bit = local_tile(mB_nkl_4bit, select<1,2>(blk_shape_4bit), make_coord(n_coord,_,l_coord));
390383

391384
//// Allocate the tiled_mma and the accumulators for the (M,N) subgroup_tile_shape
392385
TiledMma tiled_mma;
@@ -496,8 +489,14 @@ class kgemm_4bit_inference_cutlass_dequant {
496489

497490
CUTLASS_PRAGMA_UNROLL
498491
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
499-
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
500-
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
492+
if(prefetch_k < k_tile_count) {
493+
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
494+
}
495+
if(prefetch_k < k_tile_count/2) {
496+
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
497+
}
498+
//prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
499+
//prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
501500
}
502501

503502
const int k_reload_factor = params.group_size / BLK_K;
@@ -560,7 +559,7 @@ for(int i=0; i<num_Acc; i++) {
560559

561560
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>((char*)nullptr);
562561
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
563-
auto problem_shape_MNKL = problem_size; //append<4>(problem_size, 1);
562+
auto problem_shape_MNKL = append<4>(problem_size, 1);
564563
epilogue(
565564
problem_shape_MNKL,
566565
subgroup_tile_shape,
@@ -613,7 +612,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
613612

614613
StrideB stride_B_4bit = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k/2, l));
615614
auto mB_nkl_4bit = make_tensor(make_gmem_ptr(B), make_layout(make_shape(n, k/2, l), stride_B_4bit));
616-
Copy_B_4bit tiled_copy_b_4bit{Copy_B_4bit{}.with(mB_nkl_4bit)};
615+
Copy_B tiled_copy_b_4bit{Copy_B{}.with(mB_nkl_4bit)};
617616

618617
params.tiled_copy_a = tiled_copy_a;
619618
params.tiled_copy_b = tiled_copy_b;
@@ -631,7 +630,7 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
631630

632631
cutlass::KernelHardwareInfo hw_info;
633632
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
634-
auto problem_shape_MNKL = problem_size; //append<4>(problem_size, 1);
633+
auto problem_shape_MNKL = append<4>(problem_size, 1);
635634
float alpha=1.0;
636635
float beta=0.f;
637636
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));

tests/test_xpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8585
quant_storage=quant_storage,
8686
)
8787
##pdb.set_trace()
88-
C3 = torch.matmul(A, B)
88+
C3 = torch.matmul(A.t(), B)
8989
pdb.set_trace()
90-
C2 = F.gemv_4bit(A, qB, state=state)
90+
C2 = F.gemv_4bit(A, qB.t(), state=state)
9191
#pdb.set_trace()
9292
print(C3[0])
9393
print(C2[0])

0 commit comments

Comments
 (0)