Commit 570d2e9

Add quantize fused convbn bias pass (#17348)
Summary: When performing QAT on a model that has a conv layer with no bias followed by batch norm, the fusion process creates a bias. Because this happens *after* observers are attached, the resulting bias is left as float. This diff adds a pass that grabs the proper qparams and applies them to the non-quantized bias.

Differential Revision: D92733079

cc @robert-kalmar @digantdesai
1 parent 1003453 commit 570d2e9

3 files changed: 566 additions & 0 deletions
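For context on the requantization the pass performs: in integer conv arithmetic the accumulator of q_input * q_weight represents real values scaled by s_input * s_weight, so a bias folded in by conv+bn fusion must be quantized to int32 with scale s_bias = s_input * s_weight and zero point 0. A minimal standalone sketch of that step (the scale and bias values are illustrative; the real pass reads them off the dequantize nodes feeding the conv):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers the quantized_decomposed ops

# Illustrative scales; the pass takes these from the input/weight dequantize nodes.
input_scale = 0.02                                  # per-tensor activation scale
weight_scale = torch.tensor([0.005, 0.010, 0.008])  # per-channel weight scales

# The bias shares the accumulator scale: s_bias = s_input * s_weight, zero point 0.
bias_scale = input_scale * weight_scale
bias_zero_point = torch.zeros_like(weight_scale, dtype=torch.int32)

float_bias = torch.tensor([0.13, -0.40, 0.07])  # bias introduced by conv+bn fusion
qbias = torch.ops.quantized_decomposed.quantize_per_channel.default(
    float_bias, bias_scale, bias_zero_point, 0, -(2**31), 2**31 - 1, torch.int32
)
# Dequantizing recovers the original bias up to int32-grid rounding error.
recovered = torch.ops.quantized_decomposed.dequantize_per_channel.default(
    qbias, bias_scale, bias_zero_point, 0, -(2**31), 2**31 - 1, torch.int32
)
assert torch.allclose(recovered, float_bias, atol=float(bias_scale.max()))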

backends/transforms/quantize_fused_convbn_bias_pass.py

Lines changed: 355 additions & 0 deletions
@@ -0,0 +1,355 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from torch import fx
from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_lifted_tensor_constant,
    is_param,
)
from torch._guards import detect_fake_mode
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.fx.passes.infra.pass_base import PassBase, PassResult


# --- ExportedProgram param helpers ---


def _set_param_ep(exported_program, node_or_name, tensor, insert_before=None):
    """Set or create a parameter in an exported program.

    If node_or_name is a Node, updates the existing parameter or constant value.
    If node_or_name is a string, creates a new parameter placeholder.
    """
    fake_mode = detect_fake_mode(
        tuple(
            node.meta["val"]
            for node in exported_program.graph.nodes
            if node.op == "placeholder"
        )
    )

    if isinstance(node_or_name, fx.Node):
        node = node_or_name
        if node.name in exported_program.graph_signature.inputs_to_parameters:
            name = exported_program.graph_signature.inputs_to_parameters[node.name]
            exported_program.state_dict[name] = torch.nn.Parameter(
                tensor, requires_grad=False
            )
        elif (
            node.name
            in exported_program.graph_signature.inputs_to_lifted_tensor_constants
        ):
            name = exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                node.name
            ]
            exported_program.constants[name] = tensor
        else:
            raise ValueError(
                f"Node {node.name} is not a parameter or lifted tensor constant"
            )
        node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
        node.meta["val"].constant = tensor
        return node

    # Create a new parameter from string name
    name = node_or_name
    graph = exported_program.graph_module.graph
    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
    input_name = f"arg_{name}"
    with graph.inserting_before(placeholders[0]):
        new_placeholder = graph.placeholder(input_name)
    exported_program.graph_signature.input_specs.insert(
        0,
        InputSpec(
            kind=InputKind.PARAMETER,
            arg=TensorArgument(name=input_name),
            target=name,
            persistent=None,
        ),
    )
    exported_program.state_dict[name] = torch.nn.Parameter(tensor, requires_grad=False)
    new_placeholder.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
    new_placeholder.meta["val"].constant = tensor
    return new_placeholder


def _get_bias_tensor_ep(exported_program, bias_node):
    """Extract bias tensor from parameter or lifted constant in an ExportedProgram."""
    if is_param(exported_program, bias_node):
        return get_param(exported_program, bias_node)
    elif is_lifted_tensor_constant(exported_program, bias_node):
        return get_lifted_tensor_constant(exported_program, bias_node)
    return None


# --- GraphModule param helpers ---


def _get_tensor_from_node(graph_module, node):
    """Get tensor from a get_attr node on a GraphModule."""
    if node is None or node.op != "get_attr":
        return None
    target_atoms = node.target.split(".")
    attr = graph_module
    for atom in target_atoms:
        if not hasattr(attr, atom):
            return None
        attr = getattr(attr, atom)
    return attr


def _set_param_gm(graph_module, node_or_name, tensor, insert_before=None):
    """Set or create a parameter on a GraphModule using get_attr nodes.

    If node_or_name is a Node, updates the existing parameter tensor.
    If node_or_name is a string, creates a new get_attr node.
    """
    if isinstance(node_or_name, fx.Node):
        node = node_or_name
        target_atoms = node.target.split(".")
        parent = graph_module
        for atom in target_atoms[:-1]:
            parent = getattr(parent, atom)
        setattr(
            parent,
            target_atoms[-1],
            torch.nn.Parameter(tensor, requires_grad=False),
        )
        if "val" in node.meta:
            fake_mode = detect_fake_mode(
                tuple(
                    n.meta["val"]
                    for n in graph_module.graph.nodes
                    if n.op == "placeholder" and "val" in n.meta
                )
            )
            if fake_mode is not None:
                node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
            else:
                node.meta["val"] = tensor
        return node

    # Create new get_attr node
    name = node_or_name
    graph_module.register_parameter(
        name, torch.nn.Parameter(tensor, requires_grad=False)
    )
    with graph_module.graph.inserting_before(insert_before):
        new_node = graph_module.graph.get_attr(name)
    fake_mode = detect_fake_mode(
        tuple(
            n.meta["val"]
            for n in graph_module.graph.nodes
            if n.op == "placeholder" and "val" in n.meta
        )
    )
    if fake_mode is not None:
        new_node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
    else:
        new_node.meta["val"] = tensor
    return new_node


# --- Shared core logic ---


def _quantize_fused_conv_bias(
    graph_module,
    conv_targets,
    unsqueeze_targets,
    dq_per_tensor,
    dq_per_channel,
    get_bias_tensor,
    set_param,
    get_weight_scale_tensor,
    default_zero_bias=False,
):
    """Core logic for quantizing biases introduced by BatchNorm fusion/QAT.

    BatchNorm fusion or QAT introduces a bias to conv layers that originally had
    bias=False. Since the bias is added after the quantizer runs, it lacks proper
    quantize->dequantize nodes. This function adds them.

    Args:
        graph_module: The graph module to transform.
        conv_targets: Tuple of conv op targets to match.
        unsqueeze_targets: Tuple of unsqueeze op targets to unwrap.
        dq_per_tensor: The dequantize_per_tensor op for this dialect.
        dq_per_channel: The dequantize_per_channel op for this dialect.
        get_bias_tensor: Callable(node) -> Optional[Tensor].
        set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
        get_weight_scale_tensor: Callable(node) -> Tensor.
        default_zero_bias: If True, create zero bias for conv nodes without bias.

    Returns:
        True if any modifications were made.
    """
    modified = False
    for node in graph_module.graph.nodes:
        if node.target not in conv_targets:
            continue

        input_dequant = node.args[0]
        weight_dequant = node.args[1]
        bias_node = node.args[2] if len(node.args) > 2 else None

        if bias_node is None:
            if default_zero_bias:
                channel = node.meta["val"].shape[1]
                bias_node = set_param(
                    node.name + "_default_zero_bias",
                    torch.zeros(channel),
                    insert_before=node,
                )
                args = list(node.args)
                if len(args) < 3:
                    args.append(bias_node)
                else:
                    args[2] = bias_node
                node.args = tuple(args)
            else:
                continue

        bias = get_bias_tensor(bias_node)
        if bias is None or bias.dtype == torch.int32:
            continue

        if input_dequant.target in unsqueeze_targets:
            input_dequant = input_dequant.args[0]

        assert (
            input_dequant.target == dq_per_tensor
        ), f"Expected dequantize_per_tensor, got {input_dequant.target}"

        bias_val = bias_node.meta.get("val")
        dequant_val = (
            bias_val.to(torch.float32)
            if bias_val is not None
            else torch.empty(bias.shape, dtype=torch.float32)
        )

        if isinstance(weight_dequant.args[1], torch.fx.node.Node):
            weight_scale = get_weight_scale_tensor(weight_dequant.args[1])
            bias_scale = input_dequant.args[1] * weight_scale

            bias_zp = torch.zeros(bias_scale.shape, dtype=torch.int32)
            qbias = torch.ops.quantized_decomposed.quantize_per_channel.default(
                bias,
                bias_scale,
                bias_zp,
                0,
                -(2**31),
                2**31 - 1,
                torch.int32,
            )
            set_param(bias_node, qbias)

            scale_node = set_param(
                node.name + "_bias_scale", bias_scale, insert_before=node
            )
            zp_node = set_param(
                node.name + "_bias_zero_point", bias_zp, insert_before=node
            )

            with graph_module.graph.inserting_before(node):
                bias_dequant = graph_module.graph.call_function(
                    dq_per_channel,
                    (
                        bias_node,
                        scale_node,
                        zp_node,
                        0,
                        -(2**31),
                        2**31 - 1,
                        torch.int32,
                    ),
                )
            bias_dequant.meta["val"] = dequant_val
            node.replace_input_with(bias_node, bias_dequant)
        else:
            weight_scale = weight_dequant.args[1]
            bias_scale = input_dequant.args[1] * weight_scale

            qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
            )
            set_param(bias_node, qbias)

            with graph_module.graph.inserting_before(node):
                bias_dequant = graph_module.graph.call_function(
                    dq_per_tensor,
                    (bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
                )
            bias_dequant.meta["val"] = dequant_val
            node.replace_input_with(bias_node, bias_dequant)

        modified = True

    graph_module.recompile()
    return modified


class QuantizeFusedConvBnBiasAtenPass(PassBase):
    """Quantize biases introduced by BatchNorm fusion/QAT on aten dialect graphs.

    Operates on a GraphModule. If the graph_module came from an ExportedProgram
    (params are placeholder nodes), pass the exported_program so params can be
    resolved. If operating on a plain GraphModule (params are get_attr nodes),
    exported_program can be omitted.
    """

    def __init__(self, exported_program=None, default_zero_bias=False) -> None:
        self.exported_program = exported_program
        self.default_zero_bias = default_zero_bias

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        ep = self.exported_program
        if ep is not None:

            def get_bias(node):
                return _get_bias_tensor_ep(ep, node)

            def set_param(n, t, insert_before=None):
                return _set_param_ep(ep, n, t)

            def get_scale(node):
                return get_buffer(ep, node)

        else:

            def get_bias(node):
                return _get_tensor_from_node(graph_module, node)

            def set_param(n, t, insert_before=None):
                return _set_param_gm(graph_module, n, t, insert_before)

            def get_scale(node):
                return _get_tensor_from_node(graph_module, node)

        modified = _quantize_fused_conv_bias(
            graph_module,
            conv_targets=(
                torch.ops.aten.convolution.default,
                torch.ops.aten.conv2d.default,
                torch.ops.aten.conv_transpose2d.input,
            ),
            unsqueeze_targets=(
                torch.ops.aten.unsqueeze_copy.default,
                torch.ops.aten.unsqueeze.default,
            ),
            dq_per_tensor=torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            dq_per_channel=torch.ops.quantized_decomposed.dequantize_per_channel.default,
            get_bias_tensor=get_bias,
            set_param=set_param,
            get_weight_scale_tensor=get_scale,
            default_zero_bias=self.default_zero_bias,
        )
        return PassResult(graph_module, modified)
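Not part of the diff, but to make the entry point concrete, here is a hedged usage sketch on a plain GraphModule (per the docstring, exported_program may be omitted when params are get_attr nodes). The import path is inferred from targets.bzl, and `converted_model` is an assumed name for a QAT-converted module:

from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
    QuantizeFusedConvBnBiasAtenPass,
)

# `converted_model` is assumed to be a QAT-converted fx.GraphModule in which
# conv+bn fusion left behind float biases (params held as get_attr nodes).
result = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=False).call(converted_model)
if result.modified:
    # Each float bias is now an int32 parameter feeding a
    # dequantize_per_{tensor,channel} node with scale s_input * s_weight.
    result.graph_module.graph.print_tabular()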

backends/transforms/targets.bzl

Lines changed: 25 additions & 0 deletions
@@ -204,6 +204,31 @@ def define_common_targets():
         ],
     )
 
+    runtime.python_library(
+        name = "quantize_fused_convbn_bias_pass",
+        srcs = ["quantize_fused_convbn_bias_pass.py"],
+        visibility = ["PUBLIC"],
+        deps = [
+            "//caffe2:torch",
+        ],
+    )
+
+    runtime.python_test(
+        name = "test_quantize_fused_convbn_bias_pass",
+        srcs = [
+            "test/test_quantize_fused_convbn_bias_pass.py",
+        ],
+        deps = [
+            "//caffe2:torch",
+            ":quantize_fused_convbn_bias_pass",
+            "//executorch/backends/arm/quantizer:lib",
+            "//executorch/backends/arm/test:common",
+            "//executorch/backends/arm/tosa:tosa",
+            "//executorch/kernels/quantized:custom_ops_generated_lib",
+            "fbsource//third-party/pypi/pytest:pytest",
+        ],
+    )
+
     runtime.python_test(
         name = "test_duplicate_dynamic_quant_chain",
         srcs = [
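Assuming the standard Buck workflow and that these rules live in the //executorch/backends/transforms package, the new test could then be run with something like:

    buck2 test //executorch/backends/transforms:test_quantize_fused_convbn_bias_pass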
