Skip to content

Commit 62d8f21

Browse files
authored
Arm backend: Fix meta propagation in some call passes (#19154)
When a call pass creates a node, it won't get a meta["val"] field until the graph is re-traced, unless the field is explicitly set. This is not usually a problem, but if a pass needs meta["val"] from an argument, and an op is chained multiple times, the pass will crash when the meta is missing or incorrectly set. Fix a couple of passes where the field was missing or used an incorrect value.
1 parent 572f023 commit 62d8f21

10 files changed

Lines changed: 216 additions & 25 deletions

backends/arm/_passes/arm_pass_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
# LICENSE file in the root directory of this source tree.
77

88

9+
import operator
910
import traceback
1011
from inspect import isclass
11-
from typing import List, Optional, Sequence, Tuple
12+
from typing import cast, List, Optional, Sequence, Tuple
1213

1314
import torch
1415
import torch.fx
@@ -59,6 +60,17 @@ def is_get_attr_node(node: torch.fx.Node) -> bool:
5960
)
6061

6162

63+
def get_getitem_users(
64+
source_node: torch.fx.Node, max_users: int
65+
) -> dict[int, torch.fx.Node | None]:
66+
getitem_users: dict[int, torch.fx.Node | None] = {i: None for i in range(max_users)}
67+
for user in source_node.users:
68+
if user.target == operator.getitem:
69+
getitem_users[cast(int, user.args[1])] = user
70+
71+
return getitem_users
72+
73+
6274
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
6375
return (
6476
is_get_attr_node(node)

backends/arm/_passes/decompose_any_pass.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
import torch
99
from executorch.backends.arm._passes import ArmPass
10-
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
10+
from executorch.backends.arm._passes.arm_pass_utils import (
11+
create_node,
12+
get_first_fake_tensor,
13+
)
1114
from executorch.exir.dialects._ops import ( # type: ignore[import-not-found]
1215
ops as exir_ops,
1316
)
@@ -91,16 +94,21 @@ def call(self, graph_module: torch.fx.GraphModule):
9194
with graph_module.graph.inserting_before(node):
9295
for dim in dims_to_reduce:
9396
args = (input_node, dim, True)
94-
input_node = graph_module.graph.create_node(
95-
"call_function", exir_ops.edge.aten.any.dim, args, node.kwargs
97+
input_node = create_node(
98+
graph_module.graph,
99+
exir_ops.edge.aten.any.dim,
100+
args=args,
101+
kwargs=node.kwargs,
102+
from_node=node,
96103
)
97104

98105
if not keepdim:
99106
output_shape = list(get_first_fake_tensor(node).shape)
100-
input_node = graph_module.graph.create_node(
101-
"call_function",
107+
input_node = create_node(
108+
graph_module.graph,
102109
exir_ops.edge.aten.view_copy.default,
103-
(input_node, output_shape),
110+
args=(input_node, output_shape),
111+
from_node=node,
104112
)
105113

106114
node.replace_all_uses_with(input_node)

backends/arm/_passes/decompose_gru_pass.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass import ArmPass
12-
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.arm_pass_utils import (
13+
create_node,
14+
get_getitem_users,
15+
)
1316
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
1417
from executorch.exir.pass_base import ExportPass, PassResult
1518

@@ -149,6 +152,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
149152
or not self.allowed_to_transform(node.meta)
150153
):
151154
continue
155+
getitem_users = get_getitem_users(node, 2)
152156

153157
args = node.args
154158
input_node = args[0]
@@ -257,7 +261,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
257261
graph,
258262
self._cat,
259263
args=(merged, time_dim),
260-
from_node=node,
264+
from_node=getitem_users.get(0),
261265
)
262266

263267
layer_final_hiddens.append(
@@ -281,15 +285,15 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
281285
graph,
282286
self._cat,
283287
args=(fw_outputs, time_dim),
284-
from_node=node,
288+
from_node=getitem_users.get(0),
285289
)
286290

287291
layer_final_hiddens.append(
288292
create_node(
289293
graph,
290294
self._unsqueeze,
291295
args=(fw_h_final, 0),
292-
from_node=node,
296+
from_node=getitem_users.get(1),
293297
)
294298
)
295299

@@ -303,7 +307,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
303307
graph,
304308
self._cat,
305309
args=(layer_final_hiddens, 0),
306-
from_node=node,
310+
from_node=getitem_users.get(1),
307311
)
308312

