Skip to content

Commit ec8aecf

Browse files
committed
add new method dequant
1 parent 92b1d24 commit ec8aecf

1 file changed

Lines changed: 34 additions & 28 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 34 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -283,47 +283,53 @@ class gemm_4bit_cutlass_kernel {
283283
}
284284
#endif
285285
#else
286+
#if 1
286287
using format_type = uint32_t; //32
287-
static constexpr auto src_bits = sizeof_bits_v<SrcType>; //4
288-
static constexpr auto scalar = sizeof_bits_v<format_type> / src_bits; // 8
289-
static constexpr auto loop_cnt = decltype(size(out))::value / N; // 128 / 2 = 64
290-
static_assert((scalar % N) == 0);
288+
static constexpr auto compress_ratio = sizeof_bits_v<format_type> / sizeof_bits_v<SrcType>; // 8
289+
static constexpr auto K = decltype(size(out))::value / N; // 128 / 2 = 64
290+
static_assert((compress_ratio % N) == 0);
291291

292-
static constexpr auto vec_size = scalar; //8
293-
static constexpr auto splits = loop_cnt / vec_size; // 64 / 8 = 8
294-
static_assert(vec_size <= scalar);
295-
296-
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
297-
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
298-
299-
//if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, src_bits = %d, scalar = %d, loop_cnt = %d, vec_size = %d, splits = %d\n", decltype(size(out))::value, N, src_bits, scalar, loop_cnt, vec_size, splits);
292+
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<K / compress_ratio>, Int<N>>{});
293+
auto d_tensor = make_tensor(out.data(), Shape<Int<K>, Int<N>>{});
300294

301295
CUTLASS_PRAGMA_UNROLL
302296
for (int n = 0; n < N; n++) {
303297
float ts = tCrS_input(n);
304-
auto& src = *(cute::array<format_type, loop_cnt / scalar>*)(s_tensor(_, n).data());
298+
auto& src = *(cute::array<format_type, K / compress_ratio>*)(s_tensor(_, n).data());
299+
auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
305300

306301
CUTLASS_PRAGMA_UNROLL
307-
for (int s = 0; s < splits; s++) {
308-
auto idx = vec_size * s / scalar;
309-
auto format_data = src[idx];
310-
311-
auto& dst = *(cute::array<DstType, vec_size>*)(d_tensor(_, s, n).data());
302+
for (int s = 0; s < K / compress_ratio; s++) {
312303

313304
CUTLASS_PRAGMA_UNROLL
314-
for (int i = 0; i < vec_size/2; i++) {
315-
#if 0
316-
dst[i * 2] = static_cast<DstType>(1.0f * ts);
317-
dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
318-
#else
319-
dst[i * 2] = static_cast<DstType>(quant_map[(format_data >> (src_bits * (i * 2 + 1))) & 0xf] * ts);
320-
dst[i * 2 + 1] = static_cast<DstType>(quant_map[(format_data >> (src_bits * (i * 2))) & 0xf] * ts);
321-
//dst[i * 2] = quant_map[(format_data >> (src_bits * (i * 2 + 1))) & 0xf] * ts;
322-
//dst[i * 2 + 1] = quant_map[(format_data >> (src_bits * (i * 2))) & 0xf] * ts;
323-
#endif
305+
for (int i = 0; i < compress_ratio/2; i++) {
306+
dst[s * compress_ratio + i * 2] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1))) & 0xf] * ts);
307+
dst[s * compress_ratio + i * 2 + 1] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2))) & 0xf] * ts);
324308
}
325309
}
326310
}
311+
#else
312+
using compress_type = uint8_t;
313+
static constexpr auto compress_ratio = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
314+
static constexpr auto K = decltype(size(out))::value / N;
315+
auto s_tensor = make_tensor((compress_type*)(raw_pointer_cast(in.data())), Shape<Int<K/compress_ratio>, Int<N>>{});
316+
auto d_tensor = make_tensor(out.data(), Shape<Int<K>, Int<N>>{});
317+
318+
#pragma unroll
319+
for (int n = 0; n < N; n++) {
320+
float ts = tCrS_input(n);
321+
auto& src = *(cute::array<compress_type, K/compress_ratio>*)(s_tensor(_, n).data());
322+
auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
323+
//auto& src = s_tensor(_, n).data();
324+
//auto& dst = d_tensor(_, n).data();
325+
326+
#pragma unroll
327+
for (int k = 0; k < K/compress_ratio/2; k++) {
328+
dst[k * 2] = static_cast<DstType>(quant_map[src[k] >> 4] * ts);
329+
dst[k * 2 + 1] = static_cast<DstType>(quant_map[src[k] & 0xf] * ts);
330+
}
331+
}
332+
#endif
327333
#endif
328334
}
329335

0 commit comments

Comments
 (0)