@@ -292,11 +292,11 @@ CUTLASS_DEVICE void dequant(
292292 static constexpr auto N = decltype (size<1 >(in))::value;
293293 static constexpr auto K = decltype (size (out))::value / N;
294294
295- using compress_type = ushort; // uint32_t;
295+ using compress_type = uint32_t ;
296296 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
297297 static_assert ((compress_size % N) == 0 );
298298
299- static constexpr auto vec_size = 4 ;
299+ static constexpr auto vec_size = 8 ;
300300 // using VecSrcElemType = cute::array<SrcType, compress_size>;
301301 using VecSrcType = cute::array<compress_type, vec_size>; // sycl::vec<uint32_t, 4>;
302302 using VecDstElemType = cute::array<DstType, compress_size>;
@@ -307,16 +307,35 @@ CUTLASS_DEVICE void dequant(
307307 constexpr uint32_t MASK_LOW [4 ] = {0xF , 0xF00 , 0xF0000 , 0xF000000 };
308308 constexpr int SHIFT_HIGH [4 ] = {4 , 12 , 20 , 28 };
309309 constexpr int SHIFT_LOW [4 ] = {0 , 8 , 16 , 24 };
310+ constexpr int shifts[8 ] = {4 ,0 ,12 ,8 ,20 ,16 ,28 ,24 };
310311
311- auto s_tensor = make_tensor ((VecSrcType*)(raw_pointer_cast (in.data ())), Shape<Int<K / (compress_size * vec_size) >, Int<N>>{});
312- auto d_tensor = make_tensor ((VecDstType*)(raw_pointer_cast (out.data ())), Shape<Int<K / (compress_size * vec_size) >, Int<N>>{});
312+ auto s_tensor = make_tensor ((VecSrcType*)(raw_pointer_cast (in.data ())), Shape<Int<1 >, Int<N>>{});
313+ auto d_tensor = make_tensor ((VecDstType*)(raw_pointer_cast (out.data ())), Shape<Int<1 >, Int<N>>{});
313314
314315 #pragma unroll
315316 for (int n = 0 ; n < N; n++) {
316317 float ts = tCrS_input (n);
317- auto & src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor (_, n).data ());
318- auto & dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor (_, n).data ());
319-
318+ auto & src = *(cute::array<VecSrcType, 1 >*)(s_tensor (_, n).data ());
319+ auto & dst = *(cute::array<VecDstType, 1 >*)(d_tensor (_, n).data ());
320+ #if 0
321+ const auto src_val = src[0];
322+ VecDstType dst_val;
323+ #pragma unroll
324+ for (int i = 0; i < vec_size; ++i) {
325+ const compress_type val = src_val[i];
326+ VecDstElemType dst_elem;
327+ dst_elem[0] = static_cast<DstType>(quant_map[(val>>shifts[0])&0xF] * ts);
328+ dst_elem[1] = static_cast<DstType>(quant_map[(val>>shifts[1])&0xF] * ts);
329+ dst_elem[2] = static_cast<DstType>(quant_map[(val>>shifts[2])&0xF] * ts);
330+ dst_elem[3] = static_cast<DstType>(quant_map[(val>>shifts[3])&0xF] * ts);
331+ dst_elem[4] = static_cast<DstType>(quant_map[(val>>shifts[4])&0xF] * ts);
332+ dst_elem[5] = static_cast<DstType>(quant_map[(val>>shifts[5])&0xF] * ts);
333+ dst_elem[6] = static_cast<DstType>(quant_map[(val>>shifts[6])&0xF] * ts);
334+ dst_elem[7] = static_cast<DstType>(quant_map[(val>>shifts[7])&0xF] * ts);
335+ dst_val[i] = dst_elem;
336+ }
337+ dst[0] = dst_val;
338+ #else
320339 #pragma unroll
321340 for (int k = 0 ; k < K / (compress_size * vec_size); k++) {
322341 VecSrcType src_val = src[k];
@@ -332,17 +351,18 @@ CUTLASS_DEVICE void dequant(
332351 #pragma unroll
333352 for (int j = 0 ; j < compress_size / 2 ; j++) {
334353 // for (int j = 0; j < 4; j++) {
335- uint8_t high = (compressed_val & MASK_HIGH [j]) >> SHIFT_HIGH [j];
336- uint8_t low = (compressed_val & MASK_LOW [j]) >> SHIFT_LOW [j];
337- dst_elem[2 * j] = static_cast <DstType>(quant_map[high] * ts);
338- dst_elem[2 * j + 1 ] = static_cast <DstType>(quant_map[low] * ts);
339- // dst_elem[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
340- // dst_elem[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
354+ // uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
355+ // uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
356+ // dst_elem[2 * j] = static_cast<DstType>(quant_map[high] * ts);
357+ // dst_elem[2 * j + 1] = static_cast<DstType>(quant_map[low] * ts);
358+ dst_elem[2 *j] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1 ))) & 0xf ] * ts);
359+ dst_elem[2 *j+1 ] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 ))) & 0xf ] * ts);
341360 }
342361 dst_val[i] = dst_elem;
343362 }
344363 dst[k] = dst_val;
345364 }
365+ #endif
346366 }
347367}
348368#endif
0 commit comments