Skip to content

Commit d4b8fc1

Browse files
committed
refactor: move implementation into datatype.cc
1 parent b092c8b commit d4b8fc1

File tree

2 files changed

+265
-206
lines changed

2 files changed

+265
-206
lines changed

infini_train/include/datatype.h

Lines changed: 25 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

3-
#include <bit>
4-
#include <cmath>
3+
#include <cstddef>
54
#include <cstdint>
65
#include <string>
76
#include <unordered_map>
@@ -29,121 +28,15 @@ namespace detail {
2928
// ---------------------------
3029
// BF16 helpers
3130
// ---------------------------
32-
inline constexpr uint16_t FloatToBf16Bits(float value) {
33-
const uint32_t bits = std::bit_cast<uint32_t>(value);
34-
const uint32_t lsb = (bits >> 16) & 1u;
35-
const uint32_t rounding_bias = 0x7fffu + lsb;
36-
return static_cast<uint16_t>((bits + rounding_bias) >> 16);
37-
}
38-
39-
inline constexpr float Bf16BitsToFloat(uint16_t bits) {
40-
const uint32_t u32 = static_cast<uint32_t>(bits) << 16;
41-
return std::bit_cast<float>(u32);
42-
}
31+
uint16_t FloatToBf16Bits(float value);
32+
float Bf16BitsToFloat(uint16_t bits);
4333

4434
// ---------------------------
4535
// FP16 helpers
4636
// Pure software IEEE-754 half <-> float conversion for framework fallback use.
4737
// ---------------------------
48-
// Converts a binary32 float to IEEE-754 binary16 bits, rounding to nearest
// with ties to even.
//   - Inf maps to Inf; every NaN maps to the quiet NaN 0x7e00 (sign kept,
//     payload dropped).
//   - float32 zeros and subnormals map to a signed zero.
//   - Values too large for fp16 map to a signed infinity.
//   - Values below the fp16 normal range are rounded into fp16 subnormals,
//     or to a signed zero when they are too small for that.
inline constexpr uint16_t FloatToFp16Bits(float value) {
49-
const uint32_t bits = std::bit_cast<uint32_t>(value);
50-
51-
// Shifting by 16 moves the fp32 sign (bit 31) to the fp16 sign position (bit 15).
const uint32_t sign = (bits >> 16) & 0x8000u;
52-
uint32_t mantissa = bits & 0x007fffffu;
53-
// Biased fp32 exponent; kept signed because it is re-biased below and may go negative.
int32_t exp = static_cast<int32_t>((bits >> 23) & 0xffu);
54-
55-
// NaN / Inf
56-
if (exp == 0xff) {
57-
if (mantissa == 0) {
58-
return static_cast<uint16_t>(sign | 0x7c00u); // inf
59-
}
60-
return static_cast<uint16_t>(sign | 0x7e00u); // quiet NaN
61-
}
62-
63-
// Zero / subnormal in float32
// (any float32 subnormal is far below the smallest fp16 subnormal, so it becomes a signed zero)
64-
if (exp == 0) {
65-
return static_cast<uint16_t>(sign);
66-
}
67-
68-
// Convert exponent bias: fp32 bias 127 -> fp16 bias 15
69-
exp = exp - 127 + 15;
70-
71-
// Overflow -> inf
72-
if (exp >= 0x1f) {
73-
return static_cast<uint16_t>(sign | 0x7c00u);
74-
}
75-
76-
// Underflow -> subnormal / zero
77-
if (exp <= 0) {
78-
// Re-biased exponent below -10 means |value| < 2^-25, i.e. less than half
// of the smallest fp16 subnormal (2^-24), so the result rounds to zero.
if (exp < -10) {
79-
return static_cast<uint16_t>(sign);
80-
}
81-
82-
// Make the implicit leading 1 explicit so it can be shifted into the subnormal mantissa.
mantissa |= 0x00800000u;
83-
84-
// 13 bits drop the extra fp32 mantissa precision; the remaining (1 - exp)
// bits denormalize the value to the fixed subnormal exponent.
const int shift = 14 - exp;
85-
uint32_t half_mant = mantissa >> shift;
86-
87-
// Round to nearest even using the bits that were shifted out.
const uint32_t remainder = mantissa & ((1u << shift) - 1u);
88-
const uint32_t halfway = 1u << (shift - 1);
89-
if (remainder > halfway || (remainder == halfway && (half_mant & 1u))) {
90-
// A carry out of the 10-bit field produces 0x400, which is exactly the
// encoding of the smallest normal fp16 value, so no fix-up is needed.
++half_mant;
91-
}
92-
93-
return static_cast<uint16_t>(sign | half_mant);
94-
}
95-
96-
// Normal fp16
97-
uint32_t half_exp = static_cast<uint32_t>(exp) << 10;
98-
uint32_t half_mant = mantissa >> 13;
99-
100-
// The 13 discarded mantissa bits decide the rounding; 0x1000 is the exact halfway point.
const uint32_t round_bits = mantissa & 0x1fffu;
101-
if (round_bits > 0x1000u || (round_bits == 0x1000u && (half_mant & 1u))) {
102-
++half_mant;
103-
// Mantissa overflowed its 10 bits: reset it and bump the exponent instead.
if (half_mant == 0x400u) {
104-
half_mant = 0;
105-
half_exp += 0x0400u;
106-
// The bump reached the all-ones exponent: the value rounded up to infinity.
if (half_exp >= 0x7c00u) {
107-
return static_cast<uint16_t>(sign | 0x7c00u);
108-
}
109-
}
110-
}
111-
112-
return static_cast<uint16_t>(sign | half_exp | half_mant);
113-
}
114-
115-
inline constexpr float Fp16BitsToFloat(uint16_t bits) {
116-
const uint32_t sign = (static_cast<uint32_t>(bits & 0x8000u)) << 16;
117-
const uint32_t exp = (bits >> 10) & 0x1fu;
118-
const uint32_t mant = bits & 0x03ffu;
119-
120-
uint32_t out = 0;
121-
122-
if (exp == 0) {
123-
if (mant == 0) {
124-
out = sign;
125-
} else {
126-
uint32_t mantissa = mant;
127-
int32_t e = -14;
128-
while ((mantissa & 0x0400u) == 0) {
129-
mantissa <<= 1;
130-
--e;
131-
}
132-
mantissa &= 0x03ffu;
133-
const uint32_t exp32 = static_cast<uint32_t>(e + 127) << 23;
134-
const uint32_t mant32 = mantissa << 13;
135-
out = sign | exp32 | mant32;
136-
}
137-
} else if (exp == 0x1f) {
138-
out = sign | 0x7f800000u | (mant << 13);
139-
} else {
140-
const uint32_t exp32 = static_cast<uint32_t>(static_cast<int32_t>(exp) - 15 + 127) << 23;
141-
const uint32_t mant32 = mant << 13;
142-
out = sign | exp32 | mant32;
143-
}
144-
145-
return std::bit_cast<float>(out);
146-
}
38+
uint16_t FloatToFp16Bits(float value);
39+
float Fp16BitsToFloat(uint16_t bits);
14740

14841
} // namespace detail
14942

