Skip to content

Commit 4fd70a9

Browse files
committed
new method
1 parent 9eef88a commit 4fd70a9

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ class gemm_4bit_cutlass_kernel {
287287
static constexpr auto K = decltype(size(out))::value / N; // 128 / 2 = 64
288288

289289
using compress_type = uint8_t;
290-
using vec_type = uint8_t;
290+
using vec_type = uint32_t;
291291

292292
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293293
static_assert((compress_size % N) == 0);
@@ -303,16 +303,26 @@ class gemm_4bit_cutlass_kernel {
303303
float ts = tCrS_input(n);
304304
auto& src = *(cute::array<compress_type, K / compress_size>*)(s_tensor(_, n).data());
305305
auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
306+
int iter_num = 4;
306307

307308
#pragma unroll
308-
for (int s = 0; s < K / compress_size; s++) {
309+
for (int s = 0; s < K / compress_size / iter_num; s++) {
309310

311+
#if 0
310312
#pragma unroll
311313
for(int i = 0; i < compress_size / 2; i++) {
312314
int dst_offset = s * compress_size + i * 2;
313315
dst[dst_offset] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1))) & 0xf] * ts);
314316
dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2))) & 0xf] * ts);
315317
}
318+
#else
319+
#pragma unroll
320+
for(int i = 0; i < iter_num * compress_size / 2; i++) {
321+
int dst_offset = s * iter_num * compress_size + i * 2;
322+
dst[dst_offset] = static_cast<DstType>(quant_map[src[s * iter_num + i] >> 4] * ts);
323+
dst[dst_offset + 1] = static_cast<DstType>(quant_map[src[s * iter_num + i] & 0xf] * ts);
324+
}
325+
#endif
316326
}
317327
}
318328
#else

0 commit comments

Comments
 (0)