@@ -53,6 +53,14 @@ using ElementOutput = float;
5353
5454using ProblemShape = Shape<int , int , int , int >;
5555
/// Code book used to expand 4-bit quantization codes into floats.
///
/// Entry `i` is the value represented by the 4-bit code `i` (0..15). The
/// entries are strictly increasing, run from -1.0f to 1.0f, and code 7 is
/// exactly 0.0f.
/// NOTE(review): the values look like the NF4 (4-bit NormalFloat) code book
/// rounded to float precision — confirm against the quantizer that produced
/// the packed weights.
///
/// The extent is deduced from the initializer list and pinned by the
/// static_assert below, so dropping or adding an entry is a compile error
/// instead of a silently zero-filled (or truncated) table.
static constexpr float quant_map[] = {
    -1.0f,        -0.6961928f,  -0.52507305f, -0.39491749f,
    -0.28444138f, -0.18477343f, -0.09105004f,  0.0f,
     0.0795803f,   0.1609302f,   0.2461123f,   0.33791524f,
     0.44070983f,  0.562617f,    0.72295684f,  1.0f
};
static_assert(sizeof(quant_map) / sizeof(quant_map[0]) == 16,
              "quant_map must provide one entry per 4-bit code");
63+
5664// #ifndef METHOD
5765// #define METHOD 1
5866// #endif
@@ -69,10 +77,10 @@ using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
6977constexpr int PipelineStages = 2;
7078
7179#else
72- using TileShape = Shape<_32, _64, _64 >;
80+ using TileShape = Shape<_128, _128, _32 >;
7381 using TiledMma =
74- typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT >, Layout<TileShape>,
75- Layout<Shape<_1, _4 , _1>, Stride<_4 , _1, _0>>>::TiledMMA;
82+ typename TiledMMAHelper<MMA_Atom<XE_4x16x16_F32BF16BF16F32_TT >, Layout<TileShape>,
83+ Layout<Shape<_4, _8 , _1>, Stride<_8 , _1, _0>>>::TiledMMA;
7684 using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
7785 constexpr int PipelineStages = 4 ;
7886#endif
@@ -133,9 +141,9 @@ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
133141 ElementOutput,
134142 cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
135143 FusionCallBacks,
136- XE_2D_U32x8x16_LD_N , // The copy atom used to load matrix C
144+ XE_2D_U32x4x16_LD_N , // The copy atom used to load matrix C
137145 void , void ,
138- XE_2D_U32x8x16_ST_N , // The copy atom used to store matrix D
146+ XE_2D_U32x4x16_ST_N , // The copy atom used to store matrix D
139147 void , void >;
140148using EpilogueParams = typename CollectiveEpilogue::Params;
141149
@@ -174,7 +182,6 @@ using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout<CopyThread
174182using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
175183using StrideD = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;
176184
177-
178185template <typename T, int BITS >
179186class gemm_4bit_cutlass_kernel {
180187public:
@@ -247,41 +254,93 @@ class gemm_4bit_cutlass_kernel {
247254 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
248255 static_assert ((compress_size % N) == 0 );
249256
250- static constexpr auto vec_size = 8 ;
257+ static constexpr auto vec_size = 4 ;
251258 using VecSrcType = cute::array<compress_type, vec_size>;
252259 using VecDstElemType = cute::array<DstType, compress_size>;
253260 using VecDstType = cute::array<VecDstElemType, vec_size>;
254261
255262 auto s_tensor = make_tensor ((VecSrcType*)(raw_pointer_cast (in.data ())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
256263 auto d_tensor = make_tensor ((VecDstType*)(raw_pointer_cast (out.data ())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
257-
258- // if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, K = %d, compress_size = %d, vec_size = %d\n", decltype(size(out))::value, N, K, compress_size, vec_size);
259- #pragma unroll
264+
265+ // constexpr float quant_map[16] = {
266+ // -1.0,
267+ // -0.6961928009986877,
268+ // -0.5250730514526367,
269+ // -0.39491748809814453,
270+ // -0.28444138169288635,
271+ // -0.18477343022823334,
272+ // -0.09105003625154495,
273+ // 0.0,
274+ // 0.07958029955625534,
275+ // 0.16093020141124725,
276+ // 0.24611230194568634,
277+ // 0.33791524171829224,
278+ // 0.44070982933044434,
279+ // 0.5626170039176941,
280+ // 0.7229568362236023,
281+ // 1.0,
282+ // };
283+ if (cute::thread0 ()) printf (" decltype(size(out))::value = %d, N = %d, K = %d, compress_size = %d, vec_size = %d\n " , decltype (size (out))::value, N, K, compress_size, vec_size);
284+ #if 1
285+ // #pragma unroll
260286 for (int n = 0 ; n < N; n++) {
261287 float ts = tCrS_input (n);
262288 auto & src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor (_, n).data ());
263289 auto & dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor (_, n).data ());
264290
265- #pragma unroll
291+ // #pragma unroll
266292 for (int k = 0 ; k < K / (compress_size * vec_size); k++) {
267- VecSrcType src_val = src[k];
293+ // VecSrcType src_val = src[k];
268294 VecDstType dst_val;
269295
270- #pragma unroll
296+ // #pragma unroll
271297 for (int i = 0 ; i < vec_size; i++) {
272- compress_type compressed_val = src_val[i];
298+ // compress_type compressed_val = src_val[i];
273299 VecDstElemType dst_elem;
274-
275- #pragma unroll
276- for (int j = 0 ; j < compress_size / 2 ; j++) {
277- dst_elem[2 *j] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1 ))) & 0xf ] * ts);
278- dst_elem[2 *j+1 ] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 ))) & 0xf ] * ts);
300+
301+ // float4 vals = reinterpret_cast<const float4*>(quant_map)[idx/4];
302+ // #pragma unroll
303+ for (int j = 0 ; j < compress_size; j++) {
304+ // uint8_t high = (src[k][i]>> (4 * (j * 2 + 1))) & 0xf;
305+ // uint8_t low = (compressed_val >> (4 * (j * 2))) & 0xf;
306+ // dst[k][i][j] = static_cast<DstType>(quant_map[(src[k][i]>> (4 * (j * 2 + 1))) & 0xf] * ts);
307+ // dst_elem[j] = static_cast<DstType>(quant_map[(src[k][i]>> (4 * (j * 2 + 1))) & 0xf] * ts);
308+ dst_elem[j] = static_cast <DstType>(1 .5f * ts);
309+ // dst_elem[2*j+1] = static_cast<DstType>(quant_map[low] * ts);
279310 }
280311 dst_val[i] = dst_elem;
281312 }
282313 dst[k] = dst_val;
283314 }
284315 }
316+ #else
317+ constexpr int shifts[8] = {4,0,12,8,20,16,28,24};
318+
319+ #pragma unroll
320+ for (int n = 0; n < N; n++) {
321+ DstType ts = static_cast<DstType>(tCrS_input(n));
322+ auto& src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor(_, n).data());
323+ auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor(_, n).data());
324+
325+ const auto src_val = src[0];
326+ VecDstType dst_val;
327+ #pragma unroll
328+ for (int i = 0; i < vec_size; ++i) {
329+ const compress_type val = src_val[i];
330+ VecDstElemType dst_elem;
331+ dst_elem[0] = quant_map[(val>>shifts[0])&0xF] * ts;
332+ dst_elem[1] = quant_map[(val>>shifts[1])&0xF] * ts;
333+ dst_elem[2] = quant_map[(val>>shifts[2])&0xF] * ts;
334+ dst_elem[3] = quant_map[(val>>shifts[3])&0xF] * ts;
335+ dst_elem[4] = quant_map[(val>>shifts[4])&0xF] * ts;
336+ dst_elem[5] = quant_map[(val>>shifts[5])&0xF] * ts;
337+ dst_elem[6] = quant_map[(val>>shifts[6])&0xF] * ts;
338+ dst_elem[7] = quant_map[(val>>shifts[7])&0xF] * ts;
339+ dst_val[i] = dst_elem;
340+ }
341+ dst[0] = dst_val;
342+ }
343+ #endif
285344 }
286345
287346 CUTLASS_DEVICE
@@ -311,13 +370,31 @@ class gemm_4bit_cutlass_kernel {
311370 for(int i=0; i<16; i++){
312371 quant_map[i] = datatype[i];
313372 }
314- #else
373+ #else
315374 float * quant_map = reinterpret_cast <float *>(smem_buf);
316375 if (thread_idx < 16 ) {
317376 quant_map[thread_idx] = datatype[thread_idx];
318377 }
319378 barrier_arrive (3 );
320379#endif
380+ // constexpr float quant_map[16] = {
381+ // -1.0,
382+ // -0.696,//,1928009986877,
383+ // -0.525,//,0730514526367,
384+ // -0.394,//,91748809814453,
385+ // -0.284,//,44138169288635,
386+ // -0.184,//,77343022823334,
387+ // -0.091,//,05003625154495,
388+ // 0.0,
389+ // 0.079,//58029955625534,
390+ // 0.160,//93020141124725,
391+ // 0.246,//11230194568634,
392+ // 0.337,//91524171829224,
393+ // 0.440,//70982933044434,
394+ // 0.562,//6170039176941,
395+ // 0.722,//9568362236023,
396+ // 1.0,
397+ // };
321398 auto blk_shape = TileShape{};
322399 int m_coord, n_coord, l_coord;
323400 if (params.scheduler .raster_order_ == TileScheduler::RasterOrder::AlongN) {
@@ -459,19 +536,24 @@ class gemm_4bit_cutlass_kernel {
459536 }
460537
461538 for (int k_tile = k_start_idx, k_s = 0 ; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++, k_s++) {
462- copy (tiled_copy_a, tAgA (_,_,_,k_tile), frag_copy_A);
539+ // copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
463540 copy (tiled_copy_b, tBgB (_,_,_,k_tile), frag_copy_B);
464541
465542 const int s_idx = (k_start_idx + k_s) / k_reload_factor;
466543 copy (tiled_copy_scale, tSgS (_, _, _, s_idx), frag_copy_Scale);
467544
468- dequant (dequant_frag, mma_B, fragment_scale, quant_map);
545+ // dequant(dequant_frag, mma_B, fragment_scale, quant_map);
546+
547+ copy (tiled_copy_a, tAgA (_,_,_,k_tile), frag_copy_A);
469548
470549 if (prefetch_k < k_tile_count) {
471550 prefetch (tiled_prefetch_a, pAgA (_,_,_,prefetch_k));
472551 prefetch (tiled_prefetch_b, pBgB (_,_,_,prefetch_k));
473552 }
474553
554+ // dequant(dequant_frag, mma_B, fragment_scale);//, quant_map);
555+ // copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A);
556+
475557 cute::gemm (tiled_mma, mma_A, mma_B, accumulators);
476558 barrier_wait (3 );
477559 }
0 commit comments