@@ -283,47 +283,53 @@ class gemm_4bit_cutlass_kernel {
283283 }
284284#endif
285285#else
286+ #if 1
286287 using format_type = uint32_t ; // 32
287- static constexpr auto src_bits = sizeof_bits_v<SrcType>; // 4
288- static constexpr auto scalar = sizeof_bits_v<format_type> / src_bits; // 8
289- static constexpr auto loop_cnt = decltype (size (out))::value / N; // 128 / 2 = 64
290- static_assert ((scalar % N) == 0 );
288+ static constexpr auto compress_ratio = sizeof_bits_v<format_type> / sizeof_bits_v<SrcType>; // 8
289+ static constexpr auto K = decltype (size (out))::value / N; // 128 / 2 = 64
290+ static_assert ((compress_ratio % N) == 0 );
291291
292- static constexpr auto vec_size = scalar; // 8
293- static constexpr auto splits = loop_cnt / vec_size; // 64 / 8 = 8
294- static_assert (vec_size <= scalar);
295-
296- auto s_tensor = make_tensor ((format_type*)(raw_pointer_cast (in.data ())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
297- auto d_tensor = make_tensor (out.data (), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
298-
299- // if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, src_bits = %d, scalar = %d, loop_cnt = %d, vec_size = %d, splits = %d\n", decltype(size(out))::value, N, src_bits, scalar, loop_cnt, vec_size, splits);
292+ auto s_tensor = make_tensor ((format_type*)(raw_pointer_cast (in.data ())), Shape<Int<K / compress_ratio>, Int<N>>{});
293+ auto d_tensor = make_tensor (out.data (), Shape<Int<K>, Int<N>>{});
300294
301295 CUTLASS_PRAGMA_UNROLL
302296 for (int n = 0 ; n < N; n++) {
303297 float ts = tCrS_input (n);
304- auto & src = *(cute::array<format_type, loop_cnt / scalar>*)(s_tensor (_, n).data ());
298+ auto & src = *(cute::array<format_type, K / compress_ratio>*)(s_tensor (_, n).data ());
299+ auto & dst = *(cute::array<DstType, K>*)(d_tensor (_, n).data ());
305300
306301 CUTLASS_PRAGMA_UNROLL
307- for (int s = 0 ; s < splits; s++) {
308- auto idx = vec_size * s / scalar;
309- auto format_data = src[idx];
310-
311- auto & dst = *(cute::array<DstType, vec_size>*)(d_tensor (_, s, n).data ());
302+ for (int s = 0 ; s < K / compress_ratio; s++) {
312303
313304 CUTLASS_PRAGMA_UNROLL
314- for (int i = 0 ; i < vec_size/2 ; i++) {
315- #if 0
316- dst[i * 2] = static_cast<DstType>(1.0f * ts);
317- dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
318- #else
319- dst[i * 2 ] = static_cast <DstType>(quant_map[(format_data >> (src_bits * (i * 2 + 1 ))) & 0xf ] * ts);
320- dst[i * 2 + 1 ] = static_cast <DstType>(quant_map[(format_data >> (src_bits * (i * 2 ))) & 0xf ] * ts);
321- // dst[i * 2] = quant_map[(format_data >> (src_bits * (i * 2 + 1))) & 0xf] * ts;
322- // dst[i * 2 + 1] = quant_map[(format_data >> (src_bits * (i * 2))) & 0xf] * ts;
323- #endif
305+ for (int i = 0 ; i < compress_ratio/2 ; i++) {
306+ dst[s * compress_ratio + i * 2 ] = static_cast <DstType>(quant_map[(src[s] >> (4 * (i * 2 + 1 ))) & 0xf ] * ts);
307+ dst[s * compress_ratio + i * 2 + 1 ] = static_cast <DstType>(quant_map[(src[s] >> (4 * (i * 2 ))) & 0xf ] * ts);
324308 }
325309 }
326310 }
311+ #else
312+ using compress_type = uint8_t;
313+ static constexpr auto compress_ratio = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
314+ static constexpr auto K = decltype(size(out))::value / N;
315+ auto s_tensor = make_tensor((compress_type*)(raw_pointer_cast(in.data())), Shape<Int<K/compress_ratio>, Int<N>>{});
316+ auto d_tensor = make_tensor(out.data(), Shape<Int<K>, Int<N>>{});
317+
318+ #pragma unroll
319+ for (int n = 0; n < N; n++) {
320+ float ts = tCrS_input(n);
321+ auto& src = *(cute::array<compress_type, K/compress_ratio>*)(s_tensor(_, n).data());
322+ auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
323+ //auto& src = s_tensor(_, n).data();
324+ //auto& dst = d_tensor(_, n).data();
325+
326+ #pragma unroll
327+ for (int k = 0; k < K/compress_ratio/2; k++) {
328+ dst[k * 2] = static_cast<DstType>(quant_map[src[k] >> 4] * ts);
329+ dst[k * 2 + 1] = static_cast<DstType>(quant_map[src[k] & 0xf] * ts);
330+ }
331+ }
332+ #endif
327333#endif
328334 }
329335
0 commit comments