@@ -287,7 +287,7 @@ class gemm_4bit_cutlass_kernel {
287287 static constexpr auto K = decltype (size (out))::value / N; // 128 / 2 = 64
288288
289289 using compress_type = uint8_t ;
290- using vec_type = uint8_t ;
290+ using vec_type = uint32_t ;
291291
292292 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293293 static_assert ((compress_size % N) == 0 );
@@ -303,16 +303,26 @@ class gemm_4bit_cutlass_kernel {
303303 float ts = tCrS_input (n);
304304 auto & src = *(cute::array<compress_type, K / compress_size>*)(s_tensor (_, n).data ());
305305 auto & dst = *(cute::array<DstType, K>*)(d_tensor (_, n).data ());
306+ int iter_num = 4 ;
306307
307308 #pragma unroll
308- for (int s = 0 ; s < K / compress_size; s++) {
309+ for (int s = 0 ; s < K / compress_size / iter_num ; s++) {
309310
311+ #if 0
310312 #pragma unroll
311313 for(int i = 0; i < compress_size / 2; i++) {
312314 int dst_offset = s * compress_size + i * 2;
313315 dst[dst_offset] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1))) & 0xf] * ts);
314316 dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2))) & 0xf] * ts);
315317 }
318+ #else
319+ #pragma unroll
320+ for (int i = 0 ; i < iter_num * compress_size / 2 ; i++) {
321+ int dst_offset = s * iter_num * compress_size + i * 2 ;
322+ dst[dst_offset] = static_cast <DstType>(quant_map[src[s * iter_num + i] >> 4 ] * ts);
323+ dst[dst_offset + 1 ] = static_cast <DstType>(quant_map[src[s * iter_num + i] & 0xf ] * ts);
324+ }
325+ #endif
316326 }
317327 }
318328#else
0 commit comments