@@ -40,7 +40,7 @@ using namespace cutlass::gemm;
4040
4141// Define Basic information
4242// Weight-only-quant (B)
43- using MmaType = cutlass::bfloat16_t ;
43+ using MmaType = sycl::ext::oneapi::bfloat16; // cutlass::bfloat16_t;
4444using QuantType = cutlass::uint4_t ; // NF4,FP4
4545
4646using ElementA = MmaType; // bfloat16_t;
@@ -50,18 +50,23 @@ using ElementMMA = ElementA;
5050using ElementQuant = QuantType;
5151using ElementScale = MmaType; // sycl::ext::oneapi::bfloat16; //MmaType;
5252
53- using ElementC = float ;
54- using ElementD = float ;
55- using ElementAccumulator = float ; // data_type of accumulator
56- using ElementComputeEpilogue = float ; // data_type of epilogue operations
57- using ElementOutput = float ;
53+ using ElementAccumulator = MmaType; // data_type of accumulator
54+ using ElementComputeEpilogue = MmaType; // data_type of epilogue operations
55+ using ElementOutput = MmaType;
5856
5957using ProblemShape = Shape<int , int , int , int >;
6058
59+ #if 1
6160using TileShape = Shape<_256, _256, _32>;
6261using TiledMma =
63- typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT >, Layout<TileShape>,
62+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_BF16BF16BF16BF16_TT >, Layout<TileShape>,
6463 Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
64+ #else
65+ using TileShape = Shape<_16, _64, _64>;
66+ using TiledMma =
67+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
68+ Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
69+ #endif
6570
6671using WorkgroupTileShape = TileShape;
6772static constexpr auto BLK_M = get<0 >(WorkgroupTileShape{}); // 256 //16
@@ -94,7 +99,8 @@ static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // 16
9499
95100// Design Epilogue
96101using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
97- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
102+ // constexpr int kAlignment = 128 / sizeof(ElementOutput);
103+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
98104using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape, decltype (tile_shape(TiledMma()))>;
99105using SharedStorage = FusionCallBacks::SharedStorage;
100106
@@ -115,9 +121,9 @@ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
115121 ElementOutput,
116122 cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
117123 FusionCallBacks,
118- XE_2D_U32x8x16_LD_N , // The copy atom used to load matrix C
124+ XE_2D_U16x8x16_LD_N , // The copy atom used to load matrix C
119125 void , void ,
120- XE_2D_U32x8x16_ST_N , // The copy atom used to store matrix D
126+ XE_2D_U16x8x16_ST_N , // The copy atom used to store matrix D
121127 void , void >;
122128using EpilogueParams = typename CollectiveEpilogue::Params;
123129
@@ -166,7 +172,7 @@ class kgemm_4bit_inference_cutlass_dequant {
166172 int m, n, k;
167173 T* A;
168174 uint8_t * B;
169- float * out;
175+ T * out;
170176 float *datatype; // LUT
171177 int group_size;
172178
@@ -279,7 +285,7 @@ class kgemm_4bit_inference_cutlass_dequant {
279285
280286 T* A = params.A ;
281287 uint8_t * B = params.B ;
282- float * out = params.out ;
288+ T * out = params.out ;
283289 float * datatype = params.datatype ;
284290
285291 auto tiled_copy_a = params.tiled_copy_a ;
@@ -544,7 +550,7 @@ printf("\n");
544550
545551template <typename T, int BITS >
546552void gemm_4bit_inference_cutlass_dequant (int m, int n, int k, T *A, unsigned char *B,
547- T *absmax_, float *datatype, float *out, int lda,
553+ T *absmax_, float *datatype, T *out, int lda,
548554 int ldb, int ldc, int blocksize, sycl::queue *stream) {
549555 // std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
550556
@@ -593,8 +599,8 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
593599 cutlass::KernelHardwareInfo hw_info;
594600 hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count (hw_info.device_id );
595601 auto problem_shape_MNKL = problem_size; // append<4>(problem_size, 1);
596- float alpha=1 .0f ;
597- float beta=0 .f ;
602+ T alpha=1 .0f ;
603+ T beta=0 .f ;
598604 StrideC stride_C = cutlass::make_cute_packed_stride (StrideC{}, cute::make_shape (m, n, l));
599605 StrideD stride_D = cutlass::make_cute_packed_stride (StrideD{}, cute::make_shape (m, n, l));
600606
@@ -649,6 +655,6 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
649655
650656template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16 >(
651657 int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, unsigned char *B,
652- sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float *out, int lda,
658+ sycl::ext::oneapi::bfloat16 *absmax, float *datatype, sycl::ext::oneapi::bfloat16 *out, int lda,
653659 int ldb, int ldc, int blocksize, sycl::queue *stream);
654660
0 commit comments