@@ -61,10 +61,10 @@ static constexpr float quant_map_static[16] = {
6161};
6262#endif
6363
64- using TileShape = Shape<_64, _64 , _64>;
64+ using TileShape = Shape<_64, _128 , _64>;
6565using TiledMma =
6666 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
67- Layout<Shape<_2, _4 , _1>, Stride<_4 , _1, _0>>>::TiledMMA;
67+ Layout<Shape<_2, _8 , _1>, Stride<_8 , _1, _0>>>::TiledMMA;
6868using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6969using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
7070constexpr int PipelineStages = 2 ;
@@ -246,10 +246,10 @@ inline float dDequantizeNF4(unsigned char val) {
246246 quant_map_[thread_idx + 16 ] = value;
247247 quant_map_[thread_idx + 32 ] = value;
248248 quant_map_[thread_idx + 48 ] = value;
249- quant_map_[thread_idx + 64 ] = value;
250- quant_map_[thread_idx + 80 ] = value;
251- quant_map_[thread_idx + 96 ] = value;
252- quant_map_[thread_idx + 112 ] = value;
249+ // quant_map_[thread_idx + 64] = value;
250+ // quant_map_[thread_idx + 80] = value;
251+ // quant_map_[thread_idx + 96] = value;
252+ // quant_map_[thread_idx + 112] = value;
253253 }
254254 barrier_arrive (3 );
255255 // }
@@ -396,6 +396,7 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
396396 };
397397 #endif
398398#else // register
399+ #if 0
399400 auto dequant = [&] (float* quant_map){
400401 constexpr int N = decltype(cute::size<1>(mma_B))::value;
401402 constexpr int K = decltype(cute::size(mma_B))::value / N;
@@ -437,14 +438,14 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
437438 src_1 = reinterpret_cast<src_compress_type*>(cute::raw_pointer_cast(dequant_frag.data()))[v];
438439 int c = 0;
439440 uint8_t bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
440- float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 2 * 16 ];
441+ float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 4 * 16];
441442 float converted_value_2 = 0.f;
442443 #pragma unroll
443444 for (; c < src_compress_size-1;) {
444445 converted_value_2 = converted_value_1;
445446 c++;
446447 bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
447- converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1 ) % 2 * 16 ];
448+ converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1) % 4 * 16];
448449 dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
449450 }
450451 dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
@@ -459,14 +460,14 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
459460 //int map_offset = dst_base_idx % 2 * 16;
460461 int c = 0;
461462 uint8_t bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
462- float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 2 * 16 ];
463+ float converted_value_1 = quant_map[bit_value + (dst_base_idx + c) % 4 * 16];
463464 float converted_value_2 = 0.f;
464465 #pragma unroll
465466 for (; c < src_compress_size-1;) {
466467 converted_value_2 = converted_value_1;
467468 c++;
468469 bit_value = (src_2 >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
469- converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1 ) % 2 * 16 ];
470+ converted_value_1 = quant_map[bit_value + (dst_base_idx + c - 1) % 4 * 16];
470471 dst[dst_base_idx + c-1] = static_cast<ElementMMA>(converted_value_2 * scale_value);
471472 }
472473 dst[dst_base_idx + c] = static_cast<ElementMMA>(converted_value_1 * scale_value);
@@ -499,6 +500,60 @@ printf("src_compress_size = %d, dst_compress_size = %d, src_vec_size = %d, dst_v
499500// reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[1] = reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[1];
500501
501502 };
503+ #else
504+ auto dequant = [&] (float * quant_map){ // Decode the packed 4-bit codes held in dequant_frag into scaled ElementMMA values and store them in mma_B; quant_map is the code -> float table (only entries 0..15 are read here).
505+ constexpr int N = decltype (cute::size<1 >(mma_B))::value; // extent of mode 1 of mma_B
506+ constexpr int K = decltype (cute::size (mma_B))::value / N; // elements of mma_B per index n
507+ // if(cute::thread0) printf("scale num = %d\n", decltype(cute::size(fragment_scale))::value);
508+
509+ using src_compress_type = uint64_t ; // packed source data is moved in 64-bit words
510+ using dst_compress_type = uint64_t ; // results are written back in 64-bit words
511+ constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementB>; // ElementB values per 64-bit source word (16 when ElementB is 4 bits wide)
512+ constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>; // ElementMMA values per 64-bit destination word (4 when ElementMMA is 16 bits wide)
513+ constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // source words per vector load, capped at 16 (author's note: 4, 16 -> max vec_size of sycl::vec)
514+ constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // destination words per vector store, capped at 16 (author's note: 16, 16 -> max vec_size of sycl::vec)
515+ constexpr int src_loop_num = K / src_vec_size / src_compress_size; // vector loads needed to cover the K source values of one n
516+ constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size; // vector stores needed to cover the K results of one n
517+ src_compress_type src[src_vec_size]; // staging buffer for one vector load of packed words
518+ ElementMMA dst[dst_loop_num * dst_compress_size * dst_vec_size]; // decoded values for the current n (K elements when the divisions above are exact); refilled on every n iteration
519+
520+
521+ #pragma unroll
522+ for (int n = 0 ; n < N; n++) { // one pass per index along mode 1 of mma_B
523+ // float scale_value = fragment_scale(0);
524+ #pragma unroll
525+ for (int l = 0 ; l < src_loop_num; l++) {
526+ reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<src_compress_type, src_vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[n*src_loop_num + l]; // copy the l-th vector of packed words for this n from dequant_frag into src
527+
528+ #pragma unroll
529+ for (int v = 0 ; v < src_vec_size; v++) {
530+ src_compress_type src_value = src[v];
531+ int dst_base_idx = l * src_vec_size * src_compress_size + v * src_compress_size; // index in dst of the first value decoded from src[v]
532+ #pragma unroll
533+ for (int c = 0 ; c < src_compress_size; c++) {
534+ uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ; // nibble shifts run 4, 0, 12, 8, ...: within each byte the high nibble is taken before the low nibble
535+ float scale_value = fragment_scale ((n * BLK_K + dst_base_idx + c) >> (31 - std::countl_zero<unsigned int >(GROUP_SIZE ))); // shift is floor(log2(GROUP_SIZE)) for a non-zero 32-bit GROUP_SIZE, so the flat index n * BLK_K + k is divided by GROUP_SIZE when it is a power of two
536+ // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value + (dst_base_idx + c) % 4 * 16] * scale_value);
537+ dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value); // table lookup by 4-bit code, then apply the selected scale
538+
539+ // uint8_t high = (src_value >> (4 * (c * 2 + 1))) & 0xf;
540+ // uint8_t low = (src_value >> (4 * (c * 2))) & 0xf;
541+ // float ts_high = fragment_scale((n * BLK_K + dst_base_idx + 2 * c) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
542+ // float ts_low = fragment_scale((n * BLK_K + dst_base_idx + 2 * c + 1) >> (31 - std::countl_zero<unsigned int>(GROUP_SIZE)));;
543+ // dst[dst_base_idx + 2 * c] = static_cast<ElementMMA>(quant_map[high] * ts_high);
544+ // dst[dst_base_idx + 2 * c + 1] = static_cast<ElementMMA>(quant_map[low] * ts_low);
545+ }
546+ }
547+ }
548+
549+ #pragma unroll
550+ for (int l = 0 ; l < dst_loop_num; l++) {
551+ reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast (mma_B.data ()))[n * dst_loop_num + l] = reinterpret_cast <sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[l]; // copy the decoded values for this n from dst into mma_B, one vector per iteration
552+
553+ }
554+ }
555+ };
556+ #endif
502557#endif
503558
504559 CUTLASS_PRAGMA_UNROLL
@@ -578,7 +633,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
578633 // std::cout<<"group_size = "<<blocksize<<std::endl;
579634
580635#if 1
581- static constexpr int smem_size= (32 ) * sizeof (float ) * 2 * 2 ;
636+ static constexpr int smem_size= (32 ) * sizeof (float ) * 2 ; // * 2;
582637#else
583638 static constexpr int smem_size = BLK_N * BLK_K * sizeof(ElementMMA) * 2 * 2; //aligned with 128B and will be reused for dequant src and dst.
584639 #endif
0 commit comments