Skip to content

Commit acb2f4b

Browse files
Andrew Pullin and meta-codesync[bot]
authored and committed
Add DecomposeGruPass for ARM backend
Summary: Adds a decomposition pass that transforms aten.gru.input into elementary ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat). GRU cell equations per timestep: r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr) z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz) n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn)) h_t = n_t + z_t * (h_{t-1} - n_t) Features: - Multi-layer GRU support - Bidirectional GRU support - With/without bias - batch_first support - Batched gate computation (2 mm ops per timestep instead of 6) Differential Revision: D92058313
1 parent 2a68e74 commit acb2f4b

File tree

4 files changed

+491
-0
lines changed

4 files changed

+491
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from .decompose_glu_pass import DecomposeGluPass # noqa
5353
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
5454
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
55+
from .decompose_gru_pass import DecomposeGruPass # noqa
5556
from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa
5657
from .decompose_index_select_to_gather_pass import ( # noqa
5758
DecomposeIndexSelectToGatherPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
DecomposeGluPass,
6161
DecomposeGroupedConvPass,
6262
DecomposeGroupNormPass,
63+
DecomposeGruPass,
6364
DecomposeIndexCopyPass,
6465
DecomposeIndexSelectToGatherPass,
6566
DecomposeIndexTensorToGatherPass,
@@ -362,6 +363,7 @@ def _tosa_pipeline(
362363
ConvertToClampPass(),
363364
DecomposeTOSAUnsupportedClampPass(),
364365
DecomposeGroupNormPass(),
366+
DecomposeGruPass(),
365367
DecomposeLayerNormPass(),
366368
DecomposeVarPass(),
367369
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
@@ -579,6 +581,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
579581
self.add_passes(
580582
[
581583
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
584+
DecomposeGruPass(tfa_pass=True),
582585
DecomposeNotEqualPass(tfa_pass=True),
583586
DecomposeCosineSimilarityPass(tfa_pass=True),
584587
DecomposeGluPass(tfa_pass=True),
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import operator
7+
from typing import List, Set, Tuple, Type
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass import ArmPass
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
16+
class DecomposeGruPass(ArmPass):
    """Decomposes aten.gru.input into elementary ops supported by TOSA.

    GRU cell equations per timestep:
        r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
        z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
        n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
        h_t = n_t + z_t * (h_{t-1} - n_t)

    The weights are batched: one mm computes all three gates at once, then the
    result is sliced into r/z/n components. This yields 2 mm ops per timestep
    instead of 6.

    Supports multi-layer, bidirectional, with/without bias, and batch_first.
    """

    # InsertTableOpsPass must run later to lower the sigmoid/tanh nodes
    # this pass emits.
    _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}

    # The single op this pass matches and rewrites.
    _TARGET = torch.ops.aten.gru.input

    # Ops — always aten since GRU has no edge dialect variant
    _mm = torch.ops.aten.mm.default
    _t = torch.ops.aten.t.default
    _add = torch.ops.aten.add.Tensor
    _sub = torch.ops.aten.sub.Tensor
    _mul = torch.ops.aten.mul.Tensor
    _sigmoid = torch.ops.aten.sigmoid.default
    _tanh = torch.ops.aten.tanh.default
    _slice = torch.ops.aten.slice_copy.Tensor
    _unsqueeze = torch.ops.aten.unsqueeze.default
    _cat = torch.ops.aten.cat.default
    _select = torch.ops.aten.select_copy.int

    def _build_direction(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        current_input: torch.fx.Node,
        h_prev: torch.fx.Node,
        weight_ih: torch.fx.Node,
        weight_hh: torch.fx.Node,
        bias_ih,  # fx.Node, or None when the GRU has no biases
        bias_hh,  # fx.Node, or None when the GRU has no biases
        hidden_size: int,
        seq_len: int,
        time_dim: int,
        reverse: bool,
    ) -> Tuple[List[torch.fx.Node], torch.fx.Node]:
        """Build GRU cell computation for one direction.

        Args:
            graph: graph to insert the new nodes into.
            node: the original aten.gru.input node; passed as ``from_node``
                so the new nodes inherit its metadata.
            current_input: layer input sequence (the previous layer's output,
                or the model input for layer 0).
            h_prev: initial hidden state for this direction/layer.
            weight_ih / weight_hh: batched input-hidden / hidden-hidden
                weights holding all three gates (r, z, n) stacked on dim 0.
            bias_ih / bias_hh: matching batched biases, or None.
            hidden_size: per-direction hidden size H.
            seq_len: number of timesteps to unroll.
            time_dim: dimension of ``current_input`` that indexes time
                (1 for batch_first, else 0).
            reverse: iterate timesteps last-to-first (backward direction).

        Returns (timestep_outputs, h_final) where timestep_outputs are
        unsqueezed hidden states in forward time order.
        """
        # Transpose once per direction so each timestep is a plain mm.
        w_ih_t = create_node(graph, self._t, args=(weight_ih,), from_node=node)
        w_hh_t = create_node(graph, self._t, args=(weight_hh,), from_node=node)

        time_indices = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        timestep_outputs = []

        for t_idx in time_indices:
            # x_t: input at this timestep, shape (batch, input_size).
            x_t = create_node(
                graph,
                self._select,
                args=(current_input, time_dim, t_idx),
                from_node=node,
            )

            # One mm each computes all three gates (r|z|n) for x and h.
            gates_x = create_node(graph, self._mm, args=(x_t, w_ih_t), from_node=node)
            gates_h = create_node(
                graph, self._mm, args=(h_prev, w_hh_t), from_node=node
            )

            if bias_ih is not None:
                gates_x = create_node(
                    graph, self._add, args=(gates_x, bias_ih), from_node=node
                )
            if bias_hh is not None:
                gates_h = create_node(
                    graph, self._add, args=(gates_h, bias_hh), from_node=node
                )

            # Slice the batched gate outputs into r/z/n chunks of width H
            # along the feature dim (dim 1).
            H = hidden_size
            r_x = create_node(
                graph, self._slice, args=(gates_x, 1, 0, H), from_node=node
            )
            z_x = create_node(
                graph, self._slice, args=(gates_x, 1, H, 2 * H), from_node=node
            )
            n_x = create_node(
                graph,
                self._slice,
                args=(gates_x, 1, 2 * H, 3 * H),
                from_node=node,
            )
            r_h = create_node(
                graph, self._slice, args=(gates_h, 1, 0, H), from_node=node
            )
            z_h = create_node(
                graph, self._slice, args=(gates_h, 1, H, 2 * H), from_node=node
            )
            n_h = create_node(
                graph,
                self._slice,
                args=(gates_h, 1, 2 * H, 3 * H),
                from_node=node,
            )

            # r_t = sigmoid(r_x + r_h)
            r_pre = create_node(graph, self._add, args=(r_x, r_h), from_node=node)
            r_t = create_node(graph, self._sigmoid, args=(r_pre,), from_node=node)

            # z_t = sigmoid(z_x + z_h)
            z_pre = create_node(graph, self._add, args=(z_x, z_h), from_node=node)
            z_t = create_node(graph, self._sigmoid, args=(z_pre,), from_node=node)

            # n_t = tanh(n_x + r_t * n_h) — reset gate scales the hidden
            # contribution before tanh, per the GRU cell definition.
            r_n_h = create_node(graph, self._mul, args=(r_t, n_h), from_node=node)
            n_pre = create_node(graph, self._add, args=(n_x, r_n_h), from_node=node)
            n_t = create_node(graph, self._tanh, args=(n_pre,), from_node=node)

            # h_t = n_t + z_t * (h_{t-1} - n_t), i.e. the usual
            # (1 - z_t) * n_t + z_t * h_{t-1} rearranged to save one op.
            diff = create_node(graph, self._sub, args=(h_prev, n_t), from_node=node)
            z_diff = create_node(graph, self._mul, args=(z_t, diff), from_node=node)
            h_t = create_node(graph, self._add, args=(n_t, z_diff), from_node=node)
            h_prev = h_t

            # Re-insert the time dim so timestep outputs can be concatenated
            # along it later.
            h_t_expanded = create_node(
                graph, self._unsqueeze, args=(h_t, time_dim), from_node=node
            )
            timestep_outputs.append(h_t_expanded)

        # Backward outputs were appended in reverse time order; flip to
        # forward order so they align with the forward direction for concat.
        if reverse:
            timestep_outputs.reverse()

        return timestep_outputs, h_prev

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
        """Find every aten.gru.input node and replace it with an unrolled
        per-timestep decomposition built from TOSA-supported elementary ops."""
        graph = graph_module.graph
        made_changes = False

        for node in list(graph.nodes):
            if (
                node.op != "call_function"
                or node.target != self._TARGET
                or not self.allowed_to_transform(node.meta)
            ):
                continue

            # aten.gru.input schema:
            # (input, hx, params, has_biases, num_layers, dropout, train,
            #  bidirectional, batch_first)
            args = node.args
            input_node = args[0]
            hx_node = args[1]
            params = args[2]
            has_biases = args[3]
            num_layers = args[4]
            # dropout (args[5]) and train (args[6]) are unused at inference
            bidirectional = args[7]
            batch_first = args[8]

            # Static shapes from the FakeTensor metadata; seq_len must be
            # known at trace time since the loop is fully unrolled.
            input_val = input_node.meta["val"]
            hx_val = hx_node.meta["val"]

            if batch_first:
                seq_len = input_val.shape[1]
                time_dim = 1
            else:
                seq_len = input_val.shape[0]
                time_dim = 0

            hidden_size = hx_val.shape[-1]
            num_directions = 2 if bidirectional else 1
            # Params per layer: (w_ih, w_hh[, b_ih, b_hh]) * num_directions
            dir_step = 4 if has_biases else 2
            layer_step = dir_step * num_directions

            with graph.inserting_before(node):
                current_input = input_node
                # Per-(layer, direction) final hidden states, each
                # unsqueezed on dim 0, concatenated into h_n at the end.
                layer_final_hiddens = []

                for layer_idx in range(num_layers):
                    layer_offset = layer_idx * layer_step

                    # Forward direction
                    fw_off = layer_offset
                    fw_w_ih = params[fw_off]
                    fw_w_hh = params[fw_off + 1]
                    fw_b_ih = params[fw_off + 2] if has_biases else None
                    fw_b_hh = params[fw_off + 3] if has_biases else None

                    # hx is laid out (num_layers * num_directions, batch, H);
                    # the forward slot for this layer is at
                    # num_directions * layer_idx.
                    fw_h0 = create_node(
                        graph,
                        self._select,
                        args=(hx_node, 0, num_directions * layer_idx),
                        from_node=node,
                    )

                    fw_outputs, fw_h_final = self._build_direction(
                        graph,
                        node,
                        current_input,
                        fw_h0,
                        fw_w_ih,
                        fw_w_hh,
                        fw_b_ih,
                        fw_b_hh,
                        hidden_size,
                        seq_len,
                        time_dim,
                        reverse=False,
                    )

                    if bidirectional:
                        # Backward direction params follow the forward ones
                        # within the layer.
                        bw_off = layer_offset + dir_step
                        bw_w_ih = params[bw_off]
                        bw_w_hh = params[bw_off + 1]
                        bw_b_ih = params[bw_off + 2] if has_biases else None
                        bw_b_hh = params[bw_off + 3] if has_biases else None

                        # Backward slot in hx: 2 * layer_idx + 1.
                        bw_h0 = create_node(
                            graph,
                            self._select,
                            args=(hx_node, 0, 2 * layer_idx + 1),
                            from_node=node,
                        )

                        bw_outputs, bw_h_final = self._build_direction(
                            graph,
                            node,
                            current_input,
                            bw_h0,
                            bw_w_ih,
                            bw_w_hh,
                            bw_b_ih,
                            bw_b_hh,
                            hidden_size,
                            seq_len,
                            time_dim,
                            reverse=True,
                        )

                        # Concatenate fw + bw at each timestep along feature dim
                        merged = []
                        for fw_out, bw_out in zip(fw_outputs, bw_outputs):
                            merged.append(
                                create_node(
                                    graph,
                                    self._cat,
                                    args=([fw_out, bw_out], -1),
                                    from_node=node,
                                )
                            )

                        # Stack the merged timesteps back into a sequence.
                        layer_output = create_node(
                            graph,
                            self._cat,
                            args=(merged, time_dim),
                            from_node=node,
                        )

                        layer_final_hiddens.append(
                            create_node(
                                graph,
                                self._unsqueeze,
                                args=(fw_h_final, 0),
                                from_node=node,
                            )
                        )
                        layer_final_hiddens.append(
                            create_node(
                                graph,
                                self._unsqueeze,
                                args=(bw_h_final, 0),
                                from_node=node,
                            )
                        )
                    else:
                        layer_output = create_node(
                            graph,
                            self._cat,
                            args=(fw_outputs, time_dim),
                            from_node=node,
                        )

                        layer_final_hiddens.append(
                            create_node(
                                graph,
                                self._unsqueeze,
                                args=(fw_h_final, 0),
                                from_node=node,
                            )
                        )

                    # Next layer consumes this layer's output sequence.
                    current_input = layer_output

                # Build h_n
                if len(layer_final_hiddens) == 1:
                    h_n = layer_final_hiddens[0]
                else:
                    h_n = create_node(
                        graph,
                        self._cat,
                        args=(layer_final_hiddens, 0),
                        from_node=node,
                    )

                output_node = current_input

            # Replace getitem users: GRU returns (output, h_n)
            getitem_nodes = []
            for user in list(node.users.keys()):
                if user.target == operator.getitem:
                    idx = user.args[1]
                    if idx == 0:
                        user.replace_all_uses_with(output_node)
                    elif idx == 1:
                        user.replace_all_uses_with(h_n)
                    getitem_nodes.append(user)

            # Erase getitem nodes then the GRU node explicitly;
            # eliminate_dead_code does not remove GRU because aten
            # considers it impure (may have dropout side-effects).
            for gi in getitem_nodes:
                graph.erase_node(gi)
            graph.erase_node(node)
            made_changes = True

        if not made_changes:
            return PassResult(graph_module, False)

        graph_module.recompile()
        return PassResult(graph_module, True)

0 commit comments

Comments
 (0)