forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconst_prop_pass.py
More file actions
84 lines (75 loc) · 2.98 KB
/
const_prop_pass.py
File metadata and controls
84 lines (75 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import torch
from executorch.exir.dialects._ops import ops
from executorch.exir.pass_base import ExportPass, ProxyValue
from torch._subclasses.fake_tensor import FakeTensor
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.utils import _pytree as pytree
# Public names exported by `from <this module> import *`.
# NOTE(review): only the imported `quantized_decomposed_lib` is listed, so
# `ConstPropPass` is not part of the star-import surface - confirm intended.
# Presumably the library is re-exported so the import (likely needed for its
# side effect of registering the `quantized_decomposed` ops) is not reported
# as unused - TODO confirm.
__all__ = [
"quantized_decomposed_lib",
]
class ConstPropPass(ExportPass):
    """
    Performs constant folding and constant propagation.

    An operator call is evaluated eagerly, and replaced by its result, when
    every argument is a constant and the operator belongs to the category
    selected by ``propogate_quant``:

    * ``propogate_quant=False`` (default): fold everything except the
      ``quantized_decomposed`` quantize/dequantize operators.
    * ``propogate_quant=True``: fold only those quantize/dequantize operators.

    Any other call is forwarded unchanged to ``ExportPass.call_operator``.
    """

    def __init__(self, propogate_quant: bool = False) -> None:
        super().__init__()
        # The (misspelled) attribute name mirrors the constructor argument.
        self.propogate_quant = propogate_quant

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        """
        Fold ``op`` when it is eligible and all inputs are constant.

        Returns the eagerly computed result for a folded call, otherwise
        whatever the parent class' ``call_operator`` returns.
        """
        # Leaf values that count as constants on their own.
        constant_leaf_types = (
            float,
            int,
            bool,
            str,
            torch.Tensor,
            torch.device,
            torch.dtype,
            torch.layout,
        )

        # pyre-ignore
        def is_const(arg) -> bool:
            # FakeTensor is rejected before the generic torch.Tensor check
            # below can accept it.
            if isinstance(arg, FakeTensor):
                return False
            if isinstance(arg, constant_leaf_types):
                return True
            # Containers are constant only if every element / value is.
            if isinstance(arg, (tuple, list)):
                return all(is_const(item) for item in arg)
            if isinstance(arg, dict):
                return all(is_const(value) for value in arg.values())
            return False

        # Quantize/dequantize overloads, looked up in both the torch.ops and
        # the edge-dialect namespaces.
        q_dq_op_names = (
            "quantize_per_tensor",
            "dequantize_per_tensor",
            "quantize_per_channel",
            "dequantize_per_channel",
        )
        q_dq_ops = {
            getattr(namespace.quantized_decomposed, op_name).default
            for namespace in (torch.ops, ops.edge)
            for op_name in q_dq_op_names
        }
        op_is_q_dq = op in q_dq_ops

        # XNOR: with propogate_quant set, only q/dq ops are candidates;
        # without it, every op except the q/dq ops is a candidate.
        is_candidate = op_is_q_dq == bool(self.propogate_quant)
        if not (is_candidate and is_const([args, kwargs])):
            return super().call_operator(op, args, kwargs, meta)

        # Hold the guard for the duration of the eager call; deleting it in
        # `finally` releases it even if the operator raises.
        guard = torch._C._DisableTorchDispatch()  # noqa
        try:
            # Swap any ProxyValue for the data it wraps before calling op.
            raw_args, raw_kwargs = pytree.tree_map_only(
                ProxyValue, lambda proxy: proxy.data, (args, kwargs)
            )
            result = op(*raw_args, **raw_kwargs)
        finally:
            del guard

        if isinstance(result, ProxyValue):
            return result.to_tensor()
        return result