309313
output_node = current_input

backends/arm/_passes/decompose_lstm_pass.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass import ArmPass
12-
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.arm_pass_utils import (
13+
create_node,
14+
get_getitem_users,
15+
)
1316
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
1417
from executorch.exir.pass_base import ExportPass, PassResult
1518

@@ -142,6 +145,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
142145
or not self.allowed_to_transform(node.meta)
143146
):
144147
continue
148+
getitem_users = get_getitem_users(node, 3)
145149

146150
args = node.args
147151
input_node = args[0]
@@ -266,7 +270,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
266270
graph,
267271
self._cat,
268272
args=(merged, time_dim),
269-
from_node=node,
273+
from_node=(getitem_users.get(0)),
270274
)
271275

272276
layer_final_hiddens.append(
@@ -306,23 +310,23 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
306310
graph,
307311
self._cat,
308312
args=(fw_outputs, time_dim),
309-
from_node=node,
313+
from_node=(getitem_users.get(0)),
310314
)
311315

312316
layer_final_hiddens.append(
313317
create_node(
314318
graph,
315319
self._unsqueeze,
316320
args=(fw_h_final, 0),
317-
from_node=node,
321+
from_node=(getitem_users.get(1)),
318322
)
319323
)
320324
layer_final_cells.append(
321325
create_node(
322326
graph,
323327
self._unsqueeze,
324328
args=(fw_c_final, 0),
325-
from_node=node,
329+
from_node=(getitem_users.get(2)),
326330
)
327331
)
328332

@@ -336,7 +340,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
336340
graph,
337341
self._cat,
338342
args=(layer_final_hiddens, 0),
339-
from_node=node,
343+
from_node=getitem_users.get(1),
340344
)
341345
if len(layer_final_cells) == 1:
342346
c_n = layer_final_cells[0]
@@ -345,7 +349,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
345349
graph,
346350
self._cat,
347351
args=(layer_final_cells, 0),
348-
from_node=node,
352+
from_node=getitem_users.get(2),
349353
)
350354

351355
output_node = current_input

backends/arm/_passes/decompose_rnn_pass.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass import ArmPass
12-
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.arm_pass_utils import (
13+
create_node,
14+
get_getitem_users,
15+
)
1316
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
1417
from executorch.exir.pass_base import ExportPass, PassResult
1518

@@ -117,6 +120,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
117120

118121
is_relu = node.target == torch.ops.aten.rnn_relu.input
119122
activation = self._relu if is_relu else self._tanh
123+
getitem_users = get_getitem_users(node, 2)
120124

121125
args = node.args
122126
input_node = args[0]
@@ -223,7 +227,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
223227
graph,
224228
self._cat,
225229
args=([fw_combined, bw_combined], -1),
226-
from_node=node,
230+
from_node=(getitem_users.get(0)),
227231
)
228232

229233
layer_final_hiddens.append(
@@ -247,15 +251,15 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
247251
graph,
248252
self._cat,
249253
args=(fw_outputs, time_dim),
250-
from_node=node,
254+
from_node=(getitem_users.get(0)),
251255
)
252256

253257
layer_final_hiddens.append(
254258
create_node(
255259
graph,
256260
self._unsqueeze,
257261
args=(fw_h_final, 0),
258-
from_node=node,
262+
from_node=(getitem_users.get(1)),
259263
)
260264
)
261265

@@ -269,7 +273,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
269273
graph,
270274
self._cat,
271275
args=(layer_final_hiddens, 0),
272-
from_node=node,
276+
from_node=getitem_users.get(1),
273277
)
274278

275279
output_node = current_input

backends/arm/test/ops/test_any.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -41,6 +41,14 @@ def forward(self, x: torch.Tensor):
4141
return torch.any(x)
4242

4343

44+
class ChainedAny(torch.nn.Module):
45+
aten_op = "torch.ops.aten.any.default"
46+
exir_op = "executorch_exir_dialects_edge__ops_aten_any_default"
47+
48+
def forward(self, x: torch.Tensor):
49+
return torch.any(torch.any(x, dim=1, keepdim=True))
50+
51+
4452
input_t1 = Tuple[torch.Tensor] # Input x
4553

4654

@@ -152,6 +160,19 @@ def test_any_tosa_INT(test_data: input_t1):
152160
pipeline.run()
153161

154162

