1414#include " cutlass/detail/layout.hpp"
1515#include " cutlass/detail/mma.hpp"
1616#include " cutlass/cuda_host_adapter.hpp"
17+ #include < cutlass/numeric_types.h>
18+ #include < cutlass/bfloat16.h>
#include <cstring>
1719
1820#include " cutlass/kernel_launch.h"
1921#if !defined(__CUDACC_RTC__)
@@ -46,7 +48,7 @@ using ElementB = QuantType; //cutlass::gemm::collective::detail::deduce_mixed_wi
4648
// Element type aliases used by the kernel below.
4749using ElementMMA = ElementA;
4850using ElementQuant = QuantType;
// The scale element type is pinned to SYCL bfloat16 here; it previously aliased MmaType.
49- using ElementScale = MmaType;
51+ using ElementScale = sycl::ext::oneapi::bfloat16; // MmaType;
5052
// Both the C operand and the D result are single-precision float.
5153using ElementC = float ;
5254using ElementD = float ;
@@ -63,11 +65,11 @@ using ProblemShape = Shape<int, int, int, int>;
6365// inner_loop_number (Atom numbers per thread): (256/8) * (256/4) * (32/1)
6466// XE_8x16x16_F32BF16BF16F32_TT -> hardware instruction
6567// Stride<_4, _1, _0> could be optional?
66- using TileShape = Shape<_256, _256, _32>;
68+ using TileShape = Shape<_256, _256, _32>; /* TODO: maybe need to adjust me for the small tile shapes */
6769// using TileShape = Shape<_32, _32, _32>;
// Tiled MMA built from the XE_8x16x16_F32BF16BF16F32_TT atom over TileShape,
// with an (8, 4, 1) arrangement and strides (4, 1, 0).
// NOTE(review): presumably the (8, 4, 1) layout is the sub-group arrangement - confirm against TiledMMAHelper.
6870using TiledMma =
6971 typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
70- Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
72+ Layout<Shape<_8, _4, _1> /* TODO: maybe need to adjust me for the small tile shapes */ , Stride<_4, _1, _0>>>::TiledMMA;
7173
7274// Define Mainloop dispatch policy
7375constexpr int PipelineStages = 1 ;
@@ -228,7 +230,12 @@ class kgemm_4bit_inference_cutlass_dequant {
228230 return Int<cute::gcd (Cosize, 32 / cute::min (sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
229231 }
230232 };
231-
233+
/// Expands a raw bfloat16 bit pattern into the float it represents.
///
/// bfloat16 is the upper 16 bits of an IEEE-754 binary32 value, so the
/// conversion places the input in the high half of a 32-bit word and leaves
/// the low 16 mantissa bits zero.
///
/// @param bf16_bits Raw 16-bit bfloat16 encoding.
/// @return The float whose bit pattern is `bf16_bits << 16`.
float bfloat16_to_float(uint16_t bf16_bits) {
    // Widen to an unsigned 32-bit type BEFORE shifting: shifting the
    // int-promoted uint16_t would move bit 15 into the sign bit of a signed int.
    const uint32_t float_bits = static_cast<uint32_t>(bf16_bits) << 16;

    // Copy the bytes instead of aliasing a uint32_t through a float reference,
    // which would violate the strict-aliasing rule.
    float result;
    std::memcpy(&result, &float_bits, sizeof(result));
    return result;
}
238+
232239 // / Utilities to transform A.
233240 template <class EngineIn ,
234241 class EngineOut ,
@@ -251,20 +258,39 @@ class kgemm_4bit_inference_cutlass_dequant {
251258 static_assert (size_v<LayoutIn> == cosize_v<LayoutIn>);
252259 static_assert (size_v<LayoutOut> == cosize_v<LayoutOut>);
253260
254- for (int i = 0 ; i < size (tCrB_src); ++i) {
255- // uint8_t src_value = tCrB_src(i);
256- // tCrB_dst(2*i) = static_cast<ElementMMA>(quant_map[(src_value >> 4) & 0x0F]);// * tCrS(i/4) ;
257- // tCrB_dst(2*i+1) = static_cast<ElementMMA>(quant_map[src_value & 0x0F]);// * tCrS(i/4);
258- uint8_t packed = tCrB_src (i);
259- uint8_t high = (packed >> 4 ) & 0x0F ;
260- uint8_t low = packed & 0x0F ;
261-
262- // 应用缩放因子
263- float val_high = quant_map[high];// * tCrS(i/4); // 假设每32个元素共享一个scale
264- float val_low = quant_map[low];// * tCrS(i/4);
265-
266- tCrB_dst (2 *i) = static_cast <ElementMMA>(val_high);
267- tCrB_dst (2 *i+1 ) = static_cast <ElementMMA>(val_low);
261+ int scale_number = decltype (size (tCrS))::value;
262+ int src_number = decltype (size (tCrB_src))::value;
263+ int src_sub_number = src_number / scale_number;
264+
265+ // float scale_value = 1.0;
266+
267+ if (cute::thread0 ()) printf (" scale_number = %d, src_number= %d, src_sub_number = %d\n " , scale_number, src_number, src_sub_number);
268+ for (int i=0 ; i < scale_number; ++i) {
269+ auto scale_value = tCrS (i);
270+ float scale_value_float = static_cast <float >(scale_value);
271+ // uint16_t scale_bits = reinterpret_cast<uint16_t&>(scale_value);
272+ // uint16_t scale_bits = reinterpret_cast<uint16_t&>(scale_value);
273+ // printf("scale_value = %f, tCrS(%d) raw bits: %d\n",scale_value, i, static_cast<int>(scale_value));
274+ // float scale_value_float = bfloat16_to_float(scale_bits);
275+ printf (" scale_value_float = %f\n " , scale_value_float);
276+ for (int j = 0 ; j < src_sub_number; ++j) {
277+ int offset = i * src_sub_number;
278+ uint8_t packed = tCrB_src (offset + j);
279+ uint8_t high = (packed >> 4 ) & 0x0F ;
280+ uint8_t low = packed & 0x0F ;
281+
282+ float val_high = quant_map[high];
283+ float val_low = quant_map[low];
284+
285+ float val_high_scaled = val_high * scale_value_float;
286+ float val_low_scaled = val_high * scale_value_float;
287+
288+ // printf("scale value = %f, val_high_scaled = %f, val_low_scaled = %f\n", scale_value, val_high_scaled, val_low_scaled);
289+ tCrB_dst (offset + 2 * j) = static_cast <ElementMMA>(val_high_scaled);// * scale_value;
290+ tCrB_dst (offset + 2 * j + 1 ) = static_cast <ElementMMA>(val_low_scaled);// * scale_value;
291+
292+ // printf("scale value = %f, val_high = %f, val_low = %f\n", scale_value, val_high, val_low);
293+ }
268294 }
269295 }
270296
@@ -387,7 +413,7 @@ class kgemm_4bit_inference_cutlass_dequant {
387413
388414 Tensor tSgS = [&](){
389415 return make_tensor (make_inttuple_iter (make_coord (n_coord, 0 , l_coord)), // Starting coordinate (n_coord, 0, l_coord); presumably n_coord along N and 0 along K - confirm
390- make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count), // Logical shape of the coordinate iterator: [2, 2, 1, k_tile_count]
416+ make_layout (make_shape (_2{}, _2{}, _1{}, k_tile_count / 2 ), // Logical shape of the coordinate iterator: [2, 2, 1, k_tile_count / 2]; NOTE(review): confirm why the last extent is halved
391417 make_stride (E<0 >{} * _16{}, E<0 >{} * _32{}, _0{}, E<1 >{} * _1{}))); // Strides: E<0>*16, E<0>*32, 0, E<1>*1
392418 }();
393419
@@ -524,12 +550,20 @@ for(int i=0; i<num_Acc; i++) {
524550 }
525551};
526552
553+ void convert_float_to_bfloat16_host (float * src, cutlass::bfloat16_t * dst, int size) {
554+ for (int i = 0 ; i < size; ++i) {
555+ dst[i] = static_cast <cutlass::bfloat16_t >(src[i]); // 直接转换
556+ }
557+ }
558+
527559template <typename T, int BITS >
528560void gemm_4bit_inference_cutlass_dequant (int m, int n, int k, T *A, unsigned char *B,
529- float *absmax_, float *datatype, float *out, int lda,
561+ T *absmax_, float *datatype, float *out, int lda,
530562 int ldb, int ldc, int blocksize, sycl::queue *stream) {
531563 std::cout<<" this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n " ;
532564
565+
566+
533567 sycl::queue q = *stream;
534568 using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS >;
535569
@@ -539,8 +573,26 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
539573 // TODO(Xiaoli): FIX ME?? auto problem_size = ProblemShape{m, n, k};
540574 auto problem_size = ProblemShape{m, n, k, l};
541575 // TODO(Xiaoli): FIX ME
542- T* absmax = (T*)absmax_;
543-
576+ // T* absmax = (T*)absmax_;
577+ // T* absmax = static_cast<T*>(absmax_);
578+ // auto absmax = reinterpret_cast<ElementScale*>((T*)absmax_);
579+ // int scale_size = 2048;
580+ // ElementScale* absmax = new ElementScale[scale_size];
581+ // for (int i = 0; i < scale_size; ++i) {
582+ // absmax[i] = static_cast<ElementScale>(absmax_[i]); // 逐元素转换
583+ // std::cout<<"absmax_[i] = "<<absmax_[i]<<" absmax[i] = "<< absmax[i]<<std::endl;
584+ // }
585+
586+ // int N = 2048;
587+ // printf("A[0] = %f\n", A[0]);
588+ // printf("absmax_[0] = %f\n", absmax_[0]);
589+ // cutlass::bfloat16_t* absmax_test = new cutlass::bfloat16_t[N];
590+ // std::cout<<"absmax_[0] ==== "<<absmax_[0]<<std::endl;
591+ // convert_float_to_bfloat16_host(absmax_, absmax, N);
592+ // absmax_test[0] = static_cast<cutlass::bfloat16_t>(absmax[0]);
593+ // printf("absmax_test[0] = %f\n", absmax_test[0]);
594+ // std::cout<<"absmax_test[0] ==== "<<absmax_test[0]<<std::endl;
595+ #if 1
544596 // Init Params
545597 using Params = GemmKernel::Params;
546598 Params params;
@@ -570,12 +622,16 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
570622 params.tiled_copy_a = tiled_copy_a;
571623 params.tiled_copy_b = tiled_copy_b;
572624 params.tiled_copy_b_4bit = tiled_copy_b_4bit;
573-
574- const int scale_k = cute::ceil_div (k, blocksize);
625+
626+ // float f = 1.0f;
627+ // cutlass::bfloat16_t bf16 = static_cast<cutlass::bfloat16_t>(f);
628+ // std::cout<<"k = "<<k<<" blocksize = "<<blocksize<<std::endl;
629+ const int scale_k = cute::ceil_div (k / 2 , blocksize);
630+ std::cout<<" n = " <<n<<" k = " <<k<<" blocksize = " <<blocksize<<" scale_k = " <<scale_k<<std::endl;
575631 const int dq_mn_size = n; // A(m) or B(n)
576632 StrideScale stride_S = cutlass::make_cute_packed_stride (StrideScale{}, cute::make_shape (dq_mn_size, scale_k, l));
577633 auto mScale = make_tensor (
578- make_gmem_ptr (absmax ), // static_cast<NonVoidElementScale *>(absmax)),
634+ make_gmem_ptr (absmax_ ), // static_cast<ElementScale *>(absmax)),
579635 make_layout (make_shape (dq_mn_size, scale_k, l), stride_S));
580636 Copy_Scale tiled_copy_scale{Copy_Scale{}.with (mScale )};
581637
@@ -627,10 +683,11 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
627683 auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(policy, q, params);
628684 EventManager::getInstance ().addEvent (event);
629685 // syclcompat::wait();
686+ #endif
630687}
631688
// Explicit instantiation for T = sycl::ext::oneapi::bfloat16 with BITS = 16.
// The absmax argument is a bfloat16 pointer, matching the `T *absmax_` parameter of the template.
632689template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16 >(
633690 int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, unsigned char *B,
634- float *absmax, float *datatype, float *out, int lda,
691+ sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float *out, int lda,
635692 int ldb, int ldc, int blocksize, sycl::queue *stream);
636693
0 commit comments