Skip to content

Commit 70fd62b

Browse files
authored
Arm backend: allow global register for ArmBackend pass insertions (#18937)
To make the use of custom ops easier (which rely on passes to handle rewrite and prevent decomposition) - provide a simple mechanism to insert them into the pass manager using existing functionality. Also a small commit which fixes missing doc updates. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson --------- Signed-off-by: Rob Elliott <Robert.Elliott@arm.com>
1 parent cb30495 commit 70fd62b

3 files changed

Lines changed: 44 additions & 2 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,32 @@ class PassInsertions:
175175
after_passes: list = field(default_factory=list)
176176

177177

178+
_registered_pass_insertions: dict[type, PassInsertions] = {}
179+
180+
181+
def register_pass_insertions_before(
182+
target_pass_type: type, passes: list[ExportPass]
183+
) -> None:
184+
"""Register passes to be inserted before a target pass for all pipelines."""
185+
if target_pass_type not in _registered_pass_insertions:
186+
_registered_pass_insertions[target_pass_type] = PassInsertions()
187+
_registered_pass_insertions[target_pass_type].before_passes.extend(passes)
188+
189+
190+
def register_pass_insertions_after(
191+
target_pass_type: type, passes: list[ExportPass]
192+
) -> None:
193+
"""Register passes to be inserted after a target pass for all pipelines."""
194+
if target_pass_type not in _registered_pass_insertions:
195+
_registered_pass_insertions[target_pass_type] = PassInsertions()
196+
_registered_pass_insertions[target_pass_type].after_passes.extend(passes)
197+
198+
199+
def clear_registered_pass_insertions() -> None:
200+
"""Clear all globally registered pass insertions."""
201+
_registered_pass_insertions.clear()
202+
203+
178204
class ArmPassManager(PassManager):
179205
def __init__(self, compile_spec: ArmCompileSpec) -> None:
180206
self.compile_spec = compile_spec
@@ -323,13 +349,17 @@ def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
323349
"""Hook for subclasses to configure pass insertions. Called at the START
324350
of pipeline construction, before any passes are added.
325351
326-
Subclasses should override this to call insert_passes_before/after.
352+
Subclasses can override this to call insert_passes_before/after.
327353
328354
Args:
329355
exported_program: The exported program being transformed
330356
331357
"""
332-
pass
358+
for pass_type, insertions in _registered_pass_insertions.items():
359+
if insertions.before_passes:
360+
self.insert_passes_before(pass_type, list(insertions.before_passes))
361+
if insertions.after_passes:
362+
self.insert_passes_after(pass_type, list(insertions.after_passes))
333363

334364
def add_passes(self, passes: Sequence[ExportPass | None]):
335365
for p in passes:

docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,9 @@ Args:
4545
Returns:
4646
- **PartitionResult**: The input program with nodes tagged for delegation
4747
and a mapping of partition tags to delegation specs.
48+
49+
```python
50+
def EthosUPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None:
51+
```
52+
Register a custom op to be considered supported by this
53+
partitioner.

docs/source/backends/arm-vgf/arm-vgf-partitioner.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,9 @@ Args:
4545
Returns:
4646
- **PartitionResult**: The input program with nodes tagged for delegation
4747
and a mapping of partition tags to delegation specs.
48+
49+
```python
50+
def VgfPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None:
51+
```
52+
Register a custom op to be considered supported by this
53+
partitioner.

0 commit comments

Comments
 (0)