Commit 3c17a1d

ethansfng authored and facebook-github-bot committed
Gate FuseQATConvBN behind is_qat=True; opt in from QAT deployments
Summary: The FuseQATConvBN pass added in D104497938 ran unconditionally inside `apply_pre_edge_transform_passes`. Its `_prep_conv_biases` step delegates to the shared `_quantize_fused_conv_bias` helper, which iterates every conv in the graph and asserts each conv input is `dequantize_per_tensor`, an invariant that only holds inside the conv-BN simulation chain `prepare_qat_pt2e` inserts. PTQ graphs trip the assert (T271158088).

Two failure modes seen in the wild:

- `test_quantized_w8a32_conv1d_out_2` uses `CadenceW8A32MixedQuantizer` so activations stay float32; the conv input is the placeholder, not a dequant.
- `test_conv2d_out_7` is `channel_last=True`, so the conv input is `aten.permute`, not a dequant; the helper only unwraps `unsqueeze` variants.

Add an `is_qat: bool = False` parameter to `apply_pre_edge_transform_passes` and only include `FuseQATConvBN` when True. Plumb through `quantize_pt2`/`get_fake_quant_model` and forward from the modai recipe lambda so `ar_*_qat_et_recipe` factories actually opt in.

QAT-trained models lowered via blobgen need a way to reach the QAT recipe. Add `is_qat: bool` to `Packaging` and have `Rt700Hifi4Deployment` pass `train=self.packaging.is_qat` to `get_recipe_with_custom_settings`. Models like `activity_classification_artemis` should set `"is_qat": True` in their `defs.bzl` packaging block.

Differential Revision: D105061752
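The gating described in the summary can be illustrated with a minimal, self-contained sketch. Everything here is a toy stand-in and not the Cadence code: `Node`, `fuse_qat_conv_bn`, `quant_fusion`, `build_passes`, and `run` are hypothetical names, and the "graph" is just a list of nodes. The only behavior borrowed from the summary is that the QAT-only pass asserts each conv is fed by `dequantize_per_tensor`, and that it runs only when `is_qat` is True.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Node:
    """Toy stand-in for a graph node: an op name plus the op that feeds it."""
    op: str
    input_op: str


Graph = List[Node]
Pass = Callable[[Graph], Graph]


def fuse_qat_conv_bn(graph: Graph) -> Graph:
    # Hypothetical stand-in for the QAT-only pass: it assumes every conv is
    # fed by a dequantize node, which holds only for QAT-prepared graphs.
    for node in graph:
        if node.op == "conv":
            assert node.input_op == "dequantize_per_tensor", (
                f"conv input is {node.input_op}, expected dequantize_per_tensor"
            )
    return graph


def quant_fusion(graph: Graph) -> Graph:
    # Hypothetical stand-in for the fusion pass that runs for every graph.
    return graph


def build_passes(is_qat: bool = False) -> List[Pass]:
    # Same shape as the change in this commit: the QAT-only pass is opt-in.
    passes: List[Pass] = []
    if is_qat:
        passes.append(fuse_qat_conv_bn)
    passes.append(quant_fusion)
    return passes


def run(graph: Graph, is_qat: bool = False) -> Graph:
    for graph_pass in build_passes(is_qat):
        graph = graph_pass(graph)
    return graph


# A channel-last PTQ conv (fed by a permute) and a QAT conv (fed by a dequant).
ptq_graph = [Node("conv", "permute")]
qat_graph = [Node("conv", "dequantize_per_tensor")]

run(ptq_graph)               # default is_qat=False: the QAT-only pass is skipped
run(qat_graph, is_qat=True)  # the QAT graph satisfies the invariant
```

In this toy model, calling `run(ptq_graph, is_qat=True)` raises `AssertionError`, which mirrors the PTQ failure the summary reports when the pass ran unconditionally.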
1 parent 7cd209d commit 3c17a1d

1 file changed

Lines changed: 16 additions & 9 deletions

backends/cadence/aot/compiler.py

@@ -152,6 +152,7 @@ def convert_pt2(
 def apply_pre_edge_transform_passes(
     converted_program: ExportedProgram,
     quantizer: CadenceQuantizer,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Apply pre-edge transform passes including QuantFusion and torch ops passes.
@@ -166,12 +167,11 @@ def apply_pre_edge_transform_passes(
     """
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    PassManager(
-        [
-            FuseQATConvBN(converted_program),
-            QuantFusion(patterns),
-        ]
-    )(converted_program.graph_module)
+    passes = []
+    if is_qat:
+        passes.append(FuseQATConvBN(converted_program))
+    passes.append(QuantFusion(patterns))
+    PassManager(passes)(converted_program.graph_module)

     # Apply torch ops passes (e.g., ReplaceMulTensorWithMulAndFullOpsPass)
     fused_program = apply_torch_ops_passes(converted_program)
@@ -187,19 +187,24 @@ def get_fake_quant_model(
     quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
+    is_qat: bool = False,
 ) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()

     ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
-    program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)
+    program = trace(
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat
+    )

     if dump_graphs:
         logging.info("Graph after trace:")
         logging.info(program.graph.print_tabular())

     # Get prepared graph module
-    prepared_gm = prepare_traced_pt2(program, quantizer, dump_graphs=dump_graphs)
+    prepared_gm = prepare_traced_pt2(
+        program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat
+    )

     # Calibrate
     # If no calibration data is provided, use the inputs
@@ -221,6 +226,7 @@ def quantize_pt2(
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
     quant_input_args: Optional[list[str]] = None,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
@@ -242,6 +248,7 @@ def quantize_pt2(
         quantizer=quantizer,
         calibration_data=calibration_data,
         dump_graphs=dump_graphs,
+        is_qat=is_qat,
     )
     # Wrap the model to handle quantized inputs if provided
     if quant_input_args is not None:
@@ -254,7 +261,7 @@ def quantize_pt2(
     if quant_input_args is not None:
         QuantizedInputWrapper.sink_dequants(program)

-    fused_program = apply_pre_edge_transform_passes(program, quantizer)
+    fused_program = apply_pre_edge_transform_passes(program, quantizer, is_qat=is_qat)

     if dump_graphs:
         logging.info("Graph after quantization and fusion:")
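The diff above covers only `compiler.py`; the packaging-side opt-in mentioned in the summary is not shown on this page. As a rough sketch of that forwarding, assuming nothing about the real classes beyond what the summary states (an `is_qat` flag on the packaging config, forwarded as `train=`), the flow might look like the following. `ToyPackaging`, `ToyDeployment`, and `toy_get_recipe` are hypothetical stand-ins, not the actual `Packaging`, `Rt700Hifi4Deployment`, or `get_recipe_with_custom_settings`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyPackaging:
    """Hypothetical stand-in for a per-model packaging config."""
    name: str
    is_qat: bool = False  # assumed default, so existing PTQ models are unaffected


def toy_get_recipe(train: bool) -> str:
    # Hypothetical stand-in for the recipe factory: returns only a label here.
    return "qat_recipe" if train else "ptq_recipe"


@dataclass(frozen=True)
class ToyDeployment:
    """Hypothetical stand-in for a deployment that picks a lowering recipe."""
    packaging: ToyPackaging

    def recipe(self) -> str:
        # Mirrors the forwarding described in the summary:
        # train=self.packaging.is_qat
        return toy_get_recipe(train=self.packaging.is_qat)


ptq_model = ToyDeployment(ToyPackaging(name="some_ptq_model"))
qat_model = ToyDeployment(ToyPackaging(name="some_qat_model", is_qat=True))
print(ptq_model.recipe(), qat_model.recipe())
```

Keeping the flag on the packaging config means a model opts in with one declarative setting rather than a code change in the deployment class.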
