55# LICENSE file in the root directory of this source tree.
66
77# pyre-strict
8-
9- from typing import Callable , List , Optional , Union
10-
8+ import copy
9+ from typing import Callable , cast , List , Optional , Union
1110import torch
1211import torch .fx .passes .infra .pass_manager as fx
1312import torch .utils ._pytree as pytree
13+ from torch .export import ExportedProgram
1414from executorch .exir .error import ExportError , ExportErrorType
15+ from executorch .exir .pass_base import (
16+ ExportedProgramPassBase ,
17+ ExportedProgramPassResult ,
18+ LegacyPassWrapper ,
19+ )
1520from torch .fx .passes .infra .pass_base import PassResult
16- from typing_extensions import TypeAlias
1721
18- PassType : TypeAlias = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
22+ # Legacy type for backwards compatibility
23+ LegacyPassType = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
24+
25+ # New union that accepts both
26+ PassType = Union [ExportedProgramPassBase , LegacyPassType ]
27+
28+
29+ def _normalize_pass (
30+ p : Union [ExportedProgramPassBase , Callable ]
31+ ) -> ExportedProgramPassBase :
32+ """Normalize a pass to ExportedProgramPassBase, wrapping legacy callables."""
33+ if isinstance (p , ExportedProgramPassBase ):
34+ return p
35+ elif callable (p ):
36+ return LegacyPassWrapper (p )
37+ else :
38+ raise TypeError (f"Expected ExportedProgramPassBase or callable, got { type (p )} " )
1939
2040
2141class PassManager (fx .PassManager ):
@@ -48,12 +68,23 @@ def __init__(
4868 fx .pass_result_wrapper (fn ) for fn in pytree .tree_flatten (passes )[0 ]
4969 ]
5070
71+ normalized : list [ExportedProgramPassBase ] = []
72+ for p in flattened_passes :
73+ normalized .append (_normalize_pass (p ))
74+
75+ flattened_passes = normalized
76+
5177 super ().__init__ (
5278 flattened_passes ,
5379 run_checks_after_each_pass = run_checks_after_each_pass ,
5480 suppress_check_failures = suppress_check_failures ,
5581 )
5682
83+ # Improves type-checking in call_exported_program
84+ self .passes = cast (list [ExportedProgramPassBase ], self .passes )
85+
86+
87+
5788 def check (self , module : torch .nn .Module ) -> None :
5889 """
5990 Runs various checks on the given graph module to make sure it contains
@@ -76,3 +107,95 @@ def check(self, module: torch.nn.Module) -> None:
76107 ExportErrorType .NOT_SUPPORTED ,
77108 f"call_method `{ node } ` is not supported except for backend delegate." ,
78109 )
110+
111+ def call (self , module : torch .nn .Module ) -> PassResult :
112+ """
113+ Runs the passes on the given graph module.
114+
115+ Args:
116+ module: The graph module to run the passes on
117+
118+ Returns:
119+ A PassResult object containing the modified graph module and a
120+ boolean indicating whether the module was modified
121+ """
122+ # Check for passes that might have ExportedProgram-specific logic
123+ problematic_passes = []
124+ for p in self .passes :
125+ if type (p ).call_exported_program is not ExportedProgramPassBase .call_exported_program :
126+ problematic_passes .append (type (p ).__name__ )
127+
128+ if problematic_passes :
129+ fx .logger .warning (
130+ f"The following passes have overridden 'call_exported_program': { problematic_passes } . "
131+ "Calling PassManager.call() will only run the graph module logic via 'call()' "
132+ "and may miss ExportedProgram-specific transformations. "
133+ "Consider using PassManager.call_exported_program() instead." ,
134+ UserWarning ,
135+ stacklevel = 2 ,
136+ )
137+
138+ return super ()(module )
139+
140+ def call_exported_program (
141+ self , exported_program : ExportedProgram
142+ ) -> ExportedProgramPassResult :
143+ """
144+ Runs the passes on the given ExportedProgram.
145+
146+ Args:
147+ exported_program: The ExportedProgram to run the passes on
148+
149+ Returns:
150+ A PassResult object containing the modified ExportedProgram and a
151+ boolean indicating whether the ExportedProgram was modified
152+ """
153+ # Order the passes based on the constraints
154+ if not self ._validated :
155+ self .solve_constraints ()
156+
157+ # Check graph invariants
158+ self .check (exported_program .graph_module )
159+
160+ # Run the set of passes `steps` number of times or until the graph stops
161+ # changing
162+ overall_modified = False
163+
164+ # Done once at the beginning to make sure we don't modify the original
165+ # in place
166+ exported_program = copy .deepcopy (exported_program )
167+ for _ in range (self .steps ):
168+ modified = False
169+
170+ # Run the set of passes on the graph module
171+ for i , p in enumerate (self .passes ):
172+ pass_name = type (p ).__name__
173+ fx .logger .debug ("Running pass '%s'" , pass_name )
174+
175+ try :
176+ res = p .call_exported_program (exported_program )
177+ exported_program = res .exported_program
178+ modified = modified or res .modified
179+
180+ fx .logger .debug (
181+ "Graph after pass '%s': %s" ,
182+ pass_name ,
183+ exported_program .graph_module .graph ,
184+ )
185+ exported_program .graph_module .recompile ()
186+
187+ # Check graph invariants
188+ if self .run_checks_after_each_pass :
189+ self .check (exported_program .graph_module )
190+
191+ except Exception as e :
192+ prev_pass_names = [type (p ).__name__ for p in self .passes [:i ]]
193+ msg = f"An error occurred when running the '{ pass_name } ' pass after the following passes: { prev_pass_names } "
194+ raise Exception (msg ) from e # noqa: TRY002
195+
196+ # If the graph no longer changes, then we can stop running these passes
197+ overall_modified = overall_modified or modified
198+ if not modified :
199+ break
200+
201+ return ExportedProgramPassResult (exported_program , overall_modified )
0 commit comments