|
6 | 6 | # LICENSE file in the root directory of this source tree. |
7 | 7 |
|
8 | 8 | # pyre-strict |
9 | | - |
| 9 | +import enum |
10 | 10 | import operator |
11 | 11 | import traceback |
12 | 12 | from contextlib import nullcontext |
| 13 | +from dataclasses import dataclass |
13 | 14 | from typing import ( |
14 | 15 | Any, |
15 | 16 | Callable, |
16 | 17 | Dict, |
| 18 | + Final, |
17 | 19 | List, |
18 | 20 | MutableMapping, |
19 | 21 | Optional, |
|
27 | 29 |
|
28 | 30 | import torch |
29 | 31 | from executorch.exir import memory |
30 | | - |
31 | 32 | from executorch.exir.delegate import executorch_call_delegate, is_lowered_module |
32 | | - |
33 | 33 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
34 | 34 | from executorch.exir.error import ExportError, ExportErrorType |
35 | 35 | from torch import fx |
36 | 36 | from torch._dispatch.python import enable_python_dispatcher |
37 | 37 | from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException |
38 | 38 | from torch._subclasses.fake_tensor import FakeTensor |
39 | 39 | from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode |
| 40 | +from torch.export import ExportedProgram |
40 | 41 | from torch.fx import traceback as fx_traceback |
41 | 42 | from torch.fx.experimental.proxy_tensor import PythonKeyTracer |
42 | 43 | from torch.fx.graph import CodeGen |
@@ -157,12 +158,60 @@ class ExportPassBaseError(RuntimeError): |
157 | 158 | pass |
158 | 159 |
|
159 | 160 |
|
| 161 | +@dataclass(frozen=True) |
| 162 | +class ExportedProgramPassResult: |
| 163 | + exported_program: ExportedProgram |
| 164 | + modified: bool |
| 165 | + |
| 166 | + |
| 167 | +class ExportPassType(enum.Enum): |
| 168 | + EXPORTED_PROGRAM = enum.auto() |
| 169 | + GRAPH_MODULE = enum.auto() |
| 170 | + |
| 171 | + |
160 | 172 | class _ExportPassBase(PassBase): |
161 | 173 | """ |
162 | 174 | Interpreter-based pass class to help users maintain the IR spec while writing |
163 | 175 | transformations. |
164 | 176 | """ |
165 | 177 |
|
| 178 | + def __init_subclass__(cls, **kwargs: Any) -> None: |
| 179 | + # Required to ensure that subclasses of ExportPassBase only call |
| 180 | + # the correct call method based on the pass_type |
| 181 | + super().__init_subclass__(**kwargs) |
| 182 | + if "call" in cls.__dict__: |
| 183 | + original_call = cls.__dict__["call"] |
| 184 | + |
| 185 | + def wrapped_call( |
| 186 | + self: "_ExportPassBase", graph_module: torch.fx.GraphModule |
| 187 | + ) -> PassResult: |
| 188 | + if self.pass_type != ExportPassType.GRAPH_MODULE: |
| 189 | + raise ExportPassBaseError( |
| 190 | + f"Cannot call 'call' on a pass with pass_type={self.pass_type}. " |
| 191 | + f"Expected pass_type=ExportPassType.GRAPH_MODULE. " |
| 192 | + f"Use 'call_exported_program' for passes with ExportPassType.EXPORTED_PROGRAM." |
| 193 | + ) |
| 194 | + |
| 195 | + return original_call(self, graph_module) |
| 196 | + |
| 197 | + cls.call = wrapped_call |
| 198 | + |
| 199 | + if "call_exported_program" in cls.__dict__: |
| 200 | + original_call_ep = cls.__dict__["call_exported_program"] |
| 201 | + |
| 202 | + def wrapped_call_exported_program( |
| 203 | + self: "_ExportPassBase", exported_program: ExportedProgram |
| 204 | + ) -> ExportedProgramPassResult: |
| 205 | + if self.pass_type != ExportPassType.EXPORTED_PROGRAM: |
| 206 | + raise ExportPassBaseError( |
| 207 | + f"Cannot call 'call_exported_program' on a pass with pass_type={self.pass_type}. " |
| 208 | + f"Expected pass_type=ExportPassType.EXPORTED_PROGRAM. " |
| 209 | + f"Use 'call' for passes with ExportPassType.GRAPH_MODULE." |
| 210 | + ) |
| 211 | + return original_call_ep(self, exported_program) |
| 212 | + |
| 213 | + cls.call_exported_program = wrapped_call_exported_program |
| 214 | + |
166 | 215 | @staticmethod |
167 | 216 | def _create_dummy_node_metadata() -> NodeMetadata: |
168 | 217 | return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) |
@@ -394,14 +443,15 @@ def run_node(self, n: torch.fx.Node) -> Argument: |
394 | 443 | self.callback.node_debug_str = n.format_node() |
395 | 444 | return super().run_node(n) |
396 | 445 |
|
397 | | - def __init__(self) -> None: |
| 446 | + def __init__(self, pass_type: ExportPassType = ExportPassType.GRAPH_MODULE) -> None: |
398 | 447 | self.interpreter = torch.fx.Interpreter( |
399 | 448 | torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) |
400 | 449 | ) |
401 | 450 | self.tracer = self.ExportTracer(self, CodeGen()) # pyre-ignore |
402 | 451 | self.fake_tensor_mode: Optional[FakeTensorMode] = None |
403 | 452 | self._initialized = True |
404 | 453 | self.node_debug_str: Optional[str] = None |
| 454 | + self.pass_type: Final[ExportPassType] = pass_type |
405 | 455 |
|
406 | 456 | def _fx( |
407 | 457 | self, |
@@ -651,6 +701,11 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: |
651 | 701 |
|
652 | 702 | return result |
653 | 703 |
|
| 704 | + def call_exported_program( |
| 705 | + self, exported_program: ExportedProgram |
| 706 | + ) -> ExportedProgramPassResult: |
| 707 | + raise NotImplementedError("call_exported_program is not implemented.") |
| 708 | + |
654 | 709 |
|
655 | 710 | class ExportPass(_ExportPassBase): |
656 | 711 | class ExportTracer(_ExportPassBase.ExportTracer): |
|
0 commit comments