Skip to content

Commit 77df9b7

Browse files
authored
New exported program pass manager and exported program passes (#16986)
Differential Revision: D91725222 Pull Request resolved: #16986
1 parent fb420f3 commit 77df9b7

10 files changed

Lines changed: 671 additions & 153 deletions

File tree

backends/arm/test/tester/test_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec
4949
from executorch.backends.test.harness.stages import StageType
5050
from executorch.exir.pass_base import ExportPass
51-
from torch._export.pass_base import PassType
51+
from executorch.exir.pass_manager import PassType
5252
from torch.export.graph_signature import InputKind, OutputKind
5353
from torchao.quantization.pt2e.quantizer import QuantizationSpec
5454

backends/qualcomm/_passes/recompose_pad_maxpool2d.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,8 @@
1313
from executorch.exir.pass_base import ExportPass, PassResult
1414
from executorch.exir.passes import dead_code_elimination_pass
1515

16-
from torch._subclasses.fake_tensor import FakeTensorMode
17-
18-
19-
def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype):
20-
fake_mode = FakeTensorMode()
2116

17+
def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype, fake_mode):
2218
with fake_mode:
2319
batch, channels, height, width = input_shape
2420
pad_left, pad_right, pad_top, pad_bottom = padding_args
@@ -114,6 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa C901
114110
input_node.meta["val"].shape,
115111
padding,
116112
input_node.meta["val"].dtype,
113+
input_node.meta["val"].fake_mode,
117114
)
118115
if quant_attrs:
119116
padding_node.meta["quant_attrs"] = node.meta["quant_attrs"]

backends/qualcomm/_passes/utils.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,23 @@ def copy_nn_module_stack(src, target):
137137
target.meta["nn_module_stack"] = value
138138

139139

