Skip to content

Commit e7bfa16

Browse files
committed
clean code, optimize dequant
1 parent a9936ab commit e7bfa16

1 file changed

Lines changed: 79 additions & 103 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 79 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ class gemm_4bit_cutlass_kernel {
205205
}
206206
}
207207

208-
/// Utilities to transform A.
208+
#if 0
209209
template <class EngineIn,
210210
class EngineOut,
211211
class EngineScales,
@@ -229,134 +229,110 @@ class gemm_4bit_cutlass_kernel {
229229
using ScaleType = typename EngineScales::value_type;
230230

231231
static constexpr auto N = decltype(size<1>(in))::value;
232-
233-
#if 0
234-
#if 1
235-
//using format_type = int; //32
236-
static constexpr auto vec_size = decltype(size(out))::value / N / 2; // 128 / 2 / 2 = 64
237-
using format_type = std::array<unsigned char, vec_size>; //<unsigned char, 64>
238-
static constexpr auto src_bits = sizeof_bits_v<SrcType>; //8
239-
240-
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<N>>{});
241-
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size * 2>, Int<N>>{});
242-
243-
//if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, src_bits = %d, vec_size = %d\n", decltype(size(out))::value, N, src_bits, vec_size);
244-
245-
CUTLASS_PRAGMA_UNROLL
246-
for (int n = 0; n < N; n++) {
247-
float ts = tCrS_input(n);
248-
auto& src = *(cute::array<unsigned char, vec_size>*)(s_tensor(n).data());
249-
auto& dst = *(cute::array<DstType, vec_size * 2>*)(d_tensor(_, n).data());
250-
251-
CUTLASS_PRAGMA_UNROLL
252-
for (int i = 0; i < vec_size; i++) {
253-
dst[i * 2] = static_cast<DstType>(1.0f * ts);
254-
dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
255-
//dst[i * 2] = static_cast<DstType>(quant_map[(src[i] >> src_bits)] * ts);
256-
//dst[i * 2 + 1] = static_cast<DstType>(quant_map[(src[i] & 0xf)] * ts);
257-
}
258-
}
259-
#else
260-
//using format_type = int; //32
261-
static constexpr auto vec_size = decltype(size(out))::value / N / 2; // 128 / 2 / 2 = 64
262-
using format_type = std::array<unsigned char, vec_size>; //<unsigned char, 64>
263-
static constexpr auto src_bits = sizeof_bits_v<SrcType>; //8
264-
265-
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<N>>{});
266-
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size * 2>, Int<N>>{});
267-
268-
//if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, src_bits = %d, vec_size = %d\n", decltype(size(out))::value, N, src_bits, vec_size);
269-
270-
CUTLASS_PRAGMA_UNROLL
271-
for (int n = 0; n < N; n++) {
272-
float ts = tCrS_input(n);
273-
auto& src = *(cute::array<unsigned char, vec_size>*)(s_tensor(n).data());
274-
auto& dst = *(cute::array<DstType, vec_size * 2>*)(d_tensor(_, n).data());
275-
276-
CUTLASS_PRAGMA_UNROLL
277-
for (int i = 0; i < vec_size; i++) {
278-
//dst[i * 2] = static_cast<DstType>(1.0f * ts);
279-
//dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
280-
dst[i * 2] = static_cast<DstType>(quant_map[(src[i] >> src_bits)] * ts);
281-
dst[i * 2 + 1] = static_cast<DstType>(quant_map[(src[i] & 0xf)] * ts);
282-
}
283-
}
284-
#endif
285-
#else
286-
#if 1
287232
static constexpr auto K = decltype(size(out))::value / N; // 128 / 2 = 64
288233

289234
using compress_type = uint32_t;
290-
using vec_type = intel::int4; //uint32_t;
291235

292236
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293237
static_assert((compress_size % N) == 0);
294238

295-
static constexpr auto vec_num = sizeof_bits_v<vec_type> / sizeof_bits_v<compress_type>;
296-
static constexpr auto vec_size = compress_size * vec_num;
297-
298-
auto s_tensor = make_tensor((vec_type*)(raw_pointer_cast(in.data())), Shape<Int<K / vec_size>, Int<N>>{});
239+
auto s_tensor = make_tensor((compress_type*)(raw_pointer_cast(in.data())), Shape<Int<K / compress_size>, Int<N>>{});
299240
auto d_tensor = make_tensor(out.data(), Shape<Int<K>, Int<N>>{});
300241

