# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, List

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses import FakeTensor


def copy_meta(meta: Dict, callback=None):
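    """
    Return a shallow copy of a node's meta dict. If a callback is provided,
    the copied dict is passed through it before being returned.
    """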
copied = {}
for k, v in meta.items():
copied[k] = v
if callback:
copied = callback(copied)
return copied


def get_quant_attrs(
edge_program: torch.export.ExportedProgram, quant_node: torch.fx.Node
):
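    """
    Collect the quantization arguments carried by a quantize/dequantize node
    (scale, zero_point, quant_min/quant_max, dtype, ...) into a dict keyed by
    the op schema's argument names, then add QCOM_DTYPE and QCOM_ENCODING.
    """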
quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments][1:]
quant_attrs = dict.fromkeys(quant_attr_keys)
for i in range(1, len(quant_node.args)):
attr_n = quant_node.args[i]
value = attr_n
if isinstance(attr_n, torch.fx.node.Node):
# could be a commonly shared attribute between q & dq
if attr_n.target == exir_ops.edge.aten._to_copy.default:
value = get_parameter(attr_n.args[0], edge_program)
else:
value = get_parameter(attr_n, edge_program)
quant_attrs[quant_attr_keys[i - 1]] = value
# remap key for compatibility - block quantization only
if dtype := quant_attrs.get("input_dtype", None):
quant_attrs[QCOM_DTYPE] = dtype
quant_attrs[QCOM_ENCODING] = quant_node.target
return quant_attrs


def get_passes_dependency_for_capture_program():
"""
    This function records the dependencies of the passes used in
    to_edge_transform_and_lower_to_qnn. It returns a dictionary where the keys
    are pass classes and the values are lists of the passes each key depends on.
    This helps manage and organize the sequence of passes needed for
    to_edge_transform_and_lower_to_qnn to run correctly.
Returns:
dict: A dictionary mapping each pass to its corresponding list of dependencies.
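
    Example (illustrative only): the mapping can be fed to a topological sorter
    so that every pass runs after the passes it depends on, e.g. with the
    standard library's graphlib:

        from graphlib import TopologicalSorter

        deps = get_passes_dependency_for_capture_program()
        ordered_passes = list(TopologicalSorter(deps).static_order())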
"""
from executorch.backends.qualcomm._passes import (
AnnotateAdaptiveAvgPool1D,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
CanonicalizeConv,
ConvertBmmToMatmul,
DecomposeAcos,
DecomposeAny,
DecomposeColIm,
DecomposeLinalgVectorNorm,
DecomposeLogVariants,
DecomposeMaxPool3d,
DecomposeTrunc,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
FoldQDQ,
I64toI32,
LayoutTransform,
RecomposePadMaxPool2d,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
ResolveDebugHandle,
TagQuantIO,
)
return {
AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
AnnotateQuantAttrs: [
ConvertBmmToMatmul,
RecomposePixelUnshuffle,
RemoveRedundancy,
],
AnnotateStack: [RemoveRedundancy],
AnnotateUnbind: [RemoveRedundancy],
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
DecomposeAcos: [RemoveRedundancy],
DecomposeAny: [RemoveRedundancy],
DecomposeColIm: [FoldQDQ],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
DecomposeLogVariants: [RemoveRedundancy],
DecomposeMaxPool3d: [RemoveRedundancy],
DecomposeTrunc: [RemoveRedundancy],
ExpandBroadcastTensorShape: [FoldQDQ],
FixedLinearKeepDim: [FoldQDQ],
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
I64toI32: [RemoveRedundancy],
LayoutTransform: [
AnnotateQuantAttrs,
CanonicalizeConv,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
],
RecomposePadMaxPool2d: [DecomposeMaxPool3d, FoldQDQ],
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
TagQuantIO: [LayoutTransform],
ResolveDebugHandle: [
TagQuantIO
], # IMPORTANT: Please always ensure ResolveDebugHandle is the last executed pass.
}


def copy_nn_module_stack(src, target):
    """
    Copy meta["nn_module_stack"] from the src node to the target node if it exists.
    """
if value := src.meta.get("nn_module_stack"):
target.meta["nn_module_stack"] = value


def merge_decomposed_graph(
    remap: Dict[str, torch.fx.Node],
    target_node: torch.fx.Node,
    target_graph: torch.fx.Graph,
    decomposed_graph_module: torch.fx.GraphModule,
    predicate: Callable[[torch.fx.Node], bool] = None,
    # output_processor is called as (target_node, decomposed_output_node, remap)
    output_processor: Callable[
        [torch.fx.Node, torch.fx.Node, Dict[str, torch.fx.Node]], None
    ] = None,
) -> None:
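    """
    Copy the nodes of decomposed_graph_module into target_graph in place of
    target_node. Placeholders are resolved through remap (initially keyed by
    placeholder name), the decomposed output node rewires target_node's users
    (or is handled by output_processor when provided), and all other nodes are
    copied with their arguments remapped. When predicate is given, only nodes
    it accepts are processed.
    """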
def default_output_process(node):
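        # NOTE: relies on `decomposed_node` (the decomposed graph's output node)
        # being bound by the enclosing loop at the time this helper is called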
for user in node.users.copy():
            # rewire the user of target_node to the mapped decomposed output
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
for decomposed_node in decomposed_graph_module.graph.nodes:
copy_nn_module_stack(target_node, decomposed_node)
if predicate is None or predicate(decomposed_node):
            # no need to copy the existing 'output' node
if decomposed_node.op == "output":
if output_processor is None:
default_output_process(target_node)
else:
output_processor(target_node, decomposed_node, remap)
            # no need to copy existing placeholders
elif decomposed_node.op == "placeholder":
# replace node map from string to graph node
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = target_graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)


def is_float_tensor(node: torch.fx.Node) -> bool:
if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
return False
return node.meta["val"].dtype == torch.float32


def _is_node(node):
    return isinstance(node, torch.fx.Node)


def _pred(node, pat):
    return isinstance(pat, Callable) and pat(node)


def _next(node, from_args=True):
    if from_args:
        yield from [i for i in node.args if _is_node(i)]
    else:
        yield from list(node.users)


def find_pattern(
node: torch.fx.Node,
pattern: List[Callable[[torch.fx.Node], bool] | str],
from_args: bool = True,
max_wildcard_life: int = 3,
verbose: bool = False,
):
"""
    Implement wildcard pattern matching.
    - node: fx.Node to start matching from
    - pattern: list of predicates; each entry can be one of the following
        Callable(fx.Node): predicate returning True when the node matches
        '*': wildcard
        '?': any single node
    - from_args: if True, search through node.args, otherwise through node.users
    - max_wildcard_life: maximum number of nodes a wildcard may skip
    If nothing matches, return None.
    Otherwise, return a list of matched node lists, each the same length as the
    pattern (a wildcard position is recorded as None).
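
    Example (illustrative, using hypothetical predicates for a conv fed by a
    dequantize node somewhere up its argument chain):

        is_conv = lambda n: n.target == exir_ops.edge.aten.convolution.default
        is_dq = lambda n: (
            n.target == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
        )
        matches = find_pattern(conv_node, [is_conv, "*", is_dq], from_args=True)
        # matches is None, or a list of [conv_node, None, dq_node] entries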
"""
asterisk, question = "*", "?"
def _probe(
cur, hist, pat_idx, asterisk_life_count=max_wildcard_life, verbose=verbose
):
if pat_idx == len(pattern):
# Expected len(hist) is equal to pat_idx
assert len(hist) == len(pattern)
if list(hist) not in matched:
matched.append(list(hist))
return
if verbose:
print(
f"cur:{cur}, idx:{pat_idx}, life={asterisk_life_count}, pattern:{pattern[pat_idx]} hist={hist}"
)
if pattern[pat_idx] == question or _pred(cur, pattern[pat_idx]):
hist.append(cur)
for child in _next(cur, from_args):
_probe(child, hist, pat_idx + 1)
hist.pop()
elif pattern[pat_idx] == asterisk and asterisk_life_count > 0:
# 3 cases: ignore/consume/keep asterisk
            # 1. Ignore asterisk
hist.append(None)
_probe(cur, hist, pat_idx + 1)
hist.pop()
# 2. Consume asterisk
hist.append(None)
for child in _next(cur, from_args):
_probe(child, hist, pat_idx + 1)
hist.pop()
            # 3. Keep asterisk and skip to the next node
for child in _next(cur, from_args):
_probe(child, hist, pat_idx, asterisk_life_count - 1)
# Check if pattern is valid
assert all(
isinstance(i, Callable) or (isinstance(i, str) and (i == "*" or i == "?"))
for i in pattern
), f"Invalid pattern: {pattern}"
# Start probing
matched = []
_probe(node, [], 0)
return matched if matched else None


def find_patterns(node, patterns, **kwargs):
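    """
    Run find_pattern for each pattern in patterns and return a list of results,
    one per pattern (each entry is either None or the list of matched node lists).
    """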
assert isinstance(patterns, list) and isinstance(patterns[0], list)
results = []
for pattern in patterns:
result = find_pattern(node, pattern, **kwargs)
results.append(result)
return results


def append_qdq(
graph_module: torch.fx.GraphModule,
node: torch.fx.Node,
qdq_node: torch.fx.Node,
):
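    """
    Insert a quantize/dequantize pair right after node, reusing the quantization
    arguments of qdq_node. Returns the new dequantize node, or node unchanged
    when qdq_node is not a per-tensor q/dq op.
    """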
q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
if qdq_node.target not in {q_op, dq_op}:
return node
with graph_module.graph.inserting_after(node):
q_args = (node, *qdq_node.args[1:])
q_node = graph_module.graph.create_node("call_function", q_op, q_args)
q_node.meta = copy_meta(node.meta)
q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
with graph_module.graph.inserting_after(q_node):
dq_args = (q_node, *qdq_node.args[1:])
dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args)
dq_node.meta = copy_meta(node.meta)
return dq_node


def get_const_node(
graph: torch.fx.Graph,
graph_module: torch.fx.GraphModule,
attr_name: str,
value,
source_node: torch.fx.Node,
) -> torch.fx.Node:
"""
Register a scalar constant as a named buffer on the graph module and return a get_attr node referencing it.
    Used in edge-dialect op decomposition passes, where QNN op builders do not
    accept raw scalar arguments and require inputs to be graph nodes.
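
    Example (illustrative, with a hypothetical epsilon constant inside a
    decomposition pass):

        eps_node = get_const_node(graph, graph_module, "norm_eps", 1e-5, source_node)
        # eps_node is a get_attr node and can be wired in as a regular input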
"""
dtype = source_node.meta["val"].dtype
tensor = torch.tensor(value, dtype=dtype)
graph_module.register_buffer(attr_name, tensor)
fake_mode = source_node.meta["val"].fake_mode
with graph.inserting_before(next(iter(graph.nodes))):
const_node = graph.get_attr(attr_name)
const_node.meta["val"] = fake_mode.from_tensor(tensor)
return const_node