Skip to content

Commit 472a5cd

Browse files
ethansfng authored and facebook-github-bot committed
Add FuseQATConvBN to fuse_ops (#19442)
Summary: Adds a FuseQATConvBN which folds the QAT Conv-BN simulation chain (`conv → q → dq → div(scale) → add(orig_bias) → batch_norm`) inserted by `prepare_qat_pt2e` into the conv's quantized bias and removes the chain. The pass runs in two steps inside a single `call()`: 1. Bias prep — for each conv, create a zero-filled quantized bias if missing, or quantize a float bias as per-tensor int32. Required so step 2 has a quantized bias slot to write the BN correction into. 2. Fold — for each matched chain, compute the BN correction C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias and absorb it into the conv's quantized bias in place. Erase the chain + batch_norm. Differential Revision: D104497938
1 parent a49171d commit 472a5cd

4 files changed

Lines changed: 552 additions & 4 deletions

File tree

backends/cadence/aot/BUCK

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ fbcode_target(_kind = runtime.python_library,
4545
":utils",
4646
"//caffe2:torch",
4747
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
48+
"//executorch/backends/cadence/aot/quantizer/passes:fuse_ops",
4849
"//executorch/backends/cadence/aot/quantizer:quantizer",
4950
"//executorch/backends/transforms:decompose_sdpa",
5051
"//executorch/backends/transforms:remove_clone_ops",

backends/cadence/aot/compiler.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
print_memory_planning_info,
2323
)
2424
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
25+
from executorch.backends.cadence.aot.quantizer.passes.fuse_ops import FuseQATConvBN
2526
from executorch.backends.cadence.aot.quantizer.quantizer import (
2627
CadenceDefaultQuantizer,
2728
CadenceQuantizer,
@@ -37,9 +38,10 @@
3738
ExecutorchBackendConfig,
3839
ExecutorchProgramManager,
3940
)
41+
from executorch.exir.pass_manager import PassManager
4042
from executorch.exir.passes import ToOutVarPass
4143
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
42-
from executorch.exir.program._program import _transform, to_edge
44+
from executorch.exir.program._program import to_edge
4345
from torch.export.exported_program import ExportedProgram
4446
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
4547

@@ -162,13 +164,17 @@ def apply_pre_edge_transform_passes(
162164
which will instantiate a default quantizer for you if needed.
163165
Returns an ExportedProgram with the fused model.
164166
"""
165-
# Get patterns and apply fusion of dq -> op -> q to qop
166167
# pyre-ignore[16]: no attribute
167168
patterns = [q.pattern for q in quantizer.quantizers]
168-
fused_program = _transform(converted_program, QuantFusion(patterns))
169+
PassManager(
170+
[
171+
FuseQATConvBN(converted_program),
172+
QuantFusion(patterns),
173+
]
174+
)(converted_program.graph_module)
169175

170176
# Apply torch ops passes (e.g., ReplaceMulTensorWithMulAndFullOpsPass)
171-
fused_program = apply_torch_ops_passes(fused_program)
177+
fused_program = apply_torch_ops_passes(converted_program)
172178

173179
return fused_program
174180

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,16 @@
1+
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3+
4+
oncall("odai_jarvis")
5+
6+
fbcode_target(_kind = runtime.python_library,
7+
name = "fuse_ops",
8+
srcs = [
9+
"fuse_ops.py",
10+
],
11+
typing = True,
12+
deps = [
13+
"//caffe2:torch",
14+
"//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
15+
],
16+
)

0 commit comments

Comments (0)