diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK
index ae884e29deb..7d8ff3cffd2 100644
--- a/backends/cadence/aot/BUCK
+++ b/backends/cadence/aot/BUCK
@@ -45,6 +45,7 @@ fbcode_target(_kind = runtime.python_library,
         ":utils",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
+        "//executorch/backends/cadence/aot/quantizer/passes:fuse_ops",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
         "//executorch/backends/transforms:decompose_sdpa",
         "//executorch/backends/transforms:remove_clone_ops",
diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 7fa2ac6f224..5c66c9eb62b 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -22,6 +22,7 @@
     print_memory_planning_info,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
+from executorch.backends.cadence.aot.quantizer.passes.fuse_ops import FuseQATConvBN
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
@@ -37,9 +38,10 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
+from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import _transform, to_edge
+from executorch.exir.program._program import to_edge

 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
@@ -162,13 +164,17 @@ def apply_pre_edge_transform_passes(
     which will instantiate a default quantizer for you if needed.
     Returns an ExportedProgram with the fused model.
     """
-    # Get patterns and apply fusion of dq -> op -> q to qop
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    fused_program = _transform(converted_program, QuantFusion(patterns))
+    PassManager(
+        [
+            FuseQATConvBN(converted_program),
+            QuantFusion(patterns),
+        ]
+    )(converted_program.graph_module)

     # Apply torch ops passes (e.g., ReplaceMulTensorWithMulAndFullOpsPass)
-    fused_program = apply_torch_ops_passes(fused_program)
+    fused_program = apply_torch_ops_passes(converted_program)

     return fused_program
diff --git a/backends/cadence/aot/quantizer/passes/BUCK b/backends/cadence/aot/quantizer/passes/BUCK
new file mode 100644
index 00000000000..9b1e403d77c
--- /dev/null
+++ b/backends/cadence/aot/quantizer/passes/BUCK
@@ -0,0 +1,16 @@
+load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("odai_jarvis")
+
+fbcode_target(_kind = runtime.python_library,
+    name = "fuse_ops",
+    srcs = [
+        "fuse_ops.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
+    ],
+)
diff --git a/backends/cadence/aot/quantizer/passes/fuse_ops.py b/backends/cadence/aot/quantizer/passes/fuse_ops.py
new file mode 100644
index 00000000000..6c02403e644
--- /dev/null
+++ b/backends/cadence/aot/quantizer/passes/fuse_ops.py
@@ -0,0 +1,530 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+"""
+FuseQATConvBN: a pre-edge (aten dialect) Conv-BN fusion pass.
+
+Locates the QAT Conv-BN simulation chain that `prepare_qat_pt2e` inserts,
+folds it into the conv's quantized bias, and removes the chain from the
+graph. This is necessary because TorchAO's `_fold_conv_bn_qat` matcher fails
+to fold the chain when Cadence's quantizer annotates conv biases with INT32
+quantization (the matcher is hardcoded to expect INT8 bias quantization).
+
+Lives next to `fusion_pass.py:QuantFusion` — both are pre-edge quantization
+passes that operate on Q/DQ patterns and run before `to_edge`.
+"""
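+# A sketch of the chain this pass matches (node names illustrative; the exact
+# graph depends on the exporter, so treat this as an approximation, not a
+# guaranteed shape):
+#
+#   conv = aten.conv2d(x, w, dq(bias_q))
+#   q    = quantized_decomposed.quantize_per_tensor(conv, ...)
+#   dq   = quantized_decomposed.dequantize_per_tensor(q, ...)
+#   scl  = aten.div(bn_weight, aten.sqrt(aten.add(running_var, eps)))
+#   div  = aten.div(dq, aten.reshape(scl, [1, -1, 1, 1]))
+#   add  = aten.add(div, aten.reshape(orig_bias, [1, -1, 1, 1]))
+#   out  = aten.batch_norm(add, bn_weight, bn_bias, mean, var, False, mom, eps)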
+
+import operator
+from typing import Any, Optional
+
+import torch
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+
+_QAT_CONV_TARGETS: tuple[Any, ...] = (
+    torch.ops.aten.conv1d.default,
+    torch.ops.aten.conv2d.default,
+    torch.ops.aten.convolution.default,
+)
+
+_BN_TARGETS: tuple[Any, ...] = (
+    torch.ops.aten.batch_norm.default,
+    torch.ops.aten._native_batch_norm_legit_no_training.default,
+)
+
+
+def _build_placeholder_to_state_dict_key_map(
+    exported_program: "torch.export.ExportedProgram",
+) -> dict[str, str]:
+    """Map placeholder node names (`p_module_weight`, `b_module_running_mean`)
+    to the corresponding dotted state_dict keys (`module.weight`)."""
+    result: dict[str, str] = {}
+
+    def _add(dotted: str) -> None:
+        flat = dotted.replace(".", "_")
+        result[f"p_{flat}"] = dotted
+        result[f"b_{flat}"] = dotted
+        result[f"arg_{flat}"] = dotted
+        result[dotted] = dotted
+
+    for key in exported_program.state_dict:
+        _add(key)
+    for name, _ in exported_program.named_buffers():
+        _add(name)
+    for key in getattr(exported_program, "constants", None) or {}:
+        _add(key)
+    return result
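+
+
+# For example, a parameter registered under the state_dict key `conv1.weight`
+# typically appears as the placeholder `p_conv1_weight` after export, and a
+# buffer `bn1.running_mean` as `b_bn1_running_mean`; the map above translates
+# either spelling back to the dotted key. (The module and attribute names here
+# are illustrative, not taken from a specific model.)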
+
+
+def _resolve_param_tensor(
+    graph_module: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    exported_program: Optional["torch.export.ExportedProgram"],
+    key_map: Optional[dict[str, str]],
+) -> Optional[torch.Tensor]:
+    """Read a tensor from a get_attr node (GraphModule attribute) or a
+    placeholder node (ExportedProgram state_dict / buffers / constants)."""
+    if node.op == "get_attr":
+        target = node.target
+        if not isinstance(target, str):
+            return None
+        attr = graph_module
+        for part in target.split("."):
+            if not hasattr(attr, part):
+                return None
+            attr = getattr(attr, part)
+        if isinstance(attr, (torch.Tensor, torch.nn.Parameter)):
+            return attr.data if isinstance(attr, torch.nn.Parameter) else attr
+        return None
+    if (
+        node.op == "placeholder"
+        and exported_program is not None
+        and key_map is not None
+    ):
+        dotted = key_map.get(node.name)
+        if dotted is None:
+            return None
+        sd = exported_program.state_dict
+        if dotted in sd:
+            val = sd[dotted]
+            return val.data if isinstance(val, torch.nn.Parameter) else val
+        constants = getattr(exported_program, "constants", None) or {}
+        if dotted in constants:
+            return constants[dotted]
+        for name, buf in exported_program.named_buffers():
+            if name == dotted:
+                return buf
+    return None
+
+
+def _set_param_tensor(
+    graph_module: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    value: torch.Tensor,
+    exported_program: Optional["torch.export.ExportedProgram"],
+    key_map: Optional[dict[str, str]],
+) -> bool:
+    """Write a tensor to the location backing a get_attr or placeholder node."""
+    if node.op == "get_attr" and isinstance(node.target, str):
+        parts = node.target.split(".")
+        parent = graph_module
+        for part in parts[:-1]:
+            if not hasattr(parent, part):
+                return False
+            parent = getattr(parent, part)
+        setattr(parent, parts[-1], value)
+        return True
+    if (
+        node.op == "placeholder"
+        and exported_program is not None
+        and key_map is not None
+    ):
+        dotted = key_map.get(node.name)
+        if dotted is None:
+            return False
+        sd = exported_program.state_dict
+        if dotted in sd:
+            sd[dotted] = (
+                torch.nn.Parameter(value, requires_grad=False)
+                if isinstance(sd[dotted], torch.nn.Parameter)
+                else value
+            )
+            return True
+    return False
+
+
+class FuseQATConvBN(PassBase):
+    """
+    Folds the QAT Conv-BN simulation chain (inserted by `prepare_qat_pt2e`) into
+    the conv's quantized bias. Cleans up `batch_norm` nodes and the surrounding
+    sqrt/div/add ops that TorchAO's `_fold_conv_bn_qat` matcher fails to fold
+    when Cadence's quantizer annotates conv biases with INT32 quantization.
+
+    The chain looks like:
+        conv → q → dq → div(scale) → add(orig_bias) → batch_norm
+    where scale = bn_weight / sqrt(running_var + eps).
+
+    Two-step `call()`:
+      1. Bias prep: for each conv, create a zero-filled quantized bias if
+         missing, or quantize a float bias as per-tensor int32. Required so
+         step 2 has a quantized bias slot to write the BN correction into.
+      2. Fold: for each matched chain, compute the BN correction
+             C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias
+         and absorb it into the conv's quantized bias in place. Erase the
+         chain and the batch_norm node.
+
+    Pass `exported_program` when the graph_module's params are placeholders
+    (post-convert_pt2e); omit it for plain GraphModules with get_attr params.
+    """
+
+    def __init__(
+        self,
+        exported_program: Optional["torch.export.ExportedProgram"] = None,
+        default_zero_bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.exported_program = exported_program
+        self.default_zero_bias = default_zero_bias
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Step 1: prep biases so step 2 has quantized bias slots to write into.
+        prep_modified = self._prep_conv_biases(graph_module)
+
+        # Step 2: fold the BN correction into the (now-quantized) bias and
+        # delete the simulation chain + batch_norm.
+        fold_modified = self._fold_qat_chains(graph_module)
+
+        if prep_modified or fold_modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+        return PassResult(graph_module, prep_modified or fold_modified)
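+
+    # Typical wiring, mirroring the compiler.py change in this diff (a sketch;
+    # `converted_program` is assumed to be the ExportedProgram holding the
+    # converted, post-convert_pt2e graph):
+    #
+    #     from executorch.exir.pass_manager import PassManager
+    #
+    #     PassManager([FuseQATConvBN(converted_program)])(
+    #         converted_program.graph_module
+    #     )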
+
+    def _prep_conv_biases(self, graph_module: torch.fx.GraphModule) -> bool:
+        """Delegate bias prep to the shared helper. Creates zero biases for
+        biasless convs and quantizes any float biases."""
+        from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
+            _get_bias_tensor_ep,
+            _get_tensor_from_node,
+            _quantize_fused_conv_bias,
+            _set_param_ep,
+            _set_param_gm,
+        )
+        from torch._export.utils import get_buffer
+
+        ep = self.exported_program
+        if ep is not None:
+
+            def get_bias(n: torch.fx.Node) -> Any:
+                return _get_bias_tensor_ep(ep, n)
+
+            def set_param(
+                n: torch.fx.Node,
+                t: torch.Tensor,
+                insert_before: Optional[torch.fx.Node] = None,
+            ) -> Any:
+                return _set_param_ep(ep, n, t)
+
+            def get_scale(n: torch.fx.Node) -> Any:
+                return get_buffer(ep, n)
+
+        else:
+
+            def get_bias(n: torch.fx.Node) -> Any:
+                return _get_tensor_from_node(graph_module, n)
+
+            def set_param(
+                n: torch.fx.Node,
+                t: torch.Tensor,
+                insert_before: Optional[torch.fx.Node] = None,
+            ) -> Any:
+                return _set_param_gm(graph_module, n, t, insert_before)
+
+            def get_scale(n: torch.fx.Node) -> Any:
+                return _get_tensor_from_node(graph_module, n)
+
+        return _quantize_fused_conv_bias(
+            graph_module,
+            conv_targets=_QAT_CONV_TARGETS + (torch.ops.aten.conv_transpose2d.input,),
+            unsqueeze_targets=(
+                torch.ops.aten.unsqueeze_copy.default,
+                torch.ops.aten.unsqueeze.default,
+            ),
+            dq_per_tensor=torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            dq_per_channel=torch.ops.quantized_decomposed.dequantize_per_channel.default,
+            get_bias_tensor=get_bias,
+            set_param=set_param,
+            get_weight_scale_tensor=get_scale,
+            default_zero_bias=self.default_zero_bias,
+        )
+
+    def _fold_qat_chains(self, graph_module: torch.fx.GraphModule) -> bool:
+        """Walk batch_norm nodes, match the QAT simulation chain, and fold
+        the BN correction into the conv bias."""
+        graph = graph_module.graph
+        key_map = (
+            _build_placeholder_to_state_dict_key_map(self.exported_program)
+            if self.exported_program is not None
+            else None
+        )
+        nodes_to_erase: list[torch.fx.Node] = []
+        changed = False
+
+        for bn_node in list(graph.nodes):
+            if bn_node.target not in _BN_TARGETS:
+                continue
+            match = self._match_qat_chain(bn_node)
+            if match is None:
+                continue
+
+            tensors = self._read_param_tensors(graph_module, key_map, bn_node, match)
+            if tensors is None:
+                continue
+
+            new_bias = self._compute_folded_bias(match, tensors)
+            if not _set_param_tensor(
+                graph_module,
+                match["conv_bias_param"],
+                new_bias,
+                self.exported_program,
+                key_map,
+            ):
+                continue
+
+            self._rewire_and_collect_erase(bn_node, match, nodes_to_erase)
+            changed = True
+
+        for node in reversed(nodes_to_erase):
+            if len(node.users) == 0:
+                graph.erase_node(node)
+        return changed
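+
+    # The static matchers below walk backwards from each batch_norm node:
+    #   batch_norm ← add(orig_bias) ← div(scale) ← dq ← q ← conv(..., dq(bias_q))
+    # where `scale` is itself matched as div(bn_weight, sqrt(add(running_var, eps))).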
+
+    @staticmethod
+    def _match_bn_post_input(
+        bn_input: torch.fx.Node,
+    ) -> Optional[tuple[torch.fx.Node, Optional[torch.fx.Node], Any]]:
+        """Return (div_output, add_orig_bias, reshape_orig_bias) or None.
+        reshape_orig_bias is validated later by _resolve_orig_bias_node."""
+        if bn_input.target == torch.ops.aten.add.Tensor:
+            add_orig_bias = bn_input
+            div_output = add_orig_bias.args[0]
+            reshape_orig_bias = add_orig_bias.args[1]
+            if not isinstance(div_output, torch.fx.Node) or (
+                div_output.target != torch.ops.aten.div.Tensor
+            ):
+                return None
+            return div_output, add_orig_bias, reshape_orig_bias
+        if bn_input.target == torch.ops.aten.div.Tensor:
+            return bn_input, None, None
+        return None
+
+    @staticmethod
+    def _match_scale_chain(
+        div_output: torch.fx.Node,
+    ) -> Optional[
+        tuple[
+            torch.fx.Node, torch.fx.Node, torch.fx.Node, torch.fx.Node, torch.fx.Node
+        ]
+    ]:
+        """Return (dq_intermediate, reshape_scale, scale_node, sqrt_node,
+        add_var_eps) or None."""
+        dq_intermediate = div_output.args[0]
+        reshape_scale = div_output.args[1]
+        if not isinstance(dq_intermediate, torch.fx.Node) or not isinstance(
+            reshape_scale, torch.fx.Node
+        ):
+            return None
+        if (
+            dq_intermediate.target
+            != torch.ops.quantized_decomposed.dequantize_per_tensor.default
+        ):
+            return None
+        if reshape_scale.target not in (
+            torch.ops.aten.reshape.default,
+            torch.ops.aten.view.default,
+        ):
+            return None
+        scale_node = reshape_scale.args[0]
+        if not isinstance(scale_node, torch.fx.Node) or (
+            scale_node.target != torch.ops.aten.div.Tensor
+        ):
+            return None
+        sqrt_node = scale_node.args[1]
+        if not isinstance(sqrt_node, torch.fx.Node) or (
+            sqrt_node.target != torch.ops.aten.sqrt.default
+        ):
+            return None
+        add_var_eps = sqrt_node.args[0]
+        if not isinstance(add_var_eps, torch.fx.Node) or (
+            add_var_eps.target != torch.ops.aten.add.Tensor
+        ):
+            return None
+        return dq_intermediate, reshape_scale, scale_node, sqrt_node, add_var_eps
+
+    @staticmethod
+    def _match_conv_chain(
+        dq_intermediate: torch.fx.Node,
+    ) -> Optional[torch.fx.Node]:
+        """Return conv_bias_dq or None."""
+        q_intermediate = dq_intermediate.args[0]
+        if not isinstance(q_intermediate, torch.fx.Node) or (
+            q_intermediate.target
+            != torch.ops.quantized_decomposed.quantize_per_tensor.default
+        ):
+            return None
+        conv_node = q_intermediate.args[0]
+        if not isinstance(conv_node, torch.fx.Node) or (
+            conv_node.target not in _QAT_CONV_TARGETS
+        ):
+            return None
+        if len(conv_node.args) < 3:
+            return None
+        conv_bias_dq = conv_node.args[2]
+        if (
+            not isinstance(conv_bias_dq, torch.fx.Node)
+            or conv_bias_dq.target
+            != torch.ops.quantized_decomposed.dequantize_per_tensor.default
+        ):
+            return None
+        return conv_bias_dq
+
+    @staticmethod
+    def _resolve_orig_bias_node(reshape_orig_bias: Any) -> tuple[bool, Any]:
+        """Return (ok, orig_bias_node)."""
+        if reshape_orig_bias is None:
+            return True, None
+        if not isinstance(reshape_orig_bias, torch.fx.Node) or (
+            reshape_orig_bias.target
+            not in (
+                torch.ops.aten.reshape.default,
+                torch.ops.aten.view.default,
+            )
+        ):
+            return False, None
+        return True, reshape_orig_bias.args[0]
+
+    @staticmethod
+    def _match_qat_chain(bn_node: torch.fx.Node) -> Optional[dict[str, Any]]:
+        """Walk back from a batch_norm node and return the matched chain
+        components, or None if the pattern doesn't match."""
+        if bn_node.target == torch.ops.aten.batch_norm.default:
+            if bn_node.args[5] is not False:  # training=False required
+                return None
+            eps = bn_node.args[7]
+        else:  # _native_batch_norm_legit_no_training
+            eps = bn_node.args[6]
+
+        bn_input = bn_node.args[0]
+        if not isinstance(bn_input, torch.fx.Node):
+            return None
+        post_input = FuseQATConvBN._match_bn_post_input(bn_input)
+        if post_input is None:
+            return None
+        div_output, add_orig_bias, reshape_orig_bias = post_input
+
+        scale_chain = FuseQATConvBN._match_scale_chain(div_output)
+        if scale_chain is None:
+            return None
+        dq_intermediate, reshape_scale, scale_node, sqrt_node, add_var_eps = scale_chain
+
+        conv_bias_dq = FuseQATConvBN._match_conv_chain(dq_intermediate)
+        if conv_bias_dq is None:
+            return None
+
+        ok, orig_bias_node = FuseQATConvBN._resolve_orig_bias_node(reshape_orig_bias)
+        if not ok:
+            return None
+
+        return {
+            "eps": eps,
+            "bn_weight": bn_node.args[1],
+            "bn_bias": bn_node.args[2],
+            "bn_mean": bn_node.args[3],
+            "bn_var": bn_node.args[4],
+            "div_output": div_output,
+            "reshape_scale": reshape_scale,
+            "scale_node": scale_node,
+            "sqrt_node": sqrt_node,
+            "add_var_eps": add_var_eps,
+            "dq_intermediate": dq_intermediate,
+            "add_orig_bias": add_orig_bias,
+            "reshape_orig_bias": reshape_orig_bias,
+            "orig_bias_node": orig_bias_node,
+            "conv_bias_param": conv_bias_dq.args[0],
+            "bias_scale": conv_bias_dq.args[1],
+            "bias_zp": conv_bias_dq.args[2],
+            "bias_qmin": conv_bias_dq.args[3],
+            "bias_qmax": conv_bias_dq.args[4],
+        }
+
+    def _read_param_tensors(
+        self,
+        graph_module: torch.fx.GraphModule,
+        key_map: Optional[dict[str, str]],
+        bn_node: torch.fx.Node,
+        match: dict[str, Any],
+    ) -> Optional[dict[str, torch.Tensor]]:
+        def get(node: Optional[torch.fx.Node]) -> Optional[torch.Tensor]:
+            if node is None:
+                return None
+            return _resolve_param_tensor(
+                graph_module, node, self.exported_program, key_map
+            )
+
+        tensors = {
+            "bn_weight": get(match["bn_weight"]),
+            "bn_bias": get(match["bn_bias"]),
+            "bn_mean": get(match["bn_mean"]),
+            "bn_var": get(match["bn_var"]),
+            "conv_bias": get(match["conv_bias_param"]),
+        }
+        if match["orig_bias_node"] is not None:
+            tensors["orig_bias"] = get(match["orig_bias_node"])
+            if tensors["orig_bias"] is None:
+                return None
+        if any(t is None for t in tensors.values()):
+            return None
+        # Narrow Optional[Tensor] -> Tensor for the type-checker after the
+        # None-check above.
+        return {k: v for k, v in tensors.items() if v is not None}
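+
+    # A worked example of the fold below, with made-up numbers: scale=0.5,
+    # zp=0, quantized conv_bias=4, no orig_bias, bn_weight=1.0, bn_bias=0.0,
+    # running_mean=0.5, running_var=0.0625, eps≈0. Then:
+    #   running_std = sqrt(0.0625)                 = 0.25
+    #   correction  = -0.5 * 1.0 / 0.25 + 0.0      = -2.0
+    #   bias_float  = (4 - 0) * 0.5                = 2.0
+    #   new_bias    = round((2.0 - 2.0) / 0.5) + 0 = 0, clamped to [qmin, qmax]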
+
+    @staticmethod
+    def _compute_folded_bias(
+        match: dict[str, Any], tensors: dict[str, torch.Tensor]
+    ) -> torch.Tensor:
+        """Compute new int bias = round((bias_float + C) / scale) + zp, clamped."""
+        scale = match["bias_scale"]
+        zp = match["bias_zp"]
+        qmin = match["bias_qmin"]
+        qmax = match["bias_qmax"]
+
+        running_std = torch.sqrt(tensors["bn_var"] + match["eps"])
+        if "orig_bias" in tensors:
+            correction = (tensors["orig_bias"] - tensors["bn_mean"]) * tensors[
+                "bn_weight"
+            ] / running_std + tensors["bn_bias"]
+        else:
+            correction = (
+                -tensors["bn_mean"] * tensors["bn_weight"] / running_std
+                + tensors["bn_bias"]
+            )
+        bias_float = (tensors["conv_bias"].float() - zp) * scale
+        new_bias_float = bias_float + correction
+        return torch.clamp(torch.round(new_bias_float / scale) + zp, qmin, qmax).to(
+            tensors["conv_bias"].dtype
+        )
+
+    @staticmethod
+    def _rewire_and_collect_erase(
+        bn_node: torch.fx.Node,
+        match: dict[str, Any],
+        nodes_to_erase: list[torch.fx.Node],
+    ) -> None:
+        """Replace BN output with the dequantized conv output and queue the
+        intermediate ops for deletion."""
+        if (
+            bn_node.target
+            == torch.ops.aten._native_batch_norm_legit_no_training.default
+        ):
+            for user in list(bn_node.users):
+                if user.target == operator.getitem:
+                    if user.args[1] == 0:
+                        user.replace_all_uses_with(match["dq_intermediate"])
+                        nodes_to_erase.append(user)
+        else:
+            bn_node.replace_all_uses_with(match["dq_intermediate"])
+
+        nodes_to_erase.extend(
+            n
+            for n in [
+                bn_node,
+                match["div_output"],
+                match["reshape_scale"],
+                match["scale_node"],
+                match["sqrt_node"],
+                match["add_var_eps"],
+                match["add_orig_bias"],
+                match["reshape_orig_bias"],
+            ]
+            if n is not None
+        )
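+
+
+# End-to-end placement in the PT2E QAT flow (a sketch, not executable as-is;
+# `prepare_qat_pt2e` / `convert_pt2e` are the torchao entry points referenced
+# in the module docstring, and `quantizer` is assumed to be a Cadence quantizer):
+#
+#   prepared = prepare_qat_pt2e(exported.module(), quantizer)  # adds BN sim chain
+#   ... QAT fine-tuning ...
+#   converted = convert_pt2e(prepared)                         # lowers to Q/DQ ops
+#   Re-export to an ExportedProgram, then run FuseQATConvBN (followed by
+#   QuantFusion) via PassManager before to_edge; see compiler.py in this diff.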