Skip to content

Commit 3c680f4

Browse files
committed
refine perf, and tuning method
1 parent e7bfa16 commit 3c680f4

1 file changed

Lines changed: 28 additions & 15 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ using ElementOutput = float;
5353

5454
using ProblemShape = Shape<int, int, int, int>;
5555

56-
#if 0
56+
#ifndef METHOD
57+
#define METHOD 2
58+
#endif
59+
60+
#if METHOD == 1
5761
using TileShape = Shape<_256, _256, _32>;
5862
//using TileShape = Shape<_128, _128, _32>;
5963
//using TileShape = Shape<_128, _256, _32>;
@@ -92,7 +96,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
9296
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); //8*4*1*16=512 //1*2*1*16=32
9397

9498
// Define Mainloop dispatch policy
99+
#if METHOD == 1
100+
constexpr int PipelineStages = 2;
101+
#else
95102
constexpr int PipelineStages = 4;
103+
#endif
96104
using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
97105
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // 16
98106

@@ -131,7 +139,7 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
131139
using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
132140
using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{}));
133141

134-
#if 0
142+
#if METHOD == 1
135143
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
136144
#else
137145
using GmemTiledCopyA = XE_2D_U16x16x32_LD_N;
@@ -284,21 +292,21 @@ CUTLASS_DEVICE void dequant(
284292
static constexpr auto N = decltype(size<1>(in))::value;
285293
static constexpr auto K = decltype(size(out))::value / N;
286294

287-
using compress_type = uint32_t;
295+
using compress_type = ushort; //uint32_t;
288296
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
289297
static_assert((compress_size % N) == 0);
290298

291-
static constexpr auto vec_size = 2;
299+
static constexpr auto vec_size = 4;
292300
//using VecSrcElemType = cute::array<SrcType, compress_size>;
293301
using VecSrcType = cute::array<compress_type, vec_size>; //sycl::vec<uint32_t, 4>;
294302
using VecDstElemType = cute::array<DstType, compress_size>;
295303
using VecDstType = cute::array<VecDstElemType, vec_size>;
296304

297305
// 预定义掩码和位移
298-
//constexpr uint32_t MASK_HIGH[4] = {0xF0, 0xF000, 0xF00000, 0xF0000000};
299-
//constexpr uint32_t MASK_LOW[4] = {0xF, 0xF00, 0xF0000, 0xF000000};
300-
//constexpr int SHIFT_HIGH[4] = {4, 12, 20, 28};
301-
//constexpr int SHIFT_LOW[4] = {0, 8, 16, 24};
306+
constexpr uint32_t MASK_HIGH[4] = {0xF0, 0xF000, 0xF00000, 0xF0000000};
307+
constexpr uint32_t MASK_LOW[4] = {0xF, 0xF00, 0xF0000, 0xF000000};
308+
constexpr int SHIFT_HIGH[4] = {4, 12, 20, 28};
309+
constexpr int SHIFT_LOW[4] = {0, 8, 16, 24};
302310

303311
auto s_tensor = make_tensor((VecSrcType*)(raw_pointer_cast(in.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
304312
auto d_tensor = make_tensor((VecDstType*)(raw_pointer_cast(out.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
@@ -312,21 +320,26 @@ CUTLASS_DEVICE void dequant(
312320
#pragma unroll
313321
for (int k = 0; k < K / (compress_size * vec_size); k++) {
314322
VecSrcType src_val = src[k];
315-
VecDstType dst_val;// = dst[k];
323+
//VecDstType& dst_val = dst[k];
324+
VecDstType dst_val;
316325

317326
#pragma unroll
318327
for (int i = 0; i < vec_size; i++) {
319328
compress_type compressed_val = src_val[i];
320-
VecDstElemType compressed_dst_val;// = dst_val[i];
329+
//VecDstElemType& dst_elem = dst_val[i];
330+
VecDstElemType dst_elem;
321331

322332
#pragma unroll
323333
for (int j = 0; j < compress_size / 2; j++) {
324-
//uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
325-
//uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
326-
compressed_dst_val[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
327-
compressed_dst_val[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
334+
//for (int j = 0; j < 4; j++) {
335+
uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
336+
uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
337+
dst_elem[2 * j] = static_cast<DstType>(quant_map[high] * ts);
338+
dst_elem[2 * j + 1] = static_cast<DstType>(quant_map[low] * ts);
339+
//dst_elem[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
340+
//dst_elem[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
328341
}
329-
dst_val[i] = compressed_dst_val;
342+
dst_val[i] = dst_elem;
330343
}
331344
dst[k] = dst_val;
332345
}

0 commit comments

Comments
 (0)