File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -287,7 +287,7 @@ class gemm_4bit_cutlass_kernel {
287287 static constexpr auto K = decltype (size (out))::value / N; // 128 / 2 = 64
288288
289289 using compress_type = uint8_t ;
290- using vec_type = uint8_t ;
290+ using vec_type = uint32_t ;
291291
292292 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293293 static_assert ((compress_size % N) == 0 );
@@ -307,12 +307,18 @@ class gemm_4bit_cutlass_kernel {
307307 #pragma unroll
308308 for (int s = 0 ; s < K / compress_size; s++) {
309309
310+ #if 0
310311 #pragma unroll
311312 for(int i = 0; i < compress_size / 2; i++) {
312313 int dst_offset = s * compress_size + i * 2;
313314 dst[dst_offset] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1))) & 0xf] * ts);
314315 dst[dst_offset + 1] = static_cast<DstType>(quant_map[(src[s] >> (4 * (i * 2))) & 0xf] * ts);
315316 }
317+ #else
318+ int dst_offset = s * compress_size;
319+ dst[dst_offset] = static_cast <DstType>(quant_map[src[s] >> 4 ] * ts);
320+ dst[dst_offset + 1 ] = static_cast <DstType>(quant_map[src[s] & 0xf ] * ts);
321+ #endif
316322 }
317323 }
318324#else
You can’t perform that action at this time.
0 commit comments