Skip to content

Commit 4e430fc

Browse files
Arm backend: Support high-rank index.Tensor indices
- Drop the rank>=4 rejection from index.Tensor TOSA support checks - Add rank-4 and rank-5 index tensor test cases - Note: rank>=4 support has been covered by the index.Tensor refactor, removing the need for special handling in ToTosaMemoryFormatPass Change-Id: Ief40942a94040c02e54c7f276eecd660d571e46d Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent 1bf23bf commit 4e430fc

2 files changed

Lines changed: 28 additions & 23 deletions

File tree

backends/arm/operator_support/index_tensor_support.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55
"""Provide TOSA support checks for ``aten.index.Tensor``.
66
7-
Reject unsupported patterns such as high-rank index tensors, front-positioned
8-
slice/ellipsis/None markers, and cases that exceed ``int32`` element limits.
7+
Reject unsupported patterns such as front-positioned slice/ellipsis/None
8+
markers and cases that exceed ``int32`` element limits.
99
1010
"""
1111

@@ -30,18 +30,12 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
3030
This support check is intended to prevent the partitioning of
3131
currently unsupported usages of the index.Tensor operator.
3232
33-
1. Usages where indexing tensors are of rank 4 or higher.
34-
This is due to the AnnotateChannelsLastDimOrder pass and
35-
the rarity of such operation.
36-
Support is possible but would require further changes to the above
37-
pass which can be added at such a time as is necessary.
38-
39-
2. Usages where slice, ellipsis or None are present before an indexing tensor:
33+
1. Usages where slice, ellipsis or None are present before an indexing tensor:
4034
t[{start}:{end}, indexTensor] - slicing
4135
t[None, indexTensor] - unsqueeze
4236
t[..., indexTensor] - ellipsis
4337
44-
3. Usages where the value tensor contains more than int32.max elements
38+
2. Usages where the value tensor contains more than int32.max elements
4539
This is due to int32 TOSA limitation and the fact that we flatten out
4640
and accumulate all index tensors.
4741
As such to avoid overflow we reject lowering of this operator if it is
@@ -115,13 +109,12 @@ def is_node_tosa_supported(
115109
116110
Enforces the following constraints:
117111
- No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor.
118-
- Indexing tensors have rank <= 3.
119112
- The value tensor element count fits in ``int32``.
120113
121114
"""
122115
indices = node.args[1]
123116
for index in indices: # type: ignore[union-attr]
124-
# Usage 2 guard
117+
# Usage 1 guard
125118
if index is None:
126119
self.reporter.report_reject(
127120
node,
@@ -132,17 +125,7 @@ def is_node_tosa_supported(
132125
)
133126
return False
134127

135-
# Usage 1 guard
136-
index = ensure_type(torch.fx.Node, index)
137-
fake_tensor = get_first_fake_tensor(index)
138-
if len(fake_tensor.size()) > 3:
139-
self.reporter.report_reject(
140-
node,
141-
("Indexing tensors of rank >= 4 is not supported."),
142-
)
143-
return False
144-
145-
# Usage 3 guard
128+
# Usage 2 guard
146129
input_node = ensure_type(torch.fx.Node, node.args[0])
147130
total_vals = math.prod(get_first_fake_tensor(input_node).shape)
148131
if total_vals > torch.iinfo(torch.int32).max:

backends/arm/test/ops/test_index_tensor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,28 @@ class IndexTensor(torch.nn.Module):
318318
torch.randint(2, size=(15,), dtype=torch.int32),
319319
),
320320
),
321+
"test_1d_rank4_index": (
322+
torch.rand(12),
323+
(torch.randint(12, size=(1, 2, 1, 3), dtype=torch.int32),),
324+
),
325+
"test_2d_rank4_broadcastable_indices": (
326+
torch.rand(4, 6),
327+
(
328+
torch.randint(4, size=(1, 2, 1, 1), dtype=torch.int32),
329+
torch.randint(6, size=(1, 1, 3, 1), dtype=torch.int32),
330+
),
331+
),
332+
"test_1d_high_rank_index": (
333+
torch.rand(24),
334+
(torch.randint(24, size=(1, 1, 2, 1, 3), dtype=torch.int32),),
335+
),
336+
"test_2d_high_rank_broadcastable_indices": (
337+
torch.rand(4, 5),
338+
(
339+
torch.randint(4, size=(1, 2, 1, 1, 1), dtype=torch.int32),
340+
torch.randint(5, size=(1, 1, 3, 1, 1), dtype=torch.int32),
341+
),
342+
),
321343
}
322344
test_data_bf16: dict[input_params] = {
323345
"test_2d_1_idx_bf16": (

0 commit comments

Comments
 (0)