Skip to content

Commit edee20b

Browse files
committed
save code
1 parent fcf0f8f commit edee20b

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ using ElementOutput = float;
5252

5353
using ProblemShape = Shape<int, int, int, int>;
5454

55-
using TileShape = Shape<_64, _128, _64>;
55+
using TileShape = Shape<_64, _128, _32>;
5656
using TiledMma =
5757
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
5858
Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
@@ -310,10 +310,10 @@ class gemm_4bit_cutlass_kernel {
310310
prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k));
311311
}
312312

313-
int start_lut_id = sg_idx % LUT_NUM;
313+
//int start_lut_id = sg_idx % LUT_NUM;
314314

315315
#if 1
316-
auto dequant = [&](decltype(dequant_frag_a)* dequant_frag_, decltype(fragment_scale_a)* fragment_scale_, decltype(mma_B_a)* mma_B_) {
316+
auto dequant = [](decltype(dequant_frag_a)* dequant_frag_, decltype(fragment_scale_a)* fragment_scale_, decltype(mma_B_a)* mma_B_, float(*quant_map)[16]) {
317317
constexpr int N = decltype(cute::size<1>(*mma_B_))::value;
318318
constexpr int K = decltype(cute::size(*mma_B_))::value / N;
319319

@@ -330,7 +330,7 @@ class gemm_4bit_cutlass_kernel {
330330

331331
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
332332

333-
int lut_id = start_lut_id;
333+
int lut_id = syclcompat::get_nd_item<1>().get_sub_group().get_group_linear_id() % LUT_NUM; //start_lut_id;
334334
#pragma unroll
335335
for (int n = 0; n < N; n++) {
336336

@@ -339,7 +339,6 @@ class gemm_4bit_cutlass_kernel {
339339

340340
#pragma unroll
341341
for (int v = 0; v < src_vec_size; v++) {
342-
//src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag[buffer_idx]->data()))[n*src_loop_num + l][v];
343342
src_compress_type src_value = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag_->data()))[n*src_loop_num + l][v];
344343
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
345344

@@ -348,7 +347,7 @@ class gemm_4bit_cutlass_kernel {
348347
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
349348
float scale_value = (*fragment_scale_)((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
350349

351-
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_[lut_id][bit_value] * scale_value);
350+
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[lut_id][bit_value] * scale_value);
352351
lut_id = (lut_id + 1) % LUT_NUM;
353352
}
354353
}
@@ -371,16 +370,24 @@ class gemm_4bit_cutlass_kernel {
371370
}
372371
prefetch_k++;
373372

373+
int buf_idx = 0;
374+
374375
for (int k_tile = k_start_idx + 1, k_s = 1; k_tile < k_tile_count; k_tile++, k_s++, prefetch_k++) {
375-
const int buf_idx = k_tile % 2;
376+
buf_idx ^= 1; //k_tile % 2;
376377

377378
//dequant(start_lut_id, 1 - buf_idx);
378379
//if(buf_idx == 1) {
379380
// dequant(start_lut_id, 0);
380381
//} else {
381382
// dequant(start_lut_id, 1);
382383
//}
383-
dequant(dequant_frag[1 - buf_idx], fragment_scale[1 - buf_idx], mma_B[1 - buf_idx]);
384+
385+
dequant(dequant_frag[1 - buf_idx], fragment_scale[1 - buf_idx], mma_B[1 - buf_idx], quant_map_);
386+
//if(buf_idx == 1) {
387+
// dequant(dequant_frag[0], fragment_scale[0], mma_B[0]);
388+
//} else {
389+
// dequant(dequant_frag[1], fragment_scale[1], mma_B[1]);
390+
//}
384391

385392
copy(params.tiled_copy_b, tBgB(_,_,_,k_tile), *frag_copy_B[buf_idx]);
386393
copy(params.tiled_copy_scale, tSgS(_,_,_,(k_start_idx+k_s)*BLK_K/params.group_size), *frag_copy_Scale[buf_idx]);
@@ -392,6 +399,7 @@ class gemm_4bit_cutlass_kernel {
392399
}
393400

394401
cute::gemm(tiled_mma, *mma_A[1 - buf_idx], *mma_B[1 - buf_idx], accumulators);
402+
395403
barrier_wait(3);
396404
}
397405
cute::gemm(tiled_mma, *mma_A[1], *mma_B[1], accumulators);

0 commit comments

Comments
 (0)