Skip to content

Commit 513a4ea

Browse files
Arm backend: Avoid running passes with no matching target ops (pytorch#19839)
Add ArmPass.should_run_pass() as a reusable early-exit hook before call() starts the normal ExportPass retracing path. The default hook returns true, preserving existing behavior for ArmPass subclasses. Introduce ArmOpTargetedPass for passes that only transform a known set of operator targets. It implements should_run_pass() by scanning the current graph and nested GraphModules for matching target operators. If no matching target operator is found, the pass returns an unmodified PassResult. For passes that already gate transformations with allowed_to_transform(), allow the target pre-scan to apply the same check before deciding whether the pass needs to run. This avoids running TFA passes when all matching target nodes are marked as disallowed. The should_run_pass() hook and ArmOpTargetedPass pre-scan avoid rebuilding graphs for decomposition and rewrite passes that cannot affect the current graph. The speedup is most visible on large models. Single-run paired benchmarks on Arm backend model tests across FP32, INT, VGF no-quant, and VGF quant variants: | Model | E2E avg | Pass-manager avg | |-------------|--------:|-----------------:| | T5-small | +30.5% | +47.5% | | DeepLabV3 | +12.9% | +49.8% | | Wav2Letter | +16.9% | +51.2% | | InceptionV3 | +22.2% | +46.5% | | MobileNetV2 | +22.2% | +52.5% | | MobileNetV3 | +29.9% | +54.6% | Model rows are unweighted averages over successful variants. Unweighted average across 23 successful model/target variants: E2E speedup: +22.4% Pass-manager speedup: +50.5% Change-Id: Iaa09638473a1d6d1e2ce98f5a0e3fc3a14378143 cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Yufeng Shi <yufeng.shi@arm.com> Co-authored-by: Erik Lundell <erik.lundell@arm.com>
1 parent 1494535 commit 513a4ea

78 files changed

Lines changed: 593 additions & 294 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
from . import arm_pass_utils # noqa
8-
from .arm_pass import ArmPass # noqa # usort: skip
8+
from .arm_pass import ArmOpTargetedPass, ArmPass # noqa # usort: skip
99
from .accumulate_index_put_pass import AccumulateIndexPutPass # noqa
1010
from .broadcast_args_pass import BroadcastArgsPass # noqa
1111
from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa

backends/arm/_passes/accumulate_index_put_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import torch
88

9-
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm._passes import ArmOpTargetedPass
1010
from executorch.backends.arm._passes.decompose_index_tensor_to_gather_pass import (
1111
DecomposeIndexTensorToGatherPass,
1212
)
@@ -32,7 +32,7 @@ def get_ops(op):
3232
raise RuntimeError(f"Can't get index_put decomposition for op {op}")
3333

3434

35-
class AccumulateIndexPutPass(ArmPass):
35+
class AccumulateIndexPutPass(ArmOpTargetedPass):
3636
"""This pass adjusts the values arg when the accumulate arg is set to true
3737
for the index_put op.
3838
"""
@@ -41,9 +41,11 @@ class AccumulateIndexPutPass(ArmPass):
4141
DecomposeIndexTensorToGatherPass,
4242
RewriteIndexPutPass,
4343
}
44+
target_ops = aten_ops + edge_ops
45+
check_allowed_to_transform = True
4446

4547
def call_operator(self, op, args, kwargs, meta):
46-
if op not in (aten_ops + edge_ops) or not self.allowed_to_transform(meta):
48+
if op not in self.target_ops or not self.allowed_to_transform(meta):
4749
return super().call_operator(op, args, kwargs, meta)
4850

4951
source, indices, values = args[:3]

backends/arm/_passes/arm_pass.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import copy
88
import traceback
99
from abc import abstractmethod
10+
from collections.abc import Collection
1011
from typing import Any, List, Optional, Set, Type
1112

1213
import torch
1314
from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY
1415
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1516
from executorch.exir.dialects._ops import ops as exir_ops
1617
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
17-
from torch.fx import GraphModule
18+
from torch.fx import GraphModule, Node
1819
from torch.fx.passes.infra.pass_base import PassResult
1920
from torch.utils import _pytree as pytree
2021

