Skip to content

Commit 288e8ef

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
ExportedProgram passes (#16986)
Summary: - Adds support to _ExportPassBase to run passes on ExportedPrograms. Ensures that we can only run either call or call_exported_program, not both - Updates exir.PassManager to add a new pass manager which operates on exported programs. This is done to ensure backwards compatibility, while allowing _program transformations to use either pass manager Differential Revision: D91725222
1 parent fdb386c commit 288e8ef

4 files changed

Lines changed: 547 additions & 64 deletions

File tree

exir/pass_base.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
9+
import enum
1010
import operator
1111
import traceback
1212
from contextlib import nullcontext
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
1617
Dict,
18+
Final,
1719
List,
1820
MutableMapping,
1921
Optional,
@@ -27,16 +29,15 @@
2729

2830
import torch
2931
from executorch.exir import memory
30-
3132
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
32-
3333
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3434
from executorch.exir.error import ExportError, ExportErrorType
3535
from torch import fx
3636
from torch._dispatch.python import enable_python_dispatcher
3737
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
3838
from torch._subclasses.fake_tensor import FakeTensor
3939
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
40+
from torch.export import ExportedProgram
4041
from torch.fx import traceback as fx_traceback
4142
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
4243
from torch.fx.graph import CodeGen
@@ -157,12 +158,60 @@ class ExportPassBaseError(RuntimeError):
157158
pass
158159

159160

161+
@dataclass(frozen=True)
162+
class ExportedProgramPassResult:
163+
exported_program: ExportedProgram
164+
modified: bool
165+
166+
167+
class ExportPassType(enum.Enum):
168+
EXPORTED_PROGRAM = enum.auto()
169+
GRAPH_MODULE = enum.auto()
170+
171+
160172
class _ExportPassBase(PassBase):
161173
"""
162174
Interpreter-based pass class to help users maintain the IR spec while writing
163175
transformations.
164176
"""
165177

178+
def __init_subclass__(cls, **kwargs: Any) -> None:
179+
# Required to ensure that subclasses of ExportPassBase only call
180+
# the correct call method based on the pass_type
181+
super().__init_subclass__(**kwargs)
182+
if "call" in cls.__dict__:
183+
original_call = cls.__dict__["call"]
184+
185+
def wrapped_call(
186+
self: "_ExportPassBase", graph_module: torch.fx.GraphModule
187+
) -> PassResult:
188+
if self.pass_type != ExportPassType.GRAPH_MODULE:
189+
raise ExportPassBaseError(
190+
f"Cannot call 'call' on a pass with pass_type={self.pass_type}. "
191+
f"Expected pass_type=ExportPassType.GRAPH_MODULE. "
192+
f"Use 'call_exported_program' for passes with ExportPassType.EXPORTED_PROGRAM."
193+
)
194+
195+
return original_call(self, graph_module)
196+
197+
cls.call = wrapped_call
198+
199+
if "call_exported_program" in cls.__dict__:
200+
original_call_ep = cls.__dict__["call_exported_program"]
201+
202+
def wrapped_call_exported_program(
203+
self: "_ExportPassBase", exported_program: ExportedProgram
204+
) -> ExportedProgramPassResult:
205+
if self.pass_type != ExportPassType.EXPORTED_PROGRAM:
206+
raise ExportPassBaseError(
207+
f"Cannot call 'call_exported_program' on a pass with pass_type={self.pass_type}. "
208+
f"Expected pass_type=ExportPassType.EXPORTED_PROGRAM. "
209+
f"Use 'call' for passes with ExportPassType.GRAPH_MODULE."
210+
)
211+
return original_call_ep(self, exported_program)
212+
213+
cls.call_exported_program = wrapped_call_exported_program
214+
166215
@staticmethod
167216
def _create_dummy_node_metadata() -> NodeMetadata:
168217
return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
@@ -394,14 +443,15 @@ def run_node(self, n: torch.fx.Node) -> Argument:
394443
self.callback.node_debug_str = n.format_node()
395444
return super().run_node(n)
396445

397-
def __init__(self) -> None:
446+
def __init__(self, pass_type: ExportPassType = ExportPassType.GRAPH_MODULE) -> None:
398447
self.interpreter = torch.fx.Interpreter(
399448
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
400449
)
401450
self.tracer = self.ExportTracer(self, CodeGen()) # pyre-ignore
402451
self.fake_tensor_mode: Optional[FakeTensorMode] = None
403452
self._initialized = True
404453
self.node_debug_str: Optional[str] = None
454+
self.pass_type: Final[ExportPassType] = pass_type
405455

406456
def _fx(
407457
self,
@@ -651,6 +701,11 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:
651701

652702
return result
653703

704+
def call_exported_program(
705+
self, exported_program: ExportedProgram
706+
) -> ExportedProgramPassResult:
707+
raise NotImplementedError("call_exported_program is not implemented.")
708+
654709

655710
class ExportPass(_ExportPassBase):
656711
class ExportTracer(_ExportPassBase.ExportTracer):

0 commit comments

Comments
 (0)