Skip to content

Commit 3d3a362

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
ExportedProgram passes (#16986)
Summary: - Adds support to _ExportPassBase to run passes on ExportedPrograms. Ensures that we can only run either call or call_exported_program, not both - Updates exir.PassManager to run on either exported programs or graph modules. Updates PassManagers which extend it to override the correct interfaces (minimal change, just a method renaming) - If we run the pass manager with an exported program, supports both graph module and exported program passes, but will always return an ExportedProgramPassResult - Updates transform to always call pass manager with an ExportedProgram, thus getting back an ExportedProgramPassResult. Differential Revision: D91725222
1 parent d05fe5e commit 3d3a362

7 files changed

Lines changed: 265 additions & 72 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import logging
99
from collections import defaultdict
1010
from collections.abc import Sequence
11+
from typing import Union
1112

1213
import executorch.backends.arm.tosa.dialect # noqa: unused
14+
import torch
1315
from executorch.backends.arm._passes import (
1416
AccumulateIndexPutPass,
1517
AnnotateOutputDimOrderPass,
@@ -133,11 +135,10 @@
133135
TosaSpecification,
134136
)
135137
from executorch.exir import ExportedProgram
136-
from executorch.exir.pass_base import ExportPass
138+
from executorch.exir.pass_base import ExportedProgramPassResult, ExportPass
137139
from executorch.exir.pass_manager import PassManager
138140
from torch.fx import GraphModule
139141
from torch.fx.passes.infra.pass_base import PassResult
140-
from torch.nn.modules import Module
141142

142143
logger = logging.getLogger(__name__)
143144

@@ -452,9 +453,11 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
452453

453454
return self._transform(graph_module)
454455

455-
def __call__(self, module: Module) -> PassResult:
456+
def __call__(
457+
self, module_or_program: Union[torch.fx.GraphModule, ExportedProgram]
458+
) -> Union[PassResult, ExportedProgramPassResult]:
456459
try:
457-
return super().__call__(module)
460+
return super().__call__(ExportedProgramPassResult)
458461
except Exception as e:
459462
first_exception = e.__cause__ or e.__context__ or e
460463
import re

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Callable
77

88
import torch
9-
109
from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import (
1110
ConvertUnsqueezeToViewPass,
1211
)
@@ -36,7 +35,7 @@
3635
)
3736
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
3837
from executorch.exir.pass_manager import PassManager
39-
from torch import nn
38+
from torch.fx import GraphModule
4039
from torch.fx.passes.infra.pass_base import PassResult
4140

