Skip to content

Commit dbf838e

Browse files
committed
change method
1 parent 4fd70a9 commit dbf838e

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class gemm_4bit_cutlass_kernel {
286286
#if 1
287287
static constexpr auto K = decltype(size(out))::value / N; // 128 / 2 = 64
288288

289-
using compress_type = uint8_t;
289+
using compress_type = uint32_t;
290290
using vec_type = uint32_t;
291291

292292
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
@@ -303,27 +303,31 @@ class gemm_4bit_cutlass_kernel {
303303
float ts = tCrS_input(n);
304304
auto& src = *(cute::array<compress_type, K / compress_size>*)(s_tensor(_, n).data());
305305
auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
306-
int iter_num = 4;
307306

307+
#if 1
308308
#pragma unroll
309-
for (int s = 0; s < K / compress_size / iter_num; s++) {
309+
for (int s = 0; s < K / compress_size; s++) {
310310

311-
#if 0
312311
#pragma unroll
313312
for(int i = 0; i < compress_size / 2; i++) {
314313
int dst_offset = s * compress_size + i * 2;
315314
dst[dst_offset] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1))) & 0xf] * ts);
316315
dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2))) & 0xf] * ts);
317316
}
317+
}
318318
#else
319+
int iter_num = 4;
320+
#pragma unroll
321+
for (int s = 0; s < K / compress_size / iter_num; s++) {
322+
319323
#pragma unroll
320324
for(int i = 0; i < iter_num * compress_size / 2; i++) {
321325
int dst_offset = s * iter_num * compress_size + i * 2;
322326
dst[dst_offset] = static_cast<DstType>(quant_map[src[s * iter_num + i] >> 4] * ts);
323327
dst[dst_offset + 1] = static_cast<DstType>(quant_map[src[s * iter_num + i] & 0xf] * ts);
324328
}
325-
#endif
326329
}
330+
#endif
327331
}
328332
#else
329333
using compress_type = uint8_t;

0 commit comments

Comments
 (0)