@@ -286,7 +286,7 @@ class gemm_4bit_cutlass_kernel {
286286#if 1
287287 static constexpr auto K = decltype (size (out))::value / N; // 128 / 2 = 64
288288
289- using compress_type = uint8_t ;
289+ using compress_type = uint32_t ;
290290 using vec_type = uint32_t ;
291291
292292 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
@@ -303,27 +303,31 @@ class gemm_4bit_cutlass_kernel {
303303 float ts = tCrS_input (n);
304304 auto & src = *(cute::array<compress_type, K / compress_size>*)(s_tensor (_, n).data ());
305305 auto & dst = *(cute::array<DstType, K>*)(d_tensor (_, n).data ());
306- int iter_num = 4 ;
307306
307+ #if 1
308308 #pragma unroll
309- for (int s = 0 ; s < K / compress_size / iter_num ; s++) {
309+ for (int s = 0 ; s < K / compress_size; s++) {
310310
311- #if 0
312311 #pragma unroll
313312 for (int i = 0 ; i < compress_size / 2 ; i++) {
314313 int dst_offset = s * compress_size + i * 2 ;
315314 dst[dst_offset] = static_cast <DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1 ))) & 0xf ] * ts);
316315 dst[dst_offset + 1 ] = static_cast <DstType>(quant_map[(src[s] >> (4 * (i * 2 ))) & 0xf ] * ts);
317316 }
317+ }
318318#else
319+ int iter_num = 4;
320+ #pragma unroll
321+ for (int s = 0; s < K / compress_size / iter_num; s++) {
322+
319323 #pragma unroll
320324 for(int i = 0; i < iter_num * compress_size / 2; i++) {
321325 int dst_offset = s * iter_num * compress_size + i * 2;
322326 dst[dst_offset] = static_cast<DstType>(quant_map[src[s * iter_num + i] >> 4] * ts);
323327 dst[dst_offset + 1] = static_cast<DstType>(quant_map[src[s * iter_num + i] & 0xf] * ts);
324328 }
325- #endif
326329 }
330+ #endif
327331 }
328332#else
329333 using compress_type = uint8_t;
0 commit comments