11#pragma once
22
3- #include < bit>
4- #include < cmath>
3+ #include < cstddef>
54#include < cstdint>
65#include < string>
76#include < unordered_map>
@@ -29,121 +28,15 @@ namespace detail {
2928// ---------------------------
3029// BF16 helpers
3130// ---------------------------
32- inline constexpr uint16_t FloatToBf16Bits (float value) {
33- const uint32_t bits = std::bit_cast<uint32_t >(value);
34- const uint32_t lsb = (bits >> 16 ) & 1u ;
35- const uint32_t rounding_bias = 0x7fffu + lsb;
36- return static_cast <uint16_t >((bits + rounding_bias) >> 16 );
37- }
38-
39- inline constexpr float Bf16BitsToFloat (uint16_t bits) {
40- const uint32_t u32 = static_cast <uint32_t >(bits) << 16 ;
41- return std::bit_cast<float >(u32 );
42- }
31+ uint16_t FloatToBf16Bits (float value);
32+ float Bf16BitsToFloat (uint16_t bits);
4333
4434// ---------------------------
4535// FP16 helpers
4636// Pure software IEEE-754 half <-> float conversion for framework fallback use.
4737// ---------------------------
48- inline constexpr uint16_t FloatToFp16Bits (float value) {
49- const uint32_t bits = std::bit_cast<uint32_t >(value);
50-
51- const uint32_t sign = (bits >> 16 ) & 0x8000u ;
52- uint32_t mantissa = bits & 0x007fffffu ;
53- int32_t exp = static_cast <int32_t >((bits >> 23 ) & 0xffu );
54-
55- // NaN / Inf
56- if (exp == 0xff ) {
57- if (mantissa == 0 ) {
58- return static_cast <uint16_t >(sign | 0x7c00u ); // inf
59- }
60- return static_cast <uint16_t >(sign | 0x7e00u ); // quiet NaN
61- }
62-
63- // Zero / subnormal in float32
64- if (exp == 0 ) {
65- return static_cast <uint16_t >(sign);
66- }
67-
68- // Convert exponent bias: fp32 bias 127 -> fp16 bias 15
69- exp = exp - 127 + 15 ;
70-
71- // Overflow -> inf
72- if (exp >= 0x1f ) {
73- return static_cast <uint16_t >(sign | 0x7c00u );
74- }
75-
76- // Underflow -> subnormal / zero
77- if (exp <= 0 ) {
78- if (exp < -10 ) {
79- return static_cast <uint16_t >(sign);
80- }
81-
82- mantissa |= 0x00800000u ;
83-
84- const int shift = 14 - exp;
85- uint32_t half_mant = mantissa >> shift;
86-
87- const uint32_t remainder = mantissa & ((1u << shift) - 1u );
88- const uint32_t halfway = 1u << (shift - 1 );
89- if (remainder > halfway || (remainder == halfway && (half_mant & 1u ))) {
90- ++half_mant;
91- }
92-
93- return static_cast <uint16_t >(sign | half_mant);
94- }
95-
96- // Normal fp16
97- uint32_t half_exp = static_cast <uint32_t >(exp) << 10 ;
98- uint32_t half_mant = mantissa >> 13 ;
99-
100- const uint32_t round_bits = mantissa & 0x1fffu ;
101- if (round_bits > 0x1000u || (round_bits == 0x1000u && (half_mant & 1u ))) {
102- ++half_mant;
103- if (half_mant == 0x400u ) {
104- half_mant = 0 ;
105- half_exp += 0x0400u ;
106- if (half_exp >= 0x7c00u ) {
107- return static_cast <uint16_t >(sign | 0x7c00u );
108- }
109- }
110- }
111-
112- return static_cast <uint16_t >(sign | half_exp | half_mant);
113- }
114-
115- inline constexpr float Fp16BitsToFloat (uint16_t bits) {
116- const uint32_t sign = (static_cast <uint32_t >(bits & 0x8000u )) << 16 ;
117- const uint32_t exp = (bits >> 10 ) & 0x1fu ;
118- const uint32_t mant = bits & 0x03ffu ;
119-
120- uint32_t out = 0 ;
121-
122- if (exp == 0 ) {
123- if (mant == 0 ) {
124- out = sign;
125- } else {
126- uint32_t mantissa = mant;
127- int32_t e = -14 ;
128- while ((mantissa & 0x0400u ) == 0 ) {
129- mantissa <<= 1 ;
130- --e;
131- }
132- mantissa &= 0x03ffu ;
133- const uint32_t exp32 = static_cast <uint32_t >(e + 127 ) << 23 ;
134- const uint32_t mant32 = mantissa << 13 ;
135- out = sign | exp32 | mant32;
136- }
137- } else if (exp == 0x1f ) {
138- out = sign | 0x7f800000u | (mant << 13 );
139- } else {
140- const uint32_t exp32 = static_cast <uint32_t >(static_cast <int32_t >(exp) - 15 + 127 ) << 23 ;
141- const uint32_t mant32 = mant << 13 ;
142- out = sign | exp32 | mant32;
143- }
144-
145- return std::bit_cast<float >(out);
146- }
38+ uint16_t FloatToFp16Bits (float value);
39+ float Fp16BitsToFloat (uint16_t bits);
14740
14841} // namespace detail
14942
@@ -156,18 +49,15 @@ struct alignas(2) FP16 {
15649 constexpr FP16 () = default ;
15750 constexpr FP16 (uint16_t bits, from_bits_t ) : x (bits) {}
15851
159- explicit constexpr FP16 (float value) : x ( detail::FloatToFp16Bits (value)) {}
160- explicit constexpr FP16 (double value) : FP16 ( static_cast < float >(value)) {}
161- explicit constexpr FP16 (int value) : FP16 ( static_cast < float >(value)) {}
162- explicit constexpr FP16 (int64_t value) : FP16 ( static_cast < float >(value)) {}
52+ explicit FP16 (float value);
53+ explicit FP16 (double value);
54+ explicit FP16 (int value);
55+ explicit FP16 (int64_t value);
16356
164- explicit constexpr operator float () const { return detail::Fp16BitsToFloat (x); }
165- explicit constexpr operator double () const { return static_cast < double >( static_cast < float >(* this )); }
57+ explicit operator float () const ;
58+ explicit operator double () const ;
16659
167- FP16 &operator ++() {
168- *this = FP16 (static_cast <float >(*this ) + 1 .0f );
169- return *this ;
170- }
60+ FP16 &operator ++();
17161};
17262
17363struct alignas (2 ) BF16 {
@@ -179,18 +69,15 @@ struct alignas(2) BF16 {
17969 constexpr BF16 () = default ;
18070 constexpr BF16 (uint16_t bits, from_bits_t ) : x (bits) {}
18171
182- explicit constexpr BF16 (float value) : x ( detail::FloatToBf16Bits (value)) {}
183- explicit constexpr BF16 (double value) : BF16 ( static_cast < float >(value)) {}
184- explicit constexpr BF16 (int value) : BF16 ( static_cast < float >(value)) {}
185- explicit constexpr BF16 (int64_t value) : BF16 ( static_cast < float >(value)) {}
72+ explicit BF16 (float value);
73+ explicit BF16 (double value);
74+ explicit BF16 (int value);
75+ explicit BF16 (int64_t value);
18676
187- explicit constexpr operator float () const { return detail::Bf16BitsToFloat (x); }
188- explicit constexpr operator double () const { return static_cast < double >( static_cast < float >(* this )); }
77+ explicit operator float () const ;
78+ explicit operator double () const ;
18979
190- BF16 &operator ++() {
191- *this = BF16 (static_cast <float >(*this ) + 1 .0f );
192- return *this ;
193- }
80+ BF16 &operator ++();
19481};
19582
19683// -----------------------------------------------------------------------------
@@ -211,42 +98,9 @@ enum class DataType : int8_t {
21198 kFLOAT64 ,
21299};
213100
214- constexpr size_t DTypeSize (DataType data_type) {
215- switch (data_type) {
216- case DataType::kUINT8 :
217- return 1 ;
218- case DataType::kINT8 :
219- return 1 ;
220- case DataType::kUINT16 :
221- return 2 ;
222- case DataType::kINT16 :
223- return 2 ;
224- case DataType::kUINT32 :
225- return 4 ;
226- case DataType::kINT32 :
227- return 4 ;
228- case DataType::kUINT64 :
229- return 8 ;
230- case DataType::kINT64 :
231- return 8 ;
232- case DataType::kBFLOAT16 :
233- return 2 ;
234- case DataType::kFLOAT16 :
235- return 2 ;
236- case DataType::kFLOAT32 :
237- return 4 ;
238- case DataType::kFLOAT64 :
239- return 8 ;
240- }
241- return 0 ; // unreachable
242- }
243-
244- inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
245- {DataType::kUINT8 , " uint8" }, {DataType::kINT8 , " int8" }, {DataType::kUINT16 , " uint16" },
246- {DataType::kINT16 , " int16" }, {DataType::kUINT32 , " uint32" }, {DataType::kINT32 , " int32" },
247- {DataType::kUINT64 , " uint64" }, {DataType::kINT64 , " int64" }, {DataType::kBFLOAT16 , " bf16" },
248- {DataType::kFLOAT16 , " fp16" }, {DataType::kFLOAT32 , " fp32" }, {DataType::kFLOAT64 , " fp64" },
249- };
101+ size_t DTypeSize (DataType data_type);
102+
103+ extern const std::unordered_map<DataType, std::string> kDataTypeToDesc ;
250104
251105// -----------------------------------------------------------------------------
252106// Compile-time type mapping infrastructure
@@ -312,44 +166,9 @@ template <> struct DataTypeMap<BF16> {
312166// =============================================================================
313167
314168// / Returns true for floating-point DataTypes (FP16, BF16, FP32, FP64).
315- constexpr bool IsFloatingPointDType (DataType dt) {
316- return dt == DataType::kFLOAT16 || dt == DataType::kBFLOAT16 || dt == DataType::kFLOAT32
317- || dt == DataType::kFLOAT64 ;
318- }
319-
320- // / Binary DataType promotion. Safe to call in both host and device code.
321- constexpr DataType PromoteDataTypes (DataType a, DataType b) {
322- if (a == b) {
323- return a;
324- }
325-
326- // Rule 1: FP16 ↔ BF16 — no lossless path, promote to FP32
327- if ((a == DataType::kFLOAT16 && b == DataType::kBFLOAT16 )
328- || (a == DataType::kBFLOAT16 && b == DataType::kFLOAT16 )) {
329- return DataType::kFLOAT32 ;
330- }
331-
332- const bool a_fp = IsFloatingPointDType (a);
333- const bool b_fp = IsFloatingPointDType (b);
334-
335- // Rule 2: float beats integer
336- if (a_fp && !b_fp) {
337- return a;
338- }
339- if (b_fp && !a_fp) {
340- return b;
341- }
342-
343- // Rule 3: same category — wider wins
344- return DTypeSize (a) >= DTypeSize (b) ? a : b;
345- }
346-
347- // / Compile-time binary promotion: DataTypePromotion<A, B>::value
348- template <DataType A, DataType B> struct DataTypePromotion {
349- static constexpr DataType value = PromoteDataTypes(A, B);
350- };
169+ bool IsFloatingPointDType (DataType dt);
351170
352- // / Convenience variable template
353- template < DataType A , DataType B> inline constexpr DataType DataTypePromotion_v = DataTypePromotion<A, B>::value ;
171+ // / Binary DataType promotion.
172+ DataType PromoteDataTypes ( DataType a , DataType b) ;
354173
355174} // namespace infini_train
0 commit comments