Skip to content

Commit 5742605

Browse files
authored
Fix double-rounding bug in double -> (b)float16 casts (#8906)
* Fix double-rounding bug in double -> (b)float16 casts * Share more code between coming from 64 and 32 bits Also add and fix some comments
1 parent e6be49b commit 5742605

9 files changed

Lines changed: 213 additions & 41 deletions

File tree

src/CodeGen_X86.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,8 @@ void CodeGen_X86::visit(const Cast *op) {
524524
if (target.has_feature(Target::F16C) &&
525525
dst.code() == Type::Float &&
526526
src.code() == Type::Float &&
527-
(dst.bits() == 16 || src.bits() == 16)) {
527+
(dst.bits() == 16 || src.bits() == 16) &&
528+
src.bits() <= 32) { // Don't use for narrowing casts from double - it results in a libm call
528529
// Node we use code() == Type::Float instead of is_float(), because we
529530
// don't want to catch bfloat casts.
530531

src/EmulateFloat16Math.cpp

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,44 @@ namespace Halide {
99
namespace Internal {
1010

1111
Expr bfloat16_to_float32(Expr e) {
12+
const int lanes = e.type().lanes();
1213
if (e.type().is_bfloat()) {
1314
e = reinterpret(e.type().with_code(Type::UInt), e);
1415
}
15-
e = cast(UInt(32, e.type().lanes()), e);
16+
e = cast(UInt(32, lanes), e);
1617
e = e << 16;
17-
e = reinterpret(Float(32, e.type().lanes()), e);
18+
e = reinterpret(Float(32, lanes), e);
1819
e = strict_float(e);
1920
return e;
2021
}
2122

22-
Expr float32_to_bfloat16(Expr e) {
23-
internal_assert(e.type().bits() == 32);
23+
Expr float_to_bfloat16(Expr e) {
24+
const int lanes = e.type().lanes();
2425
e = strict_float(e);
25-
e = reinterpret(UInt(32, e.type().lanes()), e);
26-
// We want to round ties to even, so before truncating either
27-
// add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to
28-
// even numbers.
29-
e += 0x7fff + ((e >> 16) & 1);
26+
27+
Expr err;
28+
// First round to float and record any gain of loss of magnitude
29+
if (e.type().bits() == 64) {
30+
Expr f = cast(Float(32, lanes), e);
31+
err = abs(e) - abs(f);
32+
e = f;
33+
} else {
34+
internal_assert(e.type().bits() == 32);
35+
}
36+
e = reinterpret(UInt(32, lanes), e);
37+
38+
// We want to round ties to even, so if we have no error recorded above,
39+
// before truncating either add 0x8000 (0.5) to odd numbers or 0x7fff
40+
// (0.499999) to even numbers. If we have error, break ties using that
41+
// instead.
42+
Expr tie_breaker = (e >> 16) & 1; // 1 when rounding down would go to odd
43+
if (err.defined()) {
44+
tie_breaker = ((err == 0) & tie_breaker) | (err > 0);
45+
}
46+
e += tie_breaker + 0x7fff;
3047
e = (e >> 16);
31-
e = cast(UInt(16, e.type().lanes()), e);
32-
e = reinterpret(BFloat(16, e.type().lanes()), e);
48+
e = cast(UInt(16, lanes), e);
49+
e = reinterpret(BFloat(16, lanes), e);
3350
return e;
3451
}
3552

@@ -63,51 +80,75 @@ Expr float16_to_float32(Expr value) {
6380
return f32;
6481
}
6582

66-
Expr float32_to_float16(Expr value) {
83+
Expr float_to_float16(Expr value) {
6784
// We're about the sniff the bits of a float, so we should
6885
// guard it with strict float to ensure we don't do things
6986
// like assume it can't be denormal.
7087
value = strict_float(value);
7188

72-
Type f32_t = Float(32, value.type().lanes());
89+
const int src_bits = value.type().bits();
90+
91+
Type float_t = Float(src_bits, value.type().lanes());
7392
Type f16_t = Float(16, value.type().lanes());
74-
Type u32_t = UInt(32, value.type().lanes());
93+
Type bits_t = UInt(src_bits, value.type().lanes());
7594
Type u16_t = UInt(16, value.type().lanes());
7695

77-
Expr bits = reinterpret(u32_t, value);
96+
Expr bits = reinterpret(bits_t, value);
7897

7998
// Extract the sign bit
80-
Expr sign = bits & make_const(u32_t, 0x80000000);
99+
Expr sign = bits & make_const(bits_t, (uint64_t)1 << (src_bits - 1));
81100
bits = bits ^ sign;
82101

83102
// Test the endpoints
84-
Expr is_denorm = (bits < make_const(u32_t, 0x38800000));
85-
Expr is_inf = (bits >= make_const(u32_t, 0x47800000));
86-
Expr is_nan = (bits > make_const(u32_t, 0x7f800000));
103+
104+
// Smallest input representable as normal float16 (2^-14)
105+
Expr two_to_the_minus_14 = src_bits == 32 ?
106+
make_const(bits_t, 0x38800000) :
107+
make_const(bits_t, (uint64_t)0x3f10000000000000ULL);
108+
Expr is_denorm = bits < two_to_the_minus_14;
109+
110+
// Smallest input too big to represent as a float16 (2^16)
111+
Expr two_to_the_16 = src_bits == 32 ?
112+
make_const(bits_t, 0x47800000) :
113+
make_const(bits_t, (uint64_t)0x40f0000000000000ULL);
114+
Expr is_inf = bits >= two_to_the_16;
115+
116+
// Check if the input is a nan, which is anything bigger than an infinity bit pattern
117+
Expr input_inf_bits = src_bits == 32 ?
118+
make_const(bits_t, 0x7f800000) :
119+
make_const(bits_t, (uint64_t)0x7ff0000000000000ULL);
120+
Expr is_nan = bits > input_inf_bits;
87121

88122
// Denorms are linearly spaced, so we can handle them
89123
// by scaling up the input as a float and using the
90124
// existing int-conversion rounding instructions.
91-
Expr denorm_bits = cast(u16_t, strict_float(round(strict_float(reinterpret(f32_t, bits + 0x0c000000)))));
125+
Expr two_to_the_24 = src_bits == 32 ?
126+
make_const(bits_t, 0x0c000000) :
127+
make_const(bits_t, (uint64_t)0x0180000000000000ULL);
128+
Expr denorm_bits = cast(u16_t, strict_float(round(reinterpret(float_t, bits + two_to_the_24))));
92129
Expr inf_bits = make_const(u16_t, 0x7c00);
93130
Expr nan_bits = make_const(u16_t, 0x7fff);
94131

95132
// We want to round to nearest even, so we add either
96133
// 0.5 if the integer part is odd, or 0.4999999 if the
97134
// integer part is even, then truncate.
98-
bits += (bits >> 13) & 1;
99-
bits += 0xfff;
100-
bits = bits >> 13;
135+
const int float16_mantissa_bits = 10;
136+
const int input_mantissa_bits = src_bits == 32 ? 23 : 52;
137+
const int bits_lost = input_mantissa_bits - float16_mantissa_bits;
138+
bits += (bits >> bits_lost) & 1;
139+
bits += make_const(bits_t, ((uint64_t)1 << (bits_lost - 1)) - 1);
140+
bits = cast(u16_t, bits >> bits_lost);
141+
101142
// Rebias the exponent
102-
bits -= 0x1c000;
143+
bits -= 0x4000;
103144
// Truncate the top bits of the exponent
104145
bits = bits & 0x7fff;
105146
bits = select(is_denorm, denorm_bits,
106147
is_inf, inf_bits,
107148
is_nan, nan_bits,
108149
cast(u16_t, bits));
109150
// Recover the sign bit
110-
bits = bits | cast(u16_t, sign >> 16);
151+
bits = bits | cast(u16_t, sign >> (src_bits - 16));
111152
return common_subexpression_elimination(reinterpret(f16_t, bits));
112153
}
113154

@@ -157,7 +198,7 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *op) {
157198
Expr e = Call::make(t, it->second, new_args, op->call_type,
158199
op->func, op->value_index, op->image, op->param);
159200
if (op->type.is_float()) {
160-
e = float32_to_float16(e);
201+
e = float_to_float16(e);
161202
}
162203
internal_assert(e.type() == op->type);
163204
return e;
@@ -171,6 +212,7 @@ Expr lower_float16_cast(const Cast *op) {
171212
Type src = op->value.type();
172213
Type dst = op->type;
173214
Type f32 = Float(32, dst.lanes());
215+
Type f64 = Float(64, dst.lanes());
174216
Expr val = op->value;
175217

176218
if (src.is_bfloat()) {
@@ -183,10 +225,20 @@ Expr lower_float16_cast(const Cast *op) {
183225

184226
if (dst.is_bfloat()) {
185227
internal_assert(dst.bits() == 16);
186-
val = float32_to_bfloat16(cast(f32, val));
228+
if (src.bits() > 32) {
229+
val = cast(f64, val);
230+
} else {
231+
val = cast(f32, val);
232+
}
233+
val = float_to_bfloat16(val);
187234
} else if (dst.is_float() && dst.bits() < 32) {
188235
internal_assert(dst.bits() == 16);
189-
val = float32_to_float16(cast(f32, val));
236+
if (src.bits() > 32) {
237+
val = cast(f64, val);
238+
} else {
239+
val = cast(f32, val);
240+
}
241+
val = float_to_float16(val);
190242
}
191243

192244
return cast(dst, val);

src/EmulateFloat16Math.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *);
1919

2020
/** Cast to/from float and bfloat using bitwise math. */
2121
//@{
22-
Expr float32_to_bfloat16(Expr e);
23-
Expr float32_to_float16(Expr e);
22+
Expr float_to_bfloat16(Expr e);
23+
Expr float_to_float16(Expr e);
2424
Expr float16_to_float32(Expr e);
2525
Expr bfloat16_to_float32(Expr e);
2626
Expr lower_float16_cast(const Cast *op);

src/Float16.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ namespace Internal {
99

1010
// Conversion routines to and from float cribbed from Christian Rau's
1111
// half library (half.sourceforge.net)
12-
uint16_t float_to_float16(float value) {
12+
template<typename T>
13+
uint16_t float_to_float16(T value) {
14+
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
15+
"float_to_float16 only supports float and double types");
1316
// Start by copying over the sign bit
1417
uint16_t bits = std::signbit(value) << 15;
1518

@@ -40,14 +43,14 @@ uint16_t float_to_float16(float value) {
4043

4144
// We've normalized value as much as possible. Put the integer
4245
// portion of it into the mantissa.
43-
float ival;
44-
float frac = std::modf(value, &ival);
46+
T ival;
47+
T frac = std::modf(value, &ival);
4548
bits += (uint16_t)(std::abs((int)ival));
4649

4750
// Now consider the fractional part. We round to nearest with ties
4851
// going to even.
4952
frac = std::abs(frac);
50-
bits += (frac > 0.5f) | ((frac == 0.5f) & bits);
53+
bits += (frac > T(0.5)) | ((frac == T(0.5)) & bits);
5154

5255
return bits;
5356
}
@@ -341,6 +344,19 @@ uint16_t float_to_bfloat16(float f) {
341344
return ret >> 16;
342345
}
343346

347+
uint16_t float_to_bfloat16(double f) {
348+
// Coming from double is a little tricker. We first narrow to float and
349+
// record if any magnitude was lost or gained in the process. If so we'll
350+
// use that to break ties instead of testing whether or not truncation would
351+
// return odd.
352+
float f32 = (float)f;
353+
const double err = std::abs(f) - (double)std::abs(f32);
354+
uint32_t ret;
355+
memcpy(&ret, &f32, sizeof(float));
356+
ret += 0x7fff + (((err >= 0) & ((ret >> 16) & 1)) | (err > 0));
357+
return ret >> 16;
358+
}
359+
344360
float bfloat16_to_float(uint16_t b) {
345361
// Assume little-endian floats
346362
uint16_t bits[2] = {0, b};
@@ -362,7 +378,17 @@ float16_t::float16_t(double value)
362378
}
363379

364380
float16_t::float16_t(int value)
365-
: data(float_to_float16(value)) {
381+
: data(float_to_float16((float)value)) {
382+
// integers of any size that map to finite float16s are all representable as
383+
// float, so we can go via the float conversion method.
384+
}
385+
386+
float16_t::float16_t(int64_t value)
387+
: data(float_to_float16((float)value)) {
388+
}
389+
390+
float16_t::float16_t(uint64_t value)
391+
: data(float_to_float16((float)value)) {
366392
}
367393

368394
float16_t::operator float() const {
@@ -464,7 +490,15 @@ bfloat16_t::bfloat16_t(double value)
464490
}
465491

466492
bfloat16_t::bfloat16_t(int value)
467-
: data(float_to_bfloat16(value)) {
493+
: data(float_to_bfloat16((double)value)) {
494+
}
495+
496+
bfloat16_t::bfloat16_t(int64_t value)
497+
: data(float_to_bfloat16((double)value)) {
498+
}
499+
500+
bfloat16_t::bfloat16_t(uint64_t value)
501+
: data(float_to_bfloat16((double)value)) {
468502
}
469503

470504
bfloat16_t::operator float() const {

src/Float16.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ struct float16_t {
3232
explicit float16_t(float value);
3333
explicit float16_t(double value);
3434
explicit float16_t(int value);
35+
explicit float16_t(int64_t value);
36+
explicit float16_t(uint64_t value);
3537
// @}
3638

3739
/** Construct a float16_t with the bits initialised to 0. This represents
@@ -175,6 +177,8 @@ struct bfloat16_t {
175177
explicit bfloat16_t(float value);
176178
explicit bfloat16_t(double value);
177179
explicit bfloat16_t(int value);
180+
explicit bfloat16_t(int64_t value);
181+
explicit bfloat16_t(uint64_t value);
178182
// @}
179183

180184
/** Construct a bfloat16_t with the bits initialised to 0. This represents

src/IR.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ const char *const intrinsic_op_names[] = {
678678
"sliding_window_marker",
679679
"sorted_avg",
680680
"strict_add",
681+
"strict_cast",
681682
"strict_div",
682683
"strict_eq",
683684
"strict_le",

src/IR.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,7 @@ struct Call : public ExprNode<Call> {
626626
// them as reals and ignoring the existence of nan and inf. Using these
627627
// intrinsics instead prevents any such optimizations.
628628
strict_add,
629+
strict_cast,
629630
strict_div,
630631
strict_eq,
631632
strict_le,
@@ -792,6 +793,7 @@ struct Call : public ExprNode<Call> {
792793
bool is_strict_float_intrinsic() const {
793794
return is_intrinsic(
794795
{Call::strict_add,
796+
Call::strict_cast,
795797
Call::strict_div,
796798
Call::strict_max,
797799
Call::strict_min,

src/StrictifyFloat.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ class Strictify : public IRMutator {
8383
return IRMutator::visit(op);
8484
}
8585
}
86+
87+
Expr visit(const Cast *op) override {
88+
if (op->value.type().is_float() &&
89+
op->type.is_float()) {
90+
return Call::make(op->type, Call::strict_cast,
91+
{mutate(op->value)}, Call::PureIntrinsic);
92+
} else {
93+
return IRMutator::visit(op);
94+
}
95+
}
8696
};
8797

8898
const std::set<std::string> strict_externs = {
@@ -142,6 +152,8 @@ Expr unstrictify_float(const Call *op) {
142152
return op->args[0] <= op->args[1];
143153
} else if (op->is_intrinsic(Call::strict_eq)) {
144154
return op->args[0] == op->args[1];
155+
} else if (op->is_intrinsic(Call::strict_cast)) {
156+
return cast(op->type, op->args[0]);
145157
} else {
146158
internal_error << "Missing lowering of strict float intrinsic: "
147159
<< Expr(op) << "\n";

0 commit comments

Comments
 (0)