diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index dfd7e065750..c477f9afbe5 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -167,8 +167,8 @@ def _add_bias( weight_node: torch.fx.Node, ) -> torch.fx.Node: output_channels = get_first_fake_tensor(node).shape[1] - # add a node containging zeros if quantized, use int32, otherwise use float32 - if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: + # add a node containing zeros if quantized, use int32, otherwise use float32 + if self._is_quantized_conv(node): bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) else: output_dtype = node.meta["val"].dtype @@ -188,9 +188,40 @@ def _add_bias( node.update_arg(2, bias_node) return bias_node - def insert_output_rescale(self, graph_module, node): - input_qparams = get_input_qparams(node) - output_qparams = get_output_qparams(node)[0] + def _is_quantized_conv(self, node: torch.fx.Node) -> bool: + return bool(node.meta.get("input_qparams", {})) + + def _get_effective_output_qparams(self, node: torch.fx.Node): + """Return the quantized output domain for a conv node. + + Quantization annotation may place output qparams on a following + activation instead of on the conv itself. If that activation is not + fuseable, it survives as a quantized ``clamp`` and still owns the + branch output qparams needed for the conv output rescale. + + """ + output_qparams = node.meta.get("output_qparams", {}) + if output_qparams: + return output_qparams + + users = list(node.users) + if len(users) != 1: + raise ValueError( + f"RewriteConvPass: No output quantization parameter found in node {node}\n" + f"original_aten={node.meta.get('original_aten', 'None')}" + ) + + activation = users[0] + if activation.target == exir_ops.edge.aten.clamp.default: + activation_output_qparams = activation.meta.get("output_qparams", {}) + if activation_output_qparams: + return activation_output_qparams + + return get_output_qparams(node) + + def insert_output_rescale(self, graph_module, source_node, conv_node): + input_qparams = get_input_qparams(source_node) + output_qparams = self._get_effective_output_qparams(source_node)[0] weight_qparams = input_qparams[1] input_qparams = input_qparams[0] is_per_channel = weight_qparams.per_channel @@ -207,18 +238,18 @@ def insert_output_rescale(self, graph_module, node): itertools.cycle([output_qparams.get_scale_per_tensor()]), ) ] - with graph_module.graph.inserting_after(node): + with graph_module.graph.inserting_after(conv_node): rescale_node = create_node( graph=graph_module.graph, op_target=exir_ops.backend.tosa.RESCALE.default, args=( - node, + conv_node, output_qparams.dtype, post_conv2d_scale, 0, output_qparams.get_zp_per_tensor(), ), - from_node=node, + from_node=source_node, ) return rescale_node @@ -347,7 +378,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 tosa_node_fake_tensor.dtype == torch.int32 and input_fake_tensor.dtype == torch.int8 ): - output_rescale = self.insert_output_rescale(graph_module, tosa_op) + output_rescale = self.insert_output_rescale(graph_module, node, tosa_op) node.replace_all_uses_with(output_rescale) elif ( tosa_node_fake_tensor.dtype == torch.int32 @@ -355,7 +386,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 ): has_bias = len(node.meta["input_qparams"]) > 2 if not has_bias: - output_rescale = self.insert_output_rescale(graph_module, tosa_op) + output_rescale = self.insert_output_rescale( + graph_module, node, tosa_op + ) node.replace_all_uses_with(output_rescale) else: node.replace_all_uses_with(tosa_op) diff --git a/backends/arm/test/passes/test_rewrite_conv_pass.py b/backends/arm/test/passes/test_rewrite_conv_pass.py index d59dbc90848..8bd98437cbc 100644 --- a/backends/arm/test/passes/test_rewrite_conv_pass.py +++ b/backends/arm/test/passes/test_rewrite_conv_pass.py @@ -1,13 +1,98 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.backends.arm._passes import ( + ConvertToClampPass, + FoldAndAnnotateQParamsPass, + FuseQuantizedActivationPass, + QuantizeClampArgumentsPass, +) from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_quantization_config, + VgfQuantizer, +) from executorch.backends.arm.test.misc.test_dw_convs_with_shared_weights import ( DWConvsModule, ) from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa.specification import TosaLoweringContext +from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner +from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower +from executorch.exir.dialects._ops import ops as exir_ops + + +class TinyConvReluCat(nn.Module): + def __init__(self, conv1_bias: bool = True) -> None: + super().__init__() + self.conv1 = nn.Conv2d(4, 4, 3, padding=1, bias=conv1_bias) + self.conv2 = nn.Conv2d(8, 4, 1) + with torch.no_grad(): + for param in self.parameters(): + param.uniform_(-0.1, 0.1) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + relu_out = F.relu(self.conv1(x)) + merged = torch.cat((relu_out, y), dim=1) + return self.conv2(merged) + + +def _example_inputs() -> tuple[torch.Tensor, torch.Tensor]: + torch.manual_seed(0) + x = torch.rand(1, 4, 16, 16) + y = torch.rand(1, 4, 16, 16) - 0.065 + return x, y + + +def _compile_spec() -> VgfCompileSpec: + return VgfCompileSpec("TOSA-1.0+INT+FP") + + +def _quantizer() -> VgfQuantizer: + quantizer = VgfQuantizer(_compile_spec()) + quantizer.set_global( + get_symmetric_quantization_config( + is_per_channel=True, + act_qmin=-127, + act_qmax=127, + weight_qmin=-127, + weight_qmax=127, + ) + ) + return quantizer + + +def _export_quantized(model: nn.Module): + inputs = _example_inputs() + exported = torch.export.export(model.eval(), inputs).module(check_guards=False) + quantized = _quantizer()._quantize_with_submodules(exported, [inputs]) + return torch.export.export(quantized, inputs) + + +def _run_pre_rewrite_passes(exported_program: torch.export.ExportedProgram): + gm = exported_program.graph_module + for pass_ in ( + FuseQuantizedActivationPass(), + ConvertToClampPass(), + FoldAndAnnotateQParamsPass(exported_program), + QuantizeClampArgumentsPass(), + ): + result = pass_(gm) + assert result is not None + gm = result.graph_module + return gm + + +def _get_call_function_node(gm: torch.fx.GraphModule, target): + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == target: + return node + raise AssertionError(f"Node with target {target} not found") def test_rewrite_conv_tosa_FP(): @@ -18,3 +103,49 @@ def test_rewrite_conv_tosa_FP(): # We can't run TOSA backend dialect operators in eager mode pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() + + +def test_fold_and_annotate_q_params_vgf_quant_preserves_output_qparams_on_non_fuseable_clamp() -> ( + None +): + exported_program = _export_quantized(TinyConvReluCat()) + gm = _run_pre_rewrite_passes(to_edge(exported_program).exported_program()) + + conv = _get_call_function_node(gm, exir_ops.edge.aten.convolution.default) + clamp = _get_call_function_node(gm, exir_ops.edge.aten.clamp.default) + + assert conv.meta["input_qparams"] + assert not conv.meta["output_qparams"] + assert clamp.meta["output_qparams"] + + +def test_rewrite_conv_vgf_quant_handles_non_fuseable_conv_clamp_cat_branch() -> None: + exported_program = _export_quantized(TinyConvReluCat()) + compile_spec = _compile_spec() + + to_edge_transform_and_lower( + exported_program, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + partitioner=[VgfPartitioner(compile_spec)], + ) + + +def test_rewrite_conv_vgf_quant_infers_quantized_bias_dtype_from_inputs() -> None: + exported_program = _export_quantized(TinyConvReluCat(conv1_bias=False)) + edge_program = to_edge( + exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False) + ).exported_program() + gm = _run_pre_rewrite_passes(edge_program) + with TosaLoweringContext(_compile_spec().tosa_spec): + result = RewriteConvPass(edge_program)(gm) + assert result is not None + gm = result.graph_module + + bias_nodes = [ + node + for node in gm.graph.nodes + if node.op == "placeholder" and node.name.endswith("_bias") + ] + + assert len(bias_nodes) == 1 + assert bias_nodes[0].meta["val"].dtype == torch.int32