55# LICENSE file in the root directory of this source tree.
66
77# pyre-strict
8-
9- from typing import Callable , List , Optional , Union
8+ import copy
9+ from typing import Callable , cast , List , Optional , Union
1010
1111import torch
1212import torch .fx .passes .infra .pass_manager as fx
1313import torch .utils ._pytree as pytree
1414from executorch .exir .error import ExportError , ExportErrorType
15+ from executorch .exir .pass_base import (
16+ ExportedProgramPassBase ,
17+ ExportedProgramPassResult ,
18+ LegacyPassWrapper ,
19+ )
20+ from torch .export import ExportedProgram
1521from torch .fx .passes .infra .pass_base import PassResult
16- from typing_extensions import TypeAlias
1722
18- PassType : TypeAlias = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
23+ # Legacy type for backwards compatibility
24+ LegacyPassType = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
25+
26+ # New union that accepts both
27+ PassType = Union [ExportedProgramPassBase , LegacyPassType ]
28+
29+
30+ def _normalize_pass (
31+ p : Union [ExportedProgramPassBase , Callable ]
32+ ) -> ExportedProgramPassBase :
33+ """Normalize a pass to ExportedProgramPassBase, wrapping legacy callables."""
34+ if isinstance (p , ExportedProgramPassBase ):
35+ return p
36+ elif callable (p ):
37+ return LegacyPassWrapper (p )
38+ else :
39+ raise TypeError (f"Expected ExportedProgramPassBase or callable, got { type (p )} " )
1940
2041
2142class PassManager (fx .PassManager ):
@@ -48,12 +69,21 @@ def __init__(
4869 fx .pass_result_wrapper (fn ) for fn in pytree .tree_flatten (passes )[0 ]
4970 ]
5071
72+ normalized : list [ExportedProgramPassBase ] = []
73+ for p in flattened_passes :
74+ normalized .append (_normalize_pass (p ))
75+
76+ flattened_passes = normalized
77+
5178 super ().__init__ (
5279 flattened_passes ,
5380 run_checks_after_each_pass = run_checks_after_each_pass ,
5481 suppress_check_failures = suppress_check_failures ,
5582 )
5683
84+ # Improves type-checking in call_exported_program
85+ self .passes = cast (list [ExportedProgramPassBase ], self .passes )
86+
5787 def check (self , module : torch .nn .Module ) -> None :
5888 """
5989 Runs various checks on the given graph module to make sure it contains
@@ -76,3 +106,95 @@ def check(self, module: torch.nn.Module) -> None:
76106 ExportErrorType .NOT_SUPPORTED ,
77107 f"call_method `{ node } ` is not supported except for backend delegate." ,
78108 )
109+
110+ def call (self , module : torch .nn .Module ) -> PassResult :
111+ """
112+ Runs the passes on the given graph module.
113+
114+ Args:
115+ module: The graph module to run the passes on
116+
117+ Returns:
118+ A PassResult object containing the modified graph module and a
119+ boolean indicating whether the module was modified
120+ """
121+ # Check for passes that might have ExportedProgram-specific logic
122+ problematic_passes = []
123+ for p in self .passes :
124+ if (
125+ type (p ).call_exported_program
126+ is not ExportedProgramPassBase .call_exported_program
127+ ):
128+ problematic_passes .append (type (p ).__name__ )
129+
130+ if problematic_passes :
131+ fx .logger .warning (
132+ f"The following passes have overridden 'call_exported_program': { problematic_passes } . "
133+ "Calling PassManager.call() will only run the graph module logic via 'call()' "
134+ "and may miss ExportedProgram-specific transformations. "
135+ "Consider using PassManager.call_exported_program() instead." ,
136+ UserWarning ,
137+ stacklevel = 2 ,
138+ )
139+
140+ return super ()(module )
141+
142+ def call_exported_program (
143+ self , exported_program : ExportedProgram
144+ ) -> ExportedProgramPassResult :
145+ """
146+ Runs the passes on the given ExportedProgram.
147+
148+ Args:
149+ exported_program: The ExportedProgram to run the passes on
150+
151+ Returns:
152+ A PassResult object containing the modified ExportedProgram and a
153+ boolean indicating whether the ExportedProgram was modified
154+ """
155+ # Order the passes based on the constraints
156+ if not self ._validated :
157+ self .solve_constraints ()
158+
159+ # Check graph invariants
160+ self .check (exported_program .graph_module )
161+
162+ # Run the set of passes `steps` number of times or until the graph stops
163+ # changing
164+ overall_modified = False
165+
166+ for _ in range (self .steps ):
167+ modified = False
168+
169+ # Run the set of passes on the graph module
170+ for i , p in enumerate (self .passes ):
171+ pass_name = type (p ).__name__
172+ fx .logger .debug ("Running pass '%s'" , pass_name )
173+
174+ try :
175+ res = p .call_exported_program (exported_program )
176+ exported_program = res .exported_program
177+ modified = modified or res .modified
178+
179+ fx .logger .debug (
180+ "Graph after pass '%s': %s" ,
181+ pass_name ,
182+ exported_program .graph_module .graph ,
183+ )
184+ exported_program .graph_module .recompile ()
185+
186+ # Check graph invariants
187+ if self .run_checks_after_each_pass :
188+ self .check (exported_program .graph_module )
189+
190+ except Exception as e :
191+ prev_pass_names = [type (p ).__name__ for p in self .passes [:i ]]
192+ msg = f"An error occurred when running the '{ pass_name } ' pass after the following passes: { prev_pass_names } "
193+ raise Exception (msg ) from e # noqa: TRY002
194+
195+ # If the graph no longer changes, then we can stop running these passes
196+ overall_modified = overall_modified or modified
197+ if not modified :
198+ break
199+
200+ return ExportedProgramPassResult (exported_program , overall_modified )
0 commit comments