-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathbfloat16.hpp
More file actions
125 lines (110 loc) · 4.71 KB
/
Copy pathbfloat16.hpp
File metadata and controls
125 lines (110 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#pragma once
#include <cmath>
#include <cstdint>
#include <cstring>
class bfloat16;
/// Software 16-bit floating-point type that stores the upper 16 bits of a
/// 32-bit `float` bit pattern.
///
/// All arithmetic is carried out by converting the operands to `float`,
/// applying the built-in operator, and converting the result back.
///
/// Note: the default constructor is defaulted and `value` has no initializer,
/// so a default-initialized `bfloat16` holds indeterminate bits until it is
/// assigned.
class bfloat16 {
  using StorageType = uint16_t;

  static_assert(sizeof(float) == sizeof(uint32_t),
                "bfloat16 requires a 32-bit float representation");

  /// Raw bit pattern: the high half of the equivalent float's bits.
  StorageType value;

  /// Converts a float to the 16-bit storage pattern.
  ///
  /// Any NaN input maps to the fixed pattern 0xffc1. Every other input is
  /// rounded to nearest, with ties going to the pattern whose lowest kept bit
  /// is zero, and then truncated to the upper 16 bits.
  ///
  /// @param a value to convert.
  /// @return the 16-bit pattern representing `a`.
  static StorageType from_float(const float &a) {
    if (std::isnan(a))
      return 0xffc1;
    // memcpy is the well-defined way to inspect an object representation;
    // reading an inactive union member is undefined behaviour in C++.
    uint32_t bits;
    std::memcpy(&bits, &a, sizeof(bits));
    // Adding 0x7FFF plus the lowest kept bit rounds half-way cases to even.
    const uint32_t roundingBias = ((bits >> 16) & 0x1) + 0x00007FFF;
    return static_cast<StorageType>((bits + roundingBias) >> 16);
  }

  /// Expands a 16-bit storage pattern to the float with the same upper bits
  /// and a zero lower half.
  ///
  /// @param a 16-bit pattern to expand.
  /// @return the float whose bit pattern is `a` followed by 16 zero bits.
  static float to_float(const StorageType &a) {
    // Widen before shifting: a bare `a << 16` would promote to signed int and
    // overflow it whenever bit 15 (the sign bit) is set.
    const uint32_t bits = static_cast<uint32_t>(a) << 16;
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
  }

public:
  bfloat16() = default;
  bfloat16(const bfloat16 &) = default;
  bfloat16 &operator=(const bfloat16 &) = default;
  ~bfloat16() = default;

  /// Implicit conversion from float to bfloat16 (rounds, see from_float).
  bfloat16(const float &a) : value(from_float(a)) {}

  /// Assigns the rounded conversion of `rhs`.
  bfloat16 &operator=(const float &rhs) {
    value = from_float(rhs);
    return *this;
  }

  /// Implicit conversion from bfloat16 to float.
  operator float() const { return to_float(value); }

  /// Logical operators (!, ||, &&) are covered by the conversion to bool:
  /// true when the value compares unequal to 0.0f.
  explicit operator bool() const { return to_float(value) != 0.0f; }

  /// Unary minus: negates in float and converts the result back.
  friend bfloat16 operator-(const bfloat16 &lhs) {
    return -to_float(lhs.value);
  }

// Increment and decrement operators.
// For each `op` in {++, --} this defines the prefix form, which updates `lhs`
// in place and returns it, and the postfix form, which returns a copy of the
// previous value.
#define OP(op)                                                                 \
  friend bfloat16 &operator op(bfloat16 &lhs) {                                \
    float f = to_float(lhs.value);                                             \
    lhs.value = from_float(op f);                                              \
    return lhs;                                                                \
  }                                                                            \
  friend bfloat16 operator op(bfloat16 &lhs, int) {                            \
    bfloat16 old = lhs;                                                        \
    operator op(lhs);                                                          \
    return old;                                                                \
  }
  OP(++)
  OP(--)
#undef OP

// Compound assignment operators.
// For each `op` this defines three overloads: bfloat16 op bfloat16,
// bfloat16 op T, and T op bfloat16. In every case both operands are converted
// to float, the operation is performed in float, and the result is assigned
// back to `lhs`, which is returned by reference.
#define OP(op)                                                                 \
  friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) {           \
    float f = static_cast<float>(lhs);                                         \
    f op static_cast<float>(rhs);                                              \
    return lhs = f;                                                            \
  }                                                                            \
  template <typename T>                                                        \
  friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) {                  \
    float f = static_cast<float>(lhs);                                         \
    f op static_cast<float>(rhs);                                              \
    return lhs = f;                                                            \
  }                                                                            \
  template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) {   \
    float f = static_cast<float>(lhs);                                         \
    f op static_cast<float>(rhs);                                              \
    return lhs = f;                                                            \
  }
  OP(+=)
  OP(-=)
  OP(*=)
  OP(/=)
#undef OP

// Binary arithmetic and comparison operators.
// `type` is the result type (bfloat16 for arithmetic, bool for comparisons).
// For each `op` this defines bfloat16 op bfloat16, bfloat16 op T and
// T op bfloat16; both operands are converted to float before `op` is applied.
#define OP(type, op)                                                           \
  friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) {          \
    return type{static_cast<float>(lhs) op static_cast<float>(rhs)};           \
  }                                                                            \
  template <typename T>                                                        \
  friend type operator op(const bfloat16 &lhs, const T &rhs) {                 \
    return type{static_cast<float>(lhs) op static_cast<float>(rhs)};           \
  }                                                                            \
  template <typename T>                                                        \
  friend type operator op(const T &lhs, const bfloat16 &rhs) {                 \
    return type{static_cast<float>(lhs) op static_cast<float>(rhs)};           \
  }
  OP(bfloat16, +)
  OP(bfloat16, -)
  OP(bfloat16, *)
  OP(bfloat16, /)
  OP(bool, ==)
  OP(bool, !=)
  OP(bool, <)
  OP(bool, >)
  OP(bool, <=)
  OP(bool, >=)
#undef OP

  // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
  // for floating-point types.
};