Skip to content

Commit 8003c8a

Browse files
Andrew Pullin and facebook-github-bot
authored and committed
Add DecomposeRnnPass for ARM backend (#17139)
Summary: Adds a decomposition pass that transforms aten.rnn_tanh.input and aten.rnn_relu.input into elementary ops supported by TOSA. RNN cell equation per timestep: h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh) where activation is tanh (rnn_tanh) or relu (rnn_relu). Features: - Multi-layer RNN support - Bidirectional RNN support - With/without bias - batch_first support - Both tanh and relu nonlinearities --- > Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Confucius Session](https://www.internalfb.com/confucius?host=62602.od.fbinfra.net&port=8086&tab=Chat&session_id=e1d1ac52-0014-11f1-9d55-75b7d4e71d8a&entry_name=Code+Assist), [Trace](https://www.internalfb.com/confucius?session_id=e1d1ac52-0014-11f1-9d55-75b7d4e71d8a&tab=Trace) Differential Revision: D92059152
1 parent 0e9529a commit 8003c8a

File tree

4 files changed

+496
-0
lines changed

4 files changed

+496
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
7272
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
7373
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
74+
from .decompose_rnn_pass import DecomposeRnnPass # noqa
7475
from .decompose_round_pass import DecomposeRoundPass # noqa
7576
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
7677
from .decompose_select import DecomposeSelectPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
DecomposeNotEqualPass,
7373
DecomposeQuantNodesPass,
7474
DecomposeRemainderPass,
75+
DecomposeRnnPass,
7576
DecomposeRoundPass,
7677
DecomposeScaledDotProductAttentionPass,
7778
DecomposeSelectPass,
@@ -238,6 +239,7 @@ def _tosa_pipeline(
238239
DecomposeTOSAUnsupportedClampPass(),
239240
DecomposeGroupNormPass(),
240241
DecomposeGruPass(),
242+
DecomposeRnnPass(),
241243
DecomposeLayerNormPass(),
242244
DecomposeVarPass(),
243245
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
@@ -427,6 +429,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
427429
DecomposeAddSubAlphaPass(tfa_pass=True),
428430
DecomposeGroupNormPass(tfa_pass=True),
429431
DecomposeGruPass(tfa_pass=True),
432+
DecomposeRnnPass(tfa_pass=True),
430433
DecomposeLayerNormPass(tfa_pass=True),
431434
DecomposeVarPass(tfa_pass=True),
432435
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import operator
7+
from typing import List, Set, Tuple, Type
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass import ArmPass
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
16+
class DecomposeRnnPass(ArmPass):
    """Decomposes aten.rnn_tanh.input and aten.rnn_relu.input into
    elementary ops supported by TOSA.

    RNN cell equation per timestep:
    h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh)

    where activation is tanh (rnn_tanh) or relu (rnn_relu).

    Supports multi-layer, bidirectional, with/without bias, and batch_first.
    """

    # tanh/relu nodes emitted here are later lowered to lookup tables, so
    # the table-insertion pass must run after this one.
    _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}

    # Ops this pass rewrites; both share the same decomposition apart from
    # the nonlinearity.
    _TARGETS = {
        torch.ops.aten.rnn_tanh.input,
        torch.ops.aten.rnn_relu.input,
    }

    # Short aliases for the elementary ops the decomposition is built from.
    _mm = torch.ops.aten.mm.default
    _t = torch.ops.aten.t.default
    _add = torch.ops.aten.add.Tensor
    _tanh = torch.ops.aten.tanh.default
    _relu = torch.ops.aten.relu.default
    _unsqueeze = torch.ops.aten.unsqueeze.default
    _cat = torch.ops.aten.cat.default
    _select = torch.ops.aten.select_copy.int

    def _build_direction(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        current_input: torch.fx.Node,
        h_prev: torch.fx.Node,
        weight_ih: torch.fx.Node,
        weight_hh: torch.fx.Node,
        bias_ih,
        bias_hh,
        seq_len: int,
        time_dim: int,
        reverse: bool,
        activation,
    ) -> Tuple[List[torch.fx.Node], torch.fx.Node]:
        """Build RNN cell computation for one direction.

        Args:
            graph: Graph being mutated; new nodes inherit the caller's
                insertion point.
            node: Original RNN node, used as meta source for new nodes.
            current_input: Input sequence for this layer/direction.
            h_prev: Initial hidden state for this direction.
            weight_ih / weight_hh: Input-hidden / hidden-hidden weights.
            bias_ih / bias_hh: Optional biases; ``None`` disables the add.
            seq_len: Number of timesteps to unroll.
            time_dim: Dimension of the time axis (1 if batch_first else 0).
            reverse: Iterate timesteps backwards (backward direction).
            activation: Nonlinearity op (``_tanh`` or ``_relu``).

        Returns (timestep_outputs, h_final) where timestep_outputs are
        unsqueezed hidden states in forward time order.
        """
        # Weights are stored as (hidden, in); transpose once outside the
        # unrolled loop so each timestep is a plain mm.
        w_ih_t = create_node(graph, self._t, args=(weight_ih,), from_node=node)
        w_hh_t = create_node(graph, self._t, args=(weight_hh,), from_node=node)

        time_indices = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        timestep_outputs = []

        for t_idx in time_indices:
            # Slice out timestep t along the time axis.
            x_t = create_node(
                graph,
                self._select,
                args=(current_input, time_dim, t_idx),
                from_node=node,
            )

            # x_t @ W_ih.T and h_{t-1} @ W_hh.T
            out_ih = create_node(graph, self._mm, args=(x_t, w_ih_t), from_node=node)
            out_hh = create_node(graph, self._mm, args=(h_prev, w_hh_t), from_node=node)

            if bias_ih is not None:
                out_ih = create_node(
                    graph, self._add, args=(out_ih, bias_ih), from_node=node
                )
            if bias_hh is not None:
                out_hh = create_node(
                    graph, self._add, args=(out_hh, bias_hh), from_node=node
                )

            pre_act = create_node(
                graph, self._add, args=(out_ih, out_hh), from_node=node
            )
            h_t = create_node(graph, activation, args=(pre_act,), from_node=node)
            # Thread the hidden state into the next timestep.
            h_prev = h_t

            # Re-insert the time axis so per-timestep outputs can later be
            # concatenated along time_dim.
            h_t_expanded = create_node(
                graph, self._unsqueeze, args=(h_t, time_dim), from_node=node
            )
            timestep_outputs.append(h_t_expanded)

        if reverse:
            # Backward direction was produced in reverse chronological
            # order; flip so outputs align with the forward direction.
            timestep_outputs.reverse()

        return timestep_outputs, h_prev

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every matching RNN node into an unrolled elementary-op
        subgraph, then replace its ``getitem`` consumers."""
        graph = graph_module.graph
        made_changes = False

        for node in list(graph.nodes):
            if (
                node.op != "call_function"
                or node.target not in self._TARGETS
                or not self.allowed_to_transform(node.meta)
            ):
                continue

            is_relu = node.target == torch.ops.aten.rnn_relu.input
            activation = self._relu if is_relu else self._tanh

            # aten.rnn_{tanh,relu}.input signature:
            # (input, hx, params, has_biases, num_layers, dropout, train,
            #  bidirectional, batch_first)
            args = node.args
            input_node = args[0]
            hx_node = args[1]
            params = args[2]
            has_biases = args[3]
            num_layers = args[4]
            # dropout (args[5]) and train (args[6]) are unused at inference
            bidirectional = args[7]
            batch_first = args[8]

            input_val = input_node.meta["val"]

            if batch_first:
                seq_len = input_val.shape[1]
                time_dim = 1
            else:
                seq_len = input_val.shape[0]
                time_dim = 0

            num_directions = 2 if bidirectional else 1
            # Params per layer: (w_ih, w_hh[, b_ih, b_hh]) * num_directions
            dir_step = 4 if has_biases else 2
            layer_step = dir_step * num_directions

            # All decomposed nodes are inserted before the original RNN node
            # so topological order is preserved.
            with graph.inserting_before(node):
                current_input = input_node
                layer_final_hiddens = []

                for layer_idx in range(num_layers):
                    layer_offset = layer_idx * layer_step

                    # Forward direction
                    fw_off = layer_offset
                    fw_w_ih = params[fw_off]
                    fw_w_hh = params[fw_off + 1]
                    fw_b_ih = params[fw_off + 2] if has_biases else None
                    fw_b_hh = params[fw_off + 3] if has_biases else None

                    # hx is indexed along dim 0 per (layer, direction);
                    # presumably shaped (num_layers * num_directions, batch,
                    # hidden) as in torch.nn.RNN — TODO confirm.
                    fw_h0 = create_node(
                        graph,
                        self._select,
                        args=(hx_node, 0, num_directions * layer_idx),
                        from_node=node,
                    )

                    fw_outputs, fw_h_final = self._build_direction(
                        graph,
                        node,
                        current_input,
                        fw_h0,
                        fw_w_ih,
                        fw_w_hh,
                        fw_b_ih,
                        fw_b_hh,
                        seq_len,
                        time_dim,
                        reverse=False,
                        activation=activation,
                    )

                    if bidirectional:
                        # Backward direction: its params follow the forward
                        # direction's within the same layer.
                        bw_off = layer_offset + dir_step
                        bw_w_ih = params[bw_off]
                        bw_w_hh = params[bw_off + 1]
                        bw_b_ih = params[bw_off + 2] if has_biases else None
                        bw_b_hh = params[bw_off + 3] if has_biases else None

                        bw_h0 = create_node(
                            graph,
                            self._select,
                            args=(hx_node, 0, 2 * layer_idx + 1),
                            from_node=node,
                        )

                        bw_outputs, bw_h_final = self._build_direction(
                            graph,
                            node,
                            current_input,
                            bw_h0,
                            bw_w_ih,
                            bw_w_hh,
                            bw_b_ih,
                            bw_b_hh,
                            seq_len,
                            time_dim,
                            reverse=True,
                            activation=activation,
                        )

                        # Concatenate fw/bw hidden states per timestep along
                        # the feature axis (last dim).
                        merged = []
                        for fw_out, bw_out in zip(fw_outputs, bw_outputs):
                            merged.append(
                                create_node(
                                    graph,
                                    self._cat,
                                    args=([fw_out, bw_out], -1),
                                    from_node=node,
                                )
                            )

                        # Stack timesteps back along the time axis.
                        layer_output = create_node(
                            graph,
                            self._cat,
                            args=(merged, time_dim),
                            from_node=node,
                        )

                        # Final hidden states get a leading (layer*dir) axis
                        # for the eventual h_n concatenation.
                        layer_final_hiddens.append(
                            create_node(
                                graph,
                                self._unsqueeze,
                                args=(fw_h_final, 0),
                                from_node=node,
                            )
                        )
                        layer_final_hiddens.append(
                            create_node(
                                graph,
                                self._unsqueeze,
                                args=(bw_h_final, 0),
                                from_node=node,
                            )
                        )
                    else:
                        layer_output = create_node(
                            graph,
                            self._cat,
                            args=(fw_outputs, time_dim),
                            from_node=node,
                        )

                        layer_final_hiddens.append(
                            create_node(
                                graph,
                                self._unsqueeze,
                                args=(fw_h_final, 0),
                                from_node=node,
                            )
                        )

                    # This layer's output sequence feeds the next layer.
                    current_input = layer_output

                # Build h_n
                if len(layer_final_hiddens) == 1:
                    h_n = layer_final_hiddens[0]
                else:
                    h_n = create_node(
                        graph,
                        self._cat,
                        args=(layer_final_hiddens, 0),
                        from_node=node,
                    )

            output_node = current_input

            # Replace getitem users: RNN returns (output, h_n)
            getitem_nodes = []
            for user in list(node.users.keys()):
                if user.target == operator.getitem:
                    idx = user.args[1]
                    if idx == 0:
                        user.replace_all_uses_with(output_node)
                    elif idx == 1:
                        user.replace_all_uses_with(h_n)
                    getitem_nodes.append(user)

            # Erase consumers first so the RNN node is user-free when erased.
            for gi in getitem_nodes:
                graph.erase_node(gi)
            graph.erase_node(node)
            made_changes = True

        if not made_changes:
            return PassResult(graph_module, False)

        graph_module.recompile()
        return PassResult(graph_module, True)

0 commit comments

Comments
 (0)