/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/
#pragma once

#include <cstdint>

#include <cuda_fp16.h>

/**
 * Unpack eight unsigned 4-bit integers from one 32-bit word into eight fp16 values.
 *
 * The 32 bits of `source` are not interpreted as two halves; they are reinterpreted
 * as a raw word holding eight nibbles, nibble k occupying bits [4k, 4k + 3].
 * Each nibble n is converted to the fp16 value n (range [0, 15]; no signed
 * re-centering is applied, because 1024 rather than 1032 is subtracted).
 *
 * The output is written through `result` as four packed fp16x2 words. Within each
 * word the lower-numbered nibble lands in the low 16 bits:
 *   word 0 : nibbles 0 and 4   (labelled elt_01 by the upstream code)
 *   word 1 : nibbles 1 and 5   (elt_23)
 *   word 2 : nibbles 2 and 6   (elt_45)
 *   word 3 : nibbles 3 and 7   (elt_67)
 * NOTE(review): how these nibble positions map to logical weight indices depends on
 * the interleaved packing produced by the quantizer - confirm against the packer.
 *
 * @param source  32-bit container holding the eight packed 4-bit values.
 * @param result  Destination for the eight fp16 results (4 x 32-bit words); must not be null.
 *
 * Requires hardware support for lop3.b32 and f16x2 arithmetic (sm_53 or newer per the
 * PTX ISA) - NOTE(review): confirm against the project's minimum target architecture.
 */
__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)
{
    static_assert(sizeof(half2) == sizeof(uint32_t), "source must be exactly one 32-bit word");
    static_assert(sizeof(uint4) == 4 * sizeof(uint32_t), "result must be exactly four 32-bit words");

    uint32_t *h = reinterpret_cast<uint32_t *>(result);
    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

    // First, extract the nibbles and construct intermediate fp16 numbers.
    // immLut selects the lop3 truth table for (a & b) | c; it evaluates to 0xEA.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    // Low nibble of each 16-bit lane.
    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
    // Second nibble of each 16-bit lane (value stays shifted left by 4).
    static constexpr uint32_t TOP_MASK = 0x00f000f0;
    // half2 {1024, 1024}: OR-ing a nibble into the mantissa yields 1024 + n (or 1024 + 16n
    // for the TOP_MASK nibbles) exactly.
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // The entire sequence needs only one shift instruction. This works because of the
    // register packing format and because the integers are treated as unsigned, which is
    // accounted for in the fp16 arithmetic below. The TOP_MASK nibbles are never shifted
    // down to the bottom bits: an fma (same throughput as sub) rescales them instead.

    // Shift right by 8 to expose nibbles 2/3 and 6/7. Issued first to hide the RAW
    // dependency that would arise if it were issued immediately before use.
    const uint32_t top_i4s = i4s >> 8;

    // elt_01: (i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[0])
                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // elt_23: (i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[1])
                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // elt_45: (top_i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[2])
                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // elt_67: (top_i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[3])
                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

    // Inline PTX is used below rather than half2 operators so that the emitted
    // instructions are predictable (performance reliability over readability).

    // half2 {1024, 1024}: subtracting it maps 1024 + n to n, i.e. the unsigned range
    // [0, 15] (upstream subtracted {1032, 1032} to obtain the signed range [-8, 7]).
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // half2 {1 / 16, 1 / 16}.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // half2 {-64, -64}: (1024 + 16n) / 16 - 64 == n (upstream used {-72, -72} for the
    // signed variant).
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // Finally, construct the output numbers in place.
    // Convert elt_01
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_23
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // Convert elt_45
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_67
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
0 commit comments