55# LICENSE file in the root directory of this source tree.
66
77# pyre-strict
8-
9- from typing import Callable , List , Optional , Union
8+ import copy
9+ import inspect
10+ import logging
11+ from typing import Callable , List , Optional , TypeAlias , Union
1012
1113import torch
1214import torch .fx .passes .infra .pass_manager as fx
1315import torch .utils ._pytree as pytree
1416from executorch .exir .error import ExportError , ExportErrorType
17+ from executorch .exir .pass_base import ExportedProgramPassBase , ExportedProgramPassResult
18+ from torch .export import ExportedProgram
1519from torch .fx .passes .infra .pass_base import PassResult
16- from typing_extensions import TypeAlias
20+ from torch .fx .passes .infra .pass_manager import pass_result_wrapper
21+
22+ logger = logging .getLogger (__name__ )
23+ logger .setLevel (logging .WARNING )
24+
25+ PassType : TypeAlias = Union [
26+ ExportedProgramPassBase , Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
27+ ]
28+
1729
18- PassType : TypeAlias = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
30+ def _get_pass_name (fn : PassType ) -> str :
31+ return fn .__name__ if inspect .isfunction (fn ) else type (fn ).__name__
1932
2033
2134class PassManager (fx .PassManager ):
@@ -27,21 +40,27 @@ class PassManager(fx.PassManager):
2740 * **passes**: A list of callable passes
2841 * **params**: An instance of PassManagerParams containing the result of the
2942 flags set in the constructor.
43+
44+ Note: This class is deprecated. Please use ExportedProgramPassManager instead.
3045 """
3146
3247 def __init__ (
3348 self ,
3449 passes : Optional [Union [List [PassType ], List [List [PassType ]]]] = None ,
3550 run_checks_after_each_pass : bool = False ,
3651 suppress_check_failures : bool = False ,
52+ steps : int = 1 ,
3753 ) -> None :
3854 r"""
3955 Args:
4056 passes: A list of passes
41- enable_debug_pass: set to true to enable the debug passes
42- run_checks_after_each_pass: whether to run checks and linting after each pass
57+ run_checks_after_each_pass: Whether to run checks and linting after each pass
58+ suppress_check_failures: Whether to raise errors when running checks
59+ steps: Number of times we wish to run passes iteratively.
4360 """
44-
61+ logger .warning (
62+ "PassManager is deprecated. Please use ExportedProgramPassManager instead."
63+ )
4564 # Flatten the passes to a list of callables
4665 passes = passes if passes else []
4766 flattened_passes = [
@@ -52,6 +71,7 @@ def __init__(
5271 flattened_passes ,
5372 run_checks_after_each_pass = run_checks_after_each_pass ,
5473 suppress_check_failures = suppress_check_failures ,
74+ steps = steps ,
5575 )
5676
5777 def check (self , module : torch .nn .Module ) -> None :
@@ -65,7 +85,7 @@ def check(self, module: torch.nn.Module) -> None:
6585 node's spec field is a tuple)
6686 - Ensure that the graph module has type torch.fx.GraphModule
6787 """
68- assert isinstance (module , fx .GraphModule )
88+ assert isinstance (module , torch . fx .GraphModule )
6989 module .recompile ()
7090 module .graph .lint ()
7191 # TODO(qihan): use verifier.check_is_exir
@@ -76,3 +96,151 @@ def check(self, module: torch.nn.Module) -> None:
7696 ExportErrorType .NOT_SUPPORTED ,
7797 f"call_method `{ node } ` is not supported except for backend delegate." ,
7898 )
99+
100+
101+ class ExportedProgramPassManager (fx .PassManager ):
102+ """
103+ Runs multiple passes on an ExportedProgram.
104+
105+ This PassManager is specifically designed for ExportedProgram and supports
106+ both GraphModule-only passes and ExportedProgram-aware passes.
107+
108+ For running passes on GraphModule directly, use PassManager instead.
109+ """
110+
111+ def __init__ (
112+ self ,
113+ passes : Optional [Union [List [PassType ], List [List [PassType ]]]] = None ,
114+ constraints : Optional [List [Callable [[Callable , Callable ], bool ]]] = None ,
115+ run_checks_after_each_pass : bool = False ,
116+ # Setting default to True since many tests are failing pre-verification.
117+ suppress_exported_program_pre_verification : bool = True ,
118+ steps : int = 1 ,
119+ override_verifiers : bool = False ,
120+ ) -> None :
121+ wrapped_passes = (
122+ [
123+ (
124+ fn
125+ if isinstance (fn , ExportedProgramPassBase )
126+ else pass_result_wrapper (fn )
127+ )
128+ for fn in pytree .tree_flatten (passes )[0 ]
129+ ]
130+ if passes
131+ else []
132+ )
133+
134+ if suppress_exported_program_pre_verification :
135+ logger .warning (
136+ "Pre-verification of exported program is suppressed. This means that the exported program may pass validation prior to running the pass manager."
137+ )
138+
139+ super ().__init__ (
140+ wrapped_passes ,
141+ constraints = constraints ,
142+ run_checks_after_each_pass = run_checks_after_each_pass ,
143+ suppress_check_failures = suppress_exported_program_pre_verification ,
144+ steps = steps ,
145+ )
146+ self ._override_verifiers = override_verifiers
147+
148+ def check (self , exported_program : ExportedProgram ) -> None :
149+ """
150+ Runs exported program validation.
151+ """
152+ if not self .suppress_check_failures :
153+ exported_program .validate ()
154+
155+ module = exported_program .graph_module
156+ module .recompile ()
157+ module .graph .lint ()
158+
159+ for node in module .graph .nodes :
160+ if node .op == "call_method" :
161+ raise ExportError (
162+ ExportErrorType .NOT_SUPPORTED ,
163+ f"call_method `{ node } ` is not supported except for backend delegate." ,
164+ )
165+
166+ # pyre-ignore[14]: Intentionally overriding with different signature for ExportedProgram
167+ def __call__ (self , exported_program : ExportedProgram ) -> ExportedProgramPassResult :
168+ """
169+ Runs passes on an ExportedProgram.
170+
171+ Handles both GraphModule-only passes and ExportedProgram-aware passes.
172+
173+ Args:
174+ exported_program: The exported program to transform.
175+
176+ Returns:
177+ ExportedProgramPassResult containing the transformed program and whether or
178+ not the program was modified.
179+ """
180+ # Lazy import to avoid circular dependency
181+ from executorch .exir .program ._program import (
182+ _update_exported_program_graph_module ,
183+ )
184+
185+ if not self ._validated :
186+ self .solve_constraints ()
187+
188+ exported_program = copy .copy (exported_program )
189+
190+ # Check graph invariants before running passes
191+ self .check (exported_program )
192+
193+ overall_modified = False
194+
195+ for _ in range (self .steps ):
196+ step_modified = False
197+
198+ for i , fn in enumerate (self .passes ):
199+ try :
200+ if not isinstance (fn , ExportedProgramPassBase ):
201+ result = fn (exported_program .graph_module )
202+ if result .modified :
203+ logger .debug (
204+ "Graph after pass '%s': %s" ,
205+ _get_pass_name (fn ),
206+ result .graph_module .graph ,
207+ )
208+ result .graph_module .recompile ()
209+
210+ exported_program = _update_exported_program_graph_module (
211+ exported_program ,
212+ result .graph_module ,
213+ self ._override_verifiers ,
214+ )
215+ step_modified = step_modified or result .modified
216+
217+ if self .run_checks_after_each_pass :
218+ self .check (exported_program )
219+ else :
220+ assert isinstance (fn , ExportedProgramPassBase )
221+ result = fn (exported_program )
222+ if result .modified :
223+ logger .debug (
224+ "Graph after pass '%s': %s" ,
225+ _get_pass_name (fn ),
226+ result .exported_program .graph_module .graph ,
227+ )
228+ result .exported_program .graph_module .recompile ()
229+
230+ exported_program = result .exported_program
231+ step_modified = step_modified or result .modified
232+
233+ if self .run_checks_after_each_pass :
234+ self .check (exported_program )
235+
236+ except Exception as e :
237+ prev_names = [_get_pass_name (p ) for p in self .passes [:i ]]
238+ msg = f"An error occurred when running the '{ _get_pass_name (fn )} ' pass after the following passes: { prev_names } "
239+ e .add_note (msg )
240+ raise
241+
242+ overall_modified = overall_modified or step_modified
243+ if not step_modified :
244+ break
245+
246+ return ExportedProgramPassResult (exported_program , overall_modified )
0 commit comments