@@ -191,3 +192,99 @@ def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]):
191192
meta=meta,
192193
updated=True,
193194
)
195+
196+
def should_run_pass(self, graph_module: GraphModule) -> bool:
197+
"""Return whether this pass should run on the graph module.
198+
199+
Subclasses can override this to cheaply skip the pass before
200+
``call()`` starts the normal ``ExportPass`` retracing path.
201+
202+
Args:
203+
graph_module (GraphModule): The graph module to inspect.
204+
205+
Returns:
206+
bool: True when the pass should run.
207+
208+
"""
209+
return True
210+
211+
def __call__(self, graph_module: GraphModule) -> PassResult | None:
212+
self.requires(graph_module)
213+
if not self.should_run_pass(graph_module):
214+
self.ensures(graph_module)
215+
return PassResult(graph_module, False)
216+
res = self.call(graph_module)
217+
self.ensures(graph_module)
218+
return res
219+
220+
221+
class ArmOpTargetedPass(ArmPass):
222+
"""Base class for passes that only transform selected operators.
223+
224+
Subclasses set ``target_ops`` to the call_function targets they can
225+
transform. If the current graph and nested control-flow subgraphs do not
226+
contain any target, the pass returns immediately without paying the default
227+
ExportPass retracing cost.
228+
229+
Set ``check_allowed_to_transform`` to ``True`` when the target pre-scan
230+
should also apply ``allowed_to_transform()`` to matching target nodes. This
231+
is useful for TFA passes whose ``call_operator()`` leaves disallowed target
232+
nodes unchanged. If all matching targets are disallowed, the pass can
233+
return before entering the normal ``ExportPass`` path.
234+
235+
"""
236+
237+
target_ops: Collection[Any] = ()
238+
check_allowed_to_transform = False
239+
240+
def has_target_node(self, graph_module: GraphModule) -> bool:
241+
"""Return whether the graph module tree contains a target node.
242+
243+
Args:
244+
graph_module (GraphModule): The graph module tree to inspect.
245+
246+
Returns:
247+
bool: True if a matching call_function node is present.
248+
249+
"""
250+
visited_graph_modules = set()
251+
252+
def target_node_can_trigger_pass(node: Node) -> bool:
253+
if not self.check_allowed_to_transform:
254+
return True
255+
if self.allowed_to_transform(node.meta):
256+
return True
257+
return False
258+
259+
def graph_has_target(module: GraphModule) -> bool:
260+
if id(module) in visited_graph_modules:
261+
return False
262+
visited_graph_modules.add(id(module))
263+
264+
for target in self.target_ops:
265+
for node in module.graph.find_nodes(
266+
op="call_function",
267+
target=target,
268+
sort=False,
269+
):
270+
if target_node_can_trigger_pass(node):
271+
return True
272+
273+
return any(
274+
isinstance(child, GraphModule) and graph_has_target(child)
275+
for child in module.children()
276+
)
277+
278+
return graph_has_target(graph_module)
279+
280+
def should_run_pass(self, graph_module: GraphModule) -> bool:
281+
"""Return whether this pass has a target node to transform.
282+
283+
Args:
284+
graph_module (GraphModule): The graph module tree to inspect.
285+
286+
Returns:
287+
bool: True when a matching target node is present.
288+
289+
"""
290+
return self.has_target_node(graph_module)

backends/arm/_passes/canonicalize_gather_pass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from typing import Set, Type
77

88
import torch
9-
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm._passes import ArmOpTargetedPass
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass
1212

1313

14-
class CanonicalizeGatherPass(ArmPass):
14+
class CanonicalizeGatherPass(ArmOpTargetedPass):
1515
"""Canonicalize gather so it can be lowered to TOSA.GATHER via the backend
1616
dialect.
1717
@@ -40,10 +40,10 @@ class CanonicalizeGatherPass(ArmPass):
4040

4141
_passes_required_after: Set[Type[ExportPass]] = set()
4242

43-
_TARGET_OPS = {exir_ops.edge.aten.gather.default}
43+
target_ops = {exir_ops.edge.aten.gather.default}
4444

4545
def call_operator(self, op, args, kwargs, meta):
46-
if op not in self._TARGET_OPS:
46+
if op not in self.target_ops:
4747
return super().call_operator(op, args, kwargs, meta)
4848

4949
# edge.aten.gather.default: (x, dim, index) with kw-only sparse_grad

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from typing import Set, Type
1010

11-
from executorch.backends.arm._passes import ArmPass
11+
from executorch.backends.arm._passes import ArmOpTargetedPass
1212

1313
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1414
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
@@ -17,7 +17,7 @@
1717
from executorch.exir.pass_base import ExportPass
1818

1919

20-
class Conv1dUnsqueezePass(ArmPass):
20+
class Conv1dUnsqueezePass(ArmOpTargetedPass):
2121
"""This pass is used to change conv1d ops into conv2d since TOSA only
2222
supports 2d and 3d convolution.
2323
@@ -34,9 +34,10 @@ class Conv1dUnsqueezePass(ArmPass):
3434
RewriteConvPass,
3535
SizeAdjustInputPass,
3636
}
37+
target_ops = (exir_ops.edge.aten.convolution.default,)
3738

3839
def call_operator(self, op, args, kwargs, meta):
39-
if op != exir_ops.edge.aten.convolution.default:
40+
if op not in self.target_ops:
4041
return super().call_operator(op, args, kwargs, meta)
4142
stride = list(args[3])
4243
if len(stride) != 1:

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.arm._passes.arm_pass import ArmPass
12+
from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass
1313
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
1414
UnsqueezeBeforeRepeatPass,
1515
)
@@ -51,7 +51,7 @@ def calculate_multiples(args):
5151
return multiples, expanded_rank != len(input_shape)
5252

5353

54-
class ConvertExpandCopyToRepeatPass(ArmPass):
54+
class ConvertExpandCopyToRepeatPass(ArmOpTargetedPass):
5555
"""Replace expand copy with repeat since it is a repeat that can only repeat
5656
singleton dimensions.
5757
"""
@@ -60,9 +60,10 @@ class ConvertExpandCopyToRepeatPass(ArmPass):
6060

