Skip to content

Commit c2b44ea

Browse files
Andrew Pullin and facebook-github-bot
authored and committed
Add DecomposeRnnPass for ARM backend (#17139)
Summary: Adds a decomposition pass that transforms aten.rnn_tanh.input and aten.rnn_relu.input into elementary ops supported by TOSA. RNN cell equation per timestep: h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh) where activation is tanh (rnn_tanh) or relu (rnn_relu). Features: - Multi-layer RNN support - Bidirectional RNN support - With/without bias - batch_first support - Both tanh and relu nonlinearities Differential Revision: D92059152
1 parent 6583c97 commit c2b44ea

File tree

4 files changed

+503
-0
lines changed

4 files changed

+503
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
7878
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
7979
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
80+
from .decompose_rnn_pass import DecomposeRnnPass # noqa
8081
from .decompose_round_pass import DecomposeRoundPass # noqa
8182
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
8283
from .decompose_select import DecomposeSelectPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
DecomposeNotEqualPass,
7979
DecomposeQuantNodesPass,
8080
DecomposeRemainderPass,
81+
DecomposeRnnPass,
8182
DecomposeRoundPass,
8283
DecomposeScaledDotProductAttentionPass,
8384
DecomposeSelectPass,
@@ -364,6 +365,7 @@ def _tosa_pipeline(
364365
DecomposeTOSAUnsupportedClampPass(),
365366
DecomposeGroupNormPass(),
366367
DecomposeGruPass(),
368+
DecomposeRnnPass(),
367369
DecomposeLayerNormPass(),
368370
DecomposeVarPass(),
369371
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
@@ -582,6 +584,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
582584
[
583585
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
584586
DecomposeGruPass(tfa_pass=True),
587+
DecomposeRnnPass(tfa_pass=True),
585588
DecomposeNotEqualPass(tfa_pass=True),
586589
DecomposeCosineSimilarityPass(tfa_pass=True),
587590
DecomposeGluPass(tfa_pass=True),
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import operator
7+
from typing import List, Set, Tuple, Type
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass import ArmPass
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
16+
class DecomposeRnnPass(ArmPass):
    """Decompose ``aten.rnn_tanh.input`` and ``aten.rnn_relu.input`` into
    elementary ops supported by TOSA.

    Per timestep the RNN cell computes::

        h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh)

    where ``activation`` is tanh (rnn_tanh) or relu (rnn_relu).

    Multi-layer, bidirectional, biased/bias-free and batch_first variants
    are all handled.
    """

    _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}

    # RNN variants this pass rewrites.
    _TARGETS = {
        torch.ops.aten.rnn_tanh.input,
        torch.ops.aten.rnn_relu.input,
    }

    # Elementary ops used by the decomposition.
    _mm = torch.ops.aten.mm.default
    _t = torch.ops.aten.t.default
    _add = torch.ops.aten.add.Tensor
    _tanh = torch.ops.aten.tanh.default
    _relu = torch.ops.aten.relu.default
    _unsqueeze = torch.ops.aten.unsqueeze.default
    _cat = torch.ops.aten.cat.default
    _select = torch.ops.aten.select_copy.int

    def _build_direction(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        current_input: torch.fx.Node,
        h_prev: torch.fx.Node,
        weight_ih: torch.fx.Node,
        weight_hh: torch.fx.Node,
        bias_ih,
        bias_hh,
        seq_len: int,
        time_dim: int,
        reverse: bool,
        activation,
    ) -> Tuple[List[torch.fx.Node], torch.fx.Node]:
        """Emit the unrolled cell computation for a single direction.

        Returns ``(timestep_outputs, h_final)`` where the outputs are the
        per-timestep hidden states, unsqueezed along ``time_dim`` and
        arranged in forward time order.
        """
        # Transpose each weight once; reused by every timestep below.
        ih_t = create_node(graph, self._t, args=(weight_ih,), from_node=node)
        hh_t = create_node(graph, self._t, args=(weight_hh,), from_node=node)

        step_order = reversed(range(seq_len)) if reverse else range(seq_len)
        per_step: List[torch.fx.Node] = []
        hidden = h_prev

        for step in step_order:
            x_t = create_node(
                graph,
                self._select,
                args=(current_input, time_dim, step),
                from_node=node,
            )

            gate_ih = create_node(graph, self._mm, args=(x_t, ih_t), from_node=node)
            gate_hh = create_node(graph, self._mm, args=(hidden, hh_t), from_node=node)

            if bias_ih is not None:
                gate_ih = create_node(
                    graph, self._add, args=(gate_ih, bias_ih), from_node=node
                )
            if bias_hh is not None:
                gate_hh = create_node(
                    graph, self._add, args=(gate_hh, bias_hh), from_node=node
                )

            summed = create_node(
                graph, self._add, args=(gate_ih, gate_hh), from_node=node
            )
            hidden = create_node(graph, activation, args=(summed,), from_node=node)

            per_step.append(
                create_node(
                    graph, self._unsqueeze, args=(hidden, time_dim), from_node=node
                )
            )

        if reverse:
            # Reversed traversal produced outputs back-to-front; restore
            # forward time order so concatenation matches the forward pass.
            per_step.reverse()

        return per_step, hidden

    def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
        graph = graph_module.graph
        modified = False

        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in self._TARGETS:
                continue
            if not self.allowed_to_transform(node.meta):
                continue

            activation = (
                self._relu
                if node.target == torch.ops.aten.rnn_relu.input
                else self._tanh
            )

            (
                input_node,
                hx_node,
                params,
                has_biases,
                num_layers,
                _dropout,  # unused at inference
                _train,  # unused at inference
                bidirectional,
                batch_first,
            ) = node.args

            time_dim = 1 if batch_first else 0
            seq_len = input_node.meta["val"].shape[time_dim]

            num_directions = 2 if bidirectional else 1
            # Flat param layout per direction: w_ih, w_hh and, when biases
            # are present, b_ih, b_hh.
            dir_step = 4 if has_biases else 2
            layer_step = dir_step * num_directions

            with graph.inserting_before(node):
                layer_input = input_node
                final_hiddens: List[torch.fx.Node] = []

                for layer_idx in range(num_layers):
                    dir_results = []
                    for direction in range(num_directions):
                        offset = layer_idx * layer_step + direction * dir_step
                        w_ih = params[offset]
                        w_hh = params[offset + 1]
                        b_ih = params[offset + 2] if has_biases else None
                        b_hh = params[offset + 3] if has_biases else None

                        # hx is stacked (num_layers * num_directions, N, H).
                        h0 = create_node(
                            graph,
                            self._select,
                            args=(hx_node, 0, num_directions * layer_idx + direction),
                            from_node=node,
                        )

                        outputs, h_final = self._build_direction(
                            graph,
                            node,
                            layer_input,
                            h0,
                            w_ih,
                            w_hh,
                            b_ih,
                            b_hh,
                            seq_len,
                            time_dim,
                            reverse=(direction == 1),
                            activation=activation,
                        )
                        dir_results.append((outputs, h_final))

                    # Stitch timesteps back along the time dimension, one cat
                    # per direction; bidirectional outputs then concatenate
                    # along the feature dimension.
                    dir_cats = [
                        create_node(
                            graph, self._cat, args=(outs, time_dim), from_node=node
                        )
                        for outs, _ in dir_results
                    ]
                    if num_directions == 2:
                        layer_input = create_node(
                            graph, self._cat, args=(dir_cats, -1), from_node=node
                        )
                    else:
                        layer_input = dir_cats[0]

                    for _, h_final in dir_results:
                        final_hiddens.append(
                            create_node(
                                graph,
                                self._unsqueeze,
                                args=(h_final, 0),
                                from_node=node,
                            )
                        )

                # Stack the final hidden states into h_n.
                if len(final_hiddens) > 1:
                    h_n = create_node(
                        graph, self._cat, args=(final_hiddens, 0), from_node=node
                    )
                else:
                    h_n = final_hiddens[0]

            # The RNN op returns (output, h_n); rewire its getitem users.
            dead_getitems = []
            for user in list(node.users.keys()):
                if user.target != operator.getitem:
                    continue
                index = user.args[1]
                if index == 0:
                    user.replace_all_uses_with(layer_input)
                elif index == 1:
                    user.replace_all_uses_with(h_n)
                dead_getitems.append(user)

            for getitem in dead_getitems:
                graph.erase_node(getitem)
            graph.erase_node(node)
            modified = True

        if modified:
            graph_module.recompile()
            return PassResult(graph_module, True)
        return PassResult(graph_module, False)

0 commit comments

Comments
 (0)