@@ -9,27 +9,44 @@ namespace Halide {
99namespace Internal {
1010
1111Expr bfloat16_to_float32 (Expr e) {
12+ const int lanes = e.type ().lanes ();
1213 if (e.type ().is_bfloat ()) {
1314 e = reinterpret (e.type ().with_code (Type::UInt), e);
1415 }
15- e = cast (UInt (32 , e. type (). lanes () ), e);
16+ e = cast (UInt (32 , lanes), e);
1617 e = e << 16 ;
17- e = reinterpret (Float (32 , e. type (). lanes () ), e);
18+ e = reinterpret (Float (32 , lanes), e);
1819 e = strict_float (e);
1920 return e;
2021}
2122
22- Expr float32_to_bfloat16 (Expr e) {
23- internal_assert ( e.type ().bits () == 32 );
23+ Expr float_to_bfloat16 (Expr e) {
24+ const int lanes = e.type ().lanes ( );
2425 e = strict_float (e);
25- e = reinterpret (UInt (32 , e.type ().lanes ()), e);
26- // We want to round ties to even, so before truncating either
27- // add 0x8000 (0.5) to odd numbers or 0x7fff (0.499999) to
28- // even numbers.
29- e += 0x7fff + ((e >> 16 ) & 1 );
26+
27+ Expr err;
28+ // First round to float and record any gain of loss of magnitude
29+ if (e.type ().bits () == 64 ) {
30+ Expr f = cast (Float (32 , lanes), e);
31+ err = abs (e) - abs (f);
32+ e = f;
33+ } else {
34+ internal_assert (e.type ().bits () == 32 );
35+ }
36+ e = reinterpret (UInt (32 , lanes), e);
37+
38+ // We want to round ties to even, so if we have no error recorded above,
39+ // before truncating either add 0x8000 (0.5) to odd numbers or 0x7fff
40+ // (0.499999) to even numbers. If we have error, break ties using that
41+ // instead.
42+ Expr tie_breaker = (e >> 16 ) & 1 ; // 1 when rounding down would go to odd
43+ if (err.defined ()) {
44+ tie_breaker = ((err == 0 ) & tie_breaker) | (err > 0 );
45+ }
46+ e += tie_breaker + 0x7fff ;
3047 e = (e >> 16 );
31- e = cast (UInt (16 , e. type (). lanes () ), e);
32- e = reinterpret (BFloat (16 , e. type (). lanes () ), e);
48+ e = cast (UInt (16 , lanes), e);
49+ e = reinterpret (BFloat (16 , lanes), e);
3350 return e;
3451}
3552
@@ -63,51 +80,75 @@ Expr float16_to_float32(Expr value) {
6380 return f32 ;
6481}
6582
66- Expr float32_to_float16 (Expr value) {
83+ Expr float_to_float16 (Expr value) {
6784 // We're about the sniff the bits of a float, so we should
6885 // guard it with strict float to ensure we don't do things
6986 // like assume it can't be denormal.
7087 value = strict_float (value);
7188
72- Type f32_t = Float (32 , value.type ().lanes ());
89+ const int src_bits = value.type ().bits ();
90+
91+ Type float_t = Float (src_bits, value.type ().lanes ());
7392 Type f16_t = Float (16 , value.type ().lanes ());
74- Type u32_t = UInt (32 , value.type ().lanes ());
93+ Type bits_t = UInt (src_bits , value.type ().lanes ());
7594 Type u16_t = UInt (16 , value.type ().lanes ());
7695
77- Expr bits = reinterpret (u32_t , value);
96+ Expr bits = reinterpret (bits_t , value);
7897
7998 // Extract the sign bit
80- Expr sign = bits & make_const (u32_t , 0x80000000 );
99+ Expr sign = bits & make_const (bits_t , ( uint64_t ) 1 << (src_bits - 1 ) );
81100 bits = bits ^ sign;
82101
83102 // Test the endpoints
84- Expr is_denorm = (bits < make_const (u32_t , 0x38800000 ));
85- Expr is_inf = (bits >= make_const (u32_t , 0x47800000 ));
86- Expr is_nan = (bits > make_const (u32_t , 0x7f800000 ));
103+
104+ // Smallest input representable as normal float16 (2^-14)
105+ Expr two_to_the_minus_14 = src_bits == 32 ?
106+ make_const (bits_t , 0x38800000 ) :
107+ make_const (bits_t , (uint64_t )0x3f10000000000000ULL );
108+ Expr is_denorm = bits < two_to_the_minus_14;
109+
110+ // Smallest input too big to represent as a float16 (2^16)
111+ Expr two_to_the_16 = src_bits == 32 ?
112+ make_const (bits_t , 0x47800000 ) :
113+ make_const (bits_t , (uint64_t )0x40f0000000000000ULL );
114+ Expr is_inf = bits >= two_to_the_16;
115+
116+ // Check if the input is a nan, which is anything bigger than an infinity bit pattern
117+ Expr input_inf_bits = src_bits == 32 ?
118+ make_const (bits_t , 0x7f800000 ) :
119+ make_const (bits_t , (uint64_t )0x7ff0000000000000ULL );
120+ Expr is_nan = bits > input_inf_bits;
87121
88122 // Denorms are linearly spaced, so we can handle them
89123 // by scaling up the input as a float and using the
90124 // existing int-conversion rounding instructions.
91- Expr denorm_bits = cast (u16_t , strict_float (round (strict_float (reinterpret (f32_t , bits + 0x0c000000 )))));
125+ Expr two_to_the_24 = src_bits == 32 ?
126+ make_const (bits_t , 0x0c000000 ) :
127+ make_const (bits_t , (uint64_t )0x0180000000000000ULL );
128+ Expr denorm_bits = cast (u16_t , strict_float (round (reinterpret (float_t , bits + two_to_the_24))));
92129 Expr inf_bits = make_const (u16_t , 0x7c00 );
93130 Expr nan_bits = make_const (u16_t , 0x7fff );
94131
95132 // We want to round to nearest even, so we add either
96133 // 0.5 if the integer part is odd, or 0.4999999 if the
97134 // integer part is even, then truncate.
98- bits += (bits >> 13 ) & 1 ;
99- bits += 0xfff ;
100- bits = bits >> 13 ;
135+ const int float16_mantissa_bits = 10 ;
136+ const int input_mantissa_bits = src_bits == 32 ? 23 : 52 ;
137+ const int bits_lost = input_mantissa_bits - float16_mantissa_bits;
138+ bits += (bits >> bits_lost) & 1 ;
139+ bits += make_const (bits_t , ((uint64_t )1 << (bits_lost - 1 )) - 1 );
140+ bits = cast (u16_t , bits >> bits_lost);
141+
101142 // Rebias the exponent
102- bits -= 0x1c000 ;
143+ bits -= 0x4000 ;
103144 // Truncate the top bits of the exponent
104145 bits = bits & 0x7fff ;
105146 bits = select (is_denorm, denorm_bits,
106147 is_inf, inf_bits,
107148 is_nan, nan_bits,
108149 cast (u16_t , bits));
109150 // Recover the sign bit
110- bits = bits | cast (u16_t , sign >> 16 );
151+ bits = bits | cast (u16_t , sign >> (src_bits - 16 ) );
111152 return common_subexpression_elimination (reinterpret (f16_t , bits));
112153}
113154
@@ -157,7 +198,7 @@ Expr lower_float16_transcendental_to_float32_equivalent(const Call *op) {
157198 Expr e = Call::make (t, it->second , new_args, op->call_type ,
158199 op->func , op->value_index , op->image , op->param );
159200 if (op->type .is_float ()) {
160- e = float32_to_float16 (e);
201+ e = float_to_float16 (e);
161202 }
162203 internal_assert (e.type () == op->type );
163204 return e;
@@ -171,6 +212,7 @@ Expr lower_float16_cast(const Cast *op) {
171212 Type src = op->value .type ();
172213 Type dst = op->type ;
173214 Type f32 = Float (32 , dst.lanes ());
215+ Type f64 = Float (64 , dst.lanes ());
174216 Expr val = op->value ;
175217
176218 if (src.is_bfloat ()) {
@@ -183,10 +225,20 @@ Expr lower_float16_cast(const Cast *op) {
183225
184226 if (dst.is_bfloat ()) {
185227 internal_assert (dst.bits () == 16 );
186- val = float32_to_bfloat16 (cast (f32 , val));
228+ if (src.bits () > 32 ) {
229+ val = cast (f64 , val);
230+ } else {
231+ val = cast (f32 , val);
232+ }
233+ val = float_to_bfloat16 (val);
187234 } else if (dst.is_float () && dst.bits () < 32 ) {
188235 internal_assert (dst.bits () == 16 );
189- val = float32_to_float16 (cast (f32 , val));
236+ if (src.bits () > 32 ) {
237+ val = cast (f64 , val);
238+ } else {
239+ val = cast (f32 , val);
240+ }
241+ val = float_to_float16 (val);
190242 }
191243
192244 return cast (dst, val);
0 commit comments