Skip to content

Commit 4b0a9c3

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
EdgeProgramManager passes (#16986)
Summary: - Adds support to run to run passes on ExportedPrograms and EdgeProgramManager - EdgeProgramManager transform behaves basically like a pass manager Reviewed By: larryliu0820, ethansfng Differential Revision: D91725222
1 parent 8b30cfe commit 4b0a9c3

7 files changed

Lines changed: 797 additions & 55 deletions

File tree

exir/BUCK

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,17 @@ fbcode_target(_kind = runtime.python_library,
259259
],
260260
)
261261

262+
fbcode_target(_kind = runtime.python_library,
263+
name = "edge_program_manager_pass_base",
264+
srcs = [
265+
"edge_program_manager_pass_base.py",
266+
],
267+
deps = [
268+
"//caffe2:torch",
269+
"//executorch/exir:pass_base",
270+
],
271+
)
272+
262273
fbcode_target(_kind = runtime.python_library,
263274
name = "pass_manager",
264275
srcs = [
@@ -267,6 +278,7 @@ fbcode_target(_kind = runtime.python_library,
267278
deps = [
268279
"fbsource//third-party/pypi/typing-extensions:typing-extensions",
269280
":error",
281+
":pass_base",
270282
"//caffe2:torch",
271283
],
272284
)
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import copy
10+
from abc import ABC, abstractmethod
11+
from dataclasses import dataclass
12+
from typing import Any, Callable, Dict, Optional, Sequence, TYPE_CHECKING, Union
13+
14+
import torch
15+
from torch.export import ExportedProgram
16+
from torch.fx.passes.infra.pass_base import PassResult
17+
18+
if TYPE_CHECKING:
19+
from executorch.exir.program._program import EdgeProgramManager
20+
21+
22+
@dataclass(frozen=True)
23+
class ExportedProgramPassResult:
24+
"""Result of running a pass on an ExportedProgram."""
25+
26+
exported_program: ExportedProgram
27+
modified: bool
28+
29+
30+
class ExportedProgramPassBase(ABC):
31+
"""
32+
Base interface for implementing passes that operate on ExportedProgram.
33+
"""
34+
35+
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
36+
"""
37+
Runs the precondition check, the pass itself, and the postcondition check.
38+
"""
39+
40+
self.requires(exported_program)
41+
res = self.call(exported_program)
42+
self.ensures(exported_program)
43+
return res
44+
45+
@abstractmethod
46+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
47+
"""
48+
The pass that is run through the given exported program. To implement a
49+
pass, it is required to implement this function.
50+
51+
Args:
52+
exported_program: The exported program we will run a pass on
53+
"""
54+
55+
def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027
56+
"""
57+
This function will be called before the pass is run and will check that
58+
the given exported program contains the preconditions needed to run the
59+
pass. It is not required to implement this function.
60+
61+
Args:
62+
exported_program: The exported program we will run checks on
63+
"""
64+
65+
def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027
66+
"""
67+
This function will be called after the pass is run and will check that
68+
the given exported program contains the postconditions needed to run the
69+
pass. It is not required to implement this function.
70+
71+
Args:
72+
exported_program: The exported program we will run checks on
73+
"""
74+
75+
76+
@dataclass(frozen=True)
77+
class EdgeProgramManagerPassResult:
78+
"""Result of running a pass on an EdgeProgramManager."""
79+
80+
edge_program_manager: "EdgeProgramManager"
81+
modified: bool
82+
83+
84+
class EdgeProgramManagerPassBase(ABC):
85+
"""
86+
Base interface for implementing passes that operate on EdgeProgramManager.
87+
88+
This is the highest-level pass abstraction. Passes at this level can:
89+
- Transform individual ExportedPrograms within the manager
90+
- Modify constant methods
91+
- Split one program into multiple programs
92+
- Add or remove programs from the manager
93+
94+
Lower-level passes (ExportedProgramPassBase, GraphModule callables) can be
95+
lifted to this level using the provided wrapper classes.
96+
"""
97+
98+
def __call__(
99+
self, epm: "EdgeProgramManager"
100+
) -> EdgeProgramManagerPassResult:
101+
"""
102+
Runs the precondition check, the pass itself, and the postcondition check.
103+
"""
104+
self.requires(epm)
105+
res = self.call(epm)
106+
self.ensures(res.edge_program_manager)
107+
return res
108+
109+
@abstractmethod
110+
def call(
111+
self, epm: "EdgeProgramManager"
112+
) -> EdgeProgramManagerPassResult:
113+
"""
114+
The pass that is run on the given EdgeProgramManager. To implement a
115+
pass, it is required to implement this function.
116+
117+
Args:
118+
epm: The EdgeProgramManager to transform
119+
"""
120+
121+
def requires(self, epm: "EdgeProgramManager") -> None: # noqa: B027
122+
"""
123+
This function will be called before the pass is run and will check that
124+
the given EdgeProgramManager contains the preconditions needed to run the
125+
pass. It is not required to implement this function.
126+
127+
Args:
128+
epm: The EdgeProgramManager we will run checks on
129+
"""
130+
131+
def ensures(self, epm: "EdgeProgramManager") -> None: # noqa: B027
132+
"""
133+
This function will be called after the pass is run and will check that
134+
the given EdgeProgramManager contains the postconditions needed to run the
135+
pass. It is not required to implement this function.
136+
137+
Args:
138+
epm: The EdgeProgramManager we will run checks on
139+
"""
140+
141+
142+
class GraphModuleBackedExportedProgramPassWrapper(ExportedProgramPassBase):
143+
"""
144+
Wrapper that adapts a GraphModule pass to work as an ExportedProgramPassBase.
145+
146+
This wrapper takes a pass that operates on GraphModule and makes it compatible
147+
with ExportedProgramPassBase by extracting the graph module, running the pass,
148+
and updating the ExportedProgram in-place.
149+
"""
150+
151+
def __init__(
152+
self,
153+
graph_module_pass: Callable[[torch.fx.GraphModule], PassResult],
154+
) -> None:
155+
super().__init__()
156+
self._pass = graph_module_pass
157+
158+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
159+
from executorch.exir.program._program import (
160+
_get_updated_graph_signature,
161+
_get_updated_range_constraints,
162+
)
163+
164+
result = self._pass(exported_program.graph_module)
165+
166+
if result.modified:
167+
# Cannot use _update_exported_program_graph_module because it
168+
# runs verification, and it is not the responsibility of the
169+
# pass to run verification. EdgeProgram manager can
170+
# optionally run verification after a pass.
171+
result.graph_module.recompile()
172+
exported_program = copy.copy(exported_program) # bypasses __init__ and _validate()
173+
174+
exported_program._graph_module = result.graph_module
175+
exported_program._graph_signature = _get_updated_graph_signature(
176+
exported_program.graph_signature, result.graph_module
177+
)
178+
exported_program._range_constraints = _get_updated_range_constraints(
179+
result.graph_module
180+
)
181+
exported_program._module_call_graph = copy.deepcopy(
182+
exported_program._module_call_graph
183+
)
184+
exported_program._graph_module.meta.update(exported_program.graph_module.meta)
185+
186+
187+
return ExportedProgramPassResult(exported_program, result.modified)
188+
189+
190+
class ExportedProgramToEdgeProgramManagerPassWrapper(EdgeProgramManagerPassBase):
191+
"""
192+
Adapts an ExportedProgramPassBase to run on every method in an EdgeProgramManager.
193+
194+
This wrapper takes a pass that operates on a single ExportedProgram and applies it
195+
to every method in the EdgeProgramManager, collecting results into a new EPM.
196+
This is where the iteration over methods lives -- not in the pass manager, and not
197+
in EdgeProgramManager.transform().
198+
"""
199+
200+
def __init__(self, ep_pass: ExportedProgramPassBase) -> None:
201+
super().__init__()
202+
self._pass = ep_pass
203+
204+
def call(
205+
self, epm: "EdgeProgramManager"
206+
) -> EdgeProgramManagerPassResult:
207+
new_epm = copy.copy(epm)
208+
new_epm._edge_programs = dict(epm._edge_programs)
209+
210+
overall_modified = False
211+
for name, program in epm._edge_programs.items():
212+
result = self._pass(program)
213+
new_epm._edge_programs[name] = result.exported_program
214+
overall_modified = overall_modified or result.modified
215+
216+
new_epm._config_methods = epm._config_methods
217+
return EdgeProgramManagerPassResult(new_epm, overall_modified)
218+
219+
220+
PassType = Union[
221+
EdgeProgramManagerPassBase,
222+
ExportedProgramPassBase,
223+
Callable[[torch.fx.GraphModule], Optional[PassResult]],
224+
]
225+
226+
227+
def _get_pass_name(fn: PassType) -> str:
228+
"""Unwraps wrapper chain to get the underlying pass name."""
229+
import inspect
230+
231+
if isinstance(fn, ExportedProgramToEdgeProgramManagerPassWrapper):
232+
return _get_pass_name(fn._pass)
233+
if isinstance(fn, GraphModuleBackedExportedProgramPassWrapper):
234+
return _get_pass_name(fn._pass)
235+
return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
236+
237+
238+
def wrap_passes(
239+
passes: Sequence[PassType],
240+
) -> list[EdgeProgramManagerPassBase]:
241+
"""
242+
Wraps a list of mixed-level passes up to the EdgeProgramManager level.
243+
244+
Accepts passes at three levels:
245+
- EdgeProgramManagerPassBase: used as-is
246+
- ExportedProgramPassBase: wrapped with ExportedProgramToEdgeProgramManagerPassWrapper
247+
- GraphModule callables: wrapped with GraphModuleBackedExportedProgramPassWrapper
248+
then ExportedProgramToEdgeProgramManagerPassWrapper
249+
250+
Args:
251+
passes: A sequence of passes at any level.
252+
253+
Returns:
254+
A list of EdgeProgramManagerPassBase passes.
255+
"""
256+
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
257+
258+
wrapped: list[EdgeProgramManagerPassBase] = []
259+
for fn in passes:
260+
if isinstance(fn, EdgeProgramManagerPassBase):
261+
wrapped.append(fn)
262+
elif isinstance(fn, ExportedProgramPassBase):
263+
wrapped.append(
264+
ExportedProgramToEdgeProgramManagerPassWrapper(fn)
265+
)
266+
else:
267+
assert callable(fn)
268+
ep_pass = GraphModuleBackedExportedProgramPassWrapper(
269+
pass_result_wrapper(fn)
270+
)
271+
wrapped.append(
272+
ExportedProgramToEdgeProgramManagerPassWrapper(ep_pass)
273+
)
274+
return wrapped
275+
276+
277+
class MethodFilteredEdgeProgramManagerPass(EdgeProgramManagerPassBase):
278+
"""
279+
Applies different passes to different methods in an EdgeProgramManager.
280+
281+
Converts the Dict[str, Sequence[PassType]] pattern (previously handled inline
282+
in EdgeProgramManager.transform) into a proper pass. Used by
283+
to_edge_transform_and_lower to handle the dict case.
284+
"""
285+
286+
def __init__(self, passes_dict: Dict[str, Sequence[Any]]) -> None:
287+
super().__init__()
288+
self._passes_dict = passes_dict
289+
290+
def call(
291+
self, epm: "EdgeProgramManager"
292+
) -> EdgeProgramManagerPassResult:
293+
from executorch.exir.program._program import _transform
294+
295+
new_epm = copy.copy(epm)
296+
new_epm._edge_programs = dict(epm._edge_programs)
297+
298+
overall_modified = False
299+
for name, program in epm._edge_programs.items():
300+
if name in self._passes_dict:
301+
new_program = _transform(program, *self._passes_dict[name])
302+
new_epm._edge_programs[name] = new_program
303+
overall_modified = True
304+
305+
return EdgeProgramManagerPassResult(new_epm, overall_modified)

exir/pass_manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
8+
import logging
99
from typing import Callable, List, Optional, Union
1010

1111
import torch
1212
import torch.fx.passes.infra.pass_manager as fx
1313
import torch.utils._pytree as pytree
1414
from executorch.exir.error import ExportError, ExportErrorType
15+
1516
from torch.fx.passes.infra.pass_base import PassResult
1617
from typing_extensions import TypeAlias
1718

19+
logger = logging.getLogger(__name__)
20+
logger.setLevel(logging.WARNING)
21+
1822
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
1923

2024

@@ -27,6 +31,8 @@ class PassManager(fx.PassManager):
2731
* **passes**: A list of callable passes
2832
* **params**: An instance of PassManagerParams containing the result of the
2933
flags set in the constructor.
34+
35+
Note: This class is deprecated. Please use EdgeProgramManager.transform() instead.
3036
"""
3137

3238
def __init__(
@@ -41,6 +47,9 @@ def __init__(
4147
enable_debug_pass: set to true to enable the debug passes
4248
run_checks_after_each_pass: whether to run checks and linting after each pass
4349
"""
50+
logger.warning(
51+
"PassManager is deprecated. Please use EdgeProgramManager.transform() instead."
52+
)
4453

4554
# Flatten the passes to a list of callables
4655
passes = passes if passes else []

exir/program/BUCK

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
22
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3-
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
43

54
oncall("executorch")
65

@@ -47,6 +46,7 @@ fbcode_target(_kind = runtime.python_library,
4746
"//executorch/exir/passes:spec_prop_pass",
4847
"//executorch/exir/passes:weights_to_outputs_pass",
4948
"//executorch/exir/passes:convert_constant_dim_order_pass",
49+
"//executorch/exir:edge_program_manager_pass_base",
5050
"//executorch/exir/verification:verifier",
5151
"//executorch/extension/flat_tensor/serialize:serialize",
5252
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])

0 commit comments

Comments
 (0)