Skip to content

Commit d523daa

Browse files
Arm backend: Add FuseIdenticalInputTransformsPass (#20514)
Sinks identical data-movement ops that appear on all inputs of a binary op below that op, e.g. add(permute(x), permute(y)) becomes permute(add(x, y)). Also adds support for a -1 (inferred) dimension in view-map shapes via the normalize_view_shape helper, and for collapsing SymInts with a known value to plain ints via _simplify_dim. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent bb28350 commit d523daa

8 files changed

Lines changed: 885 additions & 36 deletions

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@
119119
)
120120
from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa
121121
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
122+
from .fuse_identical_input_transforms_pass import ( # noqa
123+
FuseIdenticalInputTransformsPass,
124+
)
122125
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
123126
from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa
124127
from .insert_const_shapes import InsertConstShapesPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
FuseConstantArgsPass,
113113
FuseDuplicateUsersPass,
114114
FuseEqualPlaceholdersPass,
115+
FuseIdenticalInputTransformsPass,
115116
FuseQuantizedActivationPass,
116117
FuseViewCopyTransformPass,
117118
InsertConstShapesPass,
@@ -500,6 +501,7 @@ def _tosa_pipeline(
500501
# TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or
501502
# before FoldAndAnnotateQParamsPass but is unable to at the moment.
502503
# Ticket: MLETORCH-1539
504+
FuseIdenticalInputTransformsPass(),
503505
DecomposeLinearPass(),
504506
InsertRescaleInt32Pass(),
505507
FuseConsecutiveRescalesPass(),

backends/arm/_passes/arm_pass_utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313

1414
import torch
1515
import torch.fx
16+
17+
from executorch.backends.arm._passes.dim_maps import (
18+
_normalize_dims,
19+
normalize_view_shape,
20+
)
1621
from executorch.backends.arm.common.debug import get_node_debug_info
1722
from executorch.backends.arm.common.type import ensure_type
1823
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1924
from executorch.exir import ExportedProgram
2025
from executorch.exir.dialects._ops import ops as exir_ops
2126
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2227
from executorch.exir.pass_base import NodeMetadata
23-
2428
from torch._export.utils import (
2529
get_buffer,
2630
get_lifted_tensor_constant,
@@ -33,6 +37,8 @@
3337
from torch._subclasses.fake_tensor import FakeTensor
3438
from torch.export.graph_signature import InputKind
3539

40+
_Dim = int | torch.SymInt
41+
3642

3743
def is_submodule_node(node: torch.fx.Node):
3844
if node.op not in ("get_attr", "placeholder"):
@@ -243,6 +249,43 @@ def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata:
243249
return NodeMetadata(plain_meta_dict)
244250

245251

252+
def refresh_permute_view_meta(node: torch.fx.Node) -> None:
    """Recompute ``node.meta["val"]`` for a view_copy/permute_copy node.

    The new value is derived from the ``"val"`` meta entry of the node's
    first input. Nothing happens when that entry is missing or when the node
    is neither an edge ``view_copy`` nor an edge ``permute_copy``.

    For tensor inputs the output shape is computed explicitly and allocated
    with ``new_empty`` so that SymInt dims are preserved; any other input
    value is passed through the op itself.
    """
    view_target = exir_ops.edge.aten.view_copy.default
    permute_target = exir_ops.edge.aten.permute_copy.default

    source_val = node.all_input_nodes[0].meta.get("val")
    if source_val is None or node.target not in {view_target, permute_target}:
        return

    has_tensor_input = isinstance(source_val, torch.Tensor)
    if has_tensor_input and node.target == view_target:
        # Resolve a possible -1 entry against the input shape.
        requested = cast(Sequence[_Dim], node.args[1])
        out_shape = tuple(normalize_view_shape(source_val.shape, requested))
        node.meta["val"] = source_val.new_empty(out_shape)
    elif has_tensor_input and node.target == permute_target:
        # Reorder the input dims according to the normalized permutation.
        order = _normalize_dims(
            cast(Sequence[int], node.args[1]), len(source_val.shape)
        )
        out_shape = tuple(source_val.shape[axis] for axis in order)
        node.meta["val"] = source_val.new_empty(out_shape)
    else:
        # Non-tensor meta value: let the op compute the result.
        node.meta["val"] = node.target(source_val, *node.args[1:])  # type: ignore[operator]
287+
288+
246289
def insert_scalar(
247290
graph: torch.fx.Graph,
248291
value: int | float,

backends/arm/_passes/canonicalize_view_copy_permute_pass.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from typing import cast, Sequence, Set, Type
99

1010
import torch
11+
1112
from executorch.backends.arm._passes.arm_pass import ArmPass
13+
from executorch.backends.arm._passes.arm_pass_utils import refresh_permute_view_meta
1214
from executorch.backends.arm._passes.dim_maps import (
1315
_dim_equals,
1416
_is_permutation,
@@ -18,7 +20,6 @@
1820
)
1921
from executorch.exir.dialects._ops import ops as exir_ops
2022
from executorch.exir.pass_base import ExportPass
21-
2223
from torch.fx import GraphModule, Node
2324
from torch.fx.node import Target
2425
from torch.fx.passes.infra.pass_base import PassResult
@@ -360,39 +361,12 @@ def _set_node_op(
360361
) -> None:
361362
node.target = target
362363
node.args = (input_node, list(arg))
363-
self._refresh_meta(node)
364+
refresh_permute_view_meta(node)
364365

365366
def _permute_dims(self, node: Node) -> list[int]:
    """Return a fresh list holding the dims argument of a permute node.

    Asserts that ``node.target`` equals ``self._PERMUTE_TARGET``.
    """
    assert node.target == self._PERMUTE_TARGET, "Expected permute node"
    order = cast(Sequence[int], node.args[1])
    return [*order]
368369

369-
@classmethod
370-
def _refresh_meta(cls, node: Node) -> None:
371-
input_node = node.args[0]
372-
assert isinstance(input_node, Node)
373-
input_val = input_node.meta.get("val")
374-
if input_val is None or node.target not in cls._TARGETS:
375-
return
376-
377-
# Compute new meta shapes to preserve SymInts.
378-
if isinstance(input_val, torch.Tensor):
379-
if node.target == cls._VIEW_TARGET:
380-
node.meta["val"] = input_val.new_empty(
381-
tuple(cast(Sequence[_Dim], node.args[1]))
382-
)
383-
return
384-
385-
if node.target == cls._PERMUTE_TARGET:
386-
dims = _normalize_dims(
387-
cast(Sequence[int], node.args[1]), len(input_val.shape)
388-
)
389-
node.meta["val"] = input_val.new_empty(
390-
tuple(input_val.shape[dim] for dim in dims)
391-
)
392-
return
393-
394-
node.meta["val"] = node.target(input_val, *node.args[1:]) # type: ignore[operator]
395-
396370
@staticmethod
def _shape(node: Node) -> list[_Dim]:
    """Return the shape of ``node.meta["val"]`` as a list of dims."""
    dims = list(node.meta["val"].shape)
    return cast(list[_Dim], dims)

backends/arm/_passes/dim_maps.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,21 @@ def _dim_expr(dim: _Dim) -> sympy.Basic:
9191
return sympy.Integer(dim) if isinstance(dim, int) else dim.node.expr
9292

9393

94+
def _simplify_dim(dim: _Dim) -> _Dim:
    """Collapse a symbolic dim to a plain int when its value is known.

    Ints are returned unchanged. For anything else, ``dim.node.maybe_as_int()``
    is consulted; its result is returned when it is not ``None``, otherwise
    the original ``dim`` is returned as-is.
    """
    if not isinstance(dim, int):
        concrete = dim.node.maybe_as_int()
        if concrete is not None:
            return concrete
    return dim
100+
101+
102+
def _simplify_shape(shape: Iterable[_Dim]) -> list[_Dim]:
    """Return a new list with ``_simplify_dim`` applied to every dim of ``shape``."""
    return list(map(_simplify_dim, shape))
104+
105+
94106
def _dim_equals(lhs: _Dim, rhs: _Dim) -> bool:
    """Return whether two dims are equal, comparing symbolically if needed.

    Both sides are first passed through ``_simplify_dim``. If both end up as
    plain ints they are compared directly; otherwise the difference of their
    sympy expressions is simplified and compared against zero.
    """
    left, right = _simplify_dim(lhs), _simplify_dim(rhs)
    if not (isinstance(left, int) and isinstance(right, int)):
        difference = _dim_expr(left) - _dim_expr(right)
        return sympy.simplify(difference) == 0
    return left == right
@@ -113,6 +127,7 @@ def _factor_int(dim: int) -> list[_FactorKey] | None:
113127

114128

115129
def _factor_dim(dim: _Dim) -> list[_FactorKey] | None:
130+
dim = _simplify_dim(dim)
116131
if _dim_equals(dim, 1):
117132
return []
118133
if isinstance(dim, int):
@@ -143,14 +158,39 @@ def _dedupe(items: Iterable[int]) -> list[int]:
143158
def numel(shape: Iterable[_Dim]) -> _Dim:
    """Return the product of all dims in ``shape``.

    An empty shape yields ``1``. Each factor and each partial product is
    passed through ``_simplify_dim`` so that symbolic values with a known
    integer value are collapsed to plain ints along the way.
    """
    # Only the simplified accumulation is kept; the plain in-place multiply it
    # superseded must not run as well, or every dim would be counted twice.
    total: _Dim = 1
    for dim in shape:
        total = _simplify_dim(total * _simplify_dim(dim))
    return total
148163

149164

150165
def same_numel(first_shape: Iterable[_Dim], second_shape: Iterable[_Dim]) -> bool:
    """Return whether both shapes describe the same number of elements."""
    first_total = numel(first_shape)
    second_total = numel(second_shape)
    return _dim_equals(first_total, second_total)
152167

153168

169+
def normalize_view_shape(
    source_shape: Sequence[_Dim], target_shape: Sequence[_Dim]
) -> list[_Dim]:
    """Resolve a view target shape that may contain a single ``-1`` entry.

    Both shapes are simplified first. If the target contains no ``-1`` the
    simplified target is returned unchanged. Otherwise the ``-1`` entry is
    replaced by the source element count divided (floor division) by the
    product of the remaining target dims.

    Asserts that at most one target dim is ``-1``.
    """
    simplified_source = _simplify_shape(source_shape)
    resolved = _simplify_shape(target_shape)

    placeholders = [
        position for position, dim in enumerate(resolved) if _dim_equals(dim, -1)
    ]
    if len(placeholders) == 0:
        return resolved
    assert len(placeholders) == 1, f"Invalid view shape {target_shape}"

    (slot,) = placeholders
    explicit_dims = resolved[:slot] + resolved[slot + 1 :]
    total = numel(simplified_source)
    explicit_total = numel(explicit_dims)

    # Skip the division when the explicit dims multiply to one.
    if _dim_equals(explicit_total, 1):
        inferred = total
    else:
        inferred = total // explicit_total
    resolved[slot] = _simplify_dim(inferred)
    return resolved
192+
193+
154194
class _UnionFind:
155195
def __init__(self, size: int) -> None:
156196
self.parents = list(range(size))
@@ -210,8 +250,10 @@ def __init__(self, view_node: Node) -> None:
210250
input_val = input_node.meta["val"]
211251
assert isinstance(input_val, torch.Tensor)
212252

213-
self.source_shape = cast(list[_Dim], list(input_val.shape))
214-
self.target_shape = list(cast(Sequence[_Dim], view_node.args[1]))
253+
self.source_shape = _simplify_shape(cast(Sequence[_Dim], input_val.shape))
254+
self.target_shape = normalize_view_shape(
255+
self.source_shape, cast(Sequence[_Dim], view_node.args[1])
256+
)
215257
self._groups = self._build_groups(self.source_shape, self.target_shape)
216258

217259
@classmethod
@@ -220,8 +262,10 @@ def from_shapes(
220262
) -> ViewMap:
221263
"""Build a view map directly from source and target shapes."""
222264
view_map = cls.__new__(cls)
223-
view_map.source_shape = list(source_shape)
224-
view_map.target_shape = list(target_shape)
265+
view_map.source_shape = _simplify_shape(source_shape)
266+
view_map.target_shape = normalize_view_shape(
267+
view_map.source_shape, target_shape
268+
)
225269
view_map._groups = cls._build_groups(
226270
view_map.source_shape, view_map.target_shape
227271
)

0 commit comments

Comments
 (0)