# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from types import FunctionType as function
from typing import Callable, Dict, List, Tuple, Union

import torch
from torch._ops import HigherOrderOperator

LeafValue = Union[
    torch.Tensor,
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.device,
    torch.memory_format,
    torch.layout,
    None,
]

# We maintain a global cache of op lookups as this significantly speeds up
# deserialization because hasattr(torch.ops, name) is an expensive call.
_cache_ops_dict: Dict[
    Tuple[str, str], Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]
] = {}
_cache_fake_ops_dict: Dict[Tuple[str, str], function] = {}
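

# Illustrative sketch (not part of the original module) of how such a cache can
# be consulted: resolve the op through the torch.ops registry once and memoize
# it, instead of paying for hasattr(torch.ops, name) on every deserialization.
# The helper name and the (namespace, op name) key layout are assumptions made
# for this example only.
def _example_cached_op_lookup(
    namespace: str, op_name: str
) -> Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]:
    key = (namespace, op_name)
    if key not in _cache_ops_dict:
        # Slow path: resolve the op via torch.ops once, then cache it.
        _cache_ops_dict[key] = getattr(getattr(torch.ops, namespace), op_name)
    return _cache_ops_dict[key]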


def _get_submodule(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
    # Resolve the get_attr argument at `arg_index` of `node` to the submodule it
    # refers to, returning (attribute name, submodule, node that uses it).
    submod_node = node.args[arg_index]
    assert isinstance(submod_node, torch.fx.Node)
    assert submod_node.op == "get_attr"
    assert isinstance(submod_node.target, str)
    submodule = graph_module.get_submodule(submod_node.target)
    # pyre-ignore
    return submod_node.target, submodule, node


def _get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
    op_to_submodule_arg_index: Dict[HigherOrderOperator, List[int]],
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).
    """
    control_flow_submodules = []
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue

        for op in op_to_submodule_arg_index:
            if node.target is not op:
                continue
            for i in op_to_submodule_arg_index[op]:
                control_flow_submodules.append(_get_submodule(graph_module, node, i))

    return control_flow_submodules


def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    (torch.ops.higher_order.cond/map/scan) that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).
    """
    return _get_control_flow_submodules(
        graph_module,
        {
            torch.ops.higher_order.cond: [1, 2],
            torch.ops.higher_order.map_impl: [0],
            torch.ops.higher_order.scan: [0],  # combine_fn is at arg index 0
        },
    )
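

# Illustrative usage sketch (not part of the original module), assuming a recent
# PyTorch where torch.cond and torch.export are available: export a toy model
# whose forward uses torch.cond, then collect the lifted true/false branch
# graphs. The class and helper names here are hypothetical.
def _example_collect_cond_branches() -> List[str]:
    class _ExampleCond(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def true_fn(x: torch.Tensor) -> torch.Tensor:
                return x + 1

            def false_fn(x: torch.Tensor) -> torch.Tensor:
                return x - 1

            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

    ep = torch.export.export(_ExampleCond(), (torch.randn(3),))
    # Each entry is (attribute name, branch GraphModule, node that uses it); for
    # cond, the branch graphs sit at argument indices 1 and 2.
    return [name for name, _, _ in get_control_flow_submodules(ep.graph_module)]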


def get_cond_while_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    (torch.ops.higher_order.cond/while_loop) that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).
    """
    return _get_control_flow_submodules(
        graph_module,
        {
            torch.ops.higher_order.cond: [1, 2],
            torch.ops.higher_order.while_loop: [0, 1],
        },
    )
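

# Illustrative sketch (not part of the original module): the node in each
# returned tuple identifies which higher-order op uses the submodule, which is
# handy when cond branches and while_loop cond/body graphs need different
# handling downstream. The helper name is hypothetical.
def _example_split_cond_while(
    graph_module: torch.fx.GraphModule,
) -> Tuple[List[str], List[str]]:
    cond_names: List[str] = []
    while_names: List[str] = []
    for name, _, node in get_cond_while_submodules(graph_module):
        if node.target is torch.ops.higher_order.cond:
            cond_names.append(name)
        else:
            while_names.append(name)
    return cond_names, while_names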


def get_scan_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for scan operations
    (torch.ops.higher_order.scan) that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).
    For scan, the combine_fn submodule is at argument index 0.
    The scan operator signature is: scan(combine_fn, init, xs, additional_inputs)
    """
    return _get_control_flow_submodules(
        graph_module,
        {
            torch.ops.higher_order.scan: [0],
        },
    )
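

# Illustrative sketch (not part of the original module): given a graph module
# that already contains torch.ops.higher_order.scan nodes (for example one
# produced by torch.export), print each combine_fn graph. The helper name is
# hypothetical.
def _example_print_scan_combine_fns(graph_module: torch.fx.GraphModule) -> None:
    for name, combine_fn, node in get_scan_submodules(graph_module):
        # `name` is the attribute the combine_fn graph is stored under and
        # `node` is the call_function node that invokes scan.
        print(f"scan combine_fn {name!r} used by {node.format_node()}")
        combine_fn.print_readable()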


def bfs_trace_with_node_process(
    gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
) -> None:
    """Traverse the graph module (and its control flow submodules, breadth-first) and apply node_op to each node."""
    assert isinstance(gm, torch.fx.GraphModule), f"Expected GraphModule, got {type(gm)}"

    queue = [gm]
    while queue:
        current_graph_module = queue.pop(0)
        for node in current_graph_module.graph.nodes:
            node_op(node)

        control_flow_submodules = [
            submodule
            for _, submodule, _ in get_control_flow_submodules(current_graph_module)
        ]
        queue.extend(control_flow_submodules)
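

# Illustrative sketch (not part of the original module): using
# bfs_trace_with_node_process to tally call_function targets across the
# top-level graph and every nested control flow submodule. The helper name is
# hypothetical.
def _example_count_call_targets(gm: torch.fx.GraphModule) -> Dict[str, int]:
    counts: Dict[str, int] = {}

    def _count(node: torch.fx.Node) -> None:
        if node.op == "call_function":
            key = str(node.target)
            counts[key] = counts.get(key, 0) + 1

    bfs_trace_with_node_process(gm, _count)
    return counts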