163+
def test_any_tosa_FP_chained_reduce_all() -> None:
164+
pipeline = TosaPipelineFP[Tuple[torch.Tensor]](
165+
ChainedAny(),
166+
(torch.randint(0, 2, (2, 3, 4), dtype=torch.bool),),
167+
ChainedAny.aten_op,
168+
ChainedAny.exir_op,
169+
atol=0,
170+
rtol=0,
171+
qtol=0,
172+
)
173+
pipeline.run()
174+
175+
155176
@common.parametrize("test_data", test_data)
156177
def test_any_u55_INT(test_data: input_t1):
157178
# Tests that we don't delegate these ops since they are not supported on U55.

backends/arm/test/ops/test_upsample_nearest2d.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@ def forward(self, x):
135135
return self.upsample(x)
136136

137137

138+
class ChainedBilinearInterpolate(torch.nn.Module):
139+
def forward(self, x):
140+
x = torch.nn.functional.interpolate(
141+
x, scale_factor=2.0, mode="bilinear", align_corners=False
142+
)
143+
return torch.nn.functional.interpolate(
144+
x, scale_factor=2.0, mode="bilinear", align_corners=False
145+
)
146+
147+
138148
@common.parametrize(
139149
"test_data", test_data_suite | test_data_suite_bf16 | test_data_suite_fp16
140150
)
@@ -233,6 +243,16 @@ def test_upsample_nearest2d_vec_tosa_INT_interpolate(test_data: torch.Tensor):
233243
pipeline.run()
234244

235245

246+
def test_upsample_nearest2d_vec_tosa_INT_chained_bilinear_interpolate():
247+
pipeline = TosaPipelineINT[input_t1](
248+
ChainedBilinearInterpolate(),
249+
(torch.rand(1, 3, 4, 4),),
250+
"torch.ops.aten.upsample_bilinear2d.vec",
251+
exir_op=[],
252+
)
253+
pipeline.run()
254+
255+
236256
@common.parametrize("test_data", test_data_suite)
237257
def test_upsample_nearest2d_vec_tosa_INT_a16w8(test_data: torch.Tensor):
238258
"""Test upsample_nearest2d vector op with int16 I/O quantization for TOSA

backends/arm/test/passes/test_decompose_gru_pass.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from typing import Tuple
99

1010
import torch
11+
from executorch.backends.arm._passes import DecomposeGruPass
1112
from executorch.backends.arm.test.tester.test_pipeline import (
13+
PassPipeline,
1214
TosaPipelineFP,
1315
TosaPipelineINT,
1416
)
@@ -62,6 +64,30 @@ def get_inputs(self) -> input_t:
6264
return (x, h)
6365

6466

67+
chain_input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
68+
69+
70+
class SequentialGRU(torch.nn.Module):
71+
def __init__(self) -> None:
72+
super().__init__()
73+
self.gru1 = torch.nn.GRU(10, 12, batch_first=True)
74+
self.gru2 = torch.nn.GRU(12, 8, batch_first=True)
75+
76+
def forward(
77+
self, x: torch.Tensor, h1: torch.Tensor, h2: torch.Tensor
78+
) -> Tuple[torch.Tensor, torch.Tensor]:
79+
y, _ = self.gru1(x, h1)
80+
z, h_n = self.gru2(y, h2)
81+
return z, h_n
82+
83+
def get_inputs(self) -> chain_input_t:
84+
return (
85+
torch.randn(2, 5, 10),
86+
torch.randn(1, 2, 12),
87+
torch.randn(1, 2, 8),
88+
)
89+
90+
6591
def _make_gru_fp_pipeline(module: GRU) -> TosaPipelineFP:
6692
pipeline = TosaPipelineFP[input_t](
6793
module,
@@ -139,3 +165,14 @@ def test_decompose_gru_tosa_FP_multilayer():
139165
def test_decompose_gru_tosa_INT_multilayer():
140166
"""Test multi-layer GRU through quantized pipeline."""
141167
_make_gru_int_pipeline(GRU(num_layers=2)).run()
168+
169+
170+
def test_decompose_gru_pass_handles_chained_grus() -> None:
171+
module = SequentialGRU()
172+
pipeline = PassPipeline(
173+
module,
174+
module.get_inputs(),
175+
quantize=True,
176+
pass_list=[DecomposeGruPass],
177+
)
178+
pipeline.run()

0 commit comments

Comments
 (0)