Skip to content

Commit bb5f715

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
ExportedProgram passes (#16986)
Summary: - Adds support to _ExportPassBase to run passes on ExportedPrograms. Ensures that we can only run either call or call_exported_program, not both - Updates exir.PassManager to add a new pass manager which operates on exported programs. This is done to ensure backwards compatibility, while allowing _program transformations to use either pass manager Differential Revision: D91725222
1 parent 6c1dc31 commit bb5f715

4 files changed

Lines changed: 525 additions & 37 deletions

File tree

exir/pass_base.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
109
import operator
1110
import traceback
11+
from abc import ABC, abstractmethod
1212
from contextlib import nullcontext
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
@@ -27,16 +28,15 @@
2728

2829
import torch
2930
from executorch.exir import memory
30-
3131
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
32-
3332
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3433
from executorch.exir.error import ExportError, ExportErrorType
3534
from torch import fx
3635
from torch._dispatch.python import enable_python_dispatcher
3736
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
3837
from torch._subclasses.fake_tensor import FakeTensor
3938
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
39+
from torch.export import ExportedProgram
4040
from torch.fx import traceback as fx_traceback
4141
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
4242
from torch.fx.graph import CodeGen
@@ -157,6 +157,58 @@ class ExportPassBaseError(RuntimeError):
157157
pass
158158

159159

160+
@dataclass(frozen=True)
161+
class ExportedProgramPassResult:
162+
exported_program: ExportedProgram
163+
modified: bool
164+
165+
166+
class ExportedProgramPassBase(ABC):
167+
"""
168+
Base interface for implementing passes that operate on ExportedProgram.
169+
"""
170+
171+
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
172+
"""
173+
Runs the precondition check, the pass itself, and the postcondition check.
174+
"""
175+
176+
self.requires(exported_program)
177+
res = self.call(exported_program)
178+
self.ensures(exported_program)
179+
return res
180+
181+
@abstractmethod
182+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
183+
"""
184+
The pass that is run through the given exported program. To implement a
185+
pass, it is required to implement this function.
186+
187+
Args:
188+
exported_program: The exported program we will run a pass on
189+
"""
190+
191+
def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027
192+
"""
193+
This function will be called before the pass is run and will check that
194+
the given exported program contains the preconditions needed to run the
195+
pass. It is not required to implement this function.
196+
197+
Args:
198+
exported_program: The exported program we will run checks on
199+
"""
200+
201+
def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027
202+
"""
203+
This function will be called after the pass is run and will check that
204+
the given exported program contains the postconditions needed to run the
205+
pass. It is not required to implement this function.
206+
207+
Args:
208+
exported_program: The exported program we will run checks on
209+
"""
210+
211+
160212
class _ExportPassBase(PassBase):
161213
"""
162214
Interpreter-based pass class to help users maintain the IR spec while writing

exir/pass_manager.py

Lines changed: 176 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,30 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
9-
from typing import Callable, List, Optional, Union
8+
import copy
9+
import inspect
10+
import logging
11+
from typing import Callable, List, Optional, TypeAlias, Union
1012

1113
import torch
1214
import torch.fx.passes.infra.pass_manager as fx
1315
import torch.utils._pytree as pytree
1416
from executorch.exir.error import ExportError, ExportErrorType
17+
from executorch.exir.pass_base import ExportedProgramPassBase, ExportedProgramPassResult
18+
from torch.export import ExportedProgram
1519
from torch.fx.passes.infra.pass_base import PassResult
16-
from typing_extensions import TypeAlias
20+
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
21+
22+
logger = logging.getLogger(__name__)
23+
logger.setLevel(logging.WARNING)
24+
25+
PassType: TypeAlias = Union[
26+
ExportedProgramPassBase, Callable[[torch.fx.GraphModule], Optional[PassResult]]
27+
]
28+
1729

18-
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
30+
def _get_pass_name(fn: PassType) -> str:
31+
return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
1932

2033

2134
class PassManager(fx.PassManager):
@@ -27,21 +40,27 @@ class PassManager(fx.PassManager):
2740
* **passes**: A list of callable passes
2841
* **params**: An instance of PassManagerParams containing the result of the
2942
flags set in the constructor.
43+
44+
Note: This class is deprecated. Please use ExportedProgramPassManager instead.
3045
"""
3146

3247
def __init__(
3348
self,
3449
passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
3550
run_checks_after_each_pass: bool = False,
3651
suppress_check_failures: bool = False,
52+
steps: int = 1,
3753
) -> None:
3854
r"""
3955
Args:
4056
passes: A list of passes
41-
enable_debug_pass: set to true to enable the debug passes
42-
run_checks_after_each_pass: whether to run checks and linting after each pass
57+
run_checks_after_each_pass: Whether to run checks and linting after each pass
58+
suppress_check_failures: Whether to raise errors when running checks
59+
steps: Number of times we wish to run passes iteratively.
4360
"""
44-
61+
logger.warning(
62+
"PassManager is deprecated. Please use ExportedProgramPassManager instead."
63+
)
4564
# Flatten the passes to a list of callables
4665
passes = passes if passes else []
4766
flattened_passes = [
@@ -52,6 +71,7 @@ def __init__(
5271
flattened_passes,
5372
run_checks_after_each_pass=run_checks_after_each_pass,
5473
suppress_check_failures=suppress_check_failures,
74+
steps=steps,
5575
)
5676

5777
def check(self, module: torch.nn.Module) -> None:
@@ -65,7 +85,7 @@ def check(self, module: torch.nn.Module) -> None:
6585
node's spec field is a tuple)
6686
- Ensure that the graph module has type torch.fx.GraphModule
6787
"""
68-
assert isinstance(module, fx.GraphModule)
88+
assert isinstance(module, torch.fx.GraphModule)
6989
module.recompile()
7090
module.graph.lint()
7191
# TODO(qihan): use verifier.check_is_exir
@@ -76,3 +96,151 @@ def check(self, module: torch.nn.Module) -> None:
7696
ExportErrorType.NOT_SUPPORTED,
7797
f"call_method `{node}` is not supported except for backend delegate.",
7898
)
99+
100+
101+
class ExportedProgramPassManager(fx.PassManager):
102+
"""
103+
Runs multiple passes on an ExportedProgram.
104+
105+
This PassManager is specifically designed for ExportedProgram and supports
106+
both GraphModule-only passes and ExportedProgram-aware passes.
107+
108+
For running passes on GraphModule directly, use PassManager instead.
109+
"""
110+
111+
def __init__(
112+
self,
113+
passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
114+
constraints: Optional[List[Callable[[Callable, Callable], bool]]] = None,
115+
run_checks_after_each_pass: bool = False,
116+
# Setting default to True since many tests are failing pre-verification.
117+
suppress_exported_program_pre_verification: bool = True,
118+
steps: int = 1,
119+
override_verifiers: bool = False,
120+
) -> None:
121+
wrapped_passes = (
122+
[
123+
(
124+
fn
125+
if isinstance(fn, ExportedProgramPassBase)
126+
else pass_result_wrapper(fn)
127+
)
128+
for fn in pytree.tree_flatten(passes)[0]
129+
]
130+
if passes
131+
else []
132+
)
133+
134+
if suppress_exported_program_pre_verification:
135+
logger.warning(
136+
"Pre-verification of exported program is suppressed. This means that the exported program may pass validation prior to running the pass manager."
137+
)
138+
139+
super().__init__(
140+
wrapped_passes,
141+
constraints=constraints,
142+
run_checks_after_each_pass=run_checks_after_each_pass,
143+
suppress_check_failures=suppress_exported_program_pre_verification,
144+
steps=steps,
145+
)
146+
self._override_verifiers = override_verifiers
147+
148+
def check(self, exported_program: ExportedProgram) -> None:
149+
"""
150+
Runs exported program validation.
151+
"""
152+
if not self.suppress_check_failures:
153+
exported_program.validate()
154+
155+
module = exported_program.graph_module
156+
module.recompile()
157+
module.graph.lint()
158+
159+
for node in module.graph.nodes:
160+
if node.op == "call_method":
161+
raise ExportError(
162+
ExportErrorType.NOT_SUPPORTED,
163+
f"call_method `{node}` is not supported except for backend delegate.",
164+
)
165+
166+
# pyre-ignore[14]: Intentionally overriding with different signature for ExportedProgram
167+
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
168+
"""
169+
Runs passes on an ExportedProgram.
170+
171+
Handles both GraphModule-only passes and ExportedProgram-aware passes.
172+
173+
Args:
174+
exported_program: The exported program to transform.
175+
176+
Returns:
177+
ExportedProgramPassResult containing the transformed program and whether or
178+
not the program was modified.
179+
"""
180+
# Lazy import to avoid circular dependency
181+
from executorch.exir.program._program import (
182+
_update_exported_program_graph_module,
183+
)
184+
185+
if not self._validated:
186+
self.solve_constraints()
187+
188+
exported_program = copy.copy(exported_program)
189+
190+
# Check graph invariants before running passes
191+
self.check(exported_program)
192+
193+
overall_modified = False
194+
195+
for _ in range(self.steps):
196+
step_modified = False
197+
198+
for i, fn in enumerate(self.passes):
199+
try:
200+
if not isinstance(fn, ExportedProgramPassBase):
201+
result = fn(exported_program.graph_module)
202+
if result.modified:
203+
logger.debug(
204+
"Graph after pass '%s': %s",
205+
_get_pass_name(fn),
206+
result.graph_module.graph,
207+
)
208+
result.graph_module.recompile()
209+
210+
exported_program = _update_exported_program_graph_module(
211+
exported_program,
212+
result.graph_module,
213+
self._override_verifiers,
214+
)
215+
step_modified = step_modified or result.modified
216+
217+
if self.run_checks_after_each_pass:
218+
self.check(exported_program)
219+
else:
220+
assert isinstance(fn, ExportedProgramPassBase)
221+
result = fn(exported_program)
222+
if result.modified:
223+
logger.debug(
224+
"Graph after pass '%s': %s",
225+
_get_pass_name(fn),
226+
result.exported_program.graph_module.graph,
227+
)
228+
result.exported_program.graph_module.recompile()
229+
230+
exported_program = result.exported_program
231+
step_modified = step_modified or result.modified
232+
233+
if self.run_checks_after_each_pass:
234+
self.check(exported_program)
235+
236+
except Exception as e:
237+
prev_names = [_get_pass_name(p) for p in self.passes[:i]]
238+
msg = f"An error occurred when running the '{_get_pass_name(fn)}' pass after the following passes: {prev_names}"
239+
e.add_note(msg)
240+
raise
241+
242+
overall_modified = overall_modified or step_modified
243+
if not step_modified:
244+
break
245+
246+
return ExportedProgramPassResult(exported_program, overall_modified)

0 commit comments

Comments
 (0)