@@ -53,7 +53,11 @@ using ElementOutput = float;
5353
5454using ProblemShape = Shape<int , int , int , int >;
5555
56- #if 0
56+ #ifndef METHOD
57+ #define METHOD 2
58+ #endif
59+
60+ #if METHOD == 1
5761using TileShape = Shape<_256, _256, _32>;
5862// using TileShape = Shape<_128, _128, _32>;
5963// using TileShape = Shape<_128, _256, _32>;
@@ -92,7 +96,11 @@ static constexpr auto SG_QNT_WIDTH = Int<SG_N>{};
9296static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); // 8*4*1*16=512 //1*2*1*16=32
9397
9498// Define Mainloop dispatch policy
99+ #if METHOD == 1
100+ constexpr int PipelineStages = 2 ;
101+ #else
95102constexpr int PipelineStages = 4 ;
103+ #endif
96104using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
97105static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // 16
98106
@@ -131,7 +139,7 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
131139using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
132140using CopyThreadShapeRev = decltype (cute::reverse(CopyThreadShape{}));
133141
134- #if 0
142+ #if METHOD == 1
135143using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
136144#else
137145using GmemTiledCopyA = XE_2D_U16x16x32_LD_N;
@@ -284,21 +292,21 @@ CUTLASS_DEVICE void dequant(
284292 static constexpr auto N = decltype (size<1 >(in))::value;
285293 static constexpr auto K = decltype (size (out))::value / N;
286294
287- using compress_type = uint32_t ;
295+ using compress_type = ushort; // uint32_t;
288296 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
289297 static_assert ((compress_size % N) == 0 );
290298
291- static constexpr auto vec_size = 2 ;
299+ static constexpr auto vec_size = 4 ;
292300 // using VecSrcElemType = cute::array<SrcType, compress_size>;
293301 using VecSrcType = cute::array<compress_type, vec_size>; // sycl::vec<uint32_t, 4>;
294302 using VecDstElemType = cute::array<DstType, compress_size>;
295303 using VecDstType = cute::array<VecDstElemType, vec_size>;
296304
297305 // 预定义掩码和位移
298- // constexpr uint32_t MASK_HIGH[4] = {0xF0, 0xF000, 0xF00000, 0xF0000000};
299- // constexpr uint32_t MASK_LOW[4] = {0xF, 0xF00, 0xF0000, 0xF000000};
300- // constexpr int SHIFT_HIGH[4] = {4, 12, 20, 28};
301- // constexpr int SHIFT_LOW[4] = {0, 8, 16, 24};
306+ constexpr uint32_t MASK_HIGH [4 ] = {0xF0 , 0xF000 , 0xF00000 , 0xF0000000 };
307+ constexpr uint32_t MASK_LOW [4 ] = {0xF , 0xF00 , 0xF0000 , 0xF000000 };
308+ constexpr int SHIFT_HIGH [4 ] = {4 , 12 , 20 , 28 };
309+ constexpr int SHIFT_LOW [4 ] = {0 , 8 , 16 , 24 };
302310
303311 auto s_tensor = make_tensor ((VecSrcType*)(raw_pointer_cast (in.data ())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
304312 auto d_tensor = make_tensor ((VecDstType*)(raw_pointer_cast (out.data ())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
@@ -312,21 +320,26 @@ CUTLASS_DEVICE void dequant(
312320 #pragma unroll
313321 for (int k = 0 ; k < K / (compress_size * vec_size); k++) {
314322 VecSrcType src_val = src[k];
315- VecDstType dst_val;// = dst[k];
323+ // VecDstType& dst_val = dst[k];
324+ VecDstType dst_val;
316325
317326 #pragma unroll
318327 for (int i = 0 ; i < vec_size; i++) {
319328 compress_type compressed_val = src_val[i];
320- VecDstElemType compressed_dst_val;// = dst_val[i];
329+ // VecDstElemType& dst_elem = dst_val[i];
330+ VecDstElemType dst_elem;
321331
322332 #pragma unroll
323333 for (int j = 0 ; j < compress_size / 2 ; j++) {
324- // uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
325- // uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
326- compressed_dst_val[2 *j] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1 ))) & 0xf ] * ts);
327- compressed_dst_val[2 *j+1 ] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 ))) & 0xf ] * ts);
334+ // for (int j = 0; j < 4; j++) {
335+ uint8_t high = (compressed_val & MASK_HIGH [j]) >> SHIFT_HIGH [j];
336+ uint8_t low = (compressed_val & MASK_LOW [j]) >> SHIFT_LOW [j];
337+ dst_elem[2 * j] = static_cast <DstType>(quant_map[high] * ts);
338+ dst_elem[2 * j + 1 ] = static_cast <DstType>(quant_map[low] * ts);
339+ // dst_elem[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
340+ // dst_elem[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
328341 }
329- dst_val[i] = compressed_dst_val;
342+ dst_val[i] = dst_elem;
330343 }
331344 dst[k] = dst_val;
332345 }
0 commit comments