# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2026 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
from collections import OrderedDict
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Tuple, Type

import torch

from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import (
    AddSimulatedLinearBatchNormFusionQATPass,
    RemoveSimulatedLinearBatchNormFusionQATPass,
)
from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
    QuantizeFusedConvBnBiasAtenPass,
)
from torch import fx
from torch._ops import OpOverload
from torch.export import ExportedProgram
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    SourcePartition,
)
from torchao.quantization.pt2e import (
    move_exported_model_to_eval,
    move_exported_model_to_train,
    ObserverOrFakeQuantize,
)
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY, Quantizer


def is_annotated(nodes: List[fx.Node]) -> bool:
    """Return True if any of the given nodes already carries a quantization annotation."""
    annotated = False
    for node in nodes:
        annotated = annotated or (
            Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated
        )
    return annotated


def no_outside_users(fused_partition) -> bool:
    """
    Checks if each partition other than the last does not have any outside users.
    """
    for source_partition in fused_partition[:-1]:
        if len(source_partition.output_nodes) != 1:
            return False
        if len(source_partition.output_nodes[0].users) != 1:
            return False
    return True


def get_bias_qparams(
    obs_or_fqs: List[ObserverOrFakeQuantize],
) -> Tuple[torch.Tensor, torch.Tensor]:
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int64)
    return bias_scale, bias_zero_point
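
# The scheme above (bias scale = activation scale * weight scale, zero-point 0)
# is the standard convention for quantizing the bias of int8 conv/linear
# kernels: products of int8 inputs and weights accumulate at exactly that
# scale, so the (int32) quantized bias can be added straight onto the
# accumulator.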


def get_aten_node_target_partitions(
    graph: torch.fx.Graph,
    wanted_original_aten_op: List[OpOverload],
):
    """
    Args:
        graph: The graph we want to partition
        wanted_original_aten_op: List of original_aten ops (OpOverload)

    Returns:
        Dictionary mapping the source module type or function (taken from each
        node's `source_fn_stack` metadata) to a list of SourcePartitions built
        from the nodes that were decomposed from the given aten ops.
    """
    modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {}

    for node in graph.nodes:
        # The metadata source_fn should contain a tuple of a unique name for the
        # source, and the source function if the node is decomposed from a
        # function, or the type of module if the node is decomposed from a leaf
        # module.
        # TODO(matthiascremon): look into ways to avoid using source_fn_stack
        if (source_fn_st := node.meta.get("source_fn_stack")) is None:
            continue
        source_fn = source_fn_st[-1]

        if node.target not in wanted_original_aten_op:
            continue

        diff_modules = modules.setdefault(source_fn[1], {})
        partition = diff_modules.setdefault(node.name, [])
        partition.append(node)

    def make_partition(
        nodes: List[torch.fx.Node], module_type: Type
    ) -> SourcePartition:
        input_nodes = set()
        output_nodes = set()
        params = set()
        for node in nodes:
            for arg in node.args:
                if isinstance(arg, torch.fx.Node) and arg not in nodes:
                    input_nodes.add(arg)

            if node.op == "get_attr":
                params.add(node)

            for user in node.users.keys():
                if user not in nodes:
                    output_nodes.add(node)

        return SourcePartition(
            nodes,
            module_type,
            list(input_nodes),
            list(output_nodes),
            list(params),  # type: ignore[arg-type]
        )

    ret: Dict[Type[Any], List[SourcePartition]] = {}
    for k, v in modules.items():
        ret[k] = [make_partition(partition, k) for partition in v.values()]

    return ret


def _partitions_sequential(partitions: Tuple[SourcePartition]) -> bool:
    prev_partition = None
    for partition in partitions:
        if prev_partition is not None and not check_subgraphs_connected(
            prev_partition, partition
        ):
            return False
        prev_partition = partition
    return True


def find_sequential_partitions_aten(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
):
    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
    for partition_type in partition_types:
        partitions = get_aten_node_target_partitions(gm.graph, [partition_type])
        typed_partitions[partition_type] = list(
            itertools.chain.from_iterable(partitions.values())
        )

    typed_partitions_list = list(typed_partitions.values())
    fusion_candidates = itertools.product(*typed_partitions_list)
    fused_partitions = []
    for candidate in fusion_candidates:
        if _partitions_sequential(candidate):
            fused_partitions.append(candidate)
    return fused_partitions
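
# Illustrative (hypothetical) use of the partition helpers above: look for
# `linear -> relu` chains in an aten-level graph so the pair can be annotated
# as one fused quantization pattern. The exact op overloads a caller matches
# on are an assumption here and may differ from the real quantizer patterns.
#
#   linear_relu_chains = find_sequential_partitions_aten(
#       gm, [torch.ops.aten.linear.default, torch.ops.aten.relu.default]
#   )
#   for linear_partition, relu_partition in linear_relu_chains:
#       ...  # annotate the nodes of both partitions together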


def calibrate_and_quantize(
    model: ExportedProgram | fx.GraphModule,
    calibration_inputs: Iterable[tuple[torch.Tensor, ...]],
    quantizer: Quantizer,
    is_qat: bool = False,
    train_fn: Callable[[torch.fx.GraphModule], None] | None = None,
) -> fx.GraphModule:
    """Quantize the provided model.

    :param model: Aten model (or its GraphModule representation) to quantize.
    :param calibration_inputs: An iterable of tuples of calibration input tensors, where each tensor corresponds
        to a model input.
    :param quantizer: Quantizer to use.
    :param is_qat: Whether quantization is done using Quantization Aware Training (QAT) or not.
        Note: in QAT mode without a `train_fn`, no training is performed; only calibration (in eval mode) is done.
    :param train_fn: Optional training function to be called during QAT.
    :return: Quantized GraphModule.
    """
    if isinstance(model, ExportedProgram):
        model = model.module()

    if is_qat:
        m = prepare_qat_pt2e(model, quantizer)
        m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module
        if train_fn:
            m = move_exported_model_to_train(m)
            train_fn(m)
            m = move_exported_model_to_eval(m)
            m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
            m = FuseBatchNormWithLinearPass()(m).graph_module
    else:
        m = prepare_pt2e(model, quantizer)

    if not is_qat or (is_qat and not train_fn):
        for data in calibration_inputs:
            m(*data)

        if is_qat:
            m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
            m = FuseBatchNormWithLinearPass()(m).graph_module

    m = convert_pt2e(m)
    m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module

    return m
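

# Minimal usage sketch for calibrate_and_quantize (post-training quantization).
# The model class, input shape, and quantizer choice below are assumptions for
# illustration only; any torchao PT2E `Quantizer` implementation can be passed.
#
#   model = MyModel().eval()
#   example_inputs = (torch.randn(1, 3, 32, 32),)
#   exported_program = torch.export.export(model, example_inputs)
#   calibration_data = [(torch.randn(1, 3, 32, 32),) for _ in range(16)]
#   quantizer = ...  # e.g. the NXP NeutronQuantizer
#   quantized_gm = calibrate_and_quantize(exported_program, calibration_data, quantizer)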