Skip to content

Commit 744f05f

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
ExportedProgram passes (#16986)
Summary: - Add ExportedProgramPassBase, which supports running passes on both `fx.GraphModule`s and `exir.ExportedProgram`s - Extend this class in `_ExportPassBase` so that any passes which import directly from `ExportPass` are immediately compatible with the current pass manager. - Create `LegacyPassWrapper` to auto-wrap passes not migrated to use ExportedProgramPassBase such that it works out of the box with EdgeProgramManager. Differential Revision: D91725222
1 parent 40d94b6 commit 744f05f

4 files changed

Lines changed: 218 additions & 41 deletions

File tree

exir/pass_base.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
109
import operator
1110
import traceback
11+
import warnings
1212
from contextlib import nullcontext
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
@@ -27,16 +28,15 @@
2728

2829
import torch
2930
from executorch.exir import memory
30-
3131
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
32-
3332
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3433
from executorch.exir.error import ExportError, ExportErrorType
3534
from torch import fx
3635
from torch._dispatch.python import enable_python_dispatcher
3736
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
3837
from torch._subclasses.fake_tensor import FakeTensor
3938
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
39+
from torch.export import ExportedProgram
4040
from torch.fx import traceback as fx_traceback
4141
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
4242
from torch.fx.graph import CodeGen
@@ -157,7 +157,64 @@ class ExportPassBaseError(RuntimeError):
157157
pass
158158

159159

160-
class _ExportPassBase(PassBase):
160+
@dataclass(frozen=True)
161+
class ExportedProgramPassResult:
162+
exported_program: ExportedProgram
163+
modified: bool
164+
165+
166+
class ExportedProgramPassBase(PassBase):
167+
"""
168+
Base class which supports running passes on either `exir.ExportedProgram` or `fx.GraphModule`
169+
types.
170+
"""
171+
172+
def call_exported_program(
173+
self, exported_program: ExportedProgram
174+
) -> ExportedProgramPassResult:
175+
pass_result = self.call(exported_program.graph_module)
176+
modified = pass_result.modified if pass_result else False
177+
if modified:
178+
exported_program._graph_module = pass_result.graph_module
179+
return ExportedProgramPassResult(exported_program, modified)
180+
181+
def call(self, graph_module: fx.GraphModule) -> Optional[PassResult]:
182+
"""
183+
Overriden version of call which is effectively a no-op, but is required
184+
since the base class method is abstract.
185+
"""
186+
return PassResult(graph_module, False)
187+
188+
189+
class LegacyPassWrapper(ExportedProgramPassBase):
190+
"""
191+
Wraps a legacy callable pass (Callable[[GraphModule], Optional[PassResult]])
192+
to work with the new ExportedProgramPassBase infrastructure.
193+
194+
This provides backwards compatibility for passes that haven't been migrated yet.
195+
"""
196+
197+
def __init__(
198+
self, legacy_pass: Callable[[torch.fx.GraphModule], Optional[PassResult]]
199+
) -> None:
200+
self._legacy_pass = legacy_pass
201+
# Preserve the original name for debugging/logging
202+
self.__class__.__name__ = getattr(
203+
legacy_pass, "__name__", type(legacy_pass).__name__
204+
)
205+
206+
warnings.warn(
207+
f"Callable pass '{self.__class__.__name__}' is deprecated. "
208+
"Please migrate to ExportedProgramPassBase.",
209+
DeprecationWarning,
210+
stacklevel=3,
211+
)
212+
213+
def call(self, graph_module: fx.GraphModule) -> Optional[PassResult]:
214+
return self._legacy_pass(graph_module)
215+
216+
217+
class _ExportPassBase(ExportedProgramPassBase):
161218
"""
162219
Interpreter-based pass class to help users maintain the IR spec while writing
163220
transformations.

exir/pass_manager.py

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,38 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
9-
from typing import Callable, List, Optional, Union
8+
import copy
9+
from typing import Callable, cast, List, Optional, Union
1010

1111
import torch
1212
import torch.fx.passes.infra.pass_manager as fx
1313
import torch.utils._pytree as pytree
1414
from executorch.exir.error import ExportError, ExportErrorType
15+
from executorch.exir.pass_base import (
16+
ExportedProgramPassBase,
17+
ExportedProgramPassResult,
18+
LegacyPassWrapper,
19+
)
20+
from torch.export import ExportedProgram
1521
from torch.fx.passes.infra.pass_base import PassResult
16-
from typing_extensions import TypeAlias
1722

18-
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
23+
# Legacy type for backwards compatibility
24+
LegacyPassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
25+
26+
# New union that accepts both
27+
PassType = Union[ExportedProgramPassBase, LegacyPassType]
28+
29+
30+
def _normalize_pass(
31+
p: Union[ExportedProgramPassBase, Callable]
32+
) -> ExportedProgramPassBase:
33+
"""Normalize a pass to ExportedProgramPassBase, wrapping legacy callables."""
34+
if isinstance(p, ExportedProgramPassBase):
35+
return p
36+
elif callable(p):
37+
return LegacyPassWrapper(p)
38+
else:
39+
raise TypeError(f"Expected ExportedProgramPassBase or callable, got {type(p)}")
1940

2041

2142
class PassManager(fx.PassManager):
@@ -48,12 +69,21 @@ def __init__(
4869
fx.pass_result_wrapper(fn) for fn in pytree.tree_flatten(passes)[0]
4970
]
5071

72+
normalized: list[ExportedProgramPassBase] = []
73+
for p in flattened_passes:
74+
normalized.append(_normalize_pass(p))
75+
76+
flattened_passes = normalized
77+
5178
super().__init__(
5279
flattened_passes,
5380
run_checks_after_each_pass=run_checks_after_each_pass,
5481
suppress_check_failures=suppress_check_failures,
5582
)
5683

84+
# Improves type-checking in call_exported_program
85+
self.passes = cast(list[ExportedProgramPassBase], self.passes)
86+
5787
def check(self, module: torch.nn.Module) -> None:
5888
"""
5989
Runs various checks on the given graph module to make sure it contains
@@ -76,3 +106,95 @@ def check(self, module: torch.nn.Module) -> None:
76106
ExportErrorType.NOT_SUPPORTED,
77107
f"call_method `{node}` is not supported except for backend delegate.",
78108
)
109+
110+
def call(self, module: torch.nn.Module) -> PassResult:
111+
"""
112+
Runs the passes on the given graph module.
113+
114+
Args:
115+
module: The graph module to run the passes on
116+
117+
Returns:
118+
A PassResult object containing the modified graph module and a
119+
boolean indicating whether the module was modified
120+
"""
121+
# Check for passes that might have ExportedProgram-specific logic
122+
problematic_passes = []
123+
for p in self.passes:
124+
if (
125+
type(p).call_exported_program
126+
is not ExportedProgramPassBase.call_exported_program
127+
):
128+
problematic_passes.append(type(p).__name__)
129+
130+
if problematic_passes:
131+
fx.logger.warning(
132+
f"The following passes have overridden 'call_exported_program': {problematic_passes}. "
133+
"Calling PassManager.call() will only run the graph module logic via 'call()' "
134+
"and may miss ExportedProgram-specific transformations. "
135+
"Consider using PassManager.call_exported_program() instead.",
136+
UserWarning,
137+
stacklevel=2,
138+
)
139+
140+
return super()(module)
141+
142+
def call_exported_program(
143+
self, exported_program: ExportedProgram
144+
) -> ExportedProgramPassResult:
145+
"""
146+
Runs the passes on the given ExportedProgram.
147+
148+
Args:
149+
exported_program: The ExportedProgram to run the passes on
150+
151+
Returns:
152+
A PassResult object containing the modified ExportedProgram and a
153+
boolean indicating whether the ExportedProgram was modified
154+
"""
155+
# Order the passes based on the constraints
156+
if not self._validated:
157+
self.solve_constraints()
158+
159+
# Check graph invariants
160+
self.check(exported_program.graph_module)
161+
162+
# Run the set of passes `steps` number of times or until the graph stops
163+
# changing
164+
overall_modified = False
165+
166+
for _ in range(self.steps):
167+
modified = False
168+
169+
# Run the set of passes on the graph module
170+
for i, p in enumerate(self.passes):
171+
pass_name = type(p).__name__
172+
fx.logger.debug("Running pass '%s'", pass_name)
173+
174+
try:
175+
res = p.call_exported_program(exported_program)
176+
exported_program = res.exported_program
177+
modified = modified or res.modified
178+
179+
fx.logger.debug(
180+
"Graph after pass '%s': %s",
181+
pass_name,
182+
exported_program.graph_module.graph,
183+
)
184+
exported_program.graph_module.recompile()
185+
186+
# Check graph invariants
187+
if self.run_checks_after_each_pass:
188+
self.check(exported_program.graph_module)
189+
190+
except Exception as e:
191+
prev_pass_names = [type(p).__name__ for p in self.passes[:i]]
192+
msg = f"An error occurred when running the '{pass_name}' pass after the following passes: {prev_pass_names}"
193+
raise Exception(msg) from e # noqa: TRY002
194+
195+
# If the graph no longer changes, then we can stop running these passes
196+
overall_modified = overall_modified or modified
197+
if not modified:
198+
break
199+
200+
return ExportedProgramPassResult(exported_program, overall_modified)

0 commit comments

Comments
 (0)