forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinsert_write_back_for_buffers_pass.py
More file actions
185 lines (160 loc) · 6.34 KB
/
insert_write_back_for_buffers_pass.py
File metadata and controls
185 lines (160 loc) · 6.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple
import torch
from executorch.exir.operator.convert import is_inplace_variant
from torch.export.exported_program import (
ExportedProgram,
ExportGraphSignature,
InputKind,
OutputKind,
OutputSpec,
)
from torch.export.graph_signature import TensorArgument
from torch.utils import _pytree as pytree
from torchgen.model import SchemaKind
def _insert_copy(
    gm: torch.fx.GraphModule,
    mutated_outputs: List[Optional[str]],
    input_name_to_node: Dict[str, torch.fx.Node],
) -> List[torch.fx.Node]:
    """
    Find the all the buffers and inputs that were mutated and insert copy_
    operators to reflect mutations.

    For every graph output marked as a mutation (a non-None entry in
    ``mutated_outputs``), inserts ``aten.copy_(placeholder, returned_value)``
    so the mutated value is written back into the original placeholder, and
    rebuilds the graph's output node so the copy_ results come first, followed
    by the untouched user outputs.

    Args:
        gm: graph module whose output node is rewritten in place.
        mutated_outputs: one entry per flattened graph output; the placeholder
            name being mutated, or None for plain user outputs.
        input_name_to_node: maps placeholder name -> placeholder node.

    Returns:
        All nodes of the new output tuple: copy_ nodes first, then the
        original user-output nodes (despite the local name, the returned list
        is not only the buffer/copy nodes).
    """
    output_node = gm.graph.output_node()
    assert output_node is not None
    # Flatten output args so they line up 1:1 with mutated_outputs.
    outputs = pytree.tree_flatten(output_node.args)[0]
    assert len(outputs) == len(mutated_outputs)
    user_output_nodes = []
    buffer_output_nodes = []
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        # User output, leave alone
        if mutated_node_name is None:
            user_output_nodes.append(return_node)
            continue
        # Mutable buffer grab the node
        if mutated_node_name in input_name_to_node:
            mutated_node = input_name_to_node[mutated_node_name]
        else:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )
        # insert copy: writes return_node's value back into the placeholder;
        # the copy_ node's result becomes the new graph output for this entry.
        with gm.graph.inserting_before(output_node):
            buffer_output = gm.graph.call_function(
                torch.ops.aten.copy_.default, (mutated_node, return_node)
            )
            # add output of copy to graph outputs
            buffer_output_nodes.append(buffer_output)
    with gm.graph.inserting_before(output_node):
        # Final output order is [mutation write-backs..., user outputs...];
        # NOTE(review): this presumably matches the ExportGraphSignature
        # convention that mutation outputs precede user outputs — confirm.
        buffer_output_nodes.extend(user_output_nodes)
        # Remove old outputs
        new_output = gm.graph.output(tuple(buffer_output_nodes))
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)
    return buffer_output_nodes
def _is_inplace_node(node: torch.fx.Node) -> bool:
"""Check if a node is an inplace node."""
return (
node.op == "call_function"
and hasattr(node.target, "_schema")
and is_inplace_variant(
node.target._schema.name, node.target._schema.overload_name # pyre-ignore
)
)
def _inplace_lineage(
output_arg: torch.fx.Node,
gs: ExportGraphSignature,
kind: SchemaKind,
) -> bool:
"""
Walk the graph backwards to see if output_arg is ultimately the same as an input.
"""
if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION:
return False
while output_arg.op != "placeholder":
if _is_inplace_node(output_arg):
# From looking at native_functions.yaml, inplace ops always have self as the first arg
output_arg = output_arg.args[0] # pyre-ignore
else:
return False
# If the output arg was a buffer then it needs to reach a buffer placeholder
if kind == OutputKind.BUFFER_MUTATION:
return output_arg.target in gs.inputs_to_buffers
# If the output arg was a user input then it needs to reach a user input placeholder
assert kind == OutputKind.USER_INPUT_MUTATION
return output_arg.target in gs.user_inputs
def insert_write_back_for_buffers_pass(
    ep: ExportedProgram,
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
    """
    Insert ``copy_`` write-back nodes for mutated buffers and user inputs.

    For each BUFFER_MUTATION / USER_INPUT_MUTATION output whose value is not
    already produced by a pure in-place chain on the corresponding
    placeholder (see ``_inplace_lineage``), a ``copy_`` back into the
    placeholder is inserted and the graph's output tuple is rebuilt.

    Args:
        ep: the exported program to transform (its graph module and output
            specs are mutated in place).

    Returns:
        The (mutated) graph module and a new ExportGraphSignature whose
        mutation output specs point at the freshly inserted copy_ nodes.
    """
    gm: torch.fx.GraphModule = ep.graph_module
    # One entry per placeholder: the name by which a mutation output spec may
    # refer to it (spec target for lifted tensors, arg name for user inputs),
    # or None when it can never be a mutation target.
    lifted_inputs: List[Optional[str]] = []
    for in_spec in ep.graph_signature.input_specs:
        if in_spec.kind in (
            InputKind.BUFFER,
            InputKind.CONSTANT_TENSOR,
            InputKind.PARAMETER,
            InputKind.CUSTOM_OBJ,
        ):
            lifted_inputs.append(in_spec.target)
        elif in_spec.kind is InputKind.USER_INPUT and isinstance(
            in_spec.arg, TensorArgument
        ):
            lifted_inputs.append(in_spec.arg.name)
        else:
            lifted_inputs.append(None)
    input_name_to_node: Dict[str, torch.fx.Node] = {}
    # Placeholders appear in the same order as input_specs, so zip is valid.
    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholder_nodes)
    # Grab the all the non user inputs
    for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
        if lifted_node is not None:
            input_name_to_node[lifted_node] = input_node
    output_node = gm.graph.output_node()
    # Grab the mutable buffer nodes in the outputs,
    mutated_outputs: List[Optional[str]] = []
    for i, out_spec in enumerate(ep.graph_signature.output_specs):
        # if the output arg is the input value then all operations on it are in-place
        # so there's no need to add a copy_ node
        if (
            out_spec.kind
            in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
            and
            # explicitly check if target exists (it should always be there)
            out_spec.target in input_name_to_node
            and
            # if the arg and target are not the same, we add a copy_ node.
            not _inplace_lineage(
                output_node.args[0][i],
                ep.graph_signature,
                ep.graph_signature.output_specs[i].kind,
            )
        ):
            mutated_outputs.append(out_spec.target)
        else:
            mutated_outputs.append(None)
    # insert the copy ops and update the outputs
    buffer_output_nodes = _insert_copy(gm, mutated_outputs, input_name_to_node)
    gm.graph.lint()
    gm.graph.eliminate_dead_code()
    gm.recompile()
    # patch the output signature to point to the new updated outputs
    new_output_specs: List[OutputSpec] = []
    # buffer_output_nodes puts the mutation write-backs first, so a running
    # index over mutation specs lines up with it.
    i = 0
    for output_spec in ep.graph_signature.output_specs:
        if output_spec.kind in (
            OutputKind.BUFFER_MUTATION,
            OutputKind.USER_INPUT_MUTATION,
        ):
            # NOTE: mutates the existing spec's arg in place; the returned
            # signature shares specs with ep.graph_signature.
            output_spec.arg.name = buffer_output_nodes[i].name
            i += 1
        new_output_specs.append(output_spec)
    signature = ExportGraphSignature(
        input_specs=ep.graph_signature.input_specs,
        output_specs=new_output_specs,
    )
    return gm, signature