Skip to content

Commit 10f22a7

Browse files
Arm backend: Add pass to create CONST_SHAPEs (#18099)
Add pass that creates CONST_SHAPEs for view_copy and repeat based on their list arguments. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Co-authored-by: Per Åstrand <per.astrand@arm.com>
1 parent 6c9a4d6 commit 10f22a7

5 files changed

Lines changed: 68 additions & 30 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
111111
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
112112
from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa
113+
from .insert_const_shapes import InsertConstShapesPass # noqa
113114
from .insert_int32_casts_after_int64_placeholders import ( # noqa
114115
InsertInt32CastsAfterInt64PlaceholdersPass,
115116
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
FuseEqualPlaceholdersPass,
103103
FuseQuantizedActivationPass,
104104
FuseViewCopyTransformPass,
105+
InsertConstShapesPass,
105106
InsertControlFlowRescalesPass,
106107
InsertInt32CastsAfterInt64PlaceholdersPass,
107108
InsertRescaleInt32Pass,
@@ -380,6 +381,7 @@ def _tosa_pipeline(
380381
RewriteMatmulPass(),
381382
RewritePadPass(),
382383
RewriteSlicePass(),
384+
InsertConstShapesPass(),
383385
]
384386
)
385387

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, Optional
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
13+
class InsertConstShapesPass(ArmPass):
14+
"""Materialize literal shape arguments as CONST_SHAPE nodes.
15+
16+
This pass targets ops such as `aten.view_copy` and `aten.repeat` whose shape
17+
arguments might otherwise remain raw Python lists/tuples. Replacing them
18+
with explicit CONST_SHAPE nodes simplifies the serialization of these ops
19+
the serialization of their arguments is handled by the CONST_SHAPE node visitor.
20+
21+
"""
22+
23+
_passes_required_after = set()
24+
targeted_ops = {
25+
exir_ops.edge.aten.view_copy.default,
26+
exir_ops.edge.aten.repeat.default,
27+
}
28+
29+
@staticmethod
30+
def _is_shape_arg(arg: Any) -> bool:
31+
"""Return True when `arg` looks like a literal shape list/tuple."""
32+
is_shape_op = meta_has_shape_mark(arg.meta) if hasattr(arg, "meta") else False
33+
return (
34+
not is_shape_op
35+
and isinstance(arg, (list, tuple))
36+
and all(isinstance(x, int) for x in arg)
37+
)
38+
39+
def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
40+
if op not in self.targeted_ops:
41+
return super().call_operator(op, args, kwargs, meta, updated)
42+
if any(InsertConstShapesPass._is_shape_arg(arg) for arg in args):
43+
new_args = []
44+
for arg in args:
45+
if InsertConstShapesPass._is_shape_arg(arg):
46+
# Insert a const node for the shape argument
47+
if op == exir_ops.edge.aten.view_copy.default:
48+
arg = meta.data["val"].shape
49+
const_node = super().call_shape_operator(
50+
exir_ops.backend.tosa.CONST_SHAPE.default,
51+
(arg,),
52+
{},
53+
meta,
54+
True,
55+
)
56+
new_args.append(const_node)
57+
updated = True
58+
else:
59+
new_args.append(arg)
60+
61+
return super().call_operator(op, tuple(new_args), kwargs, meta, updated)
62+
63+
return super().call_operator(op, args, kwargs, meta, updated)

backends/arm/operators/op_repeat.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
validate_valid_dtype,
2020
)
2121
from executorch.backends.arm.tosa.mapping import TosaArg
22-
from executorch.backends.arm.tosa.utils import tosa_shape
2322

2423

2524
@register_node_visitor
@@ -53,25 +52,13 @@ def define_node(
5352
self.tosa_spec,
5453
)
5554

56-
multiples = inputs[1].special
57-
58-
if len(multiples) == 0:
59-
raise ValueError(f"Length of multiples argument is 0: {inputs[1]}!")
60-
61-
multiple_shapes = tosa_graph.addConst(
62-
(len(multiples),),
63-
ts.DType.SHAPE,
64-
list(tosa_shape(multiples, output.dim_order)),
65-
name=output.name + "_multiples",
66-
)
67-
6855
attr = ts.TosaSerializerAttribute()
6956
attr.TileAttribute()
7057
self._serialize_operator(
7158
node,
7259
tosa_graph,
7360
ts.Op.TILE,
74-
[inputs[0].name, multiple_shapes.name],
61+
[inputs[0].name, inputs[1].name],
7562
[output.name],
7663
attr,
7764
)

backends/arm/operators/op_view.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
validate_valid_dtype,
2020
)
2121
from executorch.backends.arm.tosa.mapping import TosaArg
22-
from executorch.backends.arm.tosa.utils import tosa_shape
2322

2423

2524
@register_node_visitor
@@ -55,27 +54,13 @@ def define_node(
5554

5655
tosa_graph = cast(ts.TosaSerializer, tosa_graph)
5756

58-
if len(output.shape) != 0:
59-
shape_len = [len(output.shape)]
60-
shape_data = list(tosa_shape(output.shape, output.dim_order))
61-
else:
62-
shape_len = []
63-
shape_data = []
64-
65-
shape = tosa_graph.addConst(
66-
shape_len,
67-
ts.DType.SHAPE,
68-
shape_data,
69-
name=output.name + "_shape",
70-
)
71-
7257
attr = ts.TosaSerializerAttribute()
7358
attr.ReshapeAttribute()
7459
self._serialize_operator(
7560
node,
7661
tosa_graph,
7762
ts.Op.RESHAPE,
78-
[inputs[0].name, shape.name],
63+
[inputs[0].name, inputs[1].name],
7964
[output.name],
8065
attr,
8166
)

0 commit comments

Comments
 (0)