forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquantize_fused_convbn_bias_pass.py
More file actions
355 lines (302 loc) · 12.2 KB
/
quantize_fused_convbn_bias_pass.py
File metadata and controls
355 lines (302 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import fx
from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_lifted_tensor_constant,
is_param,
)
from torch._guards import detect_fake_mode
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.fx.passes.infra.pass_base import PassBase, PassResult
# --- ExportedProgram param helpers ---
def _set_param_ep(exported_program, node_or_name, tensor, insert_before=None):
"""Set or create a parameter in an exported program.
If node_or_name is a Node, updates the existing parameter or constant value.
If node_or_name is a string, creates a new parameter placeholder.
"""
fake_mode = detect_fake_mode(
tuple(
node.meta["val"]
for node in exported_program.graph.nodes
if node.op == "placeholder"
)
)
if isinstance(node_or_name, fx.Node):
node = node_or_name
if node.name in exported_program.graph_signature.inputs_to_parameters:
name = exported_program.graph_signature.inputs_to_parameters[node.name]
exported_program.state_dict[name] = torch.nn.Parameter(
tensor, requires_grad=False
)
elif (
node.name
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
):
name = exported_program.graph_signature.inputs_to_lifted_tensor_constants[
node.name
]
exported_program.constants[name] = tensor
else:
raise ValueError(
f"Node {node.name} is not a parameter or lifted tensor constant"
)
node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
node.meta["val"].constant = tensor
return node
# Create a new parameter from string name
name = node_or_name
graph = exported_program.graph_module.graph
placeholders = [n for n in graph.nodes if n.op == "placeholder"]
input_name = f"arg_{name}"
with graph.inserting_before(placeholders[0]):
new_placeholder = graph.placeholder(input_name)
exported_program.graph_signature.input_specs.insert(
0,
InputSpec(
kind=InputKind.PARAMETER,
arg=TensorArgument(name=input_name),
target=name,
persistent=None,
),
)
exported_program.state_dict[name] = torch.nn.Parameter(tensor, requires_grad=False)
new_placeholder.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
new_placeholder.meta["val"].constant = tensor
return new_placeholder
def _get_bias_tensor_ep(exported_program, bias_node):
"""Extract bias tensor from parameter or lifted constant in an ExportedProgram."""
if is_param(exported_program, bias_node):
return get_param(exported_program, bias_node)
elif is_lifted_tensor_constant(exported_program, bias_node):
return get_lifted_tensor_constant(exported_program, bias_node)
return None
# --- GraphModule param helpers ---
def _get_tensor_from_node(graph_module, node):
"""Get tensor from a get_attr node on a GraphModule."""
if node is None or node.op != "get_attr":
return None
target_atoms = node.target.split(".")
attr = graph_module
for atom in target_atoms:
if not hasattr(attr, atom):
return None
attr = getattr(attr, atom)
return attr
def _set_param_gm(graph_module, node_or_name, tensor, insert_before=None):
"""Set or create a parameter on a GraphModule using get_attr nodes.
If node_or_name is a Node, updates the existing parameter tensor.
If node_or_name is a string, creates a new get_attr node.
"""
if isinstance(node_or_name, fx.Node):
node = node_or_name
target_atoms = node.target.split(".")
parent = graph_module
for atom in target_atoms[:-1]:
parent = getattr(parent, atom)
setattr(
parent,
target_atoms[-1],
torch.nn.Parameter(tensor, requires_grad=False),
)
if "val" in node.meta:
fake_mode = detect_fake_mode(
tuple(
n.meta["val"]
for n in graph_module.graph.nodes
if n.op == "placeholder" and "val" in n.meta
)
)
if fake_mode is not None:
node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
else:
node.meta["val"] = tensor
return node
# Create new get_attr node
name = node_or_name
graph_module.register_parameter(
name, torch.nn.Parameter(tensor, requires_grad=False)
)
with graph_module.graph.inserting_before(insert_before):
new_node = graph_module.graph.get_attr(name)
fake_mode = detect_fake_mode(
tuple(
n.meta["val"]
for n in graph_module.graph.nodes
if n.op == "placeholder" and "val" in n.meta
)
)
if fake_mode is not None:
new_node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
else:
new_node.meta["val"] = tensor
return new_node
# --- Shared core logic ---
def _quantize_fused_conv_bias(
graph_module,
conv_targets,
unsqueeze_targets,
dq_per_tensor,
dq_per_channel,
get_bias_tensor,
set_param,
get_weight_scale_tensor,
default_zero_bias=False,
):
"""Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
BatchNorm fusion or QAT introduces a bias to conv layers that originally had
bias=False. Since the bias is added after the quantizer runs, it lacks proper
quantize->dequantize nodes. This function adds them.
Args:
graph_module: The graph module to transform.
conv_targets: Tuple of conv op targets to match.
unsqueeze_targets: Tuple of unsqueeze op targets to unwrap.
dq_per_tensor: The dequantize_per_tensor op for this dialect.
dq_per_channel: The dequantize_per_channel op for this dialect.
get_bias_tensor: Callable(node) -> Optional[Tensor].
set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
get_weight_scale_tensor: Callable(node) -> Tensor.
default_zero_bias: If True, create zero bias for conv nodes without bias.
Returns:
True if any modifications were made.
"""
modified = False
for node in graph_module.graph.nodes:
if node.target not in conv_targets:
continue
input_dequant = node.args[0]
weight_dequant = node.args[1]
bias_node = node.args[2] if len(node.args) > 2 else None
if bias_node is None:
if default_zero_bias:
channel = node.meta["val"].shape[1]
bias_node = set_param(
node.name + "_default_zero_bias",
torch.zeros(channel),
insert_before=node,
)
args = list(node.args)
if len(args) < 3:
args.append(bias_node)
else:
args[2] = bias_node
node.args = tuple(args)
else:
continue
bias = get_bias_tensor(bias_node)
if bias is None or bias.dtype == torch.int32:
continue
if input_dequant.target in unsqueeze_targets:
input_dequant = input_dequant.args[0]
assert (
input_dequant.target == dq_per_tensor
), f"Expected dequantize_per_tensor, got {input_dequant.target}"
bias_val = bias_node.meta.get("val")
dequant_val = (
bias_val.to(torch.float32)
if bias_val is not None
else torch.empty(bias.shape, dtype=torch.float32)
)
if isinstance(weight_dequant.args[1], torch.fx.node.Node):
weight_scale = get_weight_scale_tensor(weight_dequant.args[1])
bias_scale = input_dequant.args[1] * weight_scale
bias_zp = torch.zeros(bias_scale.shape, dtype=torch.int32)
qbias = torch.ops.quantized_decomposed.quantize_per_channel.default(
bias,
bias_scale,
bias_zp,
0,
-(2**31),
2**31 - 1,
torch.int32,
)
set_param(bias_node, qbias)
scale_node = set_param(
node.name + "_bias_scale", bias_scale, insert_before=node
)
zp_node = set_param(
node.name + "_bias_zero_point", bias_zp, insert_before=node
)
with graph_module.graph.inserting_before(node):
bias_dequant = graph_module.graph.call_function(
dq_per_channel,
(
bias_node,
scale_node,
zp_node,
0,
-(2**31),
2**31 - 1,
torch.int32,
),
)
bias_dequant.meta["val"] = dequant_val
node.replace_input_with(bias_node, bias_dequant)
else:
weight_scale = weight_dequant.args[1]
bias_scale = input_dequant.args[1] * weight_scale
qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
)
set_param(bias_node, qbias)
with graph_module.graph.inserting_before(node):
bias_dequant = graph_module.graph.call_function(
dq_per_tensor,
(bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
)
bias_dequant.meta["val"] = dequant_val
node.replace_input_with(bias_node, bias_dequant)
modified = True
graph_module.recompile()
return modified
class QuantizeFusedConvBnBiasAtenPass(PassBase):
    """Quantize biases introduced by BatchNorm fusion/QAT on aten dialect graphs.

    Operates on a GraphModule. If the graph_module came from an ExportedProgram
    (params are placeholder nodes), pass the exported_program so params can be
    resolved. If operating on a plain GraphModule (params are get_attr nodes),
    exported_program can be omitted.

    Args:
        exported_program: Optional ExportedProgram that owns ``graph_module``.
            When given, new tensors are added as PARAMETER inputs to its graph
            signature and state dict.
        default_zero_bias: If True, conv nodes without a bias are given an
            all-zero bias, which is then quantized like a fused bias.
    """

    def __init__(self, exported_program=None, default_zero_bias=False) -> None:
        self.exported_program = exported_program
        self.default_zero_bias = default_zero_bias

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        """Quantize fused conv biases in ``graph_module``.

        Returns:
            ``PassResult(graph_module, modified)``, where ``modified`` is True if
            any modifications were made.
        """
        ep = self.exported_program
        if ep is not None:

            def get_bias(node):
                return _get_bias_tensor_ep(ep, node)

            def set_param(n, t, insert_before=None):
                # ExportedProgram params are placeholders. _set_param_ep always
                # inserts new ones ahead of the existing placeholders, so
                # insert_before is intentionally not forwarded.
                return _set_param_ep(ep, n, t)

            def get_scale(node):
                # Resolve the per-channel weight scale whether it is stored as
                # a buffer, a lifted tensor constant, or a parameter. A None
                # here would only surface later, as a TypeError when the scale
                # is multiplied by the input scale.
                for lookup in (get_buffer, get_lifted_tensor_constant, get_param):
                    scale = lookup(ep, node)
                    if scale is not None:
                        return scale
                return None

        else:

            def get_bias(node):
                return _get_tensor_from_node(graph_module, node)

            def set_param(n, t, insert_before=None):
                return _set_param_gm(graph_module, n, t, insert_before)

            def get_scale(node):
                return _get_tensor_from_node(graph_module, node)

        modified = _quantize_fused_conv_bias(
            graph_module,
            conv_targets=(
                torch.ops.aten.convolution.default,
                torch.ops.aten.conv2d.default,
                torch.ops.aten.conv_transpose2d.input,
            ),
            unsqueeze_targets=(
                torch.ops.aten.unsqueeze_copy.default,
                torch.ops.aten.unsqueeze.default,
            ),
            dq_per_tensor=torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            dq_per_channel=torch.ops.quantized_decomposed.dequantize_per_channel.default,
            get_bias_tensor=get_bias,
            set_param=set_param,
            get_weight_scale_tensor=get_scale,
            default_zero_bias=self.default_zero_bias,
        )
        return PassResult(graph_module, modified)