88import logging
99from collections import defaultdict
1010from collections .abc import Sequence
11+ from dataclasses import dataclass , field
1112
1213import executorch .backends .arm .tosa .dialect # noqa: unused
1314from executorch .backends .arm ._passes import (
157158logger = logging .getLogger (__name__ )
158159
159160
161+ @dataclass
162+ class PassInsertions :
163+ """Holds lists of passes to be inserted before and after a target pass."""
164+
165+ before_passes : list = field (default_factory = list )
166+ after_passes : list = field (default_factory = list )
167+
168+
160169class ArmPassManager (PassManager ):
161170 def __init__ (self , compile_spec : ArmCompileSpec ) -> None :
162171 self .compile_spec = compile_spec
163172 self .tosa_spec = compile_spec .tosa_spec
164173 self ._skip_pass_types : tuple [type , ...] = ()
174+ self ._pass_insertions : dict [type , PassInsertions ] = {}
175+ self ._insertions_applied = False
165176 super ().__init__ ()
166177 self .configure_skip_passes ()
167178
@@ -223,6 +234,98 @@ def validate_constraints_mandatory(self):
223234
224235 raise RuntimeError (error_msg )
225236
237+ def insert_passes_before (
238+ self , target_pass_type : type , passes : list [ExportPass ]
239+ ) -> None :
240+ """Register passes to be inserted before instances of target_pass_type.
241+ Insertions are deferred and applied via _apply_pass_insertions().
242+
243+ Args:
244+ target_pass_type: The pass class to insert before (e.g., InsertTableOpsPass)
245+ passes: List of pass instances to insert
246+
247+ """
248+ self ._pass_insertions .setdefault (
249+ target_pass_type , PassInsertions ()
250+ ).before_passes .extend (passes )
251+
252+ def insert_passes_after (
253+ self , target_pass_type : type , passes : list [ExportPass ]
254+ ) -> None :
255+ """Register passes to be inserted after instances of target_pass_type.
256+ Insertions are deferred and applied via _apply_pass_insertions().
257+
258+ Args:
259+ target_pass_type: The pass class to insert after
260+ passes: List of pass instances to insert
261+
262+ """
263+ self ._pass_insertions .setdefault (
264+ target_pass_type , PassInsertions ()
265+ ).after_passes .extend (passes )
266+
267+ def _apply_pass_insertions (self ) -> None :
268+ """Apply all registered pass insertions to the collected passes.
269+
270+ Called ONCE after all add_passes() calls are complete, before execution.
271+
272+ Raises:
273+ ValueError: If any registered target pass type is not found in the pipeline.
274+
275+ """
276+ if self ._insertions_applied or not self ._pass_insertions :
277+ return
278+
279+ # Fail fast if any target pass type is missing from the pipeline
280+ existing_pass_types = {type (p ) for p in self .passes }
281+ for target_type in self ._pass_insertions :
282+ if target_type not in existing_pass_types :
283+ available = [type (p ).__name__ for p in self .passes ]
284+ raise ValueError (
285+ f"Target pass { target_type .__name__ } not found in the pass "
286+ f"pipeline. Available passes: { available } "
287+ )
288+
289+ # Build new pass list with insertions applied
290+ new_passes = []
291+ for pass_obj in self .passes :
292+ pass_type = type (pass_obj )
293+
294+ # Insert passes BEFORE this pass
295+ if pass_type in self ._pass_insertions :
296+ insertions = self ._pass_insertions [pass_type ]
297+ for before_pass in insertions .before_passes :
298+ # Check if we should skip this inserted pass
299+ if type (before_pass ) not in self ._skip_pass_types :
300+ new_passes .append (before_pass )
301+
302+ # Add the original pass
303+ new_passes .append (pass_obj )
304+
305+ # Insert passes AFTER this pass
306+ if pass_type in self ._pass_insertions :
307+ insertions = self ._pass_insertions [pass_type ]
308+ for after_pass in insertions .after_passes :
309+ # Check if we should skip this inserted pass
310+ if type (after_pass ) not in self ._skip_pass_types :
311+ new_passes .append (after_pass )
312+
313+ # Replace the passes list
314+ self .passes = new_passes
315+ self ._insertions_applied = True
316+
317+ def _configure_pass_insertions (self , exported_program : ExportedProgram ) -> None :
318+ """Hook for subclasses to configure pass insertions. Called at the START
319+ of pipeline construction, before any passes are added.
320+
321+ Subclasses should override this to call insert_passes_before/after.
322+
323+ Args:
324+ exported_program: The exported program being transformed
325+
326+ """
327+ pass
328+
226329 def add_passes (self , passes : Sequence [ExportPass | None ]):
227330 for p in passes :
228331 if p is not None :
@@ -241,6 +344,9 @@ def add_pass(self, pipeline_pass):
241344 def _tosa_pipeline (
242345 self , exported_program : ExportedProgram , graph_module : GraphModule
243346 ) -> GraphModule :
347+ # Allow subclasses to configure pass insertions before building pipeline
348+ self ._configure_pass_insertions (exported_program )
349+
244350 # Preprocessing passes
245351 self .add_pass (AnnotateOutputDimOrderPass ())
246352
@@ -398,6 +504,9 @@ def _tosa_pipeline(
398504 ]
399505 )
400506
507+ # Apply all pass insertions once after all passes are collected
508+ self ._apply_pass_insertions ()
509+
401510 self .validate_constraints_mandatory ()
402511 return self ._transform (graph_module )
403512
0 commit comments