Skip to content

Commit 28783db

Browse files
Arm backend: Support high rank index tensor (#17951)
- Drop the rank>=4 rejection from index.Tensor TOSA support checks - Add rank-4 and rank-5 index tensor test cases - Note: rank>=4 support has been covered by the index.Tensor refactor, removing the need for special handling in ToTosaMemoryFormatPass Change-Id: Ief40942a94040c02e54c7f276eecd660d571e46d Arm backend: Refactor aten.index.tensor suppport - Decompose edge.aten.index.Tensor via backend tosa.GATHER + shape ops - Remove the index.Tensor node visitor - Reorder AccumulateIndexPutPass (generates index.Tensor) and DecomposeSliceScatterPass (may generate index_put) before DecomposeIndexTensorToGatherPass - Correct index_put tests to not test int inputs under FP-only profile Change-Id: I5cfc7c110f0074463043ef1cb61165cc784a4db2 cc @digantdesai @SS-JIA @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell --------- Signed-off-by: Yufeng Shi <yufeng.shi@arm.com> Co-authored-by: Erik Lundell <erik.lundell@arm.com>
1 parent 322857a commit 28783db

11 files changed

Lines changed: 500 additions & 343 deletions

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
from .decompose_index_select_to_gather_pass import ( # noqa
5656
DecomposeIndexSelectToGatherPass,
5757
)
58+
from .decompose_index_tensor_to_gather_pass import ( # noqa
59+
DecomposeIndexTensorToGatherPass,
60+
)
5861
from .decompose_int16_activation_conv_pass import ( # noqa
5962
DecomposeConvWithInt16ActivationPass,
6063
)

backends/arm/_passes/accumulate_index_put_pass.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import torch
88

99
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.decompose_index_tensor_to_gather_pass import (
11+
DecomposeIndexTensorToGatherPass,
12+
)
13+
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
1014
from executorch.exir.dialects._ops import ops as exir_ops
1115
from executorch.exir.pass_base import ExportPass
1216

@@ -33,7 +37,10 @@ class AccumulateIndexPutPass(ArmPass):
3337
for the index_put op.
3438
"""
3539

36-
_passes_required_after: Set[Type[ExportPass]] = set()
40+
_passes_required_after: Set[Type[ExportPass]] = {
41+
DecomposeIndexTensorToGatherPass,
42+
RewriteIndexPutPass,
43+
}
3744

3845
def call_operator(self, op, args, kwargs, meta):
3946
if op not in (aten_ops + edge_ops) or not self.allowed_to_transform(meta):

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
DecomposeGroupedConvPass,
6161
DecomposeGroupNormPass,
6262
DecomposeIndexSelectToGatherPass,
63+
DecomposeIndexTensorToGatherPass,
6364
DecomposeIntPowPass,
6465
DecomposeLayerNormPass,
6566
DecomposeLeakyReLUPass,
@@ -307,6 +308,9 @@ def _tosa_pipeline(
307308
DecomposeEmbeddingPass(),
308309
DecomposeIndexSelectToGatherPass(),
309310
DecomposeStridedSliceCopyPass(),
311+
DecomposeSliceScatterPass(),
312+
AccumulateIndexPutPass(),
313+
DecomposeIndexTensorToGatherPass(),
310314
Conv1dUnsqueezePass(),
311315
]
312316
)
@@ -329,8 +333,6 @@ def _tosa_pipeline(
329333
# Node transformation passes (post scalar-removal)
330334
self.add_passes(
331335
[
332-
DecomposeSliceScatterPass(),
333-
AccumulateIndexPutPass(),
334336
RewriteIndexPutPass(),
335337
RewriteBoolBitwiseToLogicalPass(),
336338
DecomposeRemainderPass(),

backends/arm/_passes/arm_pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_get_control_flow_submodules,
2222
get_control_flow_submodules,
2323
)
24+
from executorch.exir.pass_base import NodeMetadata
2425

2526
from torch._export.utils import (
2627
get_buffer,
@@ -202,6 +203,14 @@ def insert_q_dq_pair(
202203
return dq
203204

204205

206+
def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata:
207+
"""Return a copy of NodeMetadata with input/output qparams cleared."""
208+
plain_meta_dict = dict(meta.data)
209+
plain_meta_dict["input_qparams"] = {}
210+
plain_meta_dict["output_qparams"] = {}
211+
return NodeMetadata(plain_meta_dict)
212+
213+
205214
def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
206215
"""Returns a FakeTensor from the meta field of 'node'.
207216

0 commit comments

Comments
 (0)