6161
expand_copy = exir_ops.edge.aten.expand_copy.default
6262
repeat = exir_ops.edge.aten.repeat.default
63+
target_ops = (expand_copy,)
6364

6465
def call_operator(self, op, args, kwargs, meta):
65-
if op != self.expand_copy:
66+
if op not in self.target_ops:
6667
return super().call_operator(op, args, kwargs, meta)
6768

6869
multiples, changes_rank = calculate_multiples(args)

backends/arm/_passes/convert_full_like_to_full_pass.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from typing import Set, Type
77

8-
from executorch.backends.arm._passes.arm_pass import ArmPass
8+
from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass
99
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
1010
ComputeConstantOpsAOTPass,
1111
)
@@ -14,7 +14,7 @@
1414
from executorch.exir.pass_base import ExportPass
1515

1616

17-
class ConvertFullLikeToFullPass(ArmPass):
17+
class ConvertFullLikeToFullPass(ArmOpTargetedPass):
1818
"""Convert edge aten full_like to full.
1919
2020
As per the full_like PyTorch documentation, `torch.full_like(input,
@@ -35,11 +35,10 @@ class ConvertFullLikeToFullPass(ArmPass):
3535
"""
3636

3737
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
38+
target_ops = (exir_ops.edge.aten.full_like.default,)
3839

3940
def call_operator(self, op, args, kwargs, meta):
40-
if op not in [
41-
exir_ops.edge.aten.full_like.default,
42-
]:
41+
if op not in self.target_ops:
4342
return super().call_operator(op, args, kwargs, meta)
4443

4544
tensor = args[0].data

backends/arm/_passes/convert_permute_singleton_to_view_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Sequence, Set, Tuple, Type
88

9-
from executorch.backends.arm._passes.arm_pass import ArmPass
9+
from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass
1010

1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.pass_base import ExportPass
@@ -20,7 +20,7 @@
2020
)
2121

2222

23-
class ConvertPermuteSingletonToViewPass(ArmPass):
23+
class ConvertPermuteSingletonToViewPass(ArmOpTargetedPass):
2424
"""Replace permutations that only move singleton axes with a reshape.
2525
2626
Examples:
@@ -34,9 +34,10 @@ class ConvertPermuteSingletonToViewPass(ArmPass):
3434
"""
3535

3636
_passes_required_after: Set[Type[ExportPass]] = set()
37+
target_ops = _PERMUTE_TARGETS
3738

3839
def call_operator(self, op, args, kwargs, meta):
39-
if op not in _PERMUTE_TARGETS:
40+
if op not in self.target_ops:
4041
return super().call_operator(op, args, kwargs, meta)
4142

4243
input_tensor = args[0].data

backends/arm/_passes/convert_squeezes_to_view.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66

77
from typing import Set, Type
88

9-
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm._passes import ArmOpTargetedPass
1010
from executorch.backends.arm._passes.fuse_view_copy_transform_pass import (
1111
FuseViewCopyTransformPass,
1212
)
1313
from executorch.exir.dialects._ops import ops as exir_ops
1414
from executorch.exir.pass_base import ExportPass
1515

1616

17-
class ConvertSqueezesToViewPass(ArmPass):
17+
class ConvertSqueezesToViewPass(ArmOpTargetedPass):
1818
"""Replaces squeeze/unsqueeze operators with view.
1919
2020
These are simply special cases of the view op, so removing them gives us
@@ -23,12 +23,13 @@ class ConvertSqueezesToViewPass(ArmPass):
2323
"""
2424

2525
_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass}
26+
target_ops = (
27+
exir_ops.edge.aten.squeeze_copy.dims,
28+
exir_ops.edge.aten.unsqueeze_copy.default,
29+
)
2630

2731
def call_operator(self, op, args, kwargs, meta):
28-
if op not in [
29-
exir_ops.edge.aten.squeeze_copy.dims,
30-
exir_ops.edge.aten.unsqueeze_copy.default,
31-
]:
32+
if op not in self.target_ops:
3233
return super().call_operator(op, args, kwargs, meta)
3334

3435
x = args[0]

backends/arm/_passes/convert_to_clamp_pass.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
from typing import Set, Tuple, Type
77

8-
from executorch.backends.arm._passes import ArmPass
8+
from executorch.backends.arm._passes import ArmOpTargetedPass
99

1010
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1111
QuantizeClampArgumentsPass,
@@ -29,11 +29,13 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:
2929
raise ValueError(f"Getting clamp parameters for op {op} is not implemented.")
3030

3131

32-
class ConvertToClampPass(ArmPass):
32+
class ConvertToClampPass(ArmOpTargetedPass):
3333
_passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass}
34+
target_ops = edge_operators
35+
check_allowed_to_transform = True
3436

3537
def call_operator(self, op, args, kwargs, meta):
36-
if op not in edge_operators or not self.allowed_to_transform(meta):
38+
if op not in self.target_ops or not self.allowed_to_transform(meta):
3739
return super().call_operator(op, args, kwargs, meta)
3840

3941
return super().call_operator(

0 commit comments

Comments
 (0)