Skip to content

Commit 1bf23bf

Browse files
Arm backend: Refactor aten.index.tensor suppport
- Decompose edge.aten.index.Tensor via backend tosa.GATHER + shape ops - Remove the index.Tensor node visitor - Reorder AccumulateIndexPutPass (generates index.Tensor) and DecomposeSliceScatterPass (may generate index_put) before DecomposeIndexTensorToGatherPass - Correct index_put tests to not test int inputs under FP-only profile Change-Id: I5cfc7c110f0074463043ef1cb61165cc784a4db2 Signed-off-by: Yufeng Shi <yufeng.shi@arm.com> Co-authored-by: Erik Lundell <erik.lundell@arm.com>
1 parent 1619308 commit 1bf23bf

9 files changed

Lines changed: 472 additions & 320 deletions

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@
5454
from .decompose_index_select_to_gather_pass import ( # noqa
5555
DecomposeIndexSelectToGatherPass,
5656
)
57+
from .decompose_index_tensor_to_gather_pass import ( # noqa
58+
DecomposeIndexTensorToGatherPass,
59+
)
5760
from .decompose_int16_activation_conv_pass import ( # noqa
5861
DecomposeConvWithInt16ActivationPass,
5962
)

backends/arm/_passes/accumulate_index_put_pass.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import torch
88

99
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.decompose_index_tensor_to_gather_pass import (
11+
DecomposeIndexTensorToGatherPass,
12+
)
13+
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
1014
from executorch.exir.dialects._ops import ops as exir_ops
1115
from executorch.exir.pass_base import ExportPass
1216

@@ -33,7 +37,10 @@ class AccumulateIndexPutPass(ArmPass):
3337
for the index_put op.
3438
"""
3539

36-
_passes_required_after: Set[Type[ExportPass]] = set()
40+
_passes_required_after: Set[Type[ExportPass]] = {
41+
DecomposeIndexTensorToGatherPass,
42+
RewriteIndexPutPass,
43+
}
3744

3845
def call_operator(self, op, args, kwargs, meta):
3946
if op not in (aten_ops + edge_ops) or not self.allowed_to_transform(meta):

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
DecomposeGroupedConvPass,
6060
DecomposeGroupNormPass,
6161
DecomposeIndexSelectToGatherPass,
62+
DecomposeIndexTensorToGatherPass,
6263
DecomposeIntPowPass,
6364
DecomposeLayerNormPass,
6465
DecomposeLeakyReLUPass,
@@ -303,6 +304,9 @@ def _tosa_pipeline(
303304
DecomposeEmbeddingPass(),
304305
DecomposeIndexSelectToGatherPass(),
305306
DecomposeStridedSliceCopyPass(),
307+
DecomposeSliceScatterPass(),
308+
AccumulateIndexPutPass(),
309+
DecomposeIndexTensorToGatherPass(),
306310
Conv1dUnsqueezePass(),
307311
]
308312
)
@@ -325,8 +329,6 @@ def _tosa_pipeline(
325329
# Node transformation passes (post scalar-removal)
326330
self.add_passes(
327331
[
328-
DecomposeSliceScatterPass(),
329-
AccumulateIndexPutPass(),
330332
RewriteIndexPutPass(),
331333
RewriteBoolBitwiseToLogicalPass(),
332334
DecomposeRemainderPass(),

backends/arm/_passes/arm_pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.exir import ExportedProgram
1818
from executorch.exir.dialects._ops import ops as exir_ops
1919
from executorch.exir.dialects.edge._ops import EdgeOpOverload
20+
from executorch.exir.pass_base import NodeMetadata
2021

2122
from torch._export.utils import (
2223
get_buffer,
@@ -197,6 +198,14 @@ def insert_q_dq_pair(
197198
return dq
198199

199200

201+
def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata:
202+
"""Return a copy of NodeMetadata with input/output qparams cleared."""
203+
plain_meta_dict = dict(meta.data)
204+
plain_meta_dict["input_qparams"] = {}
205+
plain_meta_dict["output_qparams"] = {}
206+
return NodeMetadata(plain_meta_dict)
207+
208+
200209
def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
201210
"""Returns a FakeTensor from the meta field of 'node'.
202211

0 commit comments

Comments
 (0)