@@ -36,7 +36,7 @@ using namespace cutlass::gemm;
3636
3737// Define Basic information
3838// Weight-only-quant (B)
39- using MmaType = cutlass::bfloat16_t ;
39+ using MmaType = sycl::ext::oneapi::bfloat16; // cutlass::bfloat16_t;
4040using QuantType = cutlass::uint4_t ; // NF4,FP4
4141
4242using ElementA = MmaType;
@@ -186,6 +186,7 @@ class gemm_4bit_cutlass_kernel {
186186 ? BlockIdxX () : BlockIdxY ();
187187 const int l_coord = BlockIdxZ ();
188188
189+ #if 1
189190 float * quant_map;
190191 {
191192 // Load Dequatize LUT and save to SLM, 16 for 4bits
@@ -195,7 +196,14 @@ class gemm_4bit_cutlass_kernel {
195196 }
196197 barrier_arrive (3 );
197198 }
198-
199+ #else
200+ constexpr float quant_map[16] = {
201+ -1.0f, -0.6961928f, -0.52507305f, -0.39491749f,
202+ -0.28444138f, -0.18477343f, -0.09105004f, 0.0f,
203+ 0.0795803f, 0.1609302f, 0.2461123f, 0.33791524f,
204+ 0.44070983f, 0.562617f, 0.72295684f, 1.0f
205+ };
206+ #endif
199207 Tensor mA_mkl = cute::get_pvc_tensor (make_shape (params.m , params.k , params.l ));
200208 Tensor mB_nkl = cute::get_pvc_tensor (make_shape (params.n , params.k ,1 ));
201209
@@ -260,33 +268,38 @@ class gemm_4bit_cutlass_kernel {
260268 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
261269 int prefetch_k = k_start_idx;
262270
271+ #if 0
272+ auto convert = [](uint8_t quant_idx, float scale) {
273+ const float range = 2.0f; // 假设量化范围[-1,1]
274+ return ((quant_idx / 7.5f) - 1.0f) * scale; // 7.5=15/2 (4-bit)
275+ };
276+ #endif
263277 auto dequant = [&] {
264278 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
265279 constexpr int K = decltype (cute::size (mma_B))::value / N;
266- // if(cute::thread0()) printf("K = %d, N = %d\n", K, N);
267280
268281 using compress_type = uint32_t ;
269282 constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
270- constexpr auto vec_size = K / compress_size;
283+ constexpr int vec_size = K / compress_size;
271284
272- using VecSrcType = cute::array<compress_type, vec_size> ;
273- using VecDstElemType = cute::array<ElementMMA, compress_size> ;
274- using VecDstType = cute::array<VecDstElemType, vec_size> ;
285+ // if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size) ;
286+ compress_type src[vec_size] ;
287+ ElementMMA dst[K] ;
275288
276289 float scale_value = fragment_scale (0 );
277- auto src = *(VecSrcType*)(cute::raw_pointer_cast (dequant_frag.data ()));
278- auto & dst = *(VecDstType*)(cute::raw_pointer_cast (mma_B.data ()));
279- VecDstType dst_val;
280- #pragma unroll
281- for (int i = 0 ; i < vec_size; i++) {
282- VecDstElemType dst_elem;
290+
291+ reinterpret_cast <sycl::vec<compress_type, vec_size>*>(src)[0 ] = reinterpret_cast <sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast (dequant_frag.data ()))[0 ];
292+
293+ #pragma unroll
294+ for (int i = 0 ; i < vec_size; i++) {
283295 #pragma unroll
284296 for (int j = 0 ; j < compress_size; j++) {
285- dst_elem[j] = static_cast <ElementMMA>(quant_map[(src[i] >> (4 * ((j+1 )%2 + (j/2 )*2 ))) & 0xf ] * scale_value);
297+ uint8_t bit_value = (src[i] >> (4 * ((j+1 )%2 + (j/2 )*2 ))) & 0xf ;
298+ dst[i*compress_size+j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
299+ // dst[i*compress_size+j] = static_cast<ElementMMA>(convert(bit_value, scale_value));
286300 }
287- dst_val[i] = dst_elem;
288- }
289- dst = dst_val;
301+ }
302+ reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ()))[0 ] = reinterpret_cast <sycl::vec<int64_t , 16 >*>(dst)[0 ];
290303 };
291304
292305 CUTLASS_PRAGMA_UNROLL
@@ -338,7 +351,7 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
338351
339352 using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS >;
340353
341- static constexpr int smem_size= 16 *32 /8 ;
354+ static constexpr int smem_size= ( 16 + 1 ) *32 /8 ;
342355
343356 auto problem_size = ProblemShape{m, n, k, l};
344357
0 commit comments