Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def convert_pt2(
def apply_pre_edge_transform_passes(
converted_program: ExportedProgram,
quantizer: CadenceQuantizer,
is_qat: bool = False,
) -> ExportedProgram:
"""
Apply pre-edge transform passes including QuantFusion and torch ops passes.
Expand All @@ -166,12 +167,11 @@ def apply_pre_edge_transform_passes(
"""
# pyre-ignore[16]: no attribute
patterns = [q.pattern for q in quantizer.quantizers]
PassManager(
[
FuseQATConvBN(converted_program),
QuantFusion(patterns),
]
)(converted_program.graph_module)
passes = []
if is_qat:
passes.append(FuseQATConvBN(converted_program))
passes.append(QuantFusion(patterns))
PassManager(passes)(converted_program.graph_module)

# Apply torch ops passes (e.g., ReplaceMulTensorWithMulAndFullOpsPass)
fused_program = apply_torch_ops_passes(converted_program)
Expand All @@ -187,19 +187,24 @@ def get_fake_quant_model(
quantizer: CadenceQuantizer,
calibration_data: Optional[list[tuple[object, ...]]] = None,
dump_graphs: bool = False,
is_qat: bool = False,
) -> torch.fx.GraphModule:
# Make the model inference mode by calling model.eval()
model.eval()

ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)
program = trace(
model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat
)

if dump_graphs:
logging.info("Graph after trace:")
logging.info(program.graph.print_tabular())

# Get prepared graph module
prepared_gm = prepare_traced_pt2(program, quantizer, dump_graphs=dump_graphs)
prepared_gm = prepare_traced_pt2(
program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat
)

# Calibrate
# If no calibration data is provided, use the inputs
Expand All @@ -221,6 +226,7 @@ def quantize_pt2(
calibration_data: Optional[list[tuple[object, ...]]] = None,
dump_graphs: bool = False,
quant_input_args: Optional[list[str]] = None,
is_qat: bool = False,
) -> ExportedProgram:
"""
Trace, prepare, convert and fuse the model using the given quantizer.
Expand All @@ -242,6 +248,7 @@ def quantize_pt2(
quantizer=quantizer,
calibration_data=calibration_data,
dump_graphs=dump_graphs,
is_qat=is_qat,
)
# Wrap the model to handle quantized inputs if provided
if quant_input_args is not None:
Expand All @@ -254,7 +261,7 @@ def quantize_pt2(
if quant_input_args is not None:
QuantizedInputWrapper.sink_dequants(program)

fused_program = apply_pre_edge_transform_passes(program, quantizer)
fused_program = apply_pre_edge_transform_passes(program, quantizer, is_qat=is_qat)

if dump_graphs:
logging.info("Graph after quantization and fusion:")
Expand Down
Loading