-
-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Expand file tree
/
Copy pathconstant_fold.py
More file actions
185 lines (157 loc) · 6.55 KB
/
constant_fold.py
File metadata and controls
185 lines (157 loc) · 6.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Constant folding of IR values.
For example, 3 + 5 can be constant folded into 8.
This is mostly like mypy.constant_fold, but we can bind some additional
NameExpr and MemberExpr references here, since we have more knowledge
about which definitions can be trusted -- we constant fold only references
to other compiled modules in the same compilation unit.
"""
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Final, TypeVar
from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
from mypy.nodes import (
BytesExpr,
ComplexExpr,
Expression,
FloatExpr,
IndexExpr,
IntExpr,
MemberExpr,
NameExpr,
OpExpr,
SliceExpr,
StrExpr,
UnaryExpr,
Var,
)
from mypyc.ir.ops import Value
from mypyc.irbuild.util import bytes_from_str
if TYPE_CHECKING:
from mypyc.irbuild.builder import IRBuilder
# Runtime tuple of every Python type that constant folding may produce; used in
# isinstance() checks. Keep it in sync with the ConstantValue alias below.
CONST_TYPES: Final = (int, float, complex, str, bytes)

# Static type of a successfully folded constant (same members as CONST_TYPES).
ConstantValue = int | float | complex | str | bytes

# Type variable standing for the concrete Expression subclass a transform handles.
Expr = TypeVar("Expr", bound=Expression)
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
    """Return the constant value of an expression for supported operations.

    Supported expressions:
      * int, float, str, bytes and complex literals;
      * NameExpr references to Final variables whose value is a constant;
      * MemberExpr references that builder.get_final_ref() resolves to such a
        Final variable;
      * binary operators (via constant_fold_binary_op_extended) and unary
        operators (via mypy's constant_fold_unary_op) over foldable operands;
      * indexing and slicing of foldable str/bytes values with int indices.

    Return None otherwise. Folding is best-effort: when an operand has an
    unsuitable type, or the index/slice operation itself would raise (for
    example an out-of-range index or a zero stride), the expression is simply
    not folded. The compiler does not crash in those cases.
    """
    if isinstance(expr, (IntExpr, FloatExpr, StrExpr, ComplexExpr)):
        return expr.value
    if isinstance(expr, BytesExpr):
        # The literal is stored as a str; bytes_from_str converts it to bytes.
        return bytes_from_str(expr.value)
    if isinstance(expr, NameExpr):
        node = expr.node
        if isinstance(node, Var):
            return _foldable_final_value(node)
        return None
    if isinstance(expr, MemberExpr):
        final = builder.get_final_ref(expr)
        if final is not None:
            _, final_var, _ = final
            return _foldable_final_value(final_var)
        return None
    if isinstance(expr, OpExpr):
        left = constant_fold_expr(builder, expr.left)
        right = constant_fold_expr(builder, expr.right)
        if left is not None and right is not None:
            return constant_fold_binary_op_extended(expr.op, left, right)
        return None
    if isinstance(expr, UnaryExpr):
        value = constant_fold_expr(builder, expr.expr)
        # Unary folding of bytes values is not supported; leave it to runtime.
        if value is not None and not isinstance(value, bytes):
            return constant_fold_unary_op(expr.op, value)
        return None
    if isinstance(expr, IndexExpr):
        return _constant_fold_index(builder, expr)
    return None


def _foldable_final_value(var: Var) -> ConstantValue | None:
    """Return var's value if var is Final and that value has a foldable type."""
    value = var.final_value
    if var.is_final and isinstance(value, CONST_TYPES):
        return value
    return None


def _constant_fold_index(builder: IRBuilder, expr: IndexExpr) -> ConstantValue | None:
    """Fold ``base[index]`` or ``base[begin:end:stride]`` on a str/bytes constant.

    Return None if the base is not a foldable sequence constant, if any index
    part is not a foldable int, or if performing the operation would raise.
    """
    base = constant_fold_expr(builder, expr.base)
    # Of the constant types, only str and bytes are sequences. Any other base
    # (only reachable from ill-typed code) is not folded, so the generated code
    # reports the error at runtime instead of the compiler failing here.
    if not isinstance(base, Sequence):
        return None

    index_expr = expr.index
    if isinstance(index_expr, SliceExpr):
        bounds: list[int | None] = []
        for part in (index_expr.begin_index, index_expr.end_index, index_expr.stride):
            if part is None:
                # Omitted slice component, e.g. the begin of s[:3].
                bounds.append(None)
                continue
            folded = constant_fold_expr(builder, part)
            if not isinstance(folded, int):
                return None
            bounds.append(folded)
        begin, end, stride = bounds
        try:
            return base[begin:end:stride]
        except Exception:
            # e.g. a zero stride raises ValueError; just don't fold.
            return None

    index = constant_fold_expr(builder, index_expr)
    if not isinstance(index, int):
        return None
    try:
        return base[index]
    except Exception:
        # e.g. an out-of-range index; leave the error to runtime.
        return None
def constant_fold_binary_op_extended(
    op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
    """Fold a binary operation, adding bytes support on top of mypy's folder.

    If neither operand is bytes, the work is delegated to mypy's
    constant_fold_binary_op(). Bytes are handled only here, because mypy
    cannot easily make use of folded bytes values.

    The supported bytes forms are:
      * concatenation: bytes + bytes;
      * repetition: bytes * int or int * bytes.

    Every other combination involving bytes yields None.
    """
    if not (isinstance(left, bytes) or isinstance(right, bytes)):
        return constant_fold_binary_op(op, left, right)

    if op == "+":
        if isinstance(left, bytes) and isinstance(right, bytes):
            return left + right
    elif op == "*":
        # Repetition accepts the int count on either side of the operator.
        if isinstance(left, bytes) and isinstance(right, int):
            return left * right
        if isinstance(left, int) and isinstance(right, bytes):
            return left * right
    return None
def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None:
    """Fold expr and load the result as a literal value, if possible.

    Return the Value produced by builder.load_literal_value() for the folded
    constant, or None when constant_fold_expr() cannot fold expr.
    """
    folded = constant_fold_expr(builder, expr)
    if folded is None:
        return None
    return builder.load_literal_value(folded)
def folding_candidate(
    transform: Callable[[IRBuilder, Expr], Value],
) -> Callable[[IRBuilder, Expr], Value]:
    """Decorate a transform function so that it tries constant folding first.

    The returned wrapper calls try_constant_fold() on the expression and
    returns the folded literal Value when folding succeeds. It falls back to
    calling the original transform only when the expression cannot be folded.
    """

    def constant_fold_wrap(builder: IRBuilder, expr: Expr) -> Value:
        folded = try_constant_fold(builder, expr)
        if folded is None:
            # Not a compile-time constant: do the regular transformation.
            return transform(builder, expr)
        return folded

    return constant_fold_wrap