Skip to content

Commit bec4032

Browse files
committed
save code
1 parent b337a05 commit bec4032

1 file changed

Lines changed: 104 additions & 22 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 104 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ using ElementOutput = float;
5353

5454
using ProblemShape = Shape<int, int, int, int>;
5555

56+
//constexpr int kQuantMapSize = 16;
57+
static constexpr float quant_map[16] = {
58+
-1.0f, -0.6961928f, -0.52507305f, -0.39491749f,
59+
-0.28444138f, -0.18477343f, -0.09105004f, 0.0f,
60+
0.0795803f, 0.1609302f, 0.2461123f, 0.33791524f,
61+
0.44070983f, 0.562617f, 0.72295684f, 1.0f
62+
};
63+
5664
//#ifndef METHOD
5765
//#define METHOD 1
5866
//#endif
@@ -69,10 +77,10 @@ using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6977
constexpr int PipelineStages = 2;
7078

7179
#else
72-
using TileShape = Shape<_32, _64, _64>;
80+
using TileShape = Shape<_128, _128, _32>;
7381
using TiledMma =
74-
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
75-
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
82+
typename TiledMMAHelper<MMA_Atom<XE_4x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
83+
Layout<Shape<_4, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
7684
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
7785
constexpr int PipelineStages = 4;
7886
#endif
@@ -133,9 +141,9 @@ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
133141
ElementOutput,
134142
cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
135143
FusionCallBacks,
136-
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
144+
XE_2D_U32x4x16_LD_N, // The copy atom used to load matrix C
137145
void, void,
138-
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
146+
XE_2D_U32x4x16_ST_N, // The copy atom used to store matrix D
139147
void, void>;
140148
using EpilogueParams = typename CollectiveEpilogue::Params;
141149

@@ -174,7 +182,6 @@ using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThread
174182
using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
175183
using StrideD = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
176184