301242
#pragma unroll
302243
for (int n = 0; n < N; n++) {
303244
float ts = tCrS_input(n);
304-
auto& src = *(cute::array<vec_type, K / vec_size>*)(s_tensor(_, n).data());
245+
auto& src = *(cute::array<compress_type, K / compress_size>*)(s_tensor(_, n).data());
305246
auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
306247

307-
#if 1
308248
#pragma unroll
309-
for (int s = 0; s < K / vec_size; s++) {
310-
249+
for (int s = 0; s < K / compress_size; s++) {
250+
compress_type src_val = src[s];
311251
#pragma unroll
312-
for(int i = 0; i < vec_num; i++) {
313-
#pragma unroll
314-
for(int j = 0; j < compress_size / 2; j++) {
315-
int dst_offset = s * vec_size + i * compress_size + j * 2;
316-
dst[dst_offset] = static_cast<DstType>(quant_map[(src[s][i] >> (4 * (j * 2 + 1))) & 0xf] * ts);
317-
dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s][i] >> (4 * (j * 2))) & 0xf] * ts);
318-
}
252+
for(int i = 0; i < compress_size / 2; i++) {
253+
int dst_offset = s * compress_size + i * 2;
254+
uint8_t high = (src_val >> (4 * (i * 2 + 1))) & 0xf;
255+
uint8_t low = (src_val >> (4 * (i * 2))) & 0xf;
256+
dst[dst_offset] = static_cast<DstType>(quant_map[high] * ts);
257+
dst[dst_offset + 1] = static_cast<DstType>(quant_map[low] * ts);
319258
}
320259
}
260+
}
261+
}
321262
#else
322-
int iter_num = 4;
323-
#pragma unroll
324-
for (int s = 0; s < K / compress_size / iter_num; s++) {
263+
template <class EngineIn,
264+
class EngineOut,
265+
class EngineScales,
266+
class LayoutIn,
267+
class LayoutOut,
268+
class LayoutScales,
269+
class... Ts>
270+
CUTLASS_DEVICE void dequant(
271+
Tensor<EngineIn, LayoutIn> const& in,
272+
Tensor<EngineOut, LayoutOut>& out,
273+
Tensor<EngineScales, LayoutScales>& tCrS_input,
274+
const float* quant_map
275+
) {
276+
static_assert(is_rmem<EngineIn>::value, "Input tensor must be in registers");
277+
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
278+
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
325279

326-
#pragma unroll
327-
for(int i = 0; i < iter_num * compress_size / 2; i++) {
328-
int dst_offset = s * iter_num * compress_size + i * 2;
329-
dst[dst_offset] = static_cast<DstType>(quant_map[src[s * iter_num + i] >> 4] * ts);
330-
dst[dst_offset + 1] = static_cast<DstType>(quant_map[src[s * iter_num + i] & 0xf] * ts);
331-
}
332-
}
333-
#endif
334-
}
335-
#else
336-
using compress_type = uint8_t;
337-
static constexpr auto compress_ratio = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
280+
using SrcType = typename EngineIn::value_type;
281+
using DstType = typename EngineOut::value_type;
282+
using ScaleType = typename EngineScales::value_type;
283+
284+
static constexpr auto N = decltype(size<1>(in))::value;
338285
static constexpr auto K = decltype(size(out))::value / N;
339-
auto s_tensor = make_tensor((compress_type*)(raw_pointer_cast(in.data())), Shape<Int<K/compress_ratio>, Int<N>>{});
340-
auto d_tensor = make_tensor(out.data(), Shape<Int<K>, Int<N>>{});
286+
287+
using compress_type = uint32_t;
288+
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
289+
static_assert((compress_size % N) == 0);
290+
291+
static constexpr auto vec_size = 2;
292+
//using VecSrcElemType = cute::array<SrcType, compress_size>;
293+
using VecSrcType = cute::array<compress_type, vec_size>; //sycl::vec<uint32_t, 4>;
294+
using VecDstElemType = cute::array<DstType, compress_size>;
295+
using VecDstType = cute::array<VecDstElemType, vec_size>;
296+
297+
// 预定义掩码和位移
298+
//constexpr uint32_t MASK_HIGH[4] = {0xF0, 0xF000, 0xF00000, 0xF0000000};
299+
//constexpr uint32_t MASK_LOW[4] = {0xF, 0xF00, 0xF0000, 0xF000000};
300+
//constexpr int SHIFT_HIGH[4] = {4, 12, 20, 28};
301+
//constexpr int SHIFT_LOW[4] = {0, 8, 16, 24};
302+
303+
auto s_tensor = make_tensor((VecSrcType*)(raw_pointer_cast(in.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
304+
auto d_tensor = make_tensor((VecDstType*)(raw_pointer_cast(out.data())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
341305

342306
#pragma unroll
343307
for (int n = 0; n < N; n++) {
344-
float ts = tCrS_input(n);
345-
auto& src = *(cute::array<compress_type, K/compress_ratio>*)(s_tensor(_, n).data());
346-
auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
347-
//auto& src = s_tensor(_, n).data();
348-
//auto& dst = d_tensor(_, n).data();
308+
float ts = tCrS_input(n);
309+
auto& src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor(_, n).data());
310+
auto& dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor(_, n).data());
349311

350-
#pragma unroll
351-
for (int k = 0; k < K/compress_ratio/2; k++) {
352-
dst[k * 2] = static_cast<DstType>(quant_map[src[k] >> 4] * ts);
353-
dst[k * 2 + 1] = static_cast<DstType>(quant_map[src[k] & 0xf] * ts);
354-
}
312+
#pragma unroll
313+
for (int k = 0; k < K / (compress_size * vec_size); k++) {
314+
VecSrcType src_val = src[k];
315+
VecDstType dst_val;// = dst[k];
316+
317+
#pragma unroll
318+
for (int i = 0; i < vec_size; i++) {
319+
compress_type compressed_val = src_val[i];
320+
VecDstElemType compressed_dst_val;// = dst_val[i];
321+
322+
#pragma unroll
323+
for (int j = 0; j < compress_size / 2; j++) {
324+
//uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
325+
//uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
326+
compressed_dst_val[2*j] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1))) & 0xf] * ts);
327+
compressed_dst_val[2*j+1] = static_cast<DstType>(quant_map[(compressed_val >> (4 * (j * 2))) & 0xf] * ts);
328+
}
329+
dst_val[i] = compressed_dst_val;
330+
}
331+
dst[k] = dst_val;
332+
}
355333
}
356-
#endif
357-
#endif
358-
}
359-
334+
}
335+
#endif
360336
CUTLASS_DEVICE
361337
void operator()(Params const& params, char* smem_buf) {
362338
int M = params.m;

0 commit comments

Comments
 (0)