-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathSimplify_Cast.cpp
More file actions
155 lines (147 loc) · 6.64 KB
/
Simplify_Cast.cpp
File metadata and controls
155 lines (147 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include "Simplify_Internal.h"
#include "IRPrinter.h"
namespace Halide {
namespace Internal {

// Simplify a Cast node.
//
// The visitor works in three phases:
//   1. Simplify the operand, then derive bounds/alignment facts for the cast
//      result from the operand's facts. If those facts pin the result to a
//      single value, fold straight to a constant.
//   2. Constant-fold casts of literal float / int / uint values.
//   3. Structural rewrites: collapse cast-of-cast chains, and move the cast
//      inside a Broadcast, or inside an Int(32) -> Int(64) Ramp.
// If no rule fires, the Cast is rebuilt around the simplified operand, or the
// original node is returned when the operand did not change.
//
// `info` may be null. When non-null it is filled in with what is known about
// the returned expression, or reset when nothing can be concluded.
Expr Simplify::visit(const Cast *op, ExprInfo *info) {
    ExprInfo value_info;
    Expr value = mutate(op->value, &value_info);

    if (info && no_overflow(op->type) && !op->type.can_represent(value_info.bounds)) {
        // If there's overflow in a no-overflow type (e.g. due to casting
        // from a UInt(64) to an Int(32)), then forget everything we know
        // about the Expr. The expression may or may not overflow. We don't
        // know.
        *info = ExprInfo{};
    } else {
        // This branch is also taken when info is null, so the single-point
        // fold below can fire even if the caller did not request facts.
        int64_t old_min = value_info.bounds.min;
        bool old_min_defined = value_info.bounds.min_defined;
        value_info.cast_to(op->type);
        if (op->type.is_uint() && op->type.bits() == 64 && old_min_defined && old_min > 0) {
            // It's impossible for a cast *to* a uint64 in Halide to lower the
            // min. Casts to uint64_t don't overflow for any source type.
            value_info.bounds.min_defined = true;
            value_info.bounds.min = old_min;
        }
        value_info.trim_bounds_using_alignment();
        if (info) {
            *info = value_info;
        }
        // It's possible we just reduced to a constant. E.g. if we cast an
        // even number to uint1 we get zero.
        if (value_info.bounds.is_single_point()) {
            return make_const(op->type, value_info.bounds.min, info);
        }
    }

    const Cast *cast = value.as<Cast>();
    const Broadcast *broadcast_value = value.as<Broadcast>();
    const Ramp *ramp_value = value.as<Ramp>();
    // Scratch slots for the constant-folding conditions below. Each
    // condition assigns one of these as a side effect of being tested.
    std::optional<double> f;
    std::optional<int64_t> i;
    std::optional<uint64_t> u;

    // The rules below are tried in order; the first match wins.
    if (Call::as_intrinsic(value, {Call::signed_integer_overflow})) {
        clear_expr_info(info);
        return make_signed_integer_overflow(op->type);
    } else if (value.type() == op->type) {
        // The cast is a no-op.
        if (info) {
            *info = value_info;
        }
        return value;
    } else if (op->type.is_int() &&
               (f = as_const_float(value)) &&
               std::isfinite(*f)) {
        // float -> int
        return make_const(op->type, safe_numeric_cast<int64_t>(*f), info);
    } else if (op->type.is_uint() &&
               (f = as_const_float(value)) &&
               std::isfinite(*f)) {
        // float -> uint
        return make_const(op->type, safe_numeric_cast<uint64_t>(*f), info);
    } else if (op->type.is_float() &&
               (f = as_const_float(value))) {
        // float -> float
        return make_const(op->type, *f, info);
    } else if (op->type.is_int() &&
               (i = as_const_int(value))) {
        // int -> int
        return make_const(op->type, *i, info);
    } else if (op->type.is_uint() &&
               (i = as_const_int(value))) {
        // int -> uint
        return make_const(op->type, safe_numeric_cast<uint64_t>(*i), info);
    } else if (op->type.is_float() &&
               (i = as_const_int(value))) {
        // int -> float
        return make_const(op->type, safe_numeric_cast<double>(*i), info);
    } else if (op->type.is_int() &&
               (u = as_const_uint(value))) {
        // uint -> int.
        return make_const(op->type, safe_numeric_cast<int64_t>(*u), info);
    } else if (op->type.is_uint() &&
               (u = as_const_uint(value))) {
        // uint -> uint
        return make_const(op->type, *u, info);
    } else if (op->type.is_float() &&
               (u = as_const_uint(value))) {
        // uint -> float
        return make_const(op->type, safe_numeric_cast<double>(*u), info);
    } else if (cast &&
               op->type.code() == cast->type.code() &&
               op->type.bits() < cast->type.bits()) {
        // If this is a cast of a cast of the same type, where the
        // outer cast is narrower, the inner cast can be
        // eliminated.
        return mutate(Cast::make(op->type, cast->value), info);
    } else if (cast &&
               op->type.is_int_or_uint() &&
               cast->type.is_int() &&
               cast->value.type().is_int() &&
               op->type.bits() >= cast->type.bits() &&
               cast->type.bits() >= cast->value.type().bits()) {
        // Casting from a signed type always sign-extends, so widening
        // partway to a signed type and the rest of the way to some other
        // integer type is the same as just widening to that integer type
        // directly.
        return mutate(Cast::make(op->type, cast->value), info);
    } else if (cast &&
               op->type.is_int_or_uint() &&
               cast->type.is_int_or_uint() &&
               cast->value.type().is_int_or_uint() &&
               op->type.bits() <= cast->type.bits()) {
        // A cast between integer types where the outer cast is no wider than
        // the inner cast.
        //
        // The inner cast does one of four things to cast->value:
        //   - sign extends it (signed source),
        //   - zero extends it (unsigned source),
        //   - truncates it, or
        //   - reinterprets it.
        // The outer cast then keeps only the low op->type.bits() bits.
        //
        // Those low bits are exactly what a direct cast of cast->value to
        // op->type produces. This holds whether cast->value is narrower or
        // wider than op->type:
        //   - Any extension is governed by cast->value's signedness in both
        //     forms.
        //   - Successive truncations compose.
        // So only the inner cast's result width needs to be checked, and the
        // inner cast can be eliminated.
        //
        // The requirement that cast->value is itself int-or-uint is crucial.
        // A float source makes `cast` an fp-to-int conversion, whose low bits
        // are not the same as an fp-to-int conversion to a narrower type.
        // For example, int32(uint64(float64(-21))) evaluates to 0, because
        // float-to-uint of a negative value saturates to 0 in Halide. The
        // stripped form int32(float64(-21)) evaluates to -21.
        if (op->type == cast->value.type()) {
            return mutate(cast->value, info);
        } else {
            return mutate(Cast::make(op->type, cast->value), info);
        }
    } else if (broadcast_value) {
        // cast(broadcast(x)) -> broadcast(cast(x))
        return mutate(Broadcast::make(Cast::make(op->type.with_lanes(broadcast_value->value.type().lanes()), broadcast_value->value), broadcast_value->lanes), info);
    } else if (ramp_value &&
               op->type.element_of() == Int(64) &&
               op->value.type().element_of() == Int(32)) {
        // cast(ramp(a, b, w)) -> ramp(cast(a), cast(b), w)
        return mutate(Ramp::make(Cast::make(op->type.with_lanes(ramp_value->base.type().lanes()),
                                            ramp_value->base),
                                 Cast::make(op->type.with_lanes(ramp_value->stride.type().lanes()),
                                            ramp_value->stride),
                                 ramp_value->lanes),
                      info);
    }

    // No rewrite applied: keep the cast, reusing the original node when the
    // operand did not change.
    if (value.same_as(op->value)) {
        return op;
    } else {
        return Cast::make(op->type, value);
    }
}

} // namespace Internal
} // namespace Halide