Skip to content

Commit a9936ab

Browse files
committed
new method
1 parent dbf838e commit a9936ab

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -287,7 +287,7 @@ class gemm_4bit_cutlass_kernel {
287287
static constexpr auto K = decltype(size(out))::value / N; // 128 / 2 = 64
288288

289289
using compress_type = uint32_t;
290-
using vec_type = uint32_t;
290+
using vec_type = intel::int4; //uint32_t;
291291

292292
static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293293
static_assert((compress_size % N) == 0);
@@ -301,18 +301,21 @@ class gemm_4bit_cutlass_kernel {
301301
#pragma unroll
302302
for (int n = 0; n < N; n++) {
303303
float ts = tCrS_input(n);
304-
auto& src = *(cute::array<compress_type, K / compress_size>*)(s_tensor(_, n).data());
304+
auto& src = *(cute::array<vec_type, K / vec_size>*)(s_tensor(_, n).data());
305305
auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
306306

307307
#if 1
308308
#pragma unroll
309-
for (int s = 0; s < K / compress_size; s++) {
309+
for (int s = 0; s < K / vec_size; s++) {
310310

311311
#pragma unroll
312-
for(int i = 0; i < compress_size / 2; i++) {
313-
int dst_offset = s * compress_size + i * 2;
314-
dst[dst_offset] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1))) & 0xf] * ts);
315-
dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2))) & 0xf] * ts);
312+
for(int i = 0; i < vec_num; i++) {
313+
#pragma unroll
314+
for(int j = 0; j < compress_size / 2; j++) {
315+
int dst_offset = s * vec_size + i * compress_size + j * 2;
316+
dst[dst_offset] = static_cast<DstType>(quant_map[(src[s][i] >> (4 * (j * 2 + 1))) & 0xf] * ts);
317+
dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s][i] >> (4 * (j * 2))) & 0xf] * ts);
318+
}
316319
}
317320
}
318321
#else

0 commit comments

Comments (0)