140-
def merge_decomposed_graph(
140+
def _unify_fake_mode(node: torch.fx.Node, fake_mode) -> None:
141+
val = node.meta.get("val")
142+
if val is None:
143+
return
144+
if isinstance(val, FakeTensor) and val.fake_mode is not fake_mode:
145+
node.meta["val"] = fake_mode.from_tensor(val)
146+
elif isinstance(val, (list, tuple)):
147+
unified = []
148+
for v in val:
149+
if isinstance(v, FakeTensor) and v.fake_mode is not fake_mode:
150+
unified.append(fake_mode.from_tensor(v))
151+
else:
152+
unified.append(v)
153+
node.meta["val"] = type(val)(unified)
154+
155+
156+
def merge_decomposed_graph( # noqa: C901
141157
remap: Dict[str, torch.fx.Node],
142158
target_node: torch.fx.Node,
143159
target_graph: torch.fx.GraphModule,
@@ -148,6 +164,16 @@ def merge_decomposed_graph(
148164
[torch.fx.Node, torch.fx.Node, Dict[str, torch.fx.Node]], None
149165
] = None,
150166
) -> None:
167+
target_fake_mode = None
168+
target_val = target_node.meta.get("val")
169+
if isinstance(target_val, FakeTensor):
170+
target_fake_mode = target_val.fake_mode
171+
elif isinstance(target_val, (list, tuple)):
172+
for v in target_val:
173+
if isinstance(v, FakeTensor):
174+
target_fake_mode = v.fake_mode
175+
break
176+
151177
def default_output_process(node):
152178
for user in node.users.copy():
153179
# remap
@@ -170,10 +196,13 @@ def default_output_process(node):
170196
# replace node map from string to graph node
171197
remap[decomposed_node] = remap.pop(decomposed_node.name)
172198
else:
173-
remap[decomposed_node] = target_graph.node_copy(
199+
copied = target_graph.node_copy(
174200
decomposed_node,
175201
arg_transform=lambda x, remap=remap: remap[x],
176202
)
203+
if target_fake_mode is not None:
204+
_unify_fake_mode(copied, target_fake_mode)
205+
remap[decomposed_node] = copied
177206

178207

179208
def is_float_tensor(node: torch.fx.Node) -> bool:

exir/BUCK

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,14 +259,26 @@ fbcode_target(_kind = runtime.python_library,
259259
],
260260
)
261261

262+
fbcode_target(_kind = runtime.python_library,
263+
name = "_program_utils",
264+
srcs = [
265+
"_program_utils.py",
266+
],
267+
deps = [
268+
"//caffe2:torch",
269+
],
270+
)
271+
262272
fbcode_target(_kind = runtime.python_library,
263273
name = "pass_manager",
264274
srcs = [
265275
"pass_manager.py",
266276
],
267277
deps = [
268278
"fbsource//third-party/pypi/typing-extensions:typing-extensions",
279+
":_program_utils",
269280
":error",
281+
":pass_base",
270282
"//caffe2:torch",
271283
],
272284
)

exir/_program_utils.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from torch.export.exported_program import (
11+
ConstantArgument,
12+
ExportGraphSignature,
13+
InputSpec,
14+
OutputSpec,
15+
)
16+
17+
18+
def _get_updated_range_constraints(gm):
19+
def get_shape_env(gm):
20+
vals = [
21+
node.meta["val"]
22+
for node in gm.graph.nodes
23+
if node.meta.get("val", None) is not None
24+
]
25+
from torch._guards import detect_fake_mode # type: ignore[21]
26+
27+
fake_mode = detect_fake_mode(vals)
28+
if fake_mode is not None:
29+
return fake_mode.shape_env
30+
for v in vals:
31+
if isinstance(v, torch.SymInt):
32+
return v.node.shape_env
33+
34+
shape_env = get_shape_env(gm)
35+
if shape_env is None:
36+
return {}
37+
range_constraints = {
38+
shape_env.replacements.get(k, k): v for k, v in shape_env.var_to_range.items()
39+
}
40+
# Only when we have an unbacked symint, and it's used as constructor inputs,
41+
# runtime_var_to_range will make a difference compated to var_to_range.
42+
# e.g. [2, oo) -> [0, oo)
43+
for k, v in shape_env.var_to_range.items():
44+
if k not in shape_env.replacements:
45+
range_constraints[k] = v
46+
return range_constraints
47+
48+
49+
def _get_updated_graph_signature(
50+
old_signature: ExportGraphSignature,
51+
new_gm: torch.fx.GraphModule,
52+
) -> ExportGraphSignature:
53+
"""
54+
Update the graph signature's user_input/user_outputs.
55+
"""
56+
new_input_specs = []
57+
i = 0
58+
for node in new_gm.graph.nodes:
59+
if node.op != "placeholder":
60+
continue
61+
62+
assert i < len(
63+
old_signature.input_specs
64+
), "Number of inputs changed after transformation"
65+
old_input_spec = old_signature.input_specs[i]
66+
arg = (
67+
old_input_spec.arg
68+
if isinstance(old_input_spec.arg, ConstantArgument)
69+
# pyre-fixme[20]: Argument `class_fqn` expected.
70+
else type(old_input_spec.arg)(node.name)
71+
)
72+
new_input_specs.append(
73+
InputSpec(
74+
old_input_spec.kind,
75+
arg,
76+
old_input_spec.target,
77+
persistent=old_input_spec.persistent,
78+
)
79+
)
80+
i += 1
81+
82+
output_node = new_gm.graph.output_node()
83+
assert output_node.op == "output"
84+
85+
new_output_specs = []
86+
for i, node in enumerate(output_node.args[0]):
87+
assert i < len(
88+
old_signature.output_specs
89+
), "Number of outputs changed after transformation"
90+
old_output_spec = old_signature.output_specs[i]
91+
arg = (
92+
old_output_spec.arg
93+
if isinstance(old_output_spec.arg, ConstantArgument)
94+
# pyre-fixme[20]: Argument `class_fqn` expected.
95+
else type(old_output_spec.arg)(node.name)
96+
)
97+
new_output_specs.append(
98+
OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
99+
)
100+
101+
new_signature = ExportGraphSignature(
102+
input_specs=new_input_specs, output_specs=new_output_specs
103+
)
104+
return new_signature

exir/pass_base.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
109
import operator
1110
import traceback
11+
from abc import ABC, abstractmethod
1212
from contextlib import nullcontext
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
@@ -27,16 +28,15 @@
2728

2829
import torch
2930
from executorch.exir import memory
30-
3131
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
32-
3332
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3433
from executorch.exir.error import ExportError, ExportErrorType
3534
from torch import fx
3635
from torch._dispatch.python import enable_python_dispatcher
3736
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
3837
from torch._subclasses.fake_tensor import FakeTensor
3938
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
39+
from torch.export import ExportedProgram
4040
from torch.fx import traceback as fx_traceback
4141
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
4242
from torch.fx.graph import CodeGen
@@ -182,6 +182,58 @@ class ExportPassBaseError(RuntimeError):
182182
pass
183183

184184

185+
@dataclass(frozen=True)
186+
class ExportedProgramPassResult:
187+
exported_program: ExportedProgram
188+
modified: bool
189+
190+
191+
class ExportedProgramPassBase(ABC):
192+
"""
193+
Base interface for implementing passes that operate on ExportedProgram.
194+
"""
195+
196+
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
197+
"""
198+
Runs the precondition check, the pass itself, and the postcondition check.
199+
"""
200+
201+
self.requires(exported_program)
202+
res = self.call(exported_program)
203+
self.ensures(exported_program)
204+
return res
205+
206+
@abstractmethod
207+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
208+
"""
209+
The pass that is run through the given exported program. To implement a
210+
pass, it is required to implement this function.
211+
212+
Args:
213+
exported_program: The exported program we will run a pass on
214+
"""
215+
216+
def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027
217+
"""
218+
This function will be called before the pass is run and will check that
219+
the given exported program contains the preconditions needed to run the
220+
pass. It is not required to implement this function.
221+
222+
Args:
223+
exported_program: The exported program we will run checks on
224+
"""
225+
226+
def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027
227+
"""
228+
This function will be called after the pass is run and will check that
229+
the given exported program contains the postconditions needed to run the
230+
pass. It is not required to implement this function.
231+
232+
Args:
233+
exported_program: The exported program we will run checks on
234+
"""
235+
236+
185237
class _ExportPassBase(PassBase):
186238
"""
187239
Interpreter-based pass class to help users maintain the IR spec while writing

0 commit comments

Comments
 (0)