Skip to content

Commit 989fbd3

Browse files
committed
new method
1 parent 3c680f4 commit 989fbd3

1 file changed

Lines changed: 33 additions & 13 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 33 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -292,11 +292,11 @@ CUTLASS_DEVICE void dequant(
292292
static constexpr auto N = decltype(size<1>(in))::value;
293293
static constexpr auto K = decltype(size(out))::value / N;
294294

295-
using compress_type = ushort; //uint32_t;
295+
using compress_type = uint32_t;
296296
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
297297
static_assert((compress_size % N) == 0);
298298

299-
static constexpr auto vec_size = 4;
299+
static constexpr auto vec_size = 8;
300300
//using VecSrcElemType = cute::array<SrcType, compress_size>;
301301
using VecSrcType = cute::array<compress_type, vec_size>; //sycl::vec<uint32_t, 4>;
302302
using VecDstElemType = cute::array<DstType, compress_size>;
@@ -307,16 +307,35 @@ CUTLASS_DEVICE void dequant(
307307
constexpr uint32_t MASK_LOW[4] = {0xF, 0xF00, 0xF0000, 0xF000000};
308308
constexpr int SHIFT_HIGH[4] = {4, 12, 20, 28};
309309
constexpr int SHIFT_LOW[4] = {0, 8, 16, 24};
310+
constexpr int shifts[8] = {4,0,12,8,20,16,28,24};
310311

311-
auto s_tensor = make_tensor((VecSrcType*)(raw_pointer_cast(in.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
312-
auto d_tensor = make_tensor((VecDstType*)(raw_pointer_cast(out.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
312+
auto s_tensor = make_tensor((VecSrcType*)(raw_pointer_cast(in.data())), Shape<Int<1>, Int<N>>{});
313+
auto d_tensor = make_tensor((VecDstType*)(raw_pointer_cast(out.data())), Shape<Int<1>, Int<N>>{});
313314

314315
#pragma unroll
315316
for (int n = 0; n < N; n++) {
316317
float ts = tCrS_input(n);
317-
auto& src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor(_, n).data());
318-
auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor(_, n).data());
319-
318+
auto& src = *(cute::array<VecSrcType, 1>*)(s_tensor(_, n).data());
319+
auto& dst = *(cute::array<VecDstType, 1>*)(d_tensor(_, n).data());
320+
#if 0
321+
const auto src_val = src[0];
322+
VecDstType dst_val;
323+
#pragma unroll
324+
for (int i = 0; i < vec_size; ++i) {
325+
const compress_type val = src_val[i];
326+
VecDstElemType dst_elem;
327+
dst_elem[0] = static_cast<DstType>(quant_map[(val>>shifts[0])&0xF] * ts);
328+
dst_elem[1] = static_cast<DstType>(quant_map[(val>>shifts[1])&0xF] * ts);
329+
dst_elem[2] = static_cast<DstType>(quant_map[(val>>shifts[2])&0xF] * ts);
330+
dst_elem[3] = static_cast<DstType>(quant_map[(val>>shifts[3])&0xF] * ts);
331+
dst_elem[4] = static_cast<DstType>(quant_map[(val>>shifts[4])&0xF] * ts);
332+
dst_elem[5] = static_cast<DstType>(quant_map[(val>>shifts[5])&0xF] * ts);
333+
dst_elem[6] = static_cast<DstType>(quant_map[(val>>shifts[6])&0xF] * ts);
334+
dst_elem[7] = static_cast<DstType>(quant_map[(val>>shifts[7])&0xF] * ts);
335+
dst_val[i] = dst_elem;
336+
}
337+
dst[0] = dst_val;
338+
#else
320339
#pragma unroll
321340
for (int k = 0; k < K / (compress_size * vec_size); k++) {
322341
VecSrcType src_val = src[k];
@@ -332,17 +351,18 @@ CUTLASS_DEVICE void dequant(
332351
#pragma unroll
333352
for (int j = 0; j < compress_size / 2; j++) {
334353
//for (int j = 0; j < 4; j++) {
335-
uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
336-
uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
337-
dst_elem[2 * j] = static_cast<DstType>(quant_map[high] * ts);
338-
dst_elem[2 * j + 1] = static_cast<DstType>(quant_map[low] * ts);
339-
//dst_elem[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
340-
//dst_elem[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
354+
//uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
355+
//uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
356+
//dst_elem[2 * j] = static_cast<DstType>(quant_map[high] * ts);
357+
//dst_elem[2 * j + 1] = static_cast<DstType>(quant_map[low] * ts);
358+
dst_elem[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
359+
dst_elem[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
341360
}
342361
dst_val[i] = dst_elem;
343362
}
344363
dst[k] = dst_val;
345364
}
365+
#endif
346366
}
347367
}
348368
#endif

0 commit comments

Comments
 (0)