-
Notifications
You must be signed in to change notification settings - Fork 996
Expand file tree
/
Copy pathinsert_table_ops.py
More file actions
315 lines (275 loc) · 13.4 KB
/
insert_table_ops.py
File metadata and controls
315 lines (275 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import chain
from typing import Callable, cast, Dict, Iterator, Set, Type
import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export.graph_signature import InputKind
from torch.fx import GraphModule
from torch.fx.node import Node
class TableOps:
"""Helper class for finding the corresponding table operator for a given
Node.
"""
# Targets that follow a straigtforward one-to-one mapping to their table op
unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
exir_ops.edge.aten.ceil.default: torch.ceil,
exir_ops.edge.aten.erf.default: torch.erf,
exir_ops.edge.aten.exp.default: torch.exp,
exir_ops.edge.aten.expm1.default: torch.expm1,
exir_ops.edge.aten.erfinv.default: torch.erfinv,
exir_ops.edge.aten.floor.default: torch.floor,
exir_ops.edge.aten.log.default: torch.log,
exir_ops.edge.aten.log1p.default: torch.log1p,
exir_ops.edge.aten.log10.default: torch.log10,
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
exir_ops.edge.aten.cos.default: torch.cos,
exir_ops.edge.aten.sin.default: torch.sin,
exir_ops.edge.aten.tanh.default: torch.tanh,
exir_ops.edge.aten.atan.default: torch.atan,
exir_ops.edge.aten.atanh.default: torch.atanh,
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
exir_ops.edge.aten.sinh.default: torch.sinh,
exir_ops.edge.aten.acosh.default: torch.acosh,
exir_ops.edge.aten.asin.default: torch.asin,
exir_ops.edge.aten.asinh.default: torch.asinh,
exir_ops.edge.aten.cosh.default: torch.cosh,
exir_ops.edge.aten.acos.default: torch.acos,
exir_ops.edge.aten.tan.default: torch.tan,
exir_ops.edge.aten.silu.default: torch.nn.functional.silu,
}
# Targets that must be treated explicitly
special_table_ops: Set[EdgeOpOverload] = {
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.gelu.default,
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.remainder.Scalar,
}
def __init__(self, exported_program: ExportedProgram):
self.exported_program = exported_program
def __contains__(self, node: Node) -> bool:
return (
node.target in self.unary_table_ops or node.target in self.special_table_ops
)
def __getitem__(self, node: Node):
target = cast(EdgeOpOverload, node.target)
if target in self.unary_table_ops:
return self.unary_table_ops[target]
elif target in self.special_table_ops:
match target:
case exir_ops.edge.aten.pow.Tensor_Scalar:
# Exponent is a constant. Embed it into a lambda.
exp = cast(int, node.args[1])
return lambda x: torch.pow(x, exp).flatten()
case exir_ops.edge.aten.gelu.default:
# If kwargs not present it is default "none"
approximate = cast(
str,
(
node.kwargs["approximate"]
if "approximate" in node.kwargs
else "none"
),
)
return lambda x: torch.nn.functional.gelu(
x, approximate=approximate
).flatten()
case exir_ops.edge.aten.elu.default:
input_alpha = cast(int, node.kwargs["alpha"])
return lambda x: torch.nn.functional.elu(
x, alpha=input_alpha
).flatten()
case exir_ops.edge.aten.remainder.Scalar:
divisor = cast(float | int, node.args[1])
return lambda x: torch.remainder(x, divisor).flatten()
case _:
# Op must be handled if it's inside self.special_ops
raise AssertionError("Unhandled table operation")
else:
raise KeyError("Table op for {target} does not exist")
@staticmethod
def included_ops() -> Iterator[EdgeOpOverload]:
return chain(TableOps.unary_table_ops, TableOps.special_table_ops)
class InsertTableOpsPass(ArmPass):
"""For ops in self.table_ops they need to be serialized as a TOSA TABLE.
This pass replaces these edge ops with a tosa._table(input: Tensor,
target_str: str) where target_str == str(node.target). When lowering the
_table node target_str will be used to find the corresponding torch operator
which will be used to produce the table values in operators/op_table.py.
"""
_passes_required_after: Set[Type[ExportPass]] = set()
def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.exported_program = exported_program
self.table_ops = TableOps(exported_program)
def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
"""Add buffer to self.exported_program.state_dict."""
self.exported_program.state_dict[buffer_name] = buffer
def generate_8bit_table_values(
self,
torch_op: Callable[[torch.Tensor], torch.Tensor],
in_quantargs: QuantArgs,
out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
"""Compute LUT values for a INT8 TOSA.TABLE.
Also returns 0 since no shifting is required after 8bit table. The INT8
table is a simple 256 value 1-1 LUT.
"""
def f(x: torch.Tensor) -> torch.Tensor:
x = in_quantargs.dequantize_value(x)
x = torch_op(x)
return out_quantargs.quantize_value(x)
return (
f(
torch.linspace(
start=in_quantargs.qmin,
end=in_quantargs.qmax,
steps=256,
dtype=torch.int8,
)
).to(dtype=torch.int8),
0,
)
def generate_16_bit_table_values(
self,
torch_op: Callable[[torch.Tensor], torch.Tensor],
in_quantargs: QuantArgs,
out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
"""Compute LUT values for a INT16 TOSA.TABLE with 32 bit output.
In practice the output is 23 bits that should be interpreted as 16 'whole' bits and 7 fractional bits, see
the specification: https://www.mlplatform.org/tosa/tosa_spec.html#_table. This means that the output
will interpreted as 2**7=128 times too large unless accounted for by rescaling down the table output.
Quantization can be either int16 or int32 which means that the op output could be larger than the 23 bits from
the TOSA.TABLE output. In that case, we need to rescale up the output.
To handle this we need to:
1) Make sure that our table values fit within 16 bits.
2) Insert a rescale after the table to handle the x128 from the fractional bits and match the quantization.
The function returns rescale_lshift which says how much to rescale after the table. This value can negative.
"""
def f(x: torch.Tensor) -> torch.Tensor:
x = x.clamp(in_quantargs.qmin, in_quantargs.qmax).to(
dtype=in_quantargs.dtype
)
# Dont use the 7 LSBs.
x = in_quantargs.dequantize_value((x & ~0x7F))
x = torch_op(x)
return out_quantargs.quantize_value(x)
lut_values = f(
torch.linspace(
start=in_quantargs.qmin,
end=in_quantargs.qmax + 1,
steps=513,
# use torch.int32 to avoid overflow for end=in_quantargs.qmax + 1.
dtype=torch.int32,
)
)
# Calculate how much we need to shift table values to fit in 16 signed bits
# ceil(log2(max absolute table value)) + 1 bit for signedness - 16
# Example:
# Max value in the table is 70 000. We want to fit it in 16 signed bits.
# 70 000=0b10001000101110000 (17 digits) has ceil(log2(70 000)) = ceil(16.095) = 17 bits.
# If we shift it 17-16=1 bit, we do get 16 bits (0b1000100010111000),
# but due to signedness this is a negative number! So we need to shift it one more bit.
# Note: for out_quantargs.dtype=torch.int16, rshift == 0 and rescale_lshift = -7.
rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
# The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do.
rescale_lshift = rshift - 7
lut_values = lut_values >> rshift
return lut_values.to(dtype=torch.int16), rescale_lshift
def generate_table_values(
self,
torch_op: Callable[[torch.Tensor], torch.Tensor],
in_quantargs: QuantArgs,
out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
match out_quantargs.dtype:
case torch.int8:
return self.generate_8bit_table_values(
torch_op, in_quantargs, out_quantargs
)
case torch.int16 | torch.int32:
return self.generate_16_bit_table_values(
torch_op, in_quantargs, out_quantargs
)
case _:
raise ValueError(
f"Unsupported output dtype for table: {out_quantargs.dtype}"
)
def call(self, graph_module: GraphModule) -> PassResult:
modified = False
for node in graph_module.graph.nodes:
if node.op != "call_function" or node not in self.table_ops:
continue
input_qparams = node.meta.get("input_qparams", {})
output_qparams = node.meta.get("output_qparams", {})
if len(input_qparams) == 0 or len(output_qparams) == 0:
# We only want to replace the node if it's quantized
continue
# Create table node
insert_pos = list(node.graph.nodes)[0]
with graph_module.graph.inserting_before(insert_pos):
# Expect exactly one quantization parameter for input and output
if len(input_qparams) != 1:
raise ValueError(
f"InsertTableOpsPass expected exactly one input quantization parameter, "
f"got {len(input_qparams)} for node {node.name}"
)
if len(output_qparams) != 1:
raise ValueError(
f"InsertTableOpsPass expected exactly one output quantization parameter, "
f"got {len(output_qparams)} for node {node.name}"
)
# Generate table buffer and how much to lshift the table output.
buffer, lshift = self.generate_table_values(
torch_op=self.table_ops[node],
in_quantargs=input_qparams[0],
out_quantargs=output_qparams[0],
)
# Register buffer in self.exported_program.state_dict
const_table_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=node.graph,
kind=InputKind.BUFFER,
name=node.name + "_table_constant",
data=buffer,
persistent_buffer=True,
)
# Create table node
with graph_module.graph.inserting_before(node):
table_op_node = create_node(
graph=graph_module.graph,
op_target=exir_ops.backend.tosa.TABLE.default,
args=(node.args[0], const_table_node),
)
output_node = table_op_node
if lshift != 0:
scale = 2.0**lshift
rescale_node = create_node(
graph=graph_module.graph,
op_target=exir_ops.backend.tosa.RESCALE.default,
args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0),
)
output_node = rescale_node
node.replace_all_uses_with(output_node)
graph_module.graph.erase_node(node)
table_op_node.meta["input_qparams"] = input_qparams
table_op_node.meta["output_qparams"] = output_qparams
modified = True
if modified:
# retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module
graph_module.recompile()
return PassResult(graph_module, modified)