Skip to content

Commit 91c0321

Browse files
committed
new method
1 parent 9eef88a commit 91c0321

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ class gemm_4bit_cutlass_kernel {
287287
static constexpr auto K = decltype(size(out))::value / N; // 128 / 2 = 64
288288

289289
using compress_type = uint8_t;
290-
using vec_type = uint8_t;
290+
using vec_type = uint32_t;
291291

292292
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293293
static_assert((compress_size % N) == 0);
@@ -307,12 +307,18 @@ class gemm_4bit_cutlass_kernel {
307307
#pragma unroll
308308
for (int s = 0; s < K / compress_size; s++) {
309309

310+
#if 0
310311
#pragma unroll
311312
for(int i = 0; i < compress_size / 2; i++) {
312313
int dst_offset = s * compress_size + i * 2;
313314
dst[dst_offset] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1))) & 0xf] * ts);
314315
dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2))) & 0xf] * ts);
315316
}
317+
#else
318+
int dst_offset = s * compress_size;
319+
dst[dst_offset] = static_cast<DstType>(quant_map[src[s] >> 4] * ts);
320+
dst[dst_offset + 1] = static_cast<DstType>(quant_map[src[s] & 0xf] * ts);
321+
#endif
316322
}
317323
}
318324
#else

0 commit comments

Comments
 (0)