4241
PassType = type[Callable[[torch.fx.GraphModule], PassResult]]
@@ -70,7 +69,7 @@ def __init__(
7069
passes: list[PassType] = passes or _get_default_passes(neutron_target_spec)
7170
super().__init__(passes)
7271

73-
def __call__(self, module: nn.Module) -> PassResult:
72+
def run_passes(self, module: GraphModule) -> PassResult:
7473
pass_result: PassResult = super().__call__(module)
7574

7675
graph_module = pass_result.graph_module

backends/nxp/edge_passes/neutron_edge_pass_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from executorch.backends.nxp.edge_passes.remove_as_strided_copy_nodes import (
1212
RemoveUselessAsStridedCopyNodes,
1313
)
14-
from torch.fx.passes.infra.pass_manager import PassManager
14+
from executorch.exir.pass_manager import PassManager
1515

1616

1717
class NeutronEdgePassManager(PassManager):

exir/pass_base.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
9+
import enum
1010
import operator
1111
import traceback
1212
from contextlib import nullcontext
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
@@ -27,16 +28,15 @@
2728

2829
import torch
2930
from executorch.exir import memory
30-
3131
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
32-
3332
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3433
from executorch.exir.error import ExportError, ExportErrorType
3534
from torch import fx
3635
from torch._dispatch.python import enable_python_dispatcher
3736
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
3837
from torch._subclasses.fake_tensor import FakeTensor
3938
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
39+
from torch.export import ExportedProgram
4040
from torch.fx import traceback as fx_traceback
4141
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
4242
from torch.fx.graph import CodeGen
@@ -157,12 +157,60 @@ class ExportPassBaseError(RuntimeError):
157157
pass
158158

159159

160+
@dataclass(frozen=True)
161+
class ExportedProgramPassResult:
162+
exported_program: ExportedProgram
163+
modified: bool
164+
165+
166+
class ExportPassType(enum.Enum):
167+
EXPORTED_PROGRAM = enum.auto()
168+
GRAPH_MODULE = enum.auto()
169+
170+
160171
class _ExportPassBase(PassBase):
161172
"""
162173
Interpreter-based pass class to help users maintain the IR spec while writing
163174
transformations.
164175
"""
165176

177+
def __init_subclass__(cls, **kwargs: Any) -> None:
178+
# Required to ensure that subclasses of ExportPassBase only call
179+
# the correct call method based on the pass_type
180+
super().__init_subclass__(**kwargs)
181+
if "call" in cls.__dict__:
182+
original_call = cls.__dict__["call"]
183+
184+
def wrapped_call(
185+
self: "_ExportPassBase", graph_module: torch.fx.GraphModule
186+
) -> PassResult:
187+
if self.pass_type != ExportPassType.GRAPH_MODULE:
188+
raise ExportPassBaseError(
189+
f"Cannot call 'call' on a pass with pass_type={self.pass_type}. "
190+
f"Expected pass_type=ExportPassType.GRAPH_MODULE. "
191+
f"Use 'call_exported_program' for passes with ExportPassType.EXPORTED_PROGRAM."
192+
)
193+
194+
return original_call(self, graph_module)
195+
196+
cls.call = wrapped_call
197+
198+
if "call_exported_program" in cls.__dict__:
199+
original_call_ep = cls.__dict__["call_exported_program"]
200+
201+
def wrapped_call_exported_program(
202+
self: "_ExportPassBase", exported_program: ExportedProgram
203+
) -> ExportedProgramPassResult:
204+
if self.pass_type != ExportPassType.EXPORTED_PROGRAM:
205+
raise ExportPassBaseError(
206+
f"Cannot call 'call_exported_program' on a pass with pass_type={self.pass_type}. "
207+
f"Expected pass_type=ExportPassType.EXPORTED_PROGRAM. "
208+
f"Use 'call' for passes with ExportPassType.GRAPH_MODULE."
209+
)
210+
return original_call_ep(self, exported_program)
211+
212+
cls.call_exported_program = wrapped_call_exported_program
213+
166214
@staticmethod
167215
def _create_dummy_node_metadata() -> NodeMetadata:
168216
return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
@@ -394,14 +442,15 @@ def run_node(self, n: torch.fx.Node) -> Argument:
394442
self.callback.node_debug_str = n.format_node()
395443
return super().run_node(n)
396444

397-
def __init__(self) -> None:
445+
def __init__(self, pass_type: ExportPassType = ExportPassType.GRAPH_MODULE) -> None:
398446
self.interpreter = torch.fx.Interpreter(
399447
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
400448
)
401449
self.tracer = self.ExportTracer(self, CodeGen()) # pyre-ignore
402450
self.fake_tensor_mode: Optional[FakeTensorMode] = None
403451
self._initialized = True
404452
self.node_debug_str: Optional[str] = None
453+
self.pass_type = pass_type
405454

406455
def _fx(
407456
self,
@@ -651,6 +700,11 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:
651700

652701
return result
653702

703+
def call_exported_program(
704+
self, exported_program: ExportedProgram
705+
) -> ExportedProgramPassResult:
706+
raise NotImplementedError("call_exported_program is not implemented.")
707+
654708

655709
class ExportPass(_ExportPassBase):
656710
class ExportTracer(_ExportPassBase.ExportTracer):

0 commit comments

Comments
 (0)