
Commit 9e5e0e7

Reza Sajadiany authored and facebook-github-bot committed
bypass add and addRelu that have inputs from non-dq nodes (pytorch#19315)
Summary:
Add FoldQATConvBNPass to the Cadence AOT compiler pipeline to handle QAT Conv-BN simulated fusion patterns that survive into the exported graph. When a model is exported after QAT training, the Conv-BN simulation chain (add(var + eps) -> sqrt -> div(bn_weight) -> div(conv_out / scale) -> add(orig_bias) -> batch_norm) may not be folded by TorchAO `_fold_conv_bn_qat` due to a pattern mismatch. This leaves non-quantized add/div/sqrt nodes in the graph that cause QuantFusion to crash when it tries to fuse them as quantized add ops.

The fix has three parts:

1. Add `conv1d.default` to the `QuantizeFusedConvBnBiasAtenPass` conv_targets so the pass matches conv1d ops and can create zero biases for convs without one (mirroring the existing conv2d support).

2. Add `FoldQATConvBNPass`, which matches the QAT simulation chain, computes the BN correction constant C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias, folds C into the conv's quantized bias tensor, and removes the simulation chain and batch_norm nodes. No new graph nodes are created.

3. Apply these passes in the correct order in both the `get_fake_quant_model` (pre-export, on a GraphModule) and `apply_pre_edge_transform_passes` (post-export, on an ExportedProgram) pipelines: first `QuantizeFusedConvBnBiasAtenPass` to create zero biases for convs that lack one, then `FoldQATConvBNPass` to fold the simulation chain into those biases.

Differential Revision: D103949573
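For clarity, here is a minimal sketch of the bias correction described in part 2, written as plain eager-mode PyTorch. The function name and the float-tensor interface are illustrative only; the actual pass operates on exported-graph nodes and writes the (rescaled) result into the conv's quantized bias tensor:

```python
import torch

def conv_bn_bias_correction(
    orig_bias: torch.Tensor,     # conv bias before folding (zeros if the conv had none)
    running_mean: torch.Tensor,  # batch_norm running mean
    running_var: torch.Tensor,   # batch_norm running variance
    bn_weight: torch.Tensor,     # batch_norm gamma
    bn_bias: torch.Tensor,       # batch_norm beta
    eps: float = 1e-5,
) -> torch.Tensor:
    # C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias
    return (orig_bias - running_mean) * bn_weight / torch.sqrt(running_var + eps) + bn_bias
```

Because the QAT simulation typically already bakes the bn_weight / sqrt(running_var + eps) scale into the fake-quantized conv weight, replacing the conv bias with C and deleting the div -> add -> batch_norm chain leaves the numerics unchanged, which is why the fold needs no new graph nodes.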
1 parent 1643611 commit 9e5e0e7

5 files changed: 569 additions & 0 deletions


backends/cadence/aot/BUCK

Lines changed: 29 additions & 0 deletions
@@ -31,12 +31,23 @@ fbcode_target(_kind = runtime.python_library,
     ],
 )
 
+fbcode_target(_kind = runtime.python_library,
+    name = "fold_qat_conv_bn",
+    srcs = [
+        "fold_qat_conv_bn.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 fbcode_target(_kind = runtime.python_library,
     name = "compiler",
     srcs = [
         "compiler.py",
     ],
     deps = [
+        ":fold_qat_conv_bn",
         ":memory_planning",
         ":ops_registrations",
         ":passes",
@@ -46,6 +57,7 @@ fbcode_target(_kind = runtime.python_library,
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
         "//executorch/backends/transforms:decompose_sdpa",
         "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/devtools:lib",
@@ -512,6 +524,23 @@ fbcode_target(_kind = python_unittest,
     ],
 )
 
+fbcode_target(_kind = python_unittest,
+    name = "test_fold_qat_conv_bn",
+    srcs = [
+        "tests/test_fold_qat_conv_bn.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        ":compiler",
+        ":fold_qat_conv_bn",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:ops_registrations",
+        "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
+    ],
+)
+
 fbcode_target(_kind = python_unittest,
     name = "test_remove_ops_passes",
     srcs = [

backends/cadence/aot/compiler.py

Lines changed: 22 additions & 0 deletions
@@ -21,7 +21,11 @@
     CadenceMemoryPlanning,
     print_memory_planning_info,
 )
+from executorch.backends.cadence.aot.fold_qat_conv_bn import FoldQATConvBNPass
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
+from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
+    QuantizeFusedConvBnBiasAtenPass,
+)
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
@@ -162,6 +166,17 @@ def apply_pre_edge_transform_passes(
     which will instantiate a default quantizer for you if needed.
     Returns an ExportedProgram with the fused model.
     """
+    # Create zero biases for convs without one, quantize any float biases if exists
+    converted_program = _transform(
+        converted_program, QuantizeFusedConvBnBiasAtenPass(
+            exported_program=converted_program, default_zero_bias=True
+        )
+    )
+
+    # Fold QAT Conv-BN simulated fusion patterns
+    # Removes (div(scale) → add(bias) → batch_norm chain and absorbs the correction into the conv bias
+    FoldQATConvBNPass(converted_program)(converted_program.graph_module)
+
     # Get patterns and apply fusion of dq -> op -> q to qop
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
@@ -205,6 +220,13 @@ def get_fake_quant_model(
 
     # Get converted graph module
     converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
+
+    # Create zero biases for convs without one, quantize any float biases
+    QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(converted_gm)
+
+    # Fold QAT Conv-BN simulated fusion patterns (now all convs have a bias to fold into)
+    FoldQATConvBNPass()(converted_gm)
+
     return converted_gm
