Skip to content

Commit d83aa08

Browse files
authored
Arm backend: Reuse identical CONST_SHAPE nodes (#19770)
Cache CONST_SHAPE nodes created by InsertConstShapesPass and reuse them when a later view/repeat needs the same shape. This removes duplicate shape constants. This improvement is model dependent. Models with few repeated literal shapes will not see any meaningful change, but some models can benefit from it notably. The table below shows the results of a local test lowering DeiT Tiny to TOSA-FP. The lowering time reduced in this run, likely because passes following InsertConstShapesPass had fewer nodes to iterate over. | Metric | Baseline | Optimized | Delta | | -------------- | -------- | --------- | ---------------- | | Total ops | 2106 | 1736 | -370 (-17.6%) | | CONST_SHAPE | 466 | 96 | -370 (-79.4%) | | TOSA size | 23.82 MB | 23.75 MB | -71.6 KB (-0.3%) | | Execution time | 118.7 s | 78.4 s | -40.3 s (-34.0%) | Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent dd00d42 commit d83aa08

1 file changed

Lines changed: 15 additions & 7 deletions

File tree

backends/arm/_passes/insert_const_shapes.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ class InsertConstShapesPass(ArmPass):
2626
exir_ops.edge.aten.repeat.default,
2727
}
2828

29+
def __init__(self) -> None:
30+
super().__init__()
31+
self._const_shape_cache: dict[tuple[int, ...], Any] = {}
32+
2933
@staticmethod
3034
def _is_shape_arg(arg: Any) -> bool:
3135
"""Return True when `arg` looks like a literal shape list/tuple."""
@@ -46,13 +50,17 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False)
4650
# Insert a const node for the shape argument
4751
if op == exir_ops.edge.aten.view_copy.default:
4852
arg = meta.data["val"].shape
49-
const_node = super().call_shape_operator(
50-
exir_ops.backend.tosa.CONST_SHAPE.default,
51-
(arg,),
52-
{},
53-
meta,
54-
True,
55-
)
53+
shape = tuple(arg)
54+
const_node = self._const_shape_cache.get(shape)
55+
if const_node is None:
56+
const_node = super().call_shape_operator(
57+
exir_ops.backend.tosa.CONST_SHAPE.default,
58+
(arg,),
59+
{},
60+
meta,
61+
True,
62+
)
63+
self._const_shape_cache[shape] = const_node
5664
new_args.append(const_node)
5765
updated = True
5866
else:

0 commit comments

Comments
 (0)