
Commit ccef69e

Andrew Pullin authored and meta-codesync[bot] committed
Intermediate commit for 1774374772
Differential Revision: D92059152
1 parent 76bf600 commit ccef69e

4 files changed

Lines changed: 510 additions & 0 deletions


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@
 from .decompose_ne_pass import DecomposeNotEqualPass  # noqa
 from .decompose_quant_nodes import DecomposeQuantNodesPass  # noqa
 from .decompose_remainder_pass import DecomposeRemainderPass  # noqa
+from .decompose_rnn_pass import DecomposeRnnPass  # noqa
 from .decompose_round_pass import DecomposeRoundPass  # noqa
 from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass  # noqa
 from .decompose_select import DecomposeSelectPass  # noqa
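
A minimal usage sketch (not part of this commit): the re-export above is what makes the new pass importable from the package root, which is how arm_pass_manager.py below picks it up.

    # Hypothetical consumer code; the pipeline also constructs the pass
    # with tfa_pass=True, as shown in the next file.
    from executorch.backends.arm._passes import DecomposeRnnPass

    rnn_pass = DecomposeRnnPass()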

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,7 @@
     DecomposeNotEqualPass,
     DecomposeQuantNodesPass,
     DecomposeRemainderPass,
+    DecomposeRnnPass,
     DecomposeRoundPass,
     DecomposeScaledDotProductAttentionPass,
     DecomposeSelectPass,
@@ -362,6 +363,7 @@ def _tosa_pipeline(
         DecomposeTOSAUnsupportedClampPass(),
         DecomposeGroupNormPass(),
         DecomposeGruPass(),
+        DecomposeRnnPass(),
         DecomposeLayerNormPass(),
         DecomposeVarPass(),
         DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
@@ -581,6 +583,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
     [
         NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
         DecomposeGruPass(tfa_pass=True),
+        DecomposeRnnPass(tfa_pass=True),
         DecomposeNotEqualPass(tfa_pass=True),
         DecomposeCosineSimilarityPass(tfa_pass=True),
         DecomposeGluPass(tfa_pass=True),
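
For context, a minimal sketch (not part of this commit; module, shapes, and names are made up) of the graph pattern the pipeline wiring above targets: exporting a torch.nn.RNN produces an aten.rnn_tanh.input call whose two results are read through operator.getitem, which are exactly the users DecomposeRnnPass rewires.

    import torch

    class TinyRnn(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # nonlinearity="tanh" lowers to aten.rnn_tanh.input; "relu" would
            # lower to aten.rnn_relu.input instead.
            self.rnn = torch.nn.RNN(input_size=4, hidden_size=8, num_layers=1)

        def forward(self, x, h0):
            output, h_n = self.rnn(x, h0)
            return output, h_n

    x = torch.randn(5, 2, 4)   # (seq_len, batch, input_size); batch_first=False
    h0 = torch.zeros(1, 2, 8)  # (num_layers * num_directions, batch, hidden_size)
    ep = torch.export.export(TinyRnn(), (x, h0))
    # Depending on export/decomposition settings, the printed graph contains a
    # torch.ops.aten.rnn_tanh.input node followed by operator.getitem nodes
    # for (output, h_n).
    print(ep.graph_module.graph)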
backends/arm/_passes/decompose_rnn_pass.py

Lines changed: 301 additions & 0 deletions
@@ -0,0 +1,301 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import operator
+from typing import List, Set, Tuple, Type
+
+import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeRnnPass(ArmPass):
+    """Decomposes aten.rnn_tanh.input and aten.rnn_relu.input into elementary
+    ops supported by TOSA.
+
+    RNN cell equation per timestep:
+        h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh)
+
+    where activation is tanh (rnn_tanh) or relu (rnn_relu).
+
+    Supports multi-layer, bidirectional, with/without bias, and batch_first.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}
+
+    _TARGETS = {
+        torch.ops.aten.rnn_tanh.input,
+        torch.ops.aten.rnn_relu.input,
+    }
+
+    _mm = torch.ops.aten.mm.default
+    _t = torch.ops.aten.t.default
+    _add = torch.ops.aten.add.Tensor
+    _tanh = torch.ops.aten.tanh.default
+    _relu = torch.ops.aten.relu.default
+    _unsqueeze = torch.ops.aten.unsqueeze.default
+    _cat = torch.ops.aten.cat.default
+    _select = torch.ops.aten.select_copy.int
+
+    def _build_direction(
+        self,
+        graph: torch.fx.Graph,
+        node: torch.fx.Node,
+        current_input: torch.fx.Node,
+        h_prev: torch.fx.Node,
+        weight_ih: torch.fx.Node,
+        weight_hh: torch.fx.Node,
+        bias_ih,
+        bias_hh,
+        seq_len: int,
+        time_dim: int,
+        reverse: bool,
+        activation,
+    ) -> Tuple[List[torch.fx.Node], torch.fx.Node]:
+        """Build RNN cell computation for one direction.
+
+        Returns (timestep_outputs, h_final) where timestep_outputs are
+        unsqueezed hidden states in forward time order.
+        """
+        w_ih_t = create_node(graph, self._t, args=(weight_ih,), from_node=node)
+        w_hh_t = create_node(graph, self._t, args=(weight_hh,), from_node=node)
+
+        time_indices = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
+        timestep_outputs = []
+
+        for t_idx in time_indices:
+            x_t = create_node(
+                graph,
+                self._select,
+                args=(current_input, time_dim, t_idx),
+                from_node=node,
+            )
+
+            out_ih = create_node(graph, self._mm, args=(x_t, w_ih_t), from_node=node)
+            out_hh = create_node(graph, self._mm, args=(h_prev, w_hh_t), from_node=node)
+
+            if bias_ih is not None:
+                out_ih = create_node(
+                    graph, self._add, args=(out_ih, bias_ih), from_node=node
+                )
+            if bias_hh is not None:
+                out_hh = create_node(
+                    graph, self._add, args=(out_hh, bias_hh), from_node=node
+                )
+
+            pre_act = create_node(
+                graph, self._add, args=(out_ih, out_hh), from_node=node
+            )
+            h_t = create_node(graph, activation, args=(pre_act,), from_node=node)
+            h_prev = h_t
+
+            h_t_expanded = create_node(
+                graph, self._unsqueeze, args=(h_t, time_dim), from_node=node
+            )
+            timestep_outputs.append(h_t_expanded)
+
+        if reverse:
+            timestep_outputs.reverse()
+
+        return timestep_outputs, h_prev
+
+    def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
+        graph = graph_module.graph
+        made_changes = False
+
+        for node in list(graph.nodes):
+            if (
+                node.op != "call_function"
+                or node.target not in self._TARGETS
+                or not self.allowed_to_transform(node.meta)
+            ):
+                continue
+
+            is_relu = node.target == torch.ops.aten.rnn_relu.input
+            activation = self._relu if is_relu else self._tanh
+
+            args = node.args
+            input_node = args[0]
+            hx_node = args[1]
+            params = args[2]
+            has_biases = args[3]
+            num_layers = args[4]
+            # dropout (args[5]) and train (args[6]) are unused at inference
+            bidirectional = args[7]
+            batch_first = args[8]
+
+            input_val = input_node.meta["val"]
+
+            if batch_first:
+                seq_len = input_val.shape[1]
+                time_dim = 1
+            else:
+                seq_len = input_val.shape[0]
+                time_dim = 0
+
+            num_directions = 2 if bidirectional else 1
+            # Params per layer: (w_ih, w_hh[, b_ih, b_hh]) * num_directions
+            dir_step = 4 if has_biases else 2
+            layer_step = dir_step * num_directions
+
+            with graph.inserting_before(node):
+                current_input = input_node
+                layer_final_hiddens = []
+
+                for layer_idx in range(num_layers):
+                    layer_offset = layer_idx * layer_step
+
+                    # Forward direction
+                    fw_off = layer_offset
+                    fw_w_ih = params[fw_off]
+                    fw_w_hh = params[fw_off + 1]
+                    fw_b_ih = params[fw_off + 2] if has_biases else None
+                    fw_b_hh = params[fw_off + 3] if has_biases else None
+
+                    fw_h0 = create_node(
+                        graph,
+                        self._select,
+                        args=(hx_node, 0, num_directions * layer_idx),
+                        from_node=node,
+                    )
+
+                    fw_outputs, fw_h_final = self._build_direction(
+                        graph,
+                        node,
+                        current_input,
+                        fw_h0,
+                        fw_w_ih,
+                        fw_w_hh,
+                        fw_b_ih,
+                        fw_b_hh,
+                        seq_len,
+                        time_dim,
+                        reverse=False,
+                        activation=activation,
+                    )
+
+                    if bidirectional:
+                        bw_off = layer_offset + dir_step
+                        bw_w_ih = params[bw_off]
+                        bw_w_hh = params[bw_off + 1]
+                        bw_b_ih = params[bw_off + 2] if has_biases else None
+                        bw_b_hh = params[bw_off + 3] if has_biases else None
+
+                        bw_h0 = create_node(
+                            graph,
+                            self._select,
+                            args=(hx_node, 0, 2 * layer_idx + 1),
+                            from_node=node,
+                        )
+
+                        bw_outputs, bw_h_final = self._build_direction(
+                            graph,
+                            node,
+                            current_input,
+                            bw_h0,
+                            bw_w_ih,
+                            bw_w_hh,
+                            bw_b_ih,
+                            bw_b_hh,
+                            seq_len,
+                            time_dim,
+                            reverse=True,
+                            activation=activation,
+                        )
+
+                        fw_combined = create_node(
+                            graph,
+                            self._cat,
+                            args=(fw_outputs, time_dim),
+                            from_node=node,
+                        )
+                        bw_combined = create_node(
+                            graph,
+                            self._cat,
+                            args=(bw_outputs, time_dim),
+                            from_node=node,
+                        )
+                        layer_output = create_node(
+                            graph,
+                            self._cat,
+                            args=([fw_combined, bw_combined], -1),
+                            from_node=node,
+                        )
+
+                        layer_final_hiddens.append(
+                            create_node(
+                                graph,
+                                self._unsqueeze,
+                                args=(fw_h_final, 0),
+                                from_node=node,
+                            )
+                        )
+                        layer_final_hiddens.append(
+                            create_node(
+                                graph,
+                                self._unsqueeze,
+                                args=(bw_h_final, 0),
+                                from_node=node,
+                            )
+                        )
+                    else:
+                        layer_output = create_node(
+                            graph,
+                            self._cat,
+                            args=(fw_outputs, time_dim),
+                            from_node=node,
+                        )
+
+                        layer_final_hiddens.append(
+                            create_node(
+                                graph,
+                                self._unsqueeze,
+                                args=(fw_h_final, 0),
+                                from_node=node,
+                            )
+                        )
+
+                    current_input = layer_output
+
+                # Build h_n
+                if len(layer_final_hiddens) == 1:
+                    h_n = layer_final_hiddens[0]
+                else:
+                    h_n = create_node(
+                        graph,
+                        self._cat,
+                        args=(layer_final_hiddens, 0),
+                        from_node=node,
+                    )
+
+            output_node = current_input
+
+            # Replace getitem users: RNN returns (output, h_n)
+            getitem_nodes = []
+            for user in list(node.users.keys()):
+                if user.target == operator.getitem:
+                    idx = user.args[1]
+                    if idx == 0:
+                        user.replace_all_uses_with(output_node)
+                    elif idx == 1:
+                        user.replace_all_uses_with(h_n)
+                    getitem_nodes.append(user)
+
+            for gi in getitem_nodes:
+                graph.erase_node(gi)
+            graph.erase_node(node)
+            made_changes = True
+
+        if not made_changes:
+            return PassResult(graph_module, False)
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
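
As a sanity check on the decomposition, here is a hedged eager-mode sketch (illustration only, not part of this commit; shapes and seed are made up) that replays the docstring equation h_t = tanh(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh) and compares it against torch.nn.RNN for a single forward tanh layer:

    import torch

    torch.manual_seed(0)
    rnn = torch.nn.RNN(input_size=4, hidden_size=8, num_layers=1, nonlinearity="tanh")
    x = torch.randn(5, 2, 4)  # (seq_len, batch, input_size); batch_first=False, so time_dim == 0
    h = torch.zeros(2, 8)     # h_0 for the single direction, one row per batch element

    w_ih, w_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
    b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0

    outs = []
    with torch.no_grad():
        for t in range(x.shape[0]):  # the per-timestep loop the pass unrolls into graph nodes
            h = torch.tanh(x[t] @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
            outs.append(h.unsqueeze(0))  # unsqueeze on the time dim, as in _build_direction

        out_ref, h_n_ref = rnn(x, torch.zeros(1, 2, 8))
        assert torch.allclose(torch.cat(outs, 0), out_ref, atol=1e-5)  # stacked timestep outputs
        assert torch.allclose(h.unsqueeze(0), h_n_ref, atol=1e-5)      # final hidden state h_n

For biased and bidirectional variants, nn.RNN stores its flat weights as (w_ih, w_hh[, b_ih, b_hh]) per direction per layer, which is the layout the dir_step/layer_step arithmetic in call() indexes into.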
