Skip to content

Commit 87b9650

Browse files
committed
save code
1 parent d19f6b7 commit 87b9650

1 file changed

Lines changed: 66 additions & 11 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ static constexpr float quant_map_static[16] = {
6161
};
6262
#endif
6363

64-
using TileShape = Shape<_64, _64, _64>;
64+
using TileShape = Shape<_64, _128, _64>;
6565
using TiledMma =
6666
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
67-
Layout<Shape<_2, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
67+
Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
6868
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969
using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070
constexpr int PipelineStages = 2;
@@ -246,10 +246,10 @@ inline float dDequantizeNF4(unsigned char val) {
246246
quant_map_[thread_idx + 16] = value;
247247
quant_map_[thread_idx + 32] = value;
248248
quant_map_[thread_idx + 48] = value;
249-
quant_map_[thread_idx + 64] = value;
250-
quant_map_[thread_idx + 80] = value;
251-
quant_map_[thread_idx + 96] = value;
252-
quant_map_[thread_idx + 112] = value;
249+
//quant_map_[thread_idx + 64] = value;
250+
//quant_map_[thread_idx + 80] = value;
251+
//quant_map_[thread_idx + 96] = value;
252+
//quant_map_[thread_idx + 112] = value;
253253
}
254254
barrier_arrive(3);
255255
//}
@@ -396,6 +396,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
396396
};
397397
#endif
398398
#else //register
399+
#if 0
399400
auto dequant = [&] (float* quant_map){
400401
constexpr int N = decltype(cute::size<1>(mma_B))::value;
401402
constexpr int K = decltype(cute::size(mma_B))::value / N;
@@ -437,14 +438,14 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
437438
src_1 = reinterpret_cast<src_compress_type*>(cute::raw_pointer_cast(dequant_frag.data()))[v];
438439
int c = 0;
439440
uint8_t bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
440-
float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 2 * 16];
441+
float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 4 * 16];
441442
float converted_value_2 = 0.f;
442443
#pragma unroll
443444
for (; c < src_compress_size-1;) {
444445
converted_value_2 = converted_value_1;
445446
c++;
446447
bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
447-
converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1) % 2 * 16];
448+
converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1) % 4 * 16];
448449
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
449450
}
450451
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
@@ -459,14 +460,14 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
459460
//int map_offset = dst_base_idx % 2 * 16;
460461
int c = 0;
461462
uint8_t bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
462-
float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 2 * 16];
463+
float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 4 * 16];
463464
float converted_value_2 = 0.f;
464465
#pragma unroll
465466
for (; c < src_compress_size-1;) {
466467
converted_value_2 = converted_value_1;
467468
c++;
468469
bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
469-
converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1) % 2 * 16];
470+
converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1) % 4 * 16];
470471
dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
471472
}
472473
dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
@@ -499,6 +500,60 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
499500
// reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[1] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[1];
500501

501502
};
503+
#else
504+
// Dequantizes the packed fragment `dequant_frag` into `mma_B`.
// For every 4-bit code it looks up `quant_map[code]`, multiplies the result by
// a per-group scale read from `fragment_scale`, converts the product to
// ElementMMA, and finally copies the converted values into mma_B's storage.
//
// quant_map: lookup table indexed directly by the 4-bit code (indices 0..15
//            are the only ones read here).
// All other inputs (mma_B, dequant_frag, fragment_scale, BLK_K, GROUP_SIZE,
// ElementB, ElementMMA) come from the enclosing scope via the [&] capture.
auto dequant = [&] (float* quant_map){
505+
// N is the extent of mode 1 of mma_B; K is the remaining element count per n.
constexpr int N = decltype(cute::size<1>(mma_B))::value;
506+
constexpr int K = decltype(cute::size(mma_B))::value / N;
507+
//if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
508+
509+
// Source and destination are both moved in 64-bit words; the *_compress_size
// constants are the number of ElementB / ElementMMA values per word.
// NOTE(review): the trailing "//16" and "//4" suggest ElementB is 4 bits and
// ElementMMA is 16 bits wide - confirm against their declarations.
using src_compress_type = uint64_t;
510+
using dst_compress_type = uint64_t;
511+
constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; //16
512+
constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; //4
513+
// Words per vector transfer, capped at 16 (the existing comments attribute the
// cap to the maximum sycl::vec size).
constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; //4, 16 -> max vec_size of sycl::vec
514+
constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; //16, 16 -> max vec_size of sycl::vec
515+
// Number of vector transfers needed to cover K elements on each side.
// NOTE(review): this assumes K divides evenly by vec_size * compress_size;
// otherwise the integer division would drop the tail - confirm.
constexpr int src_loop_num = K / src_vec_size / src_compress_size;
516+
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
517+
// Local staging buffers: one vector of packed source words, and the converted
// values for one n (dst_loop_num * dst_compress_size * dst_vec_size elements).
src_compress_type src[src_vec_size];
518+
ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size];
519+
520+
521+
#pragma unroll
522+
for (int n = 0; n < N; n++) {
523+
//float scale_value = fragment_scale(0);
524+
#pragma unroll
525+
for (int l = 0; l < src_loop_num; l++) {
526+
// Load one vector of packed source words for (n, l) from dequant_frag.
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[n*src_loop_num + l];
527+
528+
#pragma unroll
529+
for (int v = 0; v < src_vec_size; v++) {
530+
src_compress_type src_value = src[v];
531+
// Index in dst of the first element decoded from this source word.
int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size;
532+
#pragma unroll
533+
for (int c = 0; c < src_compress_size; c++) {
534+
// Extract one 4-bit code. The nibble position ((c + 1) & 1) + (c >> 1) * 2
// equals c ^ 1, so the two nibbles of each byte are read in swapped order
// (output 0 <- nibble 1, output 1 <- nibble 0, output 2 <- nibble 3, ...).
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
535+
// Select the scale for this element: its linear index (n * BLK_K + offset)
// is shifted right by floor(log2(GROUP_SIZE)), which equals dividing by
// GROUP_SIZE when GROUP_SIZE is a power of two.
// NOTE(review): presumably GROUP_SIZE is always a power of two - confirm.
float scale_value = fragment_scale((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));
536+
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
537+
// Look up the code in the table at index bit_value (no per-element offset,
// unlike the commented-out variant above), apply the scale, then convert.
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
538+
539+
// uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
540+
// uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
541+
// float ts_high = fragment_scale((n * BLK_K + dst_base_idx + 2 * c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
542+
// float ts_low = fragment_scale((n * BLK_K + dst_base_idx + 2 * c + 1) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
543+
// dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
544+
// dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
545+
}
546+
}
547+
}
548+
549+
// Copy the converted values for this n into mma_B's storage, one vector of
// dst_vec_size 64-bit words at a time.
#pragma unroll
550+
for (int l = 0; l < dst_loop_num; l++) {
551+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l];
552+
553+
}
554+
}
555+
};
556+
#endif
502557
#endif
503558

504559
CUTLASS_PRAGMA_UNROLL
@@ -578,7 +633,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
578633
//std::cout<<"group_size = "<<blocksize<<std::endl;
579634

580635
#if 1
581-
static constexpr int smem_size= (32) * sizeof(float) * 2 * 2;
636+
static constexpr int smem_size= (32) * sizeof(float) * 2;// * 2;
582637
#else
583638
static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
584639
#endif

0 commit comments

Comments
 (0)