@@ -58,10 +58,10 @@ using ElementOutput = float;
5858
5959using ProblemShape = Shape<int , int , int , int >;
6060
61- using TileShape = Shape<_16, _64, _64 >;
61+ using TileShape = Shape<_256, _256, _32 >;
6262using TiledMma =
6363 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
64- Layout<Shape<_1, _2 , _1>, Stride<_2 , _1, _0>>>::TiledMMA;
64+ Layout<Shape<_8, _4 , _1>, Stride<_4 , _1, _0>>>::TiledMMA;
6565
6666using WorkgroupTileShape = TileShape;
6767static constexpr auto BLK_M = get<0 >(WorkgroupTileShape{}); // 16
@@ -128,7 +128,7 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
128128using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
129129using CopyThreadShapeRev = decltype (cute::reverse(CopyThreadShape{}));
130130
131- using GmemTiledCopyA = XE_2D_U16x16x32_LD_N;
131+ using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; // XE_2D_U16x16x32_LD_N;
132132using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
133133// using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
134134using traits_load_A = Copy_Traits<GmemTiledCopyA, StrideA>;
@@ -287,7 +287,7 @@ class kgemm_4bit_inference_cutlass_dequant {
287287 for (int i = 0 ; i < vec_size; i++) {
288288 uint8_t value = (format_data >> (src_bits * i)) & 0xf ;
289289 dst[i] = static_cast <DstType>(quant_map[value] * static_cast <float >(ts));
290- // if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
290+ if (cute::thread0 ()) printf (" n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n " , n, s, i, static_cast <int >(value), quant_map[value], static_cast <float >(ts), static_cast <float >(dst[i]));
291291 }
292292 }
293293 }
@@ -303,12 +303,12 @@ class kgemm_4bit_inference_cutlass_dequant {
303303 int K = params.k ;
304304 int L = 1 ;
305305
306- const int BLK_M = 16 ;
307- const int BLK_N = 64 ;
308- const int BLK_K = 64 ;
306+ const int BLK_M = 256 ;
307+ const int BLK_N = 256 ;
308+ const int BLK_K = 32 ;
309309
310- const int ATOM_M = 1 ;
311- const int ATOM_N = 2 ;
310+ const int ATOM_M = 8 ;
311+ const int ATOM_N = 4 ;
312312const int ATOM_K = 1 ;
313313
314314const int SG_M = ceil_div (BLK_M , ATOM_M );
0 commit comments