@@ -205,7 +205,7 @@ class gemm_4bit_cutlass_kernel {
205205 }
206206 }
207207
208- // / Utilities to transform A.
208+ # if 0
209209 template <class EngineIn,
210210 class EngineOut,
211211 class EngineScales,
@@ -229,134 +229,110 @@ class gemm_4bit_cutlass_kernel {
229229 using ScaleType = typename EngineScales::value_type;
230230
231231 static constexpr auto N = decltype(size<1>(in))::value;
232-
233- #if 0
234- #if 1
235- //using format_type = int; //32
236- static constexpr auto vec_size = decltype(size(out))::value / N / 2; // 128 / 2 / 2 = 64
237- using format_type = std::array<unsigned char, vec_size>; //<unsigned char, 64>
238- static constexpr auto src_bits = sizeof_bits_v<SrcType>; //8
239-
240- auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<N>>{});
241- auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size * 2>, Int<N>>{});
242-
243- //if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, src_bits = %d, vec_size = %d\n", decltype(size(out))::value, N, src_bits, vec_size);
244-
245- CUTLASS_PRAGMA_UNROLL
246- for (int n = 0; n < N; n++) {
247- float ts = tCrS_input(n);
248- auto& src = *(cute::array<unsigned char, vec_size>*)(s_tensor(n).data());
249- auto& dst = *(cute::array<DstType, vec_size * 2>*)(d_tensor(_, n).data());
250-
251- CUTLASS_PRAGMA_UNROLL
252- for (int i = 0; i < vec_size; i++) {
253- dst[i * 2] = static_cast<DstType>(1.0f * ts);
254- dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
255- //dst[i * 2] = static_cast<DstType>(quant_map[(src[i] >> src_bits)] * ts);
256- //dst[i * 2 + 1] = static_cast<DstType>(quant_map[(src[i] & 0xf)] * ts);
257- }
258- }
259- #else
260- //using format_type = int; //32
261- static constexpr auto vec_size = decltype(size(out))::value / N / 2; // 128 / 2 / 2 = 64
262- using format_type = std::array<unsigned char, vec_size>; //<unsigned char, 64>
263- static constexpr auto src_bits = sizeof_bits_v<SrcType>; //8
264-
265- auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<N>>{});
266- auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size * 2>, Int<N>>{});
267-
268- //if(cute::thread0()) printf("decltype(size(out))::value = %d, N = %d, src_bits = %d, vec_size = %d\n", decltype(size(out))::value, N, src_bits, vec_size);
269-
270- CUTLASS_PRAGMA_UNROLL
271- for (int n = 0; n < N; n++) {
272- float ts = tCrS_input(n);
273- auto& src = *(cute::array<unsigned char, vec_size>*)(s_tensor(n).data());
274- auto& dst = *(cute::array<DstType, vec_size * 2>*)(d_tensor(_, n).data());
275-
276- CUTLASS_PRAGMA_UNROLL
277- for (int i = 0; i < vec_size; i++) {
278- //dst[i * 2] = static_cast<DstType>(1.0f * ts);
279- //dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
280- dst[i * 2] = static_cast<DstType>(quant_map[(src[i] >> src_bits)] * ts);
281- dst[i * 2 + 1] = static_cast<DstType>(quant_map[(src[i] & 0xf)] * ts);
282- }
283- }
284- #endif
285- #else
286- #if 1
287232 static constexpr auto K = decltype(size(out))::value / N; // 128 / 2 = 64
288233
289234 using compress_type = uint32_t;
290- using vec_type = intel::int4; // uint32_t;
291235
292236 static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
293237 static_assert((compress_size % N) == 0);
294238
295- static constexpr auto vec_num = sizeof_bits_v<vec_type> / sizeof_bits_v<compress_type>;
296- static constexpr auto vec_size = compress_size * vec_num;
297-
298- auto s_tensor = make_tensor ((vec_type*)(raw_pointer_cast (in.data ())), Shape<Int<K / vec_size>, Int<N>>{});
239+ auto s_tensor = make_tensor((compress_type*)(raw_pointer_cast(in.data())), Shape<Int<K / compress_size>, Int<N>>{});
299240 auto d_tensor = make_tensor(out.data(), Shape<Int<K>, Int<N>>{});
300241
301242 #pragma unroll
302243 for (int n = 0; n < N; n++) {
303244 float ts = tCrS_input(n);
304- auto & src = *(cute::array<vec_type , K / vec_size >*)(s_tensor (_, n).data ());
245+ auto& src = *(cute::array<compress_type , K / compress_size >*)(s_tensor(_, n).data());
305246 auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
306247
307- #if 1
308248 #pragma unroll
309- for (int s = 0 ; s < K / vec_size ; s++) {
310-
249+ for (int s = 0; s < K / compress_size ; s++) {
250+ compress_type src_val = src[s];
311251 #pragma unroll
312- for (int i = 0 ; i < vec_num; i++) {
313- #pragma unroll
314- for (int j = 0 ; j < compress_size / 2 ; j++) {
315- int dst_offset = s * vec_size + i * compress_size + j * 2 ;
316- dst[dst_offset] = static_cast <DstType>(quant_map[(src[s][i] >> (4 * (j * 2 + 1 ))) & 0xf ] * ts);
317- dst[dst_offset + 1 ] = static_cast <DstType>(quant_map[(src[s][i] >> (4 * (j * 2 ))) & 0xf ] * ts);
318- }
252+ for(int i = 0; i < compress_size / 2; i++) {
253+ int dst_offset = s * compress_size + i * 2;
254+ uint8_t high = (src_val >> (4 * (i * 2 + 1))) & 0xf;
255+ uint8_t low = (src_val >> (4 * (i * 2))) & 0xf;
256+ dst[dst_offset] = static_cast<DstType>(quant_map[high] * ts);
257+ dst[dst_offset + 1] = static_cast<DstType>(quant_map[low] * ts);
319258 }
320259 }
260+ }
261+ }
321262#else
322- int iter_num = 4;
323- #pragma unroll
324- for (int s = 0; s < K / compress_size / iter_num; s++) {
263+ template <class EngineIn ,
264+ class EngineOut ,
265+ class EngineScales ,
266+ class LayoutIn ,
267+ class LayoutOut ,
268+ class LayoutScales ,
269+ class ... Ts>
270+ CUTLASS_DEVICE void dequant (
271+ Tensor<EngineIn, LayoutIn> const & in,
272+ Tensor<EngineOut, LayoutOut>& out,
273+ Tensor<EngineScales, LayoutScales>& tCrS_input,
274+ const float * quant_map
275+ ) {
276+ static_assert (is_rmem<EngineIn>::value, " Input tensor must be in registers" );
277+ static_assert (size_v<LayoutIn> == cosize_v<LayoutIn>);
278+ static_assert (size_v<LayoutOut> == cosize_v<LayoutOut>);
325279
326- #pragma unroll
327- for(int i = 0; i < iter_num * compress_size / 2; i++) {
328- int dst_offset = s * iter_num * compress_size + i * 2;
329- dst[dst_offset] = static_cast<DstType>(quant_map[src[s * iter_num + i] >> 4] * ts);
330- dst[dst_offset + 1] = static_cast<DstType>(quant_map[src[s * iter_num + i] & 0xf] * ts);
331- }
332- }
333- #endif
334- }
335- #else
336- using compress_type = uint8_t;
337- static constexpr auto compress_ratio = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
280+ using SrcType = typename EngineIn::value_type;
281+ using DstType = typename EngineOut::value_type;
282+ using ScaleType = typename EngineScales::value_type;
283+
284+ static constexpr auto N = decltype (size<1 >(in))::value;
338285 static constexpr auto K = decltype (size (out))::value / N;
339- auto s_tensor = make_tensor((compress_type*)(raw_pointer_cast(in.data())), Shape<Int<K/compress_ratio>, Int<N>>{});
340- auto d_tensor = make_tensor(out.data(), Shape<Int<K>, Int<N>>{});
286+
287+ using compress_type = uint32_t ;
288+ static constexpr auto compress_size = sizeof_bits_v<compress_type> / sizeof_bits_v<SrcType>;
289+ static_assert ((compress_size % N) == 0 );
290+
291+ static constexpr auto vec_size = 2 ;
292+ // using VecSrcElemType = cute::array<SrcType, compress_size>;
293+ using VecSrcType = cute::array<compress_type, vec_size>; // sycl::vec<uint32_t, 4>;
294+ using VecDstElemType = cute::array<DstType, compress_size>;
295+ using VecDstType = cute::array<VecDstElemType, vec_size>;
296+
297+ // 预定义掩码和位移
298+ // constexpr uint32_t MASK_HIGH[4] = {0xF0, 0xF000, 0xF00000, 0xF0000000};
299+ // constexpr uint32_t MASK_LOW[4] = {0xF, 0xF00, 0xF0000, 0xF000000};
300+ // constexpr int SHIFT_HIGH[4] = {4, 12, 20, 28};
301+ // constexpr int SHIFT_LOW[4] = {0, 8, 16, 24};
302+
303+ auto s_tensor = make_tensor ((VecSrcType*)(raw_pointer_cast (in.data ())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
304+ auto d_tensor = make_tensor ((VecDstType*)(raw_pointer_cast (out.data ())), Shape<Int<K / (compress_size * vec_size)>, Int<N>>{});
341305
342306 #pragma unroll
343307 for (int n = 0 ; n < N; n++) {
344- float ts = tCrS_input(n);
345- auto& src = *(cute::array<compress_type, K/compress_ratio>*)(s_tensor(_, n).data());
346- auto& dst = *(cute::array<DstType, K>*)(d_tensor(_, n).data());
347- //auto& src = s_tensor(_, n).data();
348- //auto& dst = d_tensor(_, n).data();
308+ float ts = tCrS_input (n);
309+ auto & src = *(cute::array<VecSrcType, K / (compress_size * vec_size)>*)(s_tensor (_, n).data ());
310+ auto & dst = *(cute::array<VecDstType, K / (compress_size * vec_size)>*)(d_tensor (_, n).data ());
349311
350- #pragma unroll
351- for (int k = 0; k < K/compress_ratio/2; k++) {
352- dst[k * 2] = static_cast<DstType>(quant_map[src[k] >> 4] * ts);
353- dst[k * 2 + 1] = static_cast<DstType>(quant_map[src[k] & 0xf] * ts);
354- }
312+ #pragma unroll
313+ for (int k = 0 ; k < K / (compress_size * vec_size); k++) {
314+ VecSrcType src_val = src[k];
315+ VecDstType dst_val;// = dst[k];
316+
317+ #pragma unroll
318+ for (int i = 0 ; i < vec_size; i++) {
319+ compress_type compressed_val = src_val[i];
320+ VecDstElemType compressed_dst_val;// = dst_val[i];
321+
322+ #pragma unroll
323+ for (int j = 0 ; j < compress_size / 2 ; j++) {
324+ // uint8_t high = (compressed_val & MASK_HIGH[j]) >> SHIFT_HIGH[j];
325+ // uint8_t low = (compressed_val & MASK_LOW[j]) >> SHIFT_LOW[j];
326+ compressed_dst_val[2 *j] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 + 1 ))) & 0xf ] * ts);
327+ compressed_dst_val[2 *j+1 ] = static_cast <DstType>(quant_map[(compressed_val >> (4 * (j * 2 ))) & 0xf ] * ts);
328+ }
329+ dst_val[i] = compressed_dst_val;
330+ }
331+ dst[k] = dst_val;
332+ }
355333 }
356- #endif
357- #endif
358- }
359-
334+ }
335+ #endif
360336 CUTLASS_DEVICE
361337 void operator ()(Params const & params, char * smem_buf) {
362338 int M = params.m ;
0 commit comments