Skip to content

Commit 4227c90

Browse files
Arm backend: Support dynamic select (pytorch#19973)
Make sure negative indices are handled correctly when dimensions are symbolic. If index is negative and dimension is symbolic, express adjusted index as symbolic_dim - index rather than index % symbolic_dim. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 7e4253a commit 4227c90

2 files changed

Lines changed: 76 additions & 1 deletion

File tree

backends/arm/_passes/decompose_select.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def call(self, graph_module: torch.fx.GraphModule):
4848
rank = len(input_tensor.size())
4949
shape = input_tensor.shape
5050
dim = dim % rank if dim < 0 else dim
51-
index = index % shape[dim] if index < 0 else index
51+
if index < 0:
52+
size_at_dim = shape[dim]
53+
index = size_at_dim - abs(index)
5254

5355
with graph_module.graph.inserting_before(node):
5456
slice_node = create_node(
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import sympy # type: ignore
7+
import torch
8+
from executorch.backends.arm._passes import DecomposeSelectPass
9+
from executorch.backends.test.program_builder import ProgramBuilder
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from torch._subclasses.fake_tensor import FakeTensorMode
12+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
13+
14+
15+
def _make_symint(
16+
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
17+
) -> torch.SymInt:
18+
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
19+
assert isinstance(symint, torch.SymInt)
20+
shape_env.constrain_symbol_range(
21+
symint.node.expr, compiler_min=min, compiler_max=max
22+
)
23+
return symint
24+
25+
26+
def test_decompose_select_negative_symbolic_index_uses_symbolic_sub() -> None:
27+
shape_env = ShapeEnv()
28+
seq = _make_symint(shape_env, "seq", hint=4)
29+
30+
with FakeTensorMode(shape_env=shape_env) as mode:
31+
builder = ProgramBuilder(fake_tensor_mode=mode)
32+
x = builder.placeholder("x", mode.from_tensor(torch.empty(size=(1, seq, 576))))
33+
h = builder.call_operator(exir_ops.edge.aten.add.Tensor, (x, x))
34+
select = builder.call_operator(exir_ops.edge.aten.select_copy.int, (h, 1, -1))
35+
builder.output([select])
36+
37+
result = DecomposeSelectPass()(builder.get_program().graph_module)
38+
39+
assert result is not None
40+
41+
select_nodes = [
42+
node
43+
for node in result.graph_module.graph.nodes
44+
if node.op == "call_function"
45+
and node.target == exir_ops.edge.aten.select_copy.int
46+
]
47+
slice_nodes = [
48+
node
49+
for node in result.graph_module.graph.nodes
50+
if node.op == "call_function"
51+
and node.target == exir_ops.edge.aten.slice_copy.Tensor
52+
]
53+
squeeze_nodes = [
54+
node
55+
for node in result.graph_module.graph.nodes
56+
if node.op == "call_function"
57+
and node.target == exir_ops.edge.aten.squeeze_copy.dims
58+
]
59+
60+
assert not select_nodes
61+
assert len(slice_nodes) == 1
62+
assert len(squeeze_nodes) == 1
63+
64+
slice_node = slice_nodes[0]
65+
assert slice_node.args[1] == 1
66+
assert slice_node.args[2] != -1
67+
assert isinstance(slice_node.args[2], torch.SymInt)
68+
assert isinstance(slice_node.args[3], torch.SymInt)
69+
assert str(slice_node.args[2]).endswith(" - 1")
70+
assert str(slice_node.args[3]) in str(slice_node.args[2])
71+
assert squeeze_nodes[0].args == (slice_node, [1])
72+
73+
result.graph_module.graph.lint()

0 commit comments

Comments
 (0)