@@ -287,7 +287,7 @@ class gemm_4bit_cutlass_kernel {
287287 static constexpr auto K = decltype (size (out))::value / N; // 128 / 2 = 64
288288
289289 using compress_type = uint32_t ;
290- using vec_type = uint32_t ;
290+ using vec_type = intel::int4; // uint32_t;
291291
292292 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293293 static_assert ((compress_size % N) == 0 );
@@ -301,18 +301,21 @@ class gemm_4bit_cutlass_kernel {
301301 #pragma unroll
302302 for (int n = 0 ; n < N; n++) {
303303 float ts = tCrS_input (n);
304- auto & src = *(cute::array<compress_type , K / compress_size >*)(s_tensor (_, n).data ());
304+ auto & src = *(cute::array<vec_type , K / vec_size >*)(s_tensor (_, n).data ());
305305 auto & dst = *(cute::array<DstType, K>*)(d_tensor (_, n).data ());
306306
307307#if 1
308308 #pragma unroll
309- for (int s = 0 ; s < K / compress_size ; s++) {
309+ for (int s = 0 ; s < K / vec_size ; s++) {
310310
311311 #pragma unroll
312- for (int i = 0 ; i < compress_size / 2 ; i++) {
313- int dst_offset = s * compress_size + i * 2 ;
314- dst[dst_offset] = static_cast <DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1 ))) & 0xf ] * ts);
315- dst[dst_offset + 1 ] = static_cast <DstType>(quant_map[(src[s] >> (4 * (i * 2 ))) & 0xf ] * ts);
312+ for (int i = 0 ; i < vec_num; i++) {
313+ #pragma unroll
314+ for (int j = 0 ; j < compress_size / 2 ; j++) {
315+ int dst_offset = s * vec_size + i * compress_size + j * 2 ;
316+ dst[dst_offset] = static_cast <DstType>(quant_map[(src[s][i] >> (4 * (j * 2 + 1 ))) & 0xf ] * ts);
317+ dst[dst_offset + 1 ] = static_cast <DstType>(quant_map[(src[s][i] >> (4 * (j * 2 ))) & 0xf ] * ts);
318+ }
316319 }
317320 }
318321#else
0 commit comments