Skip to content

Commit 22174fa

Browse files
authored
Implement scheme for custom arm_pass_manager insertions (#16396)
Differential Revision: D87120925 Pull Request resolved: #16396
1 parent a81ef44 commit 22174fa

6 files changed

Lines changed: 454 additions & 6 deletions

File tree

backends/arm/_passes/TARGETS

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,55 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
22

33
runtime.python_library(
4-
name = "passes",
5-
srcs = glob(["*.py"]),
4+
name = "core",
5+
srcs = [
6+
"arm_pass.py",
7+
"arm_pass_utils.py",
8+
"quant_args.py",
9+
],
610
deps = [
711
"//executorch/backends/arm:common",
812
"//executorch/backends/arm:constants",
913
"//executorch/backends/arm/tosa:utils",
1014
"//executorch/backends/arm/tosa/dialect:lib",
15+
"//executorch/backends/transforms:utils",
16+
"//executorch/exir:lib",
17+
],
18+
)
19+
20+
runtime.python_library(
21+
name = "arm_pass_manager_base",
22+
srcs = ["arm_pass_manager.py"],
23+
deps = [
24+
":core",
25+
"//executorch/backends/arm:common",
26+
"//executorch/backends/arm/tosa:utils",
27+
"//executorch/backends/arm/tosa/dialect:lib",
28+
"//executorch/exir:lib",
29+
],
30+
)
31+
32+
runtime.python_library(
33+
name = "arm_pass_manager_fb",
34+
srcs = [],
35+
deps = [
36+
":arm_pass_manager_base",
37+
# @oss-disable[end= ]: "//executorch/backends/arm/_passes/fb:fb",
38+
],
39+
)
40+
41+
runtime.python_library(
42+
name = "passes",
43+
srcs = glob(["*.py"], exclude = [
44+
"arm_pass.py",
45+
"arm_pass_utils.py",
46+
"quant_args.py",
47+
]),
48+
deps = [
49+
":core",
50+
":arm_pass_manager_base" if runtime.is_oss else ":arm_pass_manager_fb",
51+
"//executorch/backends/arm/tosa:utils",
52+
"//executorch/backends/arm/tosa/dialect:lib",
1153
"//executorch/backends/transforms:fuse_view_copy",
1254
"//executorch/backends/transforms:remove_getitem_op",
1355
"//executorch/backends/transforms:replace_scalar_with_tensor",

backends/arm/_passes/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,13 @@
158158
from .control_flow_const_inline import ( # noqa # usort: skip
159159
ControlFlowConstInlinePass,
160160
)
161-
from .arm_pass_manager import ArmPassManager # noqa # usort: skip
161+
162+
# Import all subpackages to allow extensions to patch classes
163+
import importlib # noqa: E402
164+
import pkgutil # noqa: E402
165+
166+
for _, _modname, _ispkg in pkgutil.iter_modules(__path__, __name__ + "."):
167+
if _ispkg:
168+
importlib.import_module(_modname)
169+
170+
from .arm_pass_manager import ArmPassManager # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
from collections import defaultdict
1010
from collections.abc import Sequence
11+
from dataclasses import dataclass, field
1112

1213
import executorch.backends.arm.tosa.dialect # noqa: unused
1314
from executorch.backends.arm._passes import (
@@ -157,11 +158,21 @@
157158
logger = logging.getLogger(__name__)
158159

159160

161+
@dataclass
162+
class PassInsertions:
163+
"""Holds lists of passes to be inserted before and after a target pass."""
164+
165+
before_passes: list = field(default_factory=list)
166+
after_passes: list = field(default_factory=list)
167+
168+
160169
class ArmPassManager(PassManager):
161170
def __init__(self, compile_spec: ArmCompileSpec) -> None:
162171
self.compile_spec = compile_spec
163172
self.tosa_spec = compile_spec.tosa_spec
164173
self._skip_pass_types: tuple[type, ...] = ()
174+
self._pass_insertions: dict[type, PassInsertions] = {}
175+
self._insertions_applied = False
165176
super().__init__()
166177
self.configure_skip_passes()
167178

@@ -223,6 +234,98 @@ def validate_constraints_mandatory(self):
223234

224235
raise RuntimeError(error_msg)
225236

237+
def insert_passes_before(
238+
self, target_pass_type: type, passes: list[ExportPass]
239+
) -> None:
240+
"""Register passes to be inserted before instances of target_pass_type.
241+
Insertions are deferred and applied via _apply_pass_insertions().
242+
243+
Args:
244+
target_pass_type: The pass class to insert before (e.g., InsertTableOpsPass)
245+
passes: List of pass instances to insert
246+
247+
"""
248+
self._pass_insertions.setdefault(
249+
target_pass_type, PassInsertions()
250+
).before_passes.extend(passes)
251+
252+
def insert_passes_after(
253+
self, target_pass_type: type, passes: list[ExportPass]
254+
) -> None:
255+
"""Register passes to be inserted after instances of target_pass_type.
256+
Insertions are deferred and applied via _apply_pass_insertions().
257+
258+
Args:
259+
target_pass_type: The pass class to insert after
260+
passes: List of pass instances to insert
261+
262+
"""
263+
self._pass_insertions.setdefault(
264+
target_pass_type, PassInsertions()
265+
).after_passes.extend(passes)
266+
267+
def _apply_pass_insertions(self) -> None:
268+
"""Apply all registered pass insertions to the collected passes.
269+
270+
Called ONCE after all add_passes() calls are complete, before execution.
271+
272+
Raises:
273+
ValueError: If any registered target pass type is not found in the pipeline.
274+
275+
"""
276+
if self._insertions_applied or not self._pass_insertions:
277+
return
278+
279+
# Fail fast if any target pass type is missing from the pipeline
280+
existing_pass_types = {type(p) for p in self.passes}
281+
for target_type in self._pass_insertions:
282+
if target_type not in existing_pass_types:
283+
available = [type(p).__name__ for p in self.passes]
284+
raise ValueError(
285+
f"Target pass {target_type.__name__} not found in the pass "
286+
f"pipeline. Available passes: {available}"
287+
)
288+
289+
# Build new pass list with insertions applied
290+
new_passes = []
291+
for pass_obj in self.passes:
292+
pass_type = type(pass_obj)
293+
294+
# Insert passes BEFORE this pass
295+
if pass_type in self._pass_insertions:
296+
insertions = self._pass_insertions[pass_type]
297+
for before_pass in insertions.before_passes:
298+
# Check if we should skip this inserted pass
299+
if type(before_pass) not in self._skip_pass_types:
300+
new_passes.append(before_pass)
301+
302+
# Add the original pass
303+
new_passes.append(pass_obj)
304+
305+
# Insert passes AFTER this pass
306+
if pass_type in self._pass_insertions:
307+
insertions = self._pass_insertions[pass_type]
308+
for after_pass in insertions.after_passes:
309+
# Check if we should skip this inserted pass
310+
if type(after_pass) not in self._skip_pass_types:
311+
new_passes.append(after_pass)
312+
313+
# Replace the passes list
314+
self.passes = new_passes
315+
self._insertions_applied = True
316+
317+
def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
318+
"""Hook for subclasses to configure pass insertions. Called at the START
319+
of pipeline construction, before any passes are added.
320+
321+
Subclasses should override this to call insert_passes_before/after.
322+
323+
Args:
324+
exported_program: The exported program being transformed
325+
326+
"""
327+
pass
328+
226329
def add_passes(self, passes: Sequence[ExportPass | None]):
227330
for p in passes:
228331
if p is not None:
@@ -241,6 +344,9 @@ def add_pass(self, pipeline_pass):
241344
def _tosa_pipeline(
242345
self, exported_program: ExportedProgram, graph_module: GraphModule
243346
) -> GraphModule:
347+
# Allow subclasses to configure pass insertions before building pipeline
348+
self._configure_pass_insertions(exported_program)
349+
244350
# Preprocessing passes
245351
self.add_pass(AnnotateOutputDimOrderPass())
246352

@@ -398,6 +504,9 @@ def _tosa_pipeline(
398504
]
399505
)
400506

507+
# Apply all pass insertions once after all passes are collected
508+
self._apply_pass_insertions()
509+
401510
self.validate_constraints_mandatory()
402511
return self._transform(graph_module)
403512

backends/arm/test/ops/test_rsqrt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_rsqrt_tosa_INT_a16w8(test_tensor: torch.Tensor):
146146

147147
@common.parametrize("test_tensor", Rsqrt.test_parameters)
148148
@common.XfailIfNoCorstone300
149-
def test_rsqrt_16a8w_u55_INT16(test_tensor: torch.Tensor):
149+
def test_rsqrt_16a8w_u55_INT(test_tensor: torch.Tensor):
150150
"""Test rsqrt operation with int16 I/O quantization for U55."""
151151
# Use wider tolerances for int16 I/O quantization on U55
152152
pipeline = EthosU55PipelineINT[input_t1](

0 commit comments

Comments
 (0)