Skip to content

Commit cb27ec8

Browse files
committed
change policy
1 parent da1df6e commit cb27ec8

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ using ElementOutput = float;
5858

5959
using ProblemShape = Shape<int, int, int, int>;
6060

61-
using TileShape = Shape<_16, _64, _64>;
61+
using TileShape = Shape<_256, _256, _32>;
6262
using TiledMma =
6363
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
64-
Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
64+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
6565

6666
using WorkgroupTileShape = TileShape;
6767
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); //16
@@ -128,7 +128,7 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
128128
using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
129129
using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{}));
130130

131-
using GmemTiledCopyA = XE_2D_U16x16x32_LD_N;
131+
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; //XE_2D_U16x16x32_LD_N;
132132
using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
133133
//using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
134134
using traits_load_A = Copy_Traits<GmemTiledCopyA, StrideA>;
@@ -287,7 +287,7 @@ class kgemm_4bit_inference_cutlass_dequant {
287287
for (int i = 0; i < vec_size; i++) {
288288
uint8_t value = (format_data >> (src_bits * i)) & 0xf;
289289
dst[i] = static_cast<DstType>(quant_map[value] * static_cast<float>(ts));
290-
//if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
290+
if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, quant_map[value] = %f, ts = %f, dst = %f\n", n, s, i, static_cast<int>(value), quant_map[value], static_cast<float>(ts), static_cast<float>(dst[i]));
291291
}
292292
}
293293
}
@@ -303,12 +303,12 @@ class kgemm_4bit_inference_cutlass_dequant {
303303
int K = params.k;
304304
int L = 1;
305305

306-
const int BLK_M = 16;
307-
const int BLK_N = 64;
308-
const int BLK_K = 64;
306+
const int BLK_M = 256;
307+
const int BLK_N = 256;
308+
const int BLK_K = 32;
309309

310-
const int ATOM_M = 1;
311-
const int ATOM_N = 2;
310+
const int ATOM_M = 8;
311+
const int ATOM_N = 4;
312312
const int ATOM_K = 1;
313313

314314
const int SG_M = ceil_div(BLK_M, ATOM_M);

0 commit comments

Comments
 (0)