1010from collections .abc import Sequence
1111from dataclasses import dataclass , field
1212
13- import executorch .backends .arm .tosa .dialect # noqa: unused
1413from executorch .backends .arm ._passes import (
1514 AccumulateIndexPutPass ,
1615 AnnotateOutputDimOrderPass ,
126125 RemoveNoopPass ,
127126 ReplaceInfAndLimitValuesPass ,
128127 ReplaceScalarWithTensorByProfilePass ,
128+ RewriteAvgPool2dPass ,
129129 RewriteBoolBitwiseToLogicalPass ,
130130 RewriteBoolToFp32CastViaInt8Pass ,
131131 RewriteConvPass ,
134134 RewriteInplaceArithmeticPass ,
135135 RewriteLeLtToGeGtPass ,
136136 RewriteMatmulPass ,
137+ RewriteMaxPool2dPass ,
137138 RewritePadPass ,
138139 RewriteSlicePass ,
139140 RewriteUpsamplePass ,
143144 UnsqueezeBeforeRepeatPass ,
144145 UnsqueezeScalarPlaceholdersPass ,
145146)
146-
147147from executorch .backends .arm ._passes .arm_pass import ArmPass
148148from executorch .backends .arm .common .arm_compile_spec import ArmCompileSpec
149149from executorch .backends .arm .common .pipeline_config import (
@@ -175,6 +175,32 @@ class PassInsertions:
175175 after_passes : list = field (default_factory = list )
176176
177177
178+ _registered_pass_insertions : dict [type , PassInsertions ] = {}
179+
180+
181+ def register_pass_insertions_before (
182+ target_pass_type : type , passes : list [ExportPass ]
183+ ) -> None :
184+ """Register passes to be inserted before a target pass for all pipelines."""
185+ if target_pass_type not in _registered_pass_insertions :
186+ _registered_pass_insertions [target_pass_type ] = PassInsertions ()
187+ _registered_pass_insertions [target_pass_type ].before_passes .extend (passes )
188+
189+
190+ def register_pass_insertions_after (
191+ target_pass_type : type , passes : list [ExportPass ]
192+ ) -> None :
193+ """Register passes to be inserted after a target pass for all pipelines."""
194+ if target_pass_type not in _registered_pass_insertions :
195+ _registered_pass_insertions [target_pass_type ] = PassInsertions ()
196+ _registered_pass_insertions [target_pass_type ].after_passes .extend (passes )
197+
198+
199+ def clear_registered_pass_insertions () -> None :
200+ """Clear all globally registered pass insertions."""
201+ _registered_pass_insertions .clear ()
202+
203+
178204class ArmPassManager (PassManager ):
179205 def __init__ (self , compile_spec : ArmCompileSpec ) -> None :
180206 self .compile_spec = compile_spec
@@ -323,13 +349,17 @@ def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
323349 """Hook for subclasses to configure pass insertions. Called at the START
324350 of pipeline construction, before any passes are added.
325351
326- Subclasses should override this to call insert_passes_before/after.
352+ Subclasses can override this to call insert_passes_before/after.
327353
328354 Args:
329355 exported_program: The exported program being transformed
330356
331357 """
332- pass
358+ for pass_type , insertions in _registered_pass_insertions .items ():
359+ if insertions .before_passes :
360+ self .insert_passes_before (pass_type , list (insertions .before_passes ))
361+ if insertions .after_passes :
362+ self .insert_passes_after (pass_type , list (insertions .after_passes ))
333363
334364 def add_passes (self , passes : Sequence [ExportPass | None ]):
335365 for p in passes :
@@ -432,6 +462,8 @@ def _tosa_pipeline(
432462 DecomposeSliceScatterPass (),
433463 AccumulateIndexPutPass (),
434464 DecomposeIndexTensorToGatherPass (),
465+ DecomposeAdaptiveAvgPool2dPass (),
466+ DecomposeAvgPool2dPass (),
435467 Conv1dUnsqueezePass (),
436468 ]
437469 )
@@ -468,17 +500,16 @@ def _tosa_pipeline(
468500 DecomposeSoftmaxPass (),
469501 ConvertMinMaxPass (),
470502 DecomposeAnyPass (),
471- DecomposeAdaptiveAvgPool2dPass (),
472- DecomposeAvgPool2dPass (),
473503 DecorateFp32toInt32CastingPass (),
474- ComputeConstantOpsAOTPass (exported_program ),
475- FuseConstantArgsPass (exported_program ),
476504 ConvertExpandCopyToRepeatPass (),
477505 UnsqueezeBeforeRepeatPass (),
478506 DecomposeCumsumPass (exported_program ),
479507 DecomposeAsStridedCopyPass (),
480508 DecomposeMaxPool2dPass (),
481509 SizeAdjustInputPass (),
510+ RewriteAvgPool2dPass (),
511+ ComputeConstantOpsAOTPass (exported_program ),
512+ FuseConstantArgsPass (exported_program ),
482513 DecomposeSelectPass (),
483514 ConvertSqueezesToViewPass (),
484515 CastToInt32Pass (),
@@ -496,6 +527,7 @@ def _tosa_pipeline(
496527 self .add_passes (
497528 [
498529 RewriteUpsamplePass (),
530+ RewriteMaxPool2dPass (),
499531 RewriteConvPass (exported_program ),
500532 RewriteMatmulPass (),
501533 RewritePadPass (),
@@ -573,6 +605,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
573605 DecomposeLayerNormPass (tfa_pass = True ),
574606 DecomposeVarPass (tfa_pass = True ),
575607 DecomposeMeanDimPass (graph_module , self .tosa_spec , tfa_pass = True ),
608+ DecomposeAdaptiveAvgPool2dPass (tfa_pass = True ),
609+ DecomposeAvgPool2dPass (tfa_pass = True ),
576610 ]
577611 )
578612
@@ -598,8 +632,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
598632 DecomposeDivPass (tfa_pass = True ),
599633 DecomposeLinalgVectorNormPass (tfa_pass = True ),
600634 DecomposeSqrtPass (tfa_pass = True ),
601- DecomposeAdaptiveAvgPool2dPass (tfa_pass = True ),
602- DecomposeAvgPool2dPass (tfa_pass = True ),
603635 DecomposeSoftmaxPass (
604636 tfa_pass = True ,
605637 ),
0 commit comments