@@ -156,18 +49,15 @@ struct alignas(2) FP16 {
15649
constexpr FP16() = default;
15750
constexpr FP16(uint16_t bits, from_bits_t) : x(bits) {}
15851

159-
explicit constexpr FP16(float value) : x(detail::FloatToFp16Bits(value)) {}
160-
explicit constexpr FP16(double value) : FP16(static_cast<float>(value)) {}
161-
explicit constexpr FP16(int value) : FP16(static_cast<float>(value)) {}
162-
explicit constexpr FP16(int64_t value) : FP16(static_cast<float>(value)) {}
52+
explicit FP16(float value);
53+
explicit FP16(double value);
54+
explicit FP16(int value);
55+
explicit FP16(int64_t value);
16356

164-
explicit constexpr operator float() const { return detail::Fp16BitsToFloat(x); }
165-
explicit constexpr operator double() const { return static_cast<double>(static_cast<float>(*this)); }
57+
explicit operator float() const;
58+
explicit operator double() const;
16659

167-
FP16 &operator++() {
168-
*this = FP16(static_cast<float>(*this) + 1.0f);
169-
return *this;
170-
}
60+
FP16 &operator++();
17161
};
17262

17363
struct alignas(2) BF16 {
@@ -179,18 +69,15 @@ struct alignas(2) BF16 {
17969
constexpr BF16() = default;
18070
constexpr BF16(uint16_t bits, from_bits_t) : x(bits) {}
18171

182-
explicit constexpr BF16(float value) : x(detail::FloatToBf16Bits(value)) {}
183-
explicit constexpr BF16(double value) : BF16(static_cast<float>(value)) {}
184-
explicit constexpr BF16(int value) : BF16(static_cast<float>(value)) {}
185-
explicit constexpr BF16(int64_t value) : BF16(static_cast<float>(value)) {}
72+
explicit BF16(float value);
73+
explicit BF16(double value);
74+
explicit BF16(int value);
75+
explicit BF16(int64_t value);
18676

187-
explicit constexpr operator float() const { return detail::Bf16BitsToFloat(x); }
188-
explicit constexpr operator double() const { return static_cast<double>(static_cast<float>(*this)); }
77+
explicit operator float() const;
78+
explicit operator double() const;
18979

190-
BF16 &operator++() {
191-
*this = BF16(static_cast<float>(*this) + 1.0f);
192-
return *this;
193-
}
80+
BF16 &operator++();
19481
};
19582

19683
// -----------------------------------------------------------------------------
@@ -211,42 +98,9 @@ enum class DataType : int8_t {
21198
kFLOAT64,
21299
};
213100

214-
/// Returns the storage size in bytes of one element of `data_type`.
/// Yields 0 for a value that matches none of the handled enumerators.
constexpr size_t DTypeSize(DataType data_type) {
    switch (data_type) {
    // 8-bit types
    case DataType::kUINT8:
    case DataType::kINT8:
        return 1;

    // 16-bit types (including both half-precision float formats)
    case DataType::kUINT16:
    case DataType::kINT16:
    case DataType::kBFLOAT16:
    case DataType::kFLOAT16:
        return 2;

    // 32-bit types
    case DataType::kUINT32:
    case DataType::kINT32:
    case DataType::kFLOAT32:
        return 4;

    // 64-bit types
    case DataType::kUINT64:
    case DataType::kINT64:
    case DataType::kFLOAT64:
        return 8;
    }
    return 0; // unreachable
}
243-
244-
inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
245-
{DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
246-
{DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"},
247-
{DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"},
248-
{DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"},
249-
};
101+
size_t DTypeSize(DataType data_type);
102+
103+
extern const std::unordered_map<DataType, std::string> kDataTypeToDesc;
250104

251105
// -----------------------------------------------------------------------------
252106
// Compile-time type mapping infrastructure
@@ -312,44 +166,9 @@ template <> struct DataTypeMap<BF16> {
312166
// =============================================================================
313167

314168
/// Returns true for floating-point DataTypes (FP16, BF16, FP32, FP64).
315-
constexpr bool IsFloatingPointDType(DataType dt) {
    // Exactly four enumerators count as floating point; everything else is an integer type.
    switch (dt) {
    case DataType::kFLOAT16:
    case DataType::kBFLOAT16:
    case DataType::kFLOAT32:
    case DataType::kFLOAT64:
        return true;
    default:
        return false;
    }
}
319-
320-
/// Binary DataType promotion. Safe to call in both host and device code.
321-
constexpr DataType PromoteDataTypes(DataType a, DataType b) {
    // Identical types need no promotion.
    if (a == b) {
        return a;
    }

    // Rule 1: FP16 and BF16 mixed. Neither format can represent the other
    // losslessly, so the pair widens to FP32. Since a != b here, "both are
    // half-precision floats" means exactly one is FP16 and the other BF16.
    const bool a_is_half = (a == DataType::kFLOAT16 || a == DataType::kBFLOAT16);
    const bool b_is_half = (b == DataType::kFLOAT16 || b == DataType::kBFLOAT16);
    if (a_is_half && b_is_half) {
        return DataType::kFLOAT32;
    }

    // Rule 2: when exactly one side is floating point, that side wins
    // regardless of width.
    const bool a_is_float = IsFloatingPointDType(a);
    const bool b_is_float = IsFloatingPointDType(b);
    if (a_is_float != b_is_float) {
        return a_is_float ? a : b;
    }

    // Rule 3: same category, the wider type wins; on equal width `a` is kept.
    // NOTE(review): equal-width pairs such as (kUINT8, kINT8) therefore resolve
    // to whichever argument comes first, so the result is order-dependent for
    // those pairs. Confirm that this is intended.
    return DTypeSize(b) > DTypeSize(a) ? b : a;
}
346-
347-
/// Compile-time binary promotion: DataTypePromotion<A, B>::value
348-
// Type-level wrapper around PromoteDataTypes: the promoted DataType for the
// pair (A, B) is exposed as the static member `value`.
template <DataType A, DataType B> struct DataTypePromotion {
349-
// Initializing a static constexpr member requires PromoteDataTypes(A, B) to be
// evaluated as a constant expression.
static constexpr DataType value = PromoteDataTypes(A, B);
350-
};
169+
bool IsFloatingPointDType(DataType dt);
351170

352-
/// Convenience variable template
353-
template <DataType A, DataType B> inline constexpr DataType DataTypePromotion_v = DataTypePromotion<A, B>::value;
171+
/// Binary DataType promotion.
172+
DataType PromoteDataTypes(DataType a, DataType b);
354173

355174
} // namespace infini_train

0 commit comments

Comments
 (0)