177-
178185
template <typename T, int BITS>
179186
class gemm_4bit_cutlass_kernel {
180187
public:
@@ -247,41 +254,93 @@ class gemm_4bit_cutlass_kernel {
247254
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
248255
static_assert((compress_size % N) == 0);
249256

250-
static constexpr auto vec_size = 8;
257+
static constexpr auto vec_size = 4;
251258
using VecSrcType = cute::array<compress_type, vec_size>;
252259
using VecDstElemType = cute::array<DstType, compress_size>;
253260
using VecDstType = cute::array<VecDstElemType, vec_size>;
254261

255262
auto s_tensor = make_tensor((VecSrcType*)(raw_pointer_cast(in.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
256263
auto d_tensor = make_tensor((VecDstType*)(raw_pointer_cast(out.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
257-
258-
//if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, K = %d, compress_size = %d, vec_size = %d\n", decltype(size(out))::value, N, K, compress_size, vec_size);
259-
#pragma unroll
264+
265+
// constexpr float quant_map[16] = {
266+
// -1.0,
267+
// -0.6961928009986877,
268+
// -0.5250730514526367,
269+
// -0.39491748809814453,
270+
// -0.28444138169288635,
271+
// -0.18477343022823334,
272+
// -0.09105003625154495,
273+
// 0.0,
274+
// 0.07958029955625534,
275+
// 0.16093020141124725,
276+
// 0.24611230194568634,
277+
// 0.33791524171829224,
278+
// 0.44070982933044434,
279+
// 0.5626170039176941,
280+
// 0.7229568362236023,
281+
// 1.0,
282+
// };
283+
if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, K = %d, compress_size = %d, vec_size = %d\n", decltype(size(out))::value, N, K, compress_size, vec_size);
284+
#if 1
285+
//#pragma unroll
260286
for (int n = 0; n < N; n++) {
261287
float ts = tCrS_input(n);
262288
auto& src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor(_, n).data());
263289
auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor(_, n).data());
264290

265-
#pragma unroll
291+
//#pragma unroll
266292
for (int k = 0; k < K / (compress_size * vec_size); k++) {
267-
VecSrcType src_val = src[k];
293+
//VecSrcType src_val = src[k];
268294
VecDstType dst_val;
269295

270-
#pragma unroll
296+
//#pragma unroll
271297
for (int i = 0; i < vec_size; i++) {
272-
compress_type compressed_val = src_val[i];
298+
//compress_type compressed_val = src_val[i];
273299
VecDstElemType dst_elem;
274-
275-
#pragma unroll
276-
for (int j = 0; j < compress_size / 2; j++) {
277-
dst_elem[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
278-
dst_elem[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
300+
301+
// float4 vals = reinterpret_cast<const float4*>(quant_map)[idx/4];
302+
//#pragma unroll
303+
for (int j = 0; j < compress_size; j++) {
304+
//uint8_t high = (src[k][i]>> (4 * (j * 2 + 1))) & 0xf;
305+
//uint8_t low = (compressed_val >> (4 * (j * 2))) & 0xf;
306+
//dst[k][i][j] = static_cast<DstType>(quant_map[(src[k][i]>> (4 * (j * 2 + 1))) & 0xf] * ts);
307+
//dst_elem[j] = static_cast<DstType>(quant_map[(src[k][i]>> (4 * (j * 2 + 1))) & 0xf] * ts);
308+
dst_elem[j] = static_cast<DstType>(1.5f * ts);
309+
//dst_elem[2*j+1] = static_cast<DstType>(quant_map[low] * ts);
279310
}
280311
dst_val[i] = dst_elem;
281312
}
282313
dst[k] = dst_val;
283314
}
284315
}
316+
#else
317+
constexpr int shifts[8] = {4,0,12,8,20,16,28,24};
318+
319+
#pragma unroll
320+
for (int n = 0; n < N; n++) {
321+
DstType ts = static_cast<DstType>(tCrS_input(n));
322+
auto& src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor(_, n).data());
323+
auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor(_, n).data());
324+
325+
const auto src_val = src[0];
326+
VecDstType dst_val;
327+
#pragma unroll
328+
for (int i = 0; i < vec_size; ++i) {
329+
const compress_type val = src_val[i];
330+
VecDstElemType dst_elem;
331+
dst_elem[0] = quant_map[(val>>shifts[0])&0xF] * ts;
332+
dst_elem[1] = quant_map[(val>>shifts[1])&0xF] * ts;
333+
dst_elem[2] = quant_map[(val>>shifts[2])&0xF] * ts;
334+
dst_elem[3] = quant_map[(val>>shifts[3])&0xF] * ts;
335+
dst_elem[4] = quant_map[(val>>shifts[4])&0xF] * ts;
336+
dst_elem[5] = quant_map[(val>>shifts[5])&0xF] * ts;
337+
dst_elem[6] = quant_map[(val>>shifts[6])&0xF] * ts;
338+
dst_elem[7] = quant_map[(val>>shifts[7])&0xF] * ts;
339+
dst_val[i] = dst_elem;
340+
}
341+
dst[0] = dst_val;
342+
}
343+
#endif
285344
}
286345

287346
CUTLASS_DEVICE
@@ -311,13 +370,31 @@ class gemm_4bit_cutlass_kernel {
311370
for(int i=0; i<16; i++){
312371
quant_map[i] = datatype[i];
313372
}
314-
#else
373+
#else
315374
float* quant_map = reinterpret_cast<float*>(smem_buf);
316375
if (thread_idx < 16) {
317376
quant_map[thread_idx] = datatype[thread_idx];
318377
}
319378
barrier_arrive(3);
320379
#endif
380+
// constexpr float quant_map[16] = {
381+
// -1.0,
382+
// -0.696,//,1928009986877,
383+
// -0.525,//,0730514526367,
384+
// -0.394,//,91748809814453,
385+
// -0.284,//,44138169288635,
386+
// -0.184,//,77343022823334,
387+
// -0.091,//,05003625154495,
388+
// 0.0,
389+
// 0.079,//58029955625534,
390+
// 0.160,//93020141124725,
391+
// 0.246,//11230194568634,
392+
// 0.337,//91524171829224,
393+
// 0.440,//70982933044434,
394+
// 0.562,//6170039176941,
395+
// 0.722,//9568362236023,
396+
// 1.0,
397+
// };
321398
auto blk_shape = TileShape{};
322399
int m_coord, n_coord, l_coord;
323400
if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
@@ -459,19 +536,24 @@ class gemm_4bit_cutlass_kernel {
459536
}
460537

461538
for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++, k_s++) {
462-
copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
539+
//copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
463540
copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B);
464541

465542
const int s_idx = (k_start_idx + k_s) / k_reload_factor;
466543
copy(tiled_copy_scale, tSgS(_, _, _, s_idx), frag_copy_Scale);
467544

468-
dequant(dequant_frag, mma_B, fragment_scale, quant_map);
545+
//dequant(dequant_frag, mma_B, fragment_scale, quant_map);
546+
547+
copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
469548

470549
if(prefetch_k < k_tile_count) {
471550
prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k));
472551
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
473552
}
474553

554+
//dequant(dequant_frag, mma_B, fragment_scale);//, quant_map);
555+
//copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
556+
475557
cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
476558
barrier_wait(3);
477559
}

0 commit comments

Comments
 (0)