From 41f946aaffa11925b7b1ba5b50cf7ec02bf3873f Mon Sep 17 00:00:00 2001 From: Baris Demir Date: Wed, 25 Mar 2026 15:27:18 +0000 Subject: [PATCH] Arm backend: Handle conv qparams from non-fused clamp The bug was caused by a pass-pipeline mismatch in the Arm quantized lowering flow. For a delegated conv -> relu/clamp -> cat branch under TOSA-1.0+INT+FP, the activation can remain unfused when its output quantization is non-fuseable (zp != qmin). In that case, quantization annotation and qparam folding legitimately leave input_qparams on the conv and output_qparams on the surviving clamp, but RewriteConvPass incorrectly assumed every quantized conv would always own local output_qparams and crashed when inserting the output rescale. The fix makes RewriteConvPass resolve the conv's effective output qparams from a following quantized clamp when needed, which matches the real quantized output domain of the unfused branch. The patch also fixes synthetic bias creation for biasless quantized convs by detecting quantization from input_qparams instead of relying on local output_qparams. Signed-off-by: Baris Demir Change-Id: I373e9eca8490148bc90dee487a2963612a494d8b --- backends/arm/_passes/rewrite_conv_pass.py | 53 +++++-- .../arm/test/passes/test_rewrite_conv_pass.py | 133 +++++++++++++++++- 2 files changed, 175 insertions(+), 11 deletions(-) diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index dfd7e065750..c477f9afbe5 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -167,8 +167,8 @@ def _add_bias( weight_node: torch.fx.Node, ) -> torch.fx.Node: output_channels = get_first_fake_tensor(node).shape[1] - # add a node containging zeros if quantized, use int32, otherwise use float32 - if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: + # add a node containing zeros if quantized, use int32, otherwise use float32 + if self._is_quantized_conv(node): bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) else: output_dtype = node.meta["val"].dtype @@ -188,9 +188,40 @@ def _add_bias( node.update_arg(2, bias_node) return bias_node - def insert_output_rescale(self, graph_module, node): - input_qparams = get_input_qparams(node) - output_qparams = get_output_qparams(node)[0] + def _is_quantized_conv(self, node: torch.fx.Node) -> bool: + return bool(node.meta.get("input_qparams", {})) + + def _get_effective_output_qparams(self, node: torch.fx.Node): + """Return the quantized output domain for a conv node. + + Quantization annotation may place output qparams on a following + activation instead of on the conv itself. If that activation is not + fuseable, it survives as a quantized ``clamp`` and still owns the + branch output qparams needed for the conv output rescale. + + """ + output_qparams = node.meta.get("output_qparams", {}) + if output_qparams: + return output_qparams + + users = list(node.users) + if len(users) != 1: + raise ValueError( + f"RewriteConvPass: No output quantization parameter found in node {node}\n" + f"original_aten={node.meta.get('original_aten', 'None')}" + ) + + activation = users[0] + if activation.target == exir_ops.edge.aten.clamp.default: + activation_output_qparams = activation.meta.get("output_qparams", {}) + if activation_output_qparams: + return activation_output_qparams + + return get_output_qparams(node) + + def insert_output_rescale(self, graph_module, source_node, conv_node): + input_qparams = get_input_qparams(source_node) + output_qparams = self._get_effective_output_qparams(source_node)[0] weight_qparams = input_qparams[1] input_qparams = input_qparams[0] is_per_channel = weight_qparams.per_channel @@ -207,18 +238,18 @@ def insert_output_rescale(self, graph_module, node): itertools.cycle([output_qparams.get_scale_per_tensor()]), ) ] - with graph_module.graph.inserting_after(node): + with graph_module.graph.inserting_after(conv_node): rescale_node = create_node( graph=graph_module.graph, op_target=exir_ops.backend.tosa.RESCALE.default, args=( - node, + conv_node, output_qparams.dtype, post_conv2d_scale, 0, output_qparams.get_zp_per_tensor(), ), - from_node=node, + from_node=source_node, ) return rescale_node @@ -347,7 +378,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 tosa_node_fake_tensor.dtype == torch.int32 and input_fake_tensor.dtype == torch.int8 ): - output_rescale = self.insert_output_rescale(graph_module, tosa_op) + output_rescale = self.insert_output_rescale(graph_module, node, tosa_op) node.replace_all_uses_with(output_rescale) elif ( tosa_node_fake_tensor.dtype == torch.int32 @@ -355,7 +386,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 ): has_bias = len(node.meta["input_qparams"]) > 2 if not has_bias: - output_rescale = self.insert_output_rescale(graph_module, tosa_op) + output_rescale = self.insert_output_rescale( + graph_module, node, tosa_op + ) node.replace_all_uses_with(output_rescale) else: node.replace_all_uses_with(tosa_op) diff --git a/backends/arm/test/passes/test_rewrite_conv_pass.py b/backends/arm/test/passes/test_rewrite_conv_pass.py index d59dbc90848..8bd98437cbc 100644 --- a/backends/arm/test/passes/test_rewrite_conv_pass.py +++ b/backends/arm/test/passes/test_rewrite_conv_pass.py @@ -1,13 +1,98 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.backends.arm._passes import ( + ConvertToClampPass, + FoldAndAnnotateQParamsPass, + FuseQuantizedActivationPass, + QuantizeClampArgumentsPass, +) from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_quantization_config, + VgfQuantizer, +) from executorch.backends.arm.test.misc.test_dw_convs_with_shared_weights import ( DWConvsModule, ) from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa.specification import TosaLoweringContext +from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner +from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower +from executorch.exir.dialects._ops import ops as exir_ops + + +class TinyConvReluCat(nn.Module): + def __init__(self, conv1_bias: bool = True) -> None: + super().__init__() + self.conv1 = nn.Conv2d(4, 4, 3, padding=1, bias=conv1_bias) + self.conv2 = nn.Conv2d(8, 4, 1) + with torch.no_grad(): + for param in self.parameters(): + param.uniform_(-0.1, 0.1) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + relu_out = F.relu(self.conv1(x)) + merged = torch.cat((relu_out, y), dim=1) + return self.conv2(merged) + + +def _example_inputs() -> tuple[torch.Tensor, torch.Tensor]: + torch.manual_seed(0) + x = torch.rand(1, 4, 16, 16) + y = torch.rand(1, 4, 16, 16) - 0.065 + return x, y + + +def _compile_spec() -> VgfCompileSpec: + return VgfCompileSpec("TOSA-1.0+INT+FP") + + +def _quantizer() -> VgfQuantizer: + quantizer = VgfQuantizer(_compile_spec()) + quantizer.set_global( + get_symmetric_quantization_config( + is_per_channel=True, + act_qmin=-127, + act_qmax=127, + weight_qmin=-127, + weight_qmax=127, + ) + ) + return quantizer + + +def _export_quantized(model: nn.Module): + inputs = _example_inputs() + exported = torch.export.export(model.eval(), inputs).module(check_guards=False) + quantized = _quantizer()._quantize_with_submodules(exported, [inputs]) + return torch.export.export(quantized, inputs) + + +def _run_pre_rewrite_passes(exported_program: torch.export.ExportedProgram): + gm = exported_program.graph_module + for pass_ in ( + FuseQuantizedActivationPass(), + ConvertToClampPass(), + FoldAndAnnotateQParamsPass(exported_program), + QuantizeClampArgumentsPass(), + ): + result = pass_(gm) + assert result is not None + gm = result.graph_module + return gm + + +def _get_call_function_node(gm: torch.fx.GraphModule, target): + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == target: + return node + raise AssertionError(f"Node with target {target} not found") def test_rewrite_conv_tosa_FP(): @@ -18,3 +103,49 @@ def test_rewrite_conv_tosa_FP(): # We can't run TOSA backend dialect operators in eager mode pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() + + +def test_fold_and_annotate_q_params_vgf_quant_preserves_output_qparams_on_non_fuseable_clamp() -> ( + None +): + exported_program = _export_quantized(TinyConvReluCat()) + gm = _run_pre_rewrite_passes(to_edge(exported_program).exported_program()) + + conv = _get_call_function_node(gm, exir_ops.edge.aten.convolution.default) + clamp = _get_call_function_node(gm, exir_ops.edge.aten.clamp.default) + + assert conv.meta["input_qparams"] + assert not conv.meta["output_qparams"] + assert clamp.meta["output_qparams"] + + +def test_rewrite_conv_vgf_quant_handles_non_fuseable_conv_clamp_cat_branch() -> None: + exported_program = _export_quantized(TinyConvReluCat()) + compile_spec = _compile_spec() + + to_edge_transform_and_lower( + exported_program, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + partitioner=[VgfPartitioner(compile_spec)], + ) + + +def test_rewrite_conv_vgf_quant_infers_quantized_bias_dtype_from_inputs() -> None: + exported_program = _export_quantized(TinyConvReluCat(conv1_bias=False)) + edge_program = to_edge( + exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False) + ).exported_program() + gm = _run_pre_rewrite_passes(edge_program) + with TosaLoweringContext(_compile_spec().tosa_spec): + result = RewriteConvPass(edge_program)(gm) + assert result is not None + gm = result.graph_module + + bias_nodes = [ + node + for node in gm.graph.nodes + if node.op == "placeholder" and node.name.endswith("_bias") + ] + + assert len(bias_nodes) == 1 + assert bias_nodes[0].meta["val"].dtype == torch.int32