@@ -67,7 +67,7 @@ using TiledMma =
6767 Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
6868using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
70- constexpr int PipelineStages = 4 ;
70+ constexpr int PipelineStages = 2 ;
7171
7272using MmaAtomShape = typename TiledMma::AtomShape_MNK;
7373using WorkgroupTileShape = TileShape;
@@ -239,8 +239,8 @@ class gemm_4bit_cutlass_kernel {
239239 Tensor tAgA = thr_copy_A.retile_S (tCgA);
240240 Tensor tBgB = thr_copy_B.retile_S (tCgB);
241241
242- auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M >,Int<BLK_K >>, Num_SGs>(params.tiled_copy_a );;
243- auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N >,Int<BLK_K >>, Num_SGs>(params.tiled_copy_b );;
242+ auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M >,Int<BLK_K >>, Num_SGs>(params.tiled_copy_a );
243+ auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N >,Int<BLK_K >>, Num_SGs>(params.tiled_copy_b );
244244 auto thr_prefetch_A = tiled_prefetch_a.get_slice (thread_idx);
245245 auto thr_prefetch_B = tiled_prefetch_b.get_slice (thread_idx);
246246
@@ -273,33 +273,20 @@ class gemm_4bit_cutlass_kernel {
273273 using VecDstElemType = cute::array<ElementMMA, compress_size>;
274274 using VecDstType = cute::array<VecDstElemType, vec_size>;
275275
276- auto s_tensor = cute::make_tensor ((VecSrcType*)(cute::raw_pointer_cast (dequant_frag.data ())), cute::make_shape (cute::Int<K / (compress_size * vec_size)>{}, cute::Int<N>{}));
277- auto d_tensor = cute::make_tensor ((VecDstType*)(cute::raw_pointer_cast (mma_B.data ())), cute::make_shape (cute::Int<K / (compress_size * vec_size)>{}, cute::Int<N>{}));
278-
279- // auto src_ = *(cute::array<VecSrcType, K / (compress_size * vec_size, N)>*)(s_tensor.data());
280- // auto dst_ = *(cute::array<VecDstType, K / (compress_size * vec_size, N)>*)(d_tensor.data());
276+ float scale_value = fragment_scale (0 );
277+ auto src = *(VecSrcType*)(cute::raw_pointer_cast (dequant_frag.data ()));
278+ auto & dst = *(VecDstType*)(cute::raw_pointer_cast (mma_B.data ()));
279+ VecDstType dst_val;
281280 #pragma unroll
282- for (int n = 0 ; n < N; n++) {
283- float scale_value = fragment_scale (n);
284- auto src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor (_, n).data ());
285- auto & dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor (_, n).data ());
286- // auto& src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(src_[n]);
287- // auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(dst_[n]);
288- #pragma unroll
289- for (int k = 0 ; k < K / (compress_size * vec_size); k++) {
290- VecDstType dst_val;
281+ for (int i = 0 ; i < vec_size; i++) {
282+ VecDstElemType dst_elem;
291283 #pragma unroll
292- for (int i = 0 ; i < vec_size; i++) {
293- VecDstElemType dst_elem;
294- #pragma unroll
295- for (int j = 0 ; j < compress_size; j++) {
296- dst_elem[j] = static_cast <ElementMMA>(quant_map[(src[k][i] >> (4 * ((j+1 )%2 + (j/2 )*2 ))) & 0xf ] * scale_value);
297- }
298- dst_val[i] = dst_elem;
284+ for (int j = 0 ; j < compress_size; j++) {
285+ dst_elem[j] = static_cast <ElementMMA>(quant_map[(src[i] >> (4 * ((j+1 )%2 + (j/2 )*2 ))) & 0xf ] * scale_value);
299286 }
300- dst[k] = dst_val;
301- }
287+ dst_val[i] = dst_elem;
302288 }
289+ dst = dst_val;
303290 };
304291
305292 CUTLASS_PRAGMA_UNROLL
@@ -308,7 +295,7 @@ class gemm_4bit_cutlass_kernel {
308295 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
309296 }
310297
311- for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count; k_tile++, k_s++) {
298+ for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++ ) {
312299 copy (params.tiled_copy_b , tBgB (_,_,_,k_tile), frag_copy_B);
313300 copy (params.tiled_copy_scale , tSgS (_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
314301 // barrier_wait(3);
@@ -318,7 +305,6 @@ class gemm_4bit_cutlass_kernel {
318305 if (prefetch_k < k_tile_count) {
319306 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
320307 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
321- prefetch_k++;
322308 }
323309
324310 cute::gemm (tiled_mma, mma_A, mma_B, accumulators);
0 commit comments