forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspec_prop_pass.py
More file actions
123 lines (109 loc) · 4.97 KB
/
spec_prop_pass.py
File metadata and controls
123 lines (109 loc) · 4.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import operator
from typing import Optional
import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportGraphSignature
from torch.fx.node import Node
from torch.fx.passes.infra.pass_base import PassResult
from torch.utils import _pytree as pytree
# pyre-ignore
def make_spec(x):
    """
    Build the value stored in ``node.meta["spec"]`` for one traced value.

    Args:
        x: A ``ProxyValue`` (resolved through its node's ``meta["val"]``),
            a ``torch.Tensor``, or any other object.

    Returns:
        A ``TensorSpec`` for tensors, the value itself for ``int``/``bool``/
        ``float`` scalars, and ``None`` for everything else.
    """
    if isinstance(x, ProxyValue):
        # Derive the spec from the value recorded on the proxy's underlying node.
        return make_spec(x.node.meta["val"])
    if isinstance(x, torch.Tensor):
        return TensorSpec.from_tensor(x)
    # Plain Python scalars are kept verbatim as their own "spec".
    scalar_types = (int, bool, float)
    return x if isinstance(x, scalar_types) else None
def _is_mutable_buffer(
node: Node, graph_signature: Optional[ExportGraphSignature] = None
) -> bool:
"""
Check if the node is mutable buffer according to the provided graph signature.
"""
# graph signature is None for memory planning passes not called from EdgeProgramManager, these paths are deprecated so mutable buffers are not supported on them.
if graph_signature is None:
return False
if node.op == "placeholder":
if isinstance(node.target, str):
if node.target in graph_signature.inputs_to_buffers:
fqn = graph_signature.inputs_to_buffers[node.target]
# if the buffer is mutated then record that
if fqn in graph_signature.buffers_to_mutate.values():
return True
return False
class SpecPropPass(ExportPass):
    """
    Populate ``node.meta["spec"]`` for the nodes of a graph module and of every
    nested ``torch.fx.GraphModule`` submodule, derived from ``node.meta["val"]``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Re-trace ``graph_module`` and attach a spec to each node's metadata.

        Args:
            graph_module: The graph module to process.

        Returns:
            The ``PassResult`` from re-tracing with a plain ``ExportPass``; the
            nodes of its ``graph_module`` carry the refreshed ``meta["spec"]``.
        """
        # Re-trace metadata to ensure it's up to date.
        res = ExportPass()(graph_module)
        assert res is not None
        gm = res.graph_module

        def get_spec(x):
            # Leaves without node metadata (constants, None, ...) have no spec.
            if hasattr(x, "meta"):
                return x.meta.get("spec", None)
            else:
                return None

        for module in gm.modules():
            if isinstance(module, torch.fx.GraphModule):
                for node in module.graph.nodes:
                    meta_val = node.meta.get("val", None)
                    if node.op == "output":
                        # Mirror the structure of the returned values, reusing
                        # the specs already assigned to the producing nodes.
                        node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
                    elif node.op == "call_function" and node.target == operator.getitem:
                        # Select the element of the source node's spec at the
                        # getitem index instead of rebuilding it.
                        value_spec = pytree.tree_map(get_spec, node.args[0])
                        node.meta["spec"] = value_spec[node.args[1]]
                    elif (
                        node.op == "call_function"
                        and node.target == executorch_call_delegate
                    ):
                        # Note: We currently rely on delegate node specs not being regenerated,
                        # as the spec is set somewhat manually when adding the call delegate node.
                        # If we regenerate, it can change and break lowering (it becomes a tuple?).
                        # Ideally, we should figure out how to make the spec regeneration not break
                        # things.
                        #
                        # We do need to regenerate non-call-delegate node specs, as this pass is called
                        # multiple times in some lowering paths (backends can and do call it).
                        if "spec" not in node.meta:
                            node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
                    else:
                        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
        return res

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the pass; equivalent to calling the instance directly."""
        return self(graph_module)

    def update_placeholder_tensor_specs(
        self,
        exported_program: torch.export.ExportedProgram,
        graph_module: torch.fx.GraphModule,
    ) -> None:
        """
        Mark the specs of constant placeholders with ``spec.const = True``.

        A placeholder is constant if it is a parameter, a buffer that is not
        mutated, or a lifted tensor constant according to the program's graph
        signature.

        Args:
            exported_program: Program whose graph signature classifies inputs.
            graph_module: Graph module whose placeholder specs are updated in place.

        Raises:
            RuntimeError: If a placeholder node has no ``meta["spec"]``.
        """
        signature = exported_program.graph_signature
        # Build the set of constant placeholder names once rather than querying
        # the signature's mappings for every placeholder (NOTE(review): these
        # mappings appear to be rebuilt on each property access in torch).
        # A buffer is constant only if its FQN is not a mutation target; this
        # is the same rule applied by ``_is_mutable_buffer``.
        mutated_buffer_fqns = set(signature.buffers_to_mutate.values())
        constant_targets = set(signature.inputs_to_parameters)
        constant_targets.update(signature.inputs_to_lifted_tensor_constants)
        constant_targets.update(
            name
            for name, fqn in signature.inputs_to_buffers.items()
            if fqn not in mutated_buffer_fqns
        )

        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            if "spec" not in node.meta:
                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
            spec = node.meta["spec"]
            # The isinstance check short-circuits before hashing non-str targets.
            if isinstance(node.target, str) and node.target in constant_targets:
                spec.const = True