Skip to content

Commit 7df3e6d

Browse files
Ninja91meta-codesync[bot]
authored andcommitted
Add FuseConsecutiveRescalesPass to fuse redundant RESCALE pairs (#17830)
Summary: Pull Request resolved: #17830 TOSA requires INT32 arithmetic for add/sub/mul ops. `InsertRescaleInt32Pass` wraps each such op with input RESCALEs (INT8→INT32) and output RESCALE (INT32→INT8). When two such ops are chained, the output RESCALE of op1 feeds directly into the input RESCALE of op2, creating a redundant INT32→INT8→INT32 round-trip that wastes NPU cycles and loses precision. `FuseConsecutiveRescalesPass` detects these pairs and either: - Removes both if the composed scale is ~1.0 (identity) - Replaces both with a single INT32→INT32 RESCALE with composed scale Handles multi-user R1 nodes (e.g., residual connections, LayerNorm branching) by fusing each R1→R2 pair individually while preserving R1 for non-RESCALE users. ## Context Each unnecessary RESCALE is decomposed by Vela into Add+Mul NPU instructions (~1,130 cycles each on Ethos-U55-128). In meta-internal quantized models, RESCALE overhead accounts for 25-50% of total NPU cycles. This pass eliminates consecutive pairs at op boundaries, with multi-user handling catching additional pairs from branching patterns (LayerNorm's sub feeding both mul_square and mul_normalize). This diff also adds a `ResidualConvBlock` toy model and pass-level unit tests. Reviewed By: 3l1 Differential Revision: D94483331
1 parent 1b46fb0 commit 7df3e6d

6 files changed

Lines changed: 719 additions & 6 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
QuantizeClampArgumentsPass,
9797
)
9898
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
99+
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
99100
from .fuse_constant_ops_pass import ( # noqa
100101
ComputeConstantOpsAOTPass,
101102
FuseConstantArgsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
DecorateFp32toInt32CastingPass,
9494
FoldAndAnnotateQParamsPass,
9595
FuseBatchNorm2dPass,
96+
FuseConsecutiveRescalesPass,
9697
FuseConstantArgsPass,
9798
FuseDuplicateUsersPass,
9899
FuseEqualPlaceholdersPass,
@@ -161,8 +162,8 @@ def configure_skip_passes(
161162
self,
162163
override_config: ArmPassPipelineConfig | None = None,
163164
) -> tuple[type, ...]:
164-
"""Configures the pass manager to skip certain passes based on the
165-
ArmPassPipelineConfig class found in the compile spec.
165+
"""Configures the pass manager to skip certain passes based on
166+
the ArmPassPipelineConfig class found in the compile spec.
166167
"""
167168
skip_set: set[type] = set()
168169

@@ -189,11 +190,11 @@ def configure_skip_passes(
189190
return self._skip_pass_types
190191

191192
def validate_constraints_mandatory(self):
192-
"""Validates that necessary passes have run before transforming to
193-
backend.
193+
"""Validates that necessary passes have run before
194+
transforming to backend.
194195
195-
Note that this differs from the original validate_constraints function,
196-
which only checks the order of passes.
196+
Note that this differs from the original validate_constraints
197+
function, which only checks the order of passes.
197198
198199
"""
199200
passes_to_run = defaultdict(list)
@@ -264,6 +265,7 @@ def _tosa_pipeline(
264265
# Ticket: MLETORCH-1539
265266
DecomposeLinearPass(),
266267
InsertRescaleInt32Pass(),
268+
FuseConsecutiveRescalesPass(),
267269
InsertControlFlowRescalesPass(),
268270
DecomposeQuantNodesPass(),
269271
]
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast, Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes.arm_pass import ArmPass
10+
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass
13+
from torch.fx import GraphModule, Node
14+
from torch.fx.passes.infra.pass_base import PassResult
15+
16+
17+
class FuseConsecutiveRescalesPass(ArmPass):
18+
"""Fuse consecutive RESCALE(INT32->INT8/INT16) ->
19+
RESCALE(INT8/INT16->INT32) pairs.
20+
21+
InsertRescaleInt32Pass wraps each add/mul/sub with input rescales
22+
(INT8/INT16->INT32) and an output rescale (INT32->INT8/INT16). When
23+
two such ops are chained (e.g., add1 -> add2), the output rescale
24+
of add1 feeds directly into an input rescale of add2, creating a
25+
redundant INT32->INT8/INT16->INT32 round-trip that loses precision.
26+
27+
This pass detects such pairs and either:
28+
- Removes both if the composed scale is ~1.0 and zero points match
29+
- Replaces both with a single INT32->INT32 RESCALE with composed
30+
scale
31+
32+
Handles multi-user R1 nodes: when R1 feeds both RESCALE and
33+
non-RESCALE users, each R1->R2 RESCALE pair is fused individually
34+
while preserving R1 for its non-RESCALE users.
35+
36+
"""
37+
38+
_passes_required_after: Set[Type[ExportPass]] = set()
39+
40+
def call(self, graph_module: GraphModule) -> PassResult:
41+
graph = graph_module.graph
42+
modified = False
43+
nodes_to_erase = []
44+
45+
for node in list(graph.nodes):
46+
node = cast(Node, node)
47+
if not _is_rescale(node):
48+
continue
49+
50+
# R1 = node: output rescale (INT32 -> INT8/INT16)
51+
r1_output_dtype = node.args[1]
52+
if r1_output_dtype not in (torch.int8, torch.int16):
53+
continue
54+
55+
r1_input = node.args[0]
56+
r1_input_zp = node.args[3]
57+
r1_output_zp = node.args[4]
58+
r1_scale = float(node.args[2][0])
59+
60+
# Check each user individually (handles multi-user R1)
61+
for user in list(node.users):
62+
if not _is_rescale(user):
63+
continue
64+
65+
# R2 = user: input rescale (INT8/INT16 -> INT32)
66+
r2_output_dtype = user.args[1]
67+
if r2_output_dtype != torch.int32:
68+
continue
69+
70+
r2_input_zp = user.args[3]
71+
72+
# Guard: intermediate zero points must match for correct
73+
# composition. Without this, the offset term
74+
# (r1_output_zp - r2_input_zp) * r2_scale is silently lost.
75+
if r1_output_zp != r2_input_zp:
76+
continue
77+
78+
r2_scale = float(user.args[2][0])
79+
composed_scale = r1_scale * r2_scale
80+
r2_output_zp = user.args[4]
81+
82+
if abs(composed_scale - 1.0) < 1e-6 and r1_input_zp == r2_output_zp:
83+
# Identity: wire R1's input directly to R2's users
84+
user.replace_all_uses_with(r1_input)
85+
nodes_to_erase.append(user)
86+
else:
87+
# Non-identity: replace with single INT32->INT32 RESCALE
88+
with graph.inserting_before(user):
89+
composed_node = create_node(
90+
graph,
91+
exir_ops.backend.tosa.RESCALE.default,
92+
(
93+
r1_input,
94+
r2_output_dtype,
95+
[composed_scale],
96+
r1_input_zp,
97+
r2_output_zp,
98+
),
99+
from_node=user,
100+
)
101+
user.replace_all_uses_with(composed_node)
102+
nodes_to_erase.append(user)
103+
104+
modified = True
105+
106+
# Always consider R1 for removal; actual erasure is guarded below
107+
nodes_to_erase.append(node)
108+
109+
for node in nodes_to_erase:
110+
if len(node.users) == 0:
111+
graph.erase_node(node)
112+
113+
if modified:
114+
graph_module = super().call(graph_module).graph_module
115+
graph_module.recompile()
116+
117+
return PassResult(graph_module, modified)
118+
119+
120+
def _is_rescale(node: Node) -> bool:
121+
return (
122+
node.op == "call_function"
123+
and node.target == exir_ops.backend.tosa.RESCALE.default
124+
)
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Residual conv block model test for ARM TOSA backend.
7+
8+
Tests a minimal residual architecture with conv->batchnorm->relu->add
9+
blocks and permute operations, representative of quantized signal
10+
processing models where FuseConsecutiveRescalesPass eliminates
11+
redundant RESCALE pairs.
12+
13+
"""
14+
15+
from typing import Tuple
16+
17+
import torch
18+
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test.tester.test_pipeline import (
20+
EthosU55PipelineINT,
21+
EthosU85PipelineINT,
22+
TosaPipelineFP,
23+
TosaPipelineINT,
24+
VgfPipeline,
25+
)
26+
27+
28+
class ResidualConvBlock(torch.nn.Module):
29+
"""Residual conv block with batchnorm and permute operations.
30+
31+
Architecture: conv->bn->relu->add (residual) -> permute ->
32+
conv->bn->relu->add. When quantized, each residual add is
33+
wrapped with INT32 RESCALEs by InsertRescaleInt32Pass. Stacked
34+
blocks create consecutive RESCALE pairs (INT32->INT8->INT32)
35+
between adjacent adds that FuseConsecutiveRescalesPass
36+
eliminates.
37+
38+
"""
39+
40+
def __init__(self):
41+
super().__init__()
42+
self.conv1 = torch.nn.Conv2d(3, 3, 3, padding=1)
43+
self.bn1 = torch.nn.BatchNorm2d(3)
44+
self.relu1 = torch.nn.ReLU()
45+
self.conv2 = torch.nn.Conv2d(3, 3, 3, padding=1)
46+
self.bn2 = torch.nn.BatchNorm2d(3)
47+
self.relu2 = torch.nn.ReLU()
48+
49+
def forward(self, x):
50+
# Block 1: conv → batchnorm → relu → residual add
51+
out = self.relu1(self.bn1(self.conv1(x)))
52+
out = out + x # residual add 1
53+
54+
# Channel reordering (common in signal processing models)
55+
out = out.permute(0, 1, 3, 2)
56+
57+
# Block 2: conv → batchnorm → relu → residual add
58+
out2 = self.relu2(self.bn2(self.conv2(out)))
59+
out2 = out2 + out # residual add 2
60+
return out2
61+
62+
63+
model = ResidualConvBlock().eval()
64+
model_inputs = (torch.randn(1, 3, 8, 8),)
65+
input_t = Tuple[torch.Tensor]
66+
67+
68+
def test_residual_conv_block_tosa_FP():
69+
pipeline = TosaPipelineFP[input_t](
70+
model,
71+
model_inputs,
72+
aten_op=[],
73+
exir_op=[],
74+
use_to_edge_transform_and_lower=True,
75+
)
76+
pipeline.run()
77+
78+
79+
def test_residual_conv_block_tosa_INT():
80+
pipeline = TosaPipelineINT[input_t](
81+
model,
82+
model_inputs,
83+
aten_op=[],
84+
exir_op=[],
85+
use_to_edge_transform_and_lower=True,
86+
atol=0.25,
87+
qtol=1,
88+
frobenius_threshold=None,
89+
cosine_threshold=None,
90+
)
91+
pipeline.run()
92+
93+
94+
@common.XfailIfNoCorstone300
95+
def test_residual_conv_block_u55_INT():
96+
pipeline = EthosU55PipelineINT[input_t](
97+
model,
98+
model_inputs,
99+
aten_ops=[],
100+
exir_ops=[],
101+
use_to_edge_transform_and_lower=True,
102+
)
103+
pipeline.run()
104+
105+
106+
@common.XfailIfNoCorstone320
107+
def test_residual_conv_block_u85_INT():
108+
pipeline = EthosU85PipelineINT[input_t](
109+
model,
110+
model_inputs,
111+
aten_ops=[],
112+
exir_ops=[],
113+
use_to_edge_transform_and_lower=True,
114+
)
115+
pipeline.run()
116+
117+
118+
@common.SkipIfNoModelConverter
119+
def test_residual_conv_block_vgf_quant():
120+
pipeline = VgfPipeline[input_t](
121+
model,
122+
model_inputs,
123+
aten_op=[],
124+
exir_op=[],
125+
use_to_edge_transform_and_lower=True,
126+
quantize=True,
127+
)
128+
pipeline.run()
129+
130+
131+
@common.SkipIfNoModelConverter
132+
def test_residual_conv_block_vgf_no_quant():
133+
pipeline = VgfPipeline[input_t](
134+
model,
135+
model_inputs,
136+
aten_op=[],
137+
exir_op=[],
138+
use_to_edge_transform_and_lower=True,
139+
quantize=False,
140+
)
141+
pipeline.run()
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2026 Arm Limited and/or its affiliates.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
11+
FuseQuantizedActivationPass,
12+
)
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.Tensor]
16+
17+
18+
class ConvRelu(torch.nn.Module):
19+
"""Conv2d followed by ReLU — existing fuseable behavior."""
20+
21+
def __init__(self) -> None:
22+
super().__init__()
23+
self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
24+
self.relu = torch.nn.ReLU()
25+
26+
def get_inputs(self) -> input_t:
27+
return (torch.randn(1, 3, 8, 8),)
28+
29+
def forward(self, x: torch.Tensor) -> torch.Tensor:
30+
return self.relu(self.conv(x))
31+
32+
33+
def test_fuse_relu_after_conv_quantized() -> None:
34+
"""Existing behavior: ReLU after conv is fused in quantized graph."""
35+
module = ConvRelu()
36+
pipeline = PassPipeline[input_t](
37+
module,
38+
module.get_inputs(),
39+
quantize=True,
40+
ops_before_pass={
41+
"executorch_exir_dialects_edge__ops_aten_relu_default": 1,
42+
},
43+
ops_not_after_pass=[
44+
"executorch_exir_dialects_edge__ops_aten_relu_default",
45+
],
46+
pass_list=[FuseQuantizedActivationPass],
47+
)
48+
pipeline.pop_stage("run_method_and_compare_outputs")
49+
pipeline.run()

0 commit comments

Comments
 (0)