Skip to content

Commit c81126e

Browse files
authored
Arm backend: Refactor and bug-fix RewriteIndexPutPass (#18197)
The patch should hopefully make the pass easier to understand. Make explicit that we set N=1, handle explicit indexing by folding them in the K dimension, and handle full indexing (select all values) by folding them in the C dimension. Note that TOSA and torch have switched terminology regarding what the parameter 'values' means, instead, use a new naming: TOSA values_in == torch x/self tensor, call this 'destination'. TOSA input == torch values, call this 'data'. Additionally, the pass earlier didn't account for that 1) There are fully indexed dimensions 2) Index tensors can be broadcast 3) The data tensor can be smaller than (N, W, C), and require broadcasting first. 4) None index tensors were incorrectly handled. Regarding 1-3): Given destination of shape (N, K, C), TOSA.SCATTER semantics require the shape (N, W) of the index tensor, including possibly an implicit C dimension, to match the data shape (N, W_d, C_d). Torch can however broadcast both these inputs. We need to expand/reshape the data tensor correctly. Example (ignoring N, it's always 1): >>> destination = torch.ones(5, 2), K=5, C=2 >>> indices = (torch.tensor([0, 2]),) # Indexes K dim W=2 times, C is implicitly assumed to be C=2. >>> data = torch.tensor([10.0, 20.0]) # W_d = 1 !!, C_d=2 >>> torch.index_put(destination, indices, data) tensor([[10., 20.], [ 1., 1.], [10., 20.], [ 1., 1.], [ 1., 1.]]) Or even >>> [...] >>> data = torch.tensor([10.0]) # W_d = 1, C_d=1 !! >>> torch.index_put(destination, indices, data) tensor([[10., 10.], [ 1., 1.], [10., 10.], [ 1., 1.], [10., 10.]]) The patch generalizes this to multiple dimensions. Refer to docstring in patch for complete explanation. 4) Is handled by adding a normalization pass that rewrites None index tensors to fully indexed tensors. Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent 22174fa commit c81126e

6 files changed

Lines changed: 374 additions & 167 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@
123123
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
124124
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
125125
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
126+
from .normalize_index_put_none_indices_pass import ( # noqa
127+
NormalizeIndexPutNoneIndicesPass,
128+
)
126129
from .normalize_while_initial_args_pass import NormalizeWhileInitialArgsPass # noqa
127130
from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa
128131
from .remove_getitem_pass import RemoveGetItemPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
InsertTableOpsPass,
112112
MatchArgDtypePass,
113113
MatchArgRanksPass,
114+
NormalizeIndexPutNoneIndicesPass,
114115
NormalizeWhileInitialArgsPass,
115116
PromoteBoolOperandsPass,
116117
QuantizeClampArgumentsPass,
@@ -444,6 +445,7 @@ def _tosa_pipeline(
444445
# Node transformation passes (post scalar-removal)
445446
self.add_passes(
446447
[
448+
NormalizeIndexPutNoneIndicesPass(),
447449
RewriteIndexPutPass(),
448450
RewriteBoolBitwiseToLogicalPass(),
449451
DecomposeRemainderPass(),
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Set, Type
6+
7+
from executorch.backends.arm._passes import ArmPass
8+
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class NormalizeIndexPutNoneIndicesPass(ArmPass):
    """Normalize index_put with None:s in the indices_tensor list by moving
    None-indexed dims to the channel dimensions (*C_j in RewriteIndexPutPass
    terminology) by permuting the destination and data tensors. A None-index
    corresponds to selecting the entire dim, which is equivalent with being a
    channel dimension.

    Example:
        out = index_put(destination, [None, idx1, None, idx2], data)
    becomes
        destination_permuted = permute(destination, destination_dim_order)
        data_front_padded = reshape(data, front_padded_data_shape)
        data_permuted = permute(data, data_dim_order)
        out_permuted = index_put(destination_permuted, [idx1, idx2], data_permuted)
        out = permute(out_permuted, inverse_destination_dim_order)

    Where the permutations of destination and data are decided by how the indexes move.

    Note that None tensors are handled differently in pytorch depending on how many
    indices tensors there are, causing the data tensor to require different shapes,
    which will require different data permutation.
    Many: all explicit dims are broadcast to a single dim and put in front of data tensor
        destination shape (5,3,4,3) with indices (None, [1,0], None, [0,2]) -> data shape (2, 5, 4)
        Note that this is the behaviour we want! No permutation of data is necessary.
    One: The explicit dim is kept in place
        destination shape (5,3,4,3) with indices (None, [1,0], None, None) -> data shape (5, 2, 4, 3)
        dim 1 needs to be moved to the front: dim_order = (1,0,2,3).
        This is the same dim order as for the destination tensor.
    """

    # RewriteIndexPutPass must run afterwards: this pass only removes None
    # indices; the rewritten index_put still needs lowering to TOSA ops.
    _passes_required_after: Set[Type[ExportPass]] = {RewriteIndexPutPass}

    def __init__(self):
        super().__init__()
        # Edge-dialect ops emitted by this pass when (re)arranging tensors.
        self.permute_op = exir_ops.edge.aten.permute_copy.default
        self.reshape_op = exir_ops.edge.aten.view_copy.default

    def _get_data_dim_order(
        self,
        explicit_dims: list[int],
        destination_dim_order: list[int],
    ) -> list[int]:
        """Return dim_order of data tensor.

        Args:
            explicit_dims: destination dims that have a non-None index tensor.
            destination_dim_order: permutation applied to the destination
                (explicit dims first, then None dims, then trailing dims).

        Returns:
            The permutation to apply to the data tensor so it lines up with
            the permuted destination (see class docstring for the one-vs-many
            explicit-index cases).

        Raises:
            RuntimeError: if there is no explicit (non-None) index tensor.
        """

        # Dims of the (permuted) destination that are not explicitly indexed.
        normalized_non_index_dims = destination_dim_order[len(explicit_dims) :]
        data_dim_order = list(range(len(normalized_non_index_dims)))

        if not explicit_dims:
            raise RuntimeError("Expected at least one non-None index tensor.")
        elif len(explicit_dims) > 1:
            # For multiple explicit index tensors, data is already in the order we want.
            return data_dim_order
        else:
            # For single explicit index tensor, use same dim_order as destination
            return destination_dim_order

    def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
        """Rewrite index_put ops with None indices into None-free index_put.

        Non-index_put ops, and index_put ops without any None index, pass
        through unchanged. Otherwise the destination (and, when needed, the
        data) tensor is permuted so that all None-indexed dims become trailing
        ("channel") dims, the None entries are dropped from the index list,
        and the result is permuted back.
        """
        if op not in (exir_ops.edge.aten.index_put.default,):
            return super().call_operator(op, args, kwargs, meta)

        # args layout follows aten.index_put: (self, indices, values, ...).
        destination, indices_tensor_list, data = args[:3]
        indices_tensor_list = list(indices_tensor_list)
        if not any(indices_tensor is None for indices_tensor in indices_tensor_list):
            # Nothing to normalize.
            return super().call_operator(op, args, kwargs, meta)

        destination_shape = destination.data.shape
        # Dims with a real index tensor.
        explicit_dims = [
            dim_idx
            for dim_idx, index_tensor in enumerate(indices_tensor_list)
            if index_tensor is not None
        ]

        # Dims indexed with None (i.e. fully selected).
        none_dims = [
            dim_idx
            for dim_idx, index_tensor in enumerate(indices_tensor_list)
            if index_tensor is None
        ]
        # Dims beyond the index list are implicitly fully selected as well.
        trailing_dims = list(range(len(indices_tensor_list), len(destination_shape)))

        # Handle None indexing of destination tensor.
        destination_dim_order = explicit_dims + none_dims + trailing_dims
        needs_destination_permute = destination_dim_order != list(
            range(len(destination_shape))
        )
        if needs_destination_permute:
            destination = super().call_operator(
                self.permute_op,
                (destination, destination_dim_order),
                {},
                meta,
                updated=True,
            )

        # Handle None indexing of data tensor.
        data_dim_order = self._get_data_dim_order(
            explicit_dims=explicit_dims,
            destination_dim_order=destination_dim_order,
        )
        needs_data_permute = data_dim_order != list(range(len(data_dim_order)))

        if needs_data_permute:
            data_shape = list(data.data.shape)
            aligned_rank = len(data_dim_order)
            if len(data_shape) < aligned_rank:
                # We add dims to data when we move none dims, front pad data with unit dims to match.
                padded_shape = [1] * (aligned_rank - len(data_shape)) + data_shape
                data = super().call_operator(
                    self.reshape_op, (data, padded_shape), {}, meta, updated=True
                )
            data = super().call_operator(
                self.permute_op, (data, data_dim_order), {}, meta, updated=True
            )

        # Call index_put op.
        # Keep only the non-None index tensors; any extra args (e.g. the
        # accumulate flag) are forwarded untouched.
        explicit_indices_tensors = [
            indices_tensor_list[dim_idx] for dim_idx in explicit_dims
        ]
        normalized_args = (destination, explicit_indices_tensors, data, *args[3:])
        out = super().call_operator(op, normalized_args, kwargs, meta, updated=True)

        if not needs_destination_permute:
            return out

        # If needed, reverse permutation of destination tensor.
        inv_dim_order = [0] * len(destination_dim_order)
        for new_dim, original_dim in enumerate(destination_dim_order):
            inv_dim_order[original_dim] = new_dim

        return super().call_operator(
            self.permute_op, (out, inv_dim_order), {}, meta, updated=True
        )

0 commit comments

Comments
 (0)