Skip to content

Commit e70ee91

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
ExportedProgram passes
Summary: - Add ExportedProgramPassBase, which supports running passes on both `fx.GraphModule`s and `exir.ExportedProgram`s - Extend this class in `_ExportPassBase` so that any passes which import directly from `ExportPass` are immediately compatible with the current pass manager. - Create `LegacyPassWrapper` to auto-wrap passes not migrated to use ExportedProgramPassBase such that it works out of the box with EdgeProgramManager. Differential Revision: D91725222
1 parent ac760fc commit e70ee91

4 files changed

Lines changed: 218 additions & 34 deletions

File tree

exir/pass_base.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
109
import operator
1110
import traceback
11+
import warnings
1212
from contextlib import nullcontext
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
@@ -26,10 +27,9 @@
2627
)
2728

2829
import torch
30+
from torch.export import ExportedProgram
2931
from executorch.exir import memory
30-
3132
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
32-
3333
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3434
from executorch.exir.error import ExportError, ExportErrorType
3535
from torch import fx
@@ -157,7 +157,64 @@ class ExportPassBaseError(RuntimeError):
157157
pass
158158

159159

160-
class _ExportPassBase(PassBase):
160+
@dataclass(frozen=True)
161+
class ExportedProgramPassResult:
162+
exported_program: ExportedProgram
163+
modified: bool
164+
165+
166+
class ExportedProgramPassBase(PassBase):
167+
"""
168+
Base class which supports running passes on either `exir.ExportedProgram` or `fx.GraphModule`
169+
types.
170+
"""
171+
172+
def call_exported_program(
173+
self, exported_program: ExportedProgram
174+
) -> ExportedProgramPassResult:
175+
pass_result = self.call(exported_program.graph_module)
176+
modified = pass_result.modified if pass_result else False
177+
if modified:
178+
exported_program._graph_module = pass_result.graph_module
179+
return ExportedProgramPassResult(exported_program, modified)
180+
181+
def call(self, graph_module: fx.GraphModule) -> Optional[PassResult]:
182+
"""
183+
Overriden version of call which is effectively a no-op, but is required
184+
since the base class method is abstract.
185+
"""
186+
return PassResult(graph_module, False)
187+
188+
189+
class LegacyPassWrapper(ExportedProgramPassBase):
190+
"""
191+
Wraps a legacy callable pass (Callable[[GraphModule], Optional[PassResult]])
192+
to work with the new ExportedProgramPassBase infrastructure.
193+
194+
This provides backwards compatibility for passes that haven't been migrated yet.
195+
"""
196+
197+
def __init__(
198+
self, legacy_pass: Callable[[torch.fx.GraphModule], Optional[PassResult]]
199+
) -> None:
200+
self._legacy_pass = legacy_pass
201+
# Preserve the original name for debugging/logging
202+
self.__class__.__name__ = getattr(
203+
legacy_pass, "__name__", type(legacy_pass).__name__
204+
)
205+
206+
warnings.warn(
207+
f"Callable pass '{self.__class__.__name__}' is deprecated. "
208+
"Please migrate to ExportedProgramPassBase.",
209+
DeprecationWarning,
210+
stacklevel=3,
211+
)
212+
213+
def call(self, graph_module: fx.GraphModule) -> Optional[PassResult]:
214+
return self._legacy_pass(graph_module)
215+
216+
217+
class _ExportPassBase(ExportedProgramPassBase):
161218
"""
162219
Interpreter-based pass class to help users maintain the IR spec while writing
163220
transformations.

exir/pass_manager.py

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,37 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
9-
from typing import Callable, List, Optional, Union
10-
8+
import copy
9+
from typing import Callable, cast, List, Optional, Union
1110
import torch
1211
import torch.fx.passes.infra.pass_manager as fx
1312
import torch.utils._pytree as pytree
13+
from torch.export import ExportedProgram
1414
from executorch.exir.error import ExportError, ExportErrorType
15+
from executorch.exir.pass_base import (
16+
ExportedProgramPassBase,
17+
ExportedProgramPassResult,
18+
LegacyPassWrapper,
19+
)
1520
from torch.fx.passes.infra.pass_base import PassResult
16-
from typing_extensions import TypeAlias
1721

18-
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
22+
# Legacy type for backwards compatibility
23+
LegacyPassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
24+
25+
# New union that accepts both
26+
PassType = Union[ExportedProgramPassBase, LegacyPassType]
27+
28+
29+
def _normalize_pass(
30+
p: Union[ExportedProgramPassBase, Callable]
31+
) -> ExportedProgramPassBase:
32+
"""Normalize a pass to ExportedProgramPassBase, wrapping legacy callables."""
33+
if isinstance(p, ExportedProgramPassBase):
34+
return p
35+
elif callable(p):
36+
return LegacyPassWrapper(p)
37+
else:
38+
raise TypeError(f"Expected ExportedProgramPassBase or callable, got {type(p)}")
1939

2040

2141
class PassManager(fx.PassManager):
@@ -48,12 +68,23 @@ def __init__(
4868
fx.pass_result_wrapper(fn) for fn in pytree.tree_flatten(passes)[0]
4969
]
5070

71+
normalized: list[ExportedProgramPassBase] = []
72+
for p in flattened_passes:
73+
normalized.append(_normalize_pass(p))
74+
75+
flattened_passes = normalized
76+
5177
super().__init__(
5278
flattened_passes,
5379
run_checks_after_each_pass=run_checks_after_each_pass,
5480
suppress_check_failures=suppress_check_failures,
5581
)
5682

83+
# Improves type-checking in call_exported_program
84+
self.passes = cast(list[ExportedProgramPassBase], self.passes)
85+
86+
87+
5788
def check(self, module: torch.nn.Module) -> None:
5889
"""
5990
Runs various checks on the given graph module to make sure it contains
@@ -76,3 +107,95 @@ def check(self, module: torch.nn.Module) -> None:
76107
ExportErrorType.NOT_SUPPORTED,
77108
f"call_method `{node}` is not supported except for backend delegate.",
78109
)
110+
111+
def call(self, module: torch.nn.Module) -> PassResult:
112+
"""
113+
Runs the passes on the given graph module.
114+
115+
Args:
116+
module: The graph module to run the passes on
117+
118+
Returns:
119+
A PassResult object containing the modified graph module and a
120+
boolean indicating whether the module was modified
121+
"""
122+
# Check for passes that might have ExportedProgram-specific logic
123+
problematic_passes = []
124+
for p in self.passes:
125+
if type(p).call_exported_program is not ExportedProgramPassBase.call_exported_program:
126+
problematic_passes.append(type(p).__name__)
127+
128+
if problematic_passes:
129+
fx.logger.warning(
130+
f"The following passes have overridden 'call_exported_program': {problematic_passes}. "
131+
"Calling PassManager.call() will only run the graph module logic via 'call()' "
132+
"and may miss ExportedProgram-specific transformations. "
133+
"Consider using PassManager.call_exported_program() instead.",
134+
UserWarning,
135+
stacklevel=2,
136+
)
137+
138+
return super()(module)
139+
140+
def call_exported_program(
141+
self, exported_program: ExportedProgram
142+
) -> ExportedProgramPassResult:
143+
"""
144+
Runs the passes on the given ExportedProgram.
145+
146+
Args:
147+
exported_program: The ExportedProgram to run the passes on
148+
149+
Returns:
150+
A PassResult object containing the modified ExportedProgram and a
151+
boolean indicating whether the ExportedProgram was modified
152+
"""
153+
# Order the passes based on the constraints
154+
if not self._validated:
155+
self.solve_constraints()
156+
157+
# Check graph invariants
158+
self.check(exported_program.graph_module)
159+
160+
# Run the set of passes `steps` number of times or until the graph stops
161+
# changing
162+
overall_modified = False
163+
164+
# Done once at the beginning to make sure we don't modify the original
165+
# in place
166+
exported_program = copy.deepcopy(exported_program)
167+
for _ in range(self.steps):
168+
modified = False
169+
170+
# Run the set of passes on the graph module
171+
for i, p in enumerate(self.passes):
172+
pass_name = type(p).__name__
173+
fx.logger.debug("Running pass '%s'", pass_name)
174+
175+
try:
176+
res = p.call_exported_program(exported_program)
177+
exported_program = res.exported_program
178+
modified = modified or res.modified
179+
180+
fx.logger.debug(
181+
"Graph after pass '%s': %s",
182+
pass_name,
183+
exported_program.graph_module.graph,
184+
)
185+
exported_program.graph_module.recompile()
186+
187+
# Check graph invariants
188+
if self.run_checks_after_each_pass:
189+
self.check(exported_program.graph_module)
190+
191+
except Exception as e:
192+
prev_pass_names = [type(p).__name__ for p in self.passes[:i]]
193+
msg = f"An error occurred when running the '{pass_name}' pass after the following passes: {prev_pass_names}"
194+
raise Exception(msg) from e # noqa: TRY002
195+
196+
# If the graph no longer changes, then we can stop running these passes
197+
overall_modified = overall_modified or modified
198+
if not modified:
199+
break
200+
201+
return ExportedProgramPassResult(exported_program, overall_modified)

0 commit comments

Comments
 (0)