forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspec_prop_pass.py
More file actions
132 lines (116 loc) · 5.34 KB
/
spec_prop_pass.py
File metadata and controls
132 lines (116 loc) · 5.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import operator
from typing import Optional
import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportGraphSignature
from torch.fx.node import Node
from torch.fx.passes.infra.pass_base import PassResult
from torch.utils import _pytree as pytree
# Register llama.fallback (optional — only needed for QNN/llama sharding paths).
# The module is imported purely for its import-time side effect (hence the
# unused-import suppression). If the import fails, or the op cannot be
# resolved through exir_ops.edge, _llama_fallback_default is set to None.
try:
    import executorch.extension.llm.custom_ops.op_fallback  # noqa: F401

    _llama_fallback_default = exir_ops.edge.llama.fallback.default
except (ImportError, AttributeError):
    _llama_fallback_default = None
# pyre-ignore
def make_spec(x):
    """
    Convert a single value into the spec stored in a node's ``meta["spec"]``.

    * ``ProxyValue``: the spec is built from the ``"val"`` entry of the
      wrapped node's metadata (recursively).
    * ``torch.Tensor``: a ``TensorSpec`` created via ``TensorSpec.from_tensor``.
    * ``int`` / ``bool`` / ``float``: the value itself.
    * anything else: ``None``.
    """
    if isinstance(x, ProxyValue):
        # Unwrap the proxy and describe the value recorded on its node.
        return make_spec(x.node.meta["val"])
    if isinstance(x, torch.Tensor):
        return TensorSpec.from_tensor(x)
    # Python scalars act as their own spec; every other type has none.
    return x if isinstance(x, (int, bool, float)) else None
def _is_mutable_buffer(
node: Node, graph_signature: Optional[ExportGraphSignature] = None
) -> bool:
"""
Check if the node is mutable buffer according to the provided graph signature.
"""
# graph signature is None for memory planning passes not called from EdgeProgramManager, these paths are deprecated so mutable buffers are not supported on them.
if graph_signature is None:
return False
if node.op == "placeholder":
if isinstance(node.target, str):
if node.target in graph_signature.inputs_to_buffers:
fqn = graph_signature.inputs_to_buffers[node.target]
# if the buffer is mutated then record that
if fqn in graph_signature.buffers_to_mutate.values():
return True
return False
class SpecPropPass(ExportPass):
    """
    Pass that fills in ``node.meta["spec"]`` for the nodes of a graph module
    (and of every nested ``torch.fx.GraphModule`` it contains).
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Re-trace ``graph_module`` with a plain ``ExportPass`` and then assign a
        spec to each node of the resulting module and its sub-modules.

        Returns the ``PassResult`` produced by the re-trace; the specs are
        written into the metadata of the nodes of its ``graph_module``.
        """
        # Re-trace metadata to ensure it's up to date.
        retraced = ExportPass()(graph_module)
        assert retraced is not None

        def spec_of(arg):
            # Objects without a ``meta`` attribute contribute no spec.
            return arg.meta.get("spec", None) if hasattr(arg, "meta") else None

        # Targets whose already-present spec must be kept (see note below).
        preserved_targets = (executorch_call_delegate, _llama_fallback_default)

        for submodule in retraced.graph_module.modules():
            if not isinstance(submodule, torch.fx.GraphModule):
                continue

            for node in submodule.graph.nodes:
                val = node.meta.get("val", None)

                if node.op == "output":
                    # The output spec mirrors the specs of the returned nodes.
                    node.meta["spec"] = pytree.tree_map(spec_of, node.args[0])
                    continue

                is_call = node.op == "call_function"

                if is_call and node.target == operator.getitem:
                    # Pick the indexed element out of the producer's spec.
                    container_spec = pytree.tree_map(spec_of, node.args[0])
                    node.meta["spec"] = container_spec[node.args[1]]
                    continue

                # Note: We currently rely on delegate node specs not being
                # regenerated, as the spec is set somewhat manually when adding
                # the call delegate node. If we regenerate, it can change and
                # break lowering (it becomes a tuple?). Ideally, we should
                # figure out how to make the spec regeneration not break things.
                #
                # We do need to regenerate non-call-delegate node specs, as this
                # pass is called multiple times in some lowering paths (backends
                # can and do call it).
                if (
                    is_call
                    and node.target in preserved_targets
                    and "spec" in node.meta
                ):
                    continue

                node.meta["spec"] = pytree.tree_map(make_spec, val)

        return retraced

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the pass; identical to invoking the instance directly."""
        return self(graph_module)

    def update_placeholder_tensor_specs(
        self,
        exported_program: torch.export.ExportedProgram,
        graph_module: torch.fx.GraphModule,
    ) -> None:
        """
        Update the tensor specs for all placeholder nodes such that
        placeholders that are parameters are marked as constant.

        A placeholder's spec gets ``const = True`` when its (string) target is
        listed in the program's graph signature as a parameter, as a buffer
        that is not mutated, or as a lifted tensor constant.

        Raises:
            RuntimeError: if a placeholder node has no ``meta["spec"]``.
        """
        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            if "spec" not in node.meta:
                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")

            spec = node.meta["spec"]
            name = node.target
            if not isinstance(name, str):
                continue

            signature = exported_program.graph_signature
            is_constant = (
                name in signature.inputs_to_parameters
                or (
                    name in signature.inputs_to_buffers
                    and not _is_mutable_buffer(node, signature)
                )
                or name in signature.inputs_to_lifted_tensor_constants
            )
            if is_constant:
                spec.const = True