Skip to content

Commit 38b40bc

Browse files
authored
Add FuseConsecutiveRescalesPass to fuse redundant RESCALE pairs (#17830)
Differential Revision: D94483331 Pull Request resolved: #17830
1 parent af35006 commit 38b40bc

6 files changed

Lines changed: 1020 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
QuantizeClampArgumentsPass,
103103
)
104104
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
105+
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
105106
from .fuse_constant_ops_pass import ( # noqa
106107
ComputeConstantOpsAOTPass,
107108
FuseConstantArgsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
DecorateFp32toInt32CastingPass,
9999
FoldAndAnnotateQParamsPass,
100100
FuseBatchNorm2dPass,
101+
FuseConsecutiveRescalesPass,
101102
FuseConstantArgsPass,
102103
FuseDuplicateUsersPass,
103104
FuseEqualPlaceholdersPass,
@@ -380,6 +381,7 @@ def _tosa_pipeline(
380381
# Ticket: MLETORCH-1539
381382
DecomposeLinearPass(),
382383
InsertRescaleInt32Pass(),
384+
FuseConsecutiveRescalesPass(),
383385
InsertControlFlowRescalesPass(),
384386
DecomposeQuantNodesPass(),
385387
]
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from typing import cast, Set, Type
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass import ArmPass
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
from torch.fx import GraphModule, Node
14+
15+
logger: logging.Logger = logging.getLogger(__name__)
16+
17+
# Indexing helpers for reading RESCALE node args throughout this pass.
# TOSA RESCALE argument positions:
# args[0] = input tensor (Node)
# args[1] = output dtype (e.g., torch.int8, torch.int32)
# args[2] = scale list (List[float]; per-tensor when len == 1)
# args[3] = input zero point (int)
# args[4] = output zero point (int)
_ARG_INPUT = 0
_ARG_OUTPUT_DTYPE = 1
_ARG_SCALE = 2
_ARG_INPUT_ZP = 3
_ARG_OUTPUT_ZP = 4
28+
29+
30+
class FuseConsecutiveRescalesPass(ArmPass):
    """Remove redundant back-to-back RESCALE round-trips.

    ``InsertRescaleInt32Pass`` brackets every quantized arithmetic and
    comparison operator (add, sub, mul, abs, eq, ge, gt, le, lt, max, min,
    sum) with INT8/INT16->INT32 input rescales and one INT32->INT8/INT16
    output rescale. Chaining two such operators therefore produces an
    INT32 -> INT8/INT16 -> INT32 round-trip that loses precision for no
    benefit.

    Two situations are distinguished:

    - **Identity pair** (composed scale ~1.0, matching zero points): both
      RESCALEs are dropped and R1's INT32 input is wired straight to R2's
      users. Skipping the intermediate INT8/INT16 clamp can in theory shift
      results by up to ~120 INT8 steps when all inputs sit near the clamp
      boundary; observed differences are 0-1 steps for typical
      distributions (tests use qtol=1).

    - **Non-identity pair**: left untouched. A single fused INT32->INT32
      RESCALE would be correct TOSA, but the Vela NPU compiler miscompiles
      it (all-zero NPU outputs), so the INT8/INT16 intermediate is kept.

    R1 nodes with mixed fan-out are supported: each R1->R2 RESCALE pair is
    fused on its own while R1 survives for any non-RESCALE users.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call(self, graph_module: GraphModule) -> PassResult:
        """Scan the graph once and fuse every identity R1->R2 pair found."""
        graph = graph_module.graph
        n_rescales_start = sum(_is_rescale(n) for n in graph.nodes)
        fused_pairs = 0

        for candidate in list(graph.nodes):
            r1 = cast(Node, candidate)
            if not _is_fuseable_r1(r1):
                continue

            # Capture R1's parameters once; they stay valid while its
            # individual users are fused below.
            src = r1.args[_ARG_INPUT]
            src_zp = r1.args[_ARG_INPUT_ZP]
            scale1 = float(r1.args[_ARG_SCALE][0])  # type: ignore[arg-type]

            # Snapshot the users: fusing rewires edges and mutates r1.users.
            fused_pairs += sum(
                _try_fuse_identity_pair(r1, consumer, src, src_zp, scale1)
                for consumer in list(r1.users)
            )

        changed = fused_pairs > 0
        if changed:
            graph.eliminate_dead_code()
            n_rescales_end = sum(_is_rescale(n) for n in graph.nodes)
            logger.info(
                "FuseConsecutiveRescalesPass: removed %d identity pairs "
                "(%d RESCALEs: %d -> %d)",
                fused_pairs,
                n_rescales_start - n_rescales_end,
                n_rescales_start,
                n_rescales_end,
            )
            graph_module.recompile()
            graph.lint()
            # super().call() is intentionally skipped: this pass only
            # rewires edges and removes nodes, introducing no new
            # operations, so retracing is unnecessary.

        return PassResult(graph_module, changed)
106+
107+
108+
def _is_rescale(node: Node) -> bool:
    """Return True when ``node`` is a TOSA RESCALE call_function node."""
    if node.op != "call_function":
        return False
    return node.target == exir_ops.backend.tosa.RESCALE.default
113+
114+
115+
def _is_fuseable_r1(node: Node) -> bool:
    """Decide whether ``node`` can serve as the R1 half of a fuseable pair.

    R1 must be a per-tensor RESCALE narrowing INT32 down to INT8/INT16.
    """
    if not _is_rescale(node):
        return False
    if node.args[_ARG_OUTPUT_DTYPE] not in (torch.int8, torch.int16):
        return False
    # Per-channel scales (len > 1) are not handled by this pass.
    if len(node.args[_ARG_SCALE]) != 1:  # type: ignore[arg-type]
        return False
    source = node.args[_ARG_INPUT]
    if not (isinstance(source, Node) and "val" in source.meta):
        return False
    return source.meta["val"].dtype == torch.int32
133+
134+
135+
def _try_fuse_identity_pair(
    r1: Node,
    r2: Node,
    r1_input: Node,
    r1_input_zp: int,
    r1_scale: float,
) -> bool:
    """Fuse one R1->R2 pair when the composed rescale is the identity.

    Returns True when the pair was fused, False otherwise.
    """
    # R2 must be a per-tensor RESCALE back up to INT32 whose input zero
    # point matches R1's output zero point.
    if not _is_rescale(r2):
        return False
    if r2.args[_ARG_OUTPUT_DTYPE] != torch.int32:
        return False
    if r1.args[_ARG_OUTPUT_ZP] != r2.args[_ARG_INPUT_ZP]:
        return False
    if len(r2.args[_ARG_SCALE]) != 1:  # type: ignore[arg-type]
        return False

    total_scale = r1_scale * float(r2.args[_ARG_SCALE][0])  # type: ignore[arg-type, index]
    is_identity = (
        abs(total_scale - 1.0) < 1e-6 and r1_input_zp == r2.args[_ARG_OUTPUT_ZP]
    )

    if not is_identity:
        # Non-identity: leave the pair unchanged. A single fused
        # INT32->INT32 RESCALE with the composed scale would be
        # semantically correct (the TOSA reference model handles it), but
        # the Vela NPU compiler produces all-zero outputs for INT32->INT32
        # RESCALE operations.
        return False

    # Identity case: the round-trip is a no-op modulo the INT8/INT16
    # clamp, so wire R1's INT32 input straight to R2's users. Bypassing
    # the clamp can in theory cause up to ~120 INT8 steps of difference
    # near clamp boundaries; observed differences are 0-1 steps. Tests
    # use qtol=1. Both RESCALEs become dead and are removed by the
    # caller's eliminate_dead_code().
    r2.replace_all_uses_with(r1_input)
    return True
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Residual conv block model test for ARM TOSA backend.
6+
7+
Tests a minimal residual architecture with conv->batchnorm->relu->add blocks and
8+
permute operations, representative of quantized signal processing models where
9+
FuseConsecutiveRescalesPass eliminates redundant RESCALE pairs.
10+
11+
"""
12+
13+
from typing import Tuple
14+
15+
import torch
16+
from executorch.backends.arm.test import common
17+
from executorch.backends.arm.test.tester.test_pipeline import (
18+
EthosU55PipelineINT,
19+
EthosU85PipelineINT,
20+
TosaPipelineFP,
21+
TosaPipelineINT,
22+
VgfPipeline,
23+
)
24+
25+
26+
class ResidualConvBlock(torch.nn.Module):
    """Two stacked conv->bn->relu->add residual blocks joined by a permute.

    Under quantization, each residual add is wrapped with INT32 RESCALEs by
    InsertRescaleInt32Pass; the stacked blocks then create consecutive
    RESCALE pairs (INT32->INT8->INT32) between the adjacent adds, which
    FuseConsecutiveRescalesPass eliminates.
    """

    def __init__(self):
        super().__init__()
        # Both blocks keep the channel count at 3 so the residual adds
        # stay shape-compatible with their inputs.
        self.conv1 = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(3)
        self.relu1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(3)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        # Block 1: conv -> batchnorm -> relu, then residual add 1.
        branch = self.relu1(self.bn1(self.conv1(x)))
        skip = branch + x

        # Swap the spatial axes (common in signal processing models).
        transposed = skip.permute(0, 1, 3, 2)

        # Block 2: conv -> batchnorm -> relu, then residual add 2.
        branch2 = self.relu2(self.bn2(self.conv2(transposed)))
        return branch2 + transposed
59+
60+
61+
# Module-level model and example inputs shared by every test in this file.
model = ResidualConvBlock().eval()
model_inputs = (torch.randn(1, 3, 8, 8),)
# Input type alias used to parameterize the test pipelines below.
input_t = Tuple[torch.Tensor]
64+
65+
66+
def test_residual_conv_block_tosa_FP():
    """Run the FP32 TOSA pipeline on the residual conv block model."""
    TosaPipelineFP[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
    ).run()
75+
76+
77+
def test_residual_conv_block_tosa_INT():
    """Run the quantized TOSA pipeline; qtol=1 absorbs fused-clamp drift."""
    TosaPipelineINT[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
        atol=0.25,
        qtol=1,
        frobenius_threshold=None,
        cosine_threshold=None,
    ).run()
90+
91+
92+
@common.XfailIfNoCorstone300
def test_residual_conv_block_u55_INT():
    """Run the quantized Ethos-U55 pipeline (xfails without Corstone-300)."""
    EthosU55PipelineINT[input_t](
        model,
        model_inputs,
        aten_ops=[],
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
        atol=0.25,
        qtol=1,
    ).run()
104+
105+
106+
@common.XfailIfNoCorstone320
def test_residual_conv_block_u85_INT():
    """Run the quantized Ethos-U85 pipeline (xfails without Corstone-320)."""
    EthosU85PipelineINT[input_t](
        model,
        model_inputs,
        aten_ops=[],
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
        atol=0.25,
        qtol=1,
    ).run()
118+
119+
120+
@common.SkipIfNoModelConverter
def test_residual_conv_block_vgf_quant():
    """Run the quantized VGF pipeline (skipped without the model converter)."""
    VgfPipeline[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
        quantize=True,
    ).run()
131+
132+
133+
@common.SkipIfNoModelConverter
def test_residual_conv_block_vgf_no_quant():
    """Run the float VGF pipeline (skipped without the model converter)."""
    VgfPipeline[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
        quantize=False,
    ).run()
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2026 Arm Limited and/or its affiliates.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
11+
FuseQuantizedActivationPass,
12+
)
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
# Input type alias used by the module and the pipeline below.
input_t = Tuple[torch.Tensor]


class ConvRelu(torch.nn.Module):
    """Conv2d followed by ReLU — the existing fuseable activation pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.relu = torch.nn.ReLU()

    def get_inputs(self) -> input_t:
        """Return a single random NCHW example input."""
        return (torch.randn(1, 3, 8, 8),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.conv(x)
        return self.relu(features)
31+
32+
33+
def test_fuse_relu_after_conv_quantized() -> None:
    """Existing behavior: ReLU after conv is fused in the quantized graph."""
    relu_edge_op = "executorch_exir_dialects_edge__ops_aten_relu_default"
    module = ConvRelu()
    PassPipeline[input_t](
        module,
        module.get_inputs(),
        quantize=True,
        ops_before_pass={relu_edge_op: 1},
        ops_not_after_pass=[relu_edge_op],
        pass_list=[FuseQuantizedActivationPass],
    ).run()

0 commit comments

Comments (0)