Skip to content

Commit 30de594

Browse files
committed
refine code
1 parent 67f79ea commit 30de594

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ class kgemm_4bit_inference_cutlass_dequant {
394394
if(cute::thread0()) printf("M = %d, N=%d, K=%d, L=%d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n",M, N, K, L, m_coord, n_coord, l_coord, BlockIdxX(), BlockIdxY(), BlockIdxZ());
395395

396396
constexpr auto workgroup_shape = WorkgroupTileShape{}; //256, 256, 32
397-
constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // number of atom level workgroup: 256/8=32, 256/16=16, 32/16=2
397+
constexpr auto subgroup_tile_shape = SubgroupTileShape{}; //32, 64, 32 (number of atom level workgroup: 256/8=32, 256/4=64, 32/2=32)
398398

399399
Tensor mA_mkl = cute::get_pvc_tensor(make_shape(M,K,L)); //coordinate tensor: 0,1,2....
400400
Tensor mB_nkl = cute::get_pvc_tensor(make_shape(N,K,L)); //coordinate tensor: 0,1,2....
@@ -505,7 +505,7 @@ class kgemm_4bit_inference_cutlass_dequant {
505505
// 生成一个逻辑视图 tAgA,其形状和步长与 tCgA 相同,但数据仍存储在原始位置(共享内存)
506506
// 共享内存 → retile_S → 逻辑视图 (next step later → 寄存器 (实际复制))
507507
Tensor tAgA = thr_copy_A.retile_S(tCgA);
508-
Tensor tBgB = thr_copy_B_4bit.retile_S(tCgB);
508+
Tensor tBgB = thr_copy_B_4bit.retile_S(tCgB_4bit);
509509

510510
//// Prepare for prefetch
511511
// BLK_M, BLK_N, BLK_K, Num_SGs: Gemm Tile Atom information.
@@ -533,7 +533,7 @@ class kgemm_4bit_inference_cutlass_dequant {
533533
////
534534
// 在矩阵乘法(GEMM)中动态计算每个线程块(CTA)需要处理的数据分块位置
535535
auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
536-
m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; // m_idx * BLK_M:分块在 M 维度的起始全局坐标; get_sub_group_id() / ATOM_N) * SG_M:子组在 M 维度的偏移(用于细粒度并行)
536+
m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; // m_idx * BLK_M:分块在 M 维度的起始全局坐标; (get_sub_group_id() / ATOM_N) * SG_M:子组在 M 维度的偏移(用于细粒度并行)
537537
n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; // n_idx * BLK_N:分块在 N 维度的起始全局坐标; (get_sub_group_id() % ATOM_N) * SG_N:子组在 N 维度的偏移
538538
l_coord = l_idx;
539539

@@ -594,7 +594,7 @@ class kgemm_4bit_inference_cutlass_dequant {
594594
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
595595
}
596596

597-
const int k_reload_factor = params.group_size / BLK_K / 2;
597+
const int k_reload_factor = params.group_size / BLK_K;
598598
if(cute::thread0()) printf("k_reload_factor = %d\n", k_reload_factor);
599599

600600
//CUTLASS_PRAGMA_UNROLL
@@ -733,8 +733,8 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
733733
dim3 const block = GemmKernel::get_block_shape();
734734
dim3 const grid = GemmKernel::get_grid_shape(params);
735735

736-
const syclcompat::dim3 sycl_block(block.x, block.y, block.z); //workgroup_size: 8*4*1, 1, 1
737-
const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z); //workgroup_number (problem_size / tile_size): N/256, M/256, K/32
736+
const syclcompat::dim3 sycl_block(block.x, block.y, block.z); //workgroup_size: 8*4*1*16, 1, 1
737+
const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z); //workgroup_number (problem_size / tile_size): N/256, M/256, 1
738738
printf("Host Grid: (%d, %d, %d)\n", grid.x, grid.y, grid.z);
739739
printf("Host Block: (%d, %d, %d)\n", block.x, block.y, block.z);
740740

tests/test_xpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class TestXPU:
4747
[torch.uint8],
4848
ids=describe_dtype,
4949
)
50-
@pytest.mark.parametrize("dim", [512], ids=id_formatter("dim"))
50+
@pytest.mark.parametrize("dim", [256], ids=id_formatter("dim"))
5151
def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
5252
errs1 = []
5353
errs2 = []

0 commit comments

Comments
 (0)