Commit 3c1ea80
Add FoldQATConvBNPass to fold QAT Conv-BN simulation chains into conv bias (#19315)
Summary:
Add FoldQATConvBNPass to the Cadence AOT compiler pipeline to handle QAT Conv-BN simulated fusion patterns that survive into the exported graph.
When a model is exported after QAT training, the Conv-BN simulation chain (add(var, eps) -> sqrt -> div(bn_weight) -> div(conv_out, scale) -> add(orig_bias) -> batch_norm) may not be folded by TorchAO's `_fold_conv_bn_qat` due to a pattern mismatch. The surviving non-quantized add/div/sqrt nodes then cause QuantFusion to crash when it tries to fuse them as quantized add ops.
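For reference, the surviving chain computes the same result as an eager-mode simulation like the following. This is a minimal sketch with illustrative names and conv2d shapes, not the actual graph-level pattern the pass matches:

```python
import torch
import torch.nn.functional as F

def qat_conv_bn_sim(x, w, orig_bias, bn_weight, bn_bias,
                    running_mean, running_var, eps=1e-5):
    # add(var, eps) -> sqrt -> div(bn_weight): per-channel BN scale factor
    scale = bn_weight / torch.sqrt(running_var + eps)
    # conv runs with the scaled weight and no bias
    out = F.conv2d(x, w * scale.view(-1, 1, 1, 1))
    # div(conv_out, scale): undo the scaling after the conv
    out = out / scale.view(1, -1, 1, 1)
    # add(orig_bias), then eval-mode batch_norm itself
    out = out + orig_bias.view(1, -1, 1, 1)
    return F.batch_norm(out, running_mean, running_var,
                        bn_weight, bn_bias, training=False, eps=eps)
```

Mathematically the divide cancels the weight scaling, so the chain equals a plain conv followed by batch_norm; that equivalence is what makes the chain safe to fold away.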
The fix has three parts:
1. Add `conv1d.default` to the `QuantizeFusedConvBnBiasAtenPass` conv_targets so the pass matches conv1d ops and can create zero biases for convs that lack one, mirroring the existing conv2d support.
2. Add `FoldQATConvBNPass` which matches the QAT simulation chain, computes the BN correction constant C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias, folds C into the conv quantized bias tensor, and removes the simulation chain + batch_norm nodes. No new graph nodes are created.
3. Apply these passes in the correct order in both the `get_fake_quant_model` (pre-export, on GraphModule) and `apply_pre_edge_transform_passes` (post-export, on ExportedProgram) pipelines: first `QuantizeFusedConvBnBiasAtenPass` to create zero biases for convs that lack one, then `FoldQATConvBNPass` to fold the simulation chain into those biases.
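The bias correction in step 2 is the standard Conv-BN folding identity. A numerical check of that identity is sketched below, using illustrative names and plain float tensors rather than the quantized bias tensor the pass actually rewrites:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
orig_bias = torch.randn(4)
bn_weight = torch.rand(4) + 0.5    # gamma
bn_bias = torch.randn(4)           # beta
running_mean = torch.randn(4)
running_var = torch.rand(4) + 0.1

# Reference: conv with the original bias, then eval-mode batch_norm
ref = F.batch_norm(F.conv2d(x, w, orig_bias), running_mean, running_var,
                   bn_weight, bn_bias, training=False, eps=eps)

# Folded form: per-channel scale on the weight, correction constant C as the bias
scale = bn_weight / torch.sqrt(running_var + eps)
C = (orig_bias - running_mean) * scale + bn_bias
folded = F.conv2d(x, w * scale.view(-1, 1, 1, 1), C)

assert torch.allclose(ref, folded, atol=1e-4)
```

Note that the folded conv keeps the per-channel weight scale; in the exported QAT graph that scaling already sits on the conv weight from the simulation, which is presumably why the pass only has to write C into the bias tensor and delete the chain.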
Differential Revision: D1039495731
Parent: bf8abb6
5 files changed: 580 additions, 1 deletion
File tree
- backends
- cadence/aot
- tests
- transforms