@@ -52,10 +52,18 @@ using ElementOutput = float;
5252
// Problem extents carried as four runtime ints.
// NOTE(review): presumably (M, N, K, L) to match the launcher's m/n/k/l arguments — confirm.
5353using ProblemShape = Shape<int , int , int , int >;
5454
// Compile-time switch between two tile configurations. With `#if 1` the first
// (256 x 256 x 32) configuration is compiled; flip to `#if 0` to build the
// 16 x 64 x 64 configuration in the #else branch instead.
55+ #if 1
5556using TileShape = Shape<_256, _256, _32>;
57+ // using TileShape = Shape<_128, _256, _32>;
// Tiled MMA built from the XE_8x16x16_F32BF16BF16F32_TT atom over TileShape.
// NOTE(review): the Shape<_8, _4, _1> layout looks like the arrangement of
// MMA atoms/subgroups across the tile — confirm against TiledMMAHelper.
5658using TiledMma =
5759 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
5860 Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
61+ #else
// Alternative configuration: smaller 16 x 64 x 64 tile with a Shape<_1, _2, _1> layout,
// using the same MMA atom as the branch above.
62+ using TileShape = Shape<_16, _64, _64>;
63+ using TiledMma =
64+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
65+ Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
66+ #endif
5967
// The workgroup tile is whichever TileShape the preprocessor branch above selected.
6068using WorkgroupTileShape = TileShape;
6169static constexpr auto BLK_M = get<0 >(WorkgroupTileShape{}); // first tile extent: 256 with the `#if 1` TileShape, 16 with the #else TileShape
@@ -204,8 +212,7 @@ class gemm_4bit_cutlass_kernel {
204212 Tensor<EngineIn, LayoutIn> const & in,
205213 Tensor<EngineOut, LayoutOut>& out,
206214 Tensor<EngineScales, LayoutScales>& tCrS_input,
207- T* quant_map,
208- int n_coord, int thread_idx, int k_start_idx, int k_s, int k_reload_factor, int s_idx
215+ T* quant_map
209216 ) {
210217 static_assert (is_rmem<EngineIn>::value, " Input tensor for A conversion must come from registers" );
211218 static_assert (size_v<LayoutIn> == cosize_v<LayoutIn>);
@@ -243,13 +250,9 @@ class gemm_4bit_cutlass_kernel {
243250 auto & dst = *(cute::array<DstType, vec_size>*)(d_tensor (_, s, n).data ());
244251
245252 CUTLASS_PRAGMA_UNROLL
246- for (int i = 0 ; i < vec_size; i++) {
247- uint8_t value = (format_data >> (src_bits * i)) & 0xf ;
248- if (i % 2 != 0 ) { // 1,3, high_4bit
249- dst[i-1 ] = static_cast <DstType>(quant_map[value] * ts);
250- } else {
251- dst[i+1 ] = static_cast <DstType>(quant_map[value] * ts);
252- }
253+ for (int i = 0 ; i < vec_size/2 ; i++) {
254+ dst[i * 2 ] = static_cast <DstType>(quant_map[(format_data >> (src_bits * (i * 2 + 1 ))) & 0xf ] * ts);
255+ dst[i * 2 + 1 ] = static_cast <DstType>(quant_map[(format_data >> (src_bits * (i * 2 ))) & 0xf ] * ts);
253256 }
254257 }
255258 }
@@ -280,7 +283,7 @@ class gemm_4bit_cutlass_kernel {
280283 if (thread_idx < 16 ) {
281284 quant_map[thread_idx] = T (datatype[thread_idx]);
282285 }
283- barrier_wait ( 1 );
286+ barrier_arrive ( 3 );
284287
285288 auto blk_shape = TileShape{};
286289 int m_coord, n_coord, l_coord;
@@ -362,7 +365,7 @@ class gemm_4bit_cutlass_kernel {
362365 }();
363366
364367#if 0
365- if (thread_idx==16 && n_coord == 0 && l_coord==1 ) {
368+ if (thread_idx==0 && n_coord == 0 && l_coord==0 ) {
366369 print("\n\n======================= A: \n");
367370 print(" gA : "); print(gA); print("\n");
368371 print(" tCgA : "); print(tCgA); print("\n");
@@ -423,7 +426,7 @@ class gemm_4bit_cutlass_kernel {
423426 }
424427
425428 for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++, k_s++) {
426- barrier_arrive (2 );
429+ // barrier_arrive(3 );
427430
428431 copy (tiled_copy_a, tAgA (_,_,_,k_tile), frag_copy_A);
429432 copy (tiled_copy_b, tBgB (_,_,_,k_tile), frag_copy_B);
@@ -438,11 +441,11 @@ class gemm_4bit_cutlass_kernel {
438441 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
439442 }
440443
441- dequant (dequant_frag, mma_B, fragment_scale, quant_map, n_coord, thread_idx, k_start_idx, k_s, k_reload_factor, s_idx );
444+ dequant (dequant_frag, mma_B, fragment_scale, quant_map);
442445
443446
444447 cute::gemm (tiled_mma, mma_A, mma_B, accumulators);
445- barrier_wait (2 );
448+ barrier_wait (3 );
446449 }
447450
448451 SharedStorage& shared_storage = *reinterpret_cast <SharedStorage*>((char *)nullptr );
@@ -467,7 +470,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
467470
468471 using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS >;
469472
470- static constexpr int smem_size= 256 ;
473+ static constexpr int smem_size= 16 * 16 / 2 ;
471474
472475 auto problem_size = ProblemShape{m, n, k, l};
473476
0 commit comments