@@ -394,7 +394,7 @@ class kgemm_4bit_inference_cutlass_dequant {
394394 if (cute::thread0 ()) printf (" M = %d, N=%d, K=%d, L=%d, m_coord = %d, n_coord = %d, l_coord = %d, BlockIdxX() = %d, BlockIdxY() = %d, BlockIdxZ() = %d\n " ,M, N, K, L, m_coord, n_coord, l_coord, BlockIdxX (), BlockIdxY (), BlockIdxZ ());
395395
396396 constexpr auto workgroup_shape = WorkgroupTileShape{}; // 256, 256, 32
397- constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // number of atom level workgroup: 256/8=32, 256/16=16 , 32/16=2
397+ constexpr auto subgroup_tile_shape = SubgroupTileShape{}; // 32, 64, 32 ( number of atom level workgroup: 256/8=32, 256/4=64 , 32/2=32)
398398
399399 Tensor mA_mkl = cute::get_pvc_tensor (make_shape (M,K,L)); // coordinate tensor: 0,1,2....
400400 Tensor mB_nkl = cute::get_pvc_tensor (make_shape (N,K,L)); // coordinate tensor: 0,1,2....
@@ -505,7 +505,7 @@ class kgemm_4bit_inference_cutlass_dequant {
505505 // 生成一个逻辑视图 tAgA,其形状和步长与 tCgA 相同,但数据仍存储在原始位置(共享内存)
506506 // 共享内存 → retile_S → 逻辑视图 (next step later → 寄存器 (实际复制))
507507 Tensor tAgA = thr_copy_A.retile_S (tCgA);
508- Tensor tBgB = thr_copy_B_4bit.retile_S (tCgB );
508+ Tensor tBgB = thr_copy_B_4bit.retile_S (tCgB_4bit );
509509
510510// // Prepare for prefetch
511511 // BLK_M, BLK_N, BLK_K, Num_SGs: Gemm Tile Atom information.
@@ -533,7 +533,7 @@ class kgemm_4bit_inference_cutlass_dequant {
533533// //
534534 // 在矩阵乘法(GEMM)中动态计算每个线程块(CTA)需要处理的数据分块位置
535535 auto [m_idx, n_idx, k_idx, l_idx] = blk_coord_mnkl;
536- m_coord = m_idx * BLK_M + (get_sub_group_id () / ATOM_N ) * SG_M ; // m_idx * BLK_M:分块在 M 维度的起始全局坐标; get_sub_group_id() / ATOM_N) * SG_M:子组在 M 维度的偏移(用于细粒度并行)
536+ m_coord = m_idx * BLK_M + (get_sub_group_id () / ATOM_N ) * SG_M ; // m_idx * BLK_M:分块在 M 维度的起始全局坐标; ( get_sub_group_id() / ATOM_N) * SG_M:子组在 M 维度的偏移(用于细粒度并行)
537537 n_coord = n_idx * BLK_N + (get_sub_group_id () % ATOM_N ) * SG_N ; // n_idx * BLK_N:分块在 N 维度的起始全局坐标; (get_sub_group_id() % ATOM_N) * SG_N:子组在 N 维度的偏移
538538 l_coord = l_idx;
539539
@@ -594,7 +594,7 @@ class kgemm_4bit_inference_cutlass_dequant {
594594 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
595595 }
596596
597- const int k_reload_factor = params.group_size / BLK_K / 2 ;
597+ const int k_reload_factor = params.group_size / BLK_K ;
598598 if (cute::thread0 ()) printf (" k_reload_factor = %d\n " , k_reload_factor);
599599
600600 // CUTLASS_PRAGMA_UNROLL
@@ -733,8 +733,8 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
733733 dim3 const block = GemmKernel::get_block_shape ();
734734 dim3 const grid = GemmKernel::get_grid_shape (params);
735735
736- const syclcompat::dim3 sycl_block (block.x , block.y , block.z ); // workgroup_size: 8*4*1, 1, 1
737- const syclcompat::dim3 sycl_grid (grid.x , grid.y , grid.z ); // workgroup_number (problem_size / tile_size): N/256, M/256, K/32
736+ const syclcompat::dim3 sycl_block (block.x , block.y , block.z ); // workgroup_size: 8*4*1*16 , 1, 1
737+ const syclcompat::dim3 sycl_grid (grid.x , grid.y , grid.z ); // workgroup_number (problem_size / tile_size): N/256, M/256, 1
738738 printf (" Host Grid: (%d, %d, %d)\n " , grid.x , grid.y , grid.z );
739739 printf (" Host Block: (%d, %d, %d)\n " , block.x , block.y , block.z );
740740
0 commit comments