Skip to content

Commit 1bb039f

Browse files
Arm backend: Remove use of tosa_dim_order (#18948)
This patch removes the use of the ToTosaMemoryFormatPass and tosa_dim_order in favor of serializing directly to the contiguous stride shape. To allow this, modify all channels-last tosa dialect ops to be explicitly channels last and lower them with permutes, e.g. (1,2,3,3) -> aten conv -> (1,4,3,3) lowers to (1,2,3,3) -> (1,3,3,2) -> tosa conv -> (1,3,3,4) -> (1,4,3,3). To handle channels-last input/output, NormalizeDelegateIOLayoutPass permutes the shape of such inputs and outputs and inserts a permute to force it to be contiguous. This permute will then typically cancel out the top convolution permute, enabling zero-transpose graphs in many cases. Additionally, - Conv1D input is unsqueezed (N,C,L)->(N,C,1,L) instead of (N,C,L,1) to match avg_pool1d and max_pool2d shapes better - Bias for int16 convolutions is handled by lowering the bias to an int48 tensor rather than using a separate add operator. - Fix remove_permutes_around_elementwise_ops for permutes with negative indices. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Co-authored-by: mcremon-meta <134334895+mcremon-meta@users.noreply.github.com>
1 parent c0c079b commit 1bb039f

34 files changed

Lines changed: 1071 additions & 315 deletions

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
132132
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
133133
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
134+
from .normalize_delegate_io_layout_pass import NormalizeDelegateIOLayoutPass # noqa
134135
from .normalize_index_put_bool_index_tensor_pass import ( # noqa
135136
NormalizeIndexPutBoolIndexTensorPass,
136137
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from executorch.backends.arm._passes import (
1414
AccumulateIndexPutPass,
15-
AnnotateOutputDimOrderPass,
1615
BroadcastArgsPass,
1716
CanonicalizeGatherPass,
1817
CastInt64BuffersToInt32Pass,
@@ -44,7 +43,6 @@
4443
DecomposeAtanPass,
4544
DecomposeAvgPool2dPass,
4645
DecomposeBatchNormNoStatsPass,
47-
DecomposeConvWithInt16ActivationPass,
4846
DecomposeCoshPass,
4947
DecomposeCosineSimilarityPass,
5048
DecomposeCumsumPass,
@@ -117,6 +115,7 @@
117115
InsertTableOpsPass,
118116
MatchArgDtypePass,
119117
MatchArgRanksPass,
118+
NormalizeDelegateIOLayoutPass,
120119
NormalizeIndexPutBoolIndexTensorPass,
121120
NormalizeIndexPutNoneIndicesPass,
122121
NormalizeWhileInitialArgsPass,
@@ -142,7 +141,6 @@
142141
RewriteUpsamplePass,
143142
ScalarsToAttributePass,
144143
SizeAdjustInputPass,
145-
ToTosaMemoryFormatPass,
146144
UnsqueezeBeforeRepeatPass,
147145
UnsqueezeScalarPlaceholdersPass,
148146
)
@@ -158,6 +156,16 @@
158156
TosaLoweringContext,
159157
TosaSpecification,
160158
)
159+
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
160+
FuseCascadedTransposeOrPermuteOps,
161+
)
162+
from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
163+
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
164+
)
165+
166+
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
167+
RemovePermutesAroundElementwiseOps,
168+
)
161169
from executorch.exir import ExportedProgram
162170
from executorch.exir.pass_base import ExportPass
163171
from executorch.exir.pass_manager import PassManager
@@ -386,12 +394,10 @@ def _tosa_pipeline(
386394
# Allow subclasses to configure pass insertions before building pipeline
387395
self._configure_pass_insertions(exported_program)
388396

389-
# Preprocessing passes
390-
self.add_pass(AnnotateOutputDimOrderPass())
391-
392397
# Node transformation passes (pre q/dq folding)
393398
self.add_passes(
394399
[
400+
NormalizeDelegateIOLayoutPass(exported_program),
395401
FuseQuantizedActivationPass(),
396402
RewriteBoolToFp32CastViaInt8Pass(),
397403
CanonicalizeGatherPass(),
@@ -516,12 +522,9 @@ def _tosa_pipeline(
516522
ConvertSqueezesToViewPass(),
517523
CastToInt32Pass(),
518524
BroadcastArgsPass(),
519-
ConvertPermuteSingletonToViewPass(),
520-
RewriteHighRankSingletonPermutePass(),
521-
FuseViewCopyTransformPass(),
522-
DecomposeConvWithInt16ActivationPass(),
523525
DecomposeSumPass(),
524526
InsertTableOpsPass(exported_program),
527+
RemoveNoopPass(),
525528
]
526529
)
527530

@@ -534,6 +537,12 @@ def _tosa_pipeline(
534537
RewriteMatmulPass(),
535538
RewritePadPass(),
536539
RewriteSlicePass(),
540+
FuseViewCopyTransformPass(),
541+
RemovePermutesAroundElementwiseOps(),
542+
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
543+
FuseCascadedTransposeOrPermuteOps(),
544+
ConvertPermuteSingletonToViewPass(),
545+
RewriteHighRankSingletonPermutePass(),
537546
InsertConstShapesPass(),
538547
]
539548
)
@@ -544,7 +553,6 @@ def _tosa_pipeline(
544553
CastInt64BuffersToInt32Pass(exported_program),
545554
FuseEqualPlaceholdersPass(exported_program),
546555
FuseConsecutiveConcatShapesPass(),
547-
ToTosaMemoryFormatPass(exported_program),
548556
EnsureUniqueOutputNodesPass(),
549557
RemoveNoopPass(),
550558
InsertRescalePass(),

backends/arm/_passes/arm_pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,12 @@ def get_cond_while_submodules_nested(
397397
}
398398
# collect cond/while submodules (using mapping indices)
399399
return _get_control_flow_submodules(graph_module, mapping)
400+
401+
def to_2tuple(value):
    """Normalize an int or a 1-element sequence to a tuple of length 2.

    Args:
        value: Either an ``int`` or a sized, indexable sequence.

    Returns:
        ``(value, value)`` when ``value`` is an ``int``; ``(value[0],
        value[0])`` when it is a sequence of length 1; otherwise
        ``tuple(value)``. Sequences of any other length (including empty
        ones) are passed through without validation, so the result is only
        guaranteed to have length 2 for the first two cases.

    """
    if isinstance(value, int):
        return (value, value)
    if len(value) == 1:
        # Broadcast the single element to both positions.
        return (value[0], value[0])
    return tuple(value)

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def call_operator(self, op, args, kwargs, meta):
4747
x_meta.data["output_qparams"] = {}
4848

4949
x = args[0]
50-
x_unsqueezed_shape = list(x.data.shape) + [1]
50+
x_unsqueezed_shape = list(x.data.shape[:-1]) + [1] + [x.data.shape[-1]]
5151
x = super().call_operator(
5252
exir_ops.edge.aten.view_copy.default,
5353
(x, x_unsqueezed_shape),
@@ -61,7 +61,7 @@ def call_operator(self, op, args, kwargs, meta):
6161
w_meta.data["output_qparams"] = {}
6262

6363
w = args[1]
64-
w_unsqueezed_shape = list(w.data.shape) + [1]
64+
w_unsqueezed_shape = list(w.data.shape[:-1]) + [1] + [w.data.shape[-1]]
6565
w = super().call_operator(
6666
exir_ops.edge.aten.view_copy.default,
6767
(w, w_unsqueezed_shape),
@@ -74,11 +74,11 @@ def call_operator(self, op, args, kwargs, meta):
7474
x,
7575
w,
7676
args[2],
77-
args[3] + [1], # stride
78-
args[4] + [0], # padding
79-
args[5] + [1], # dilation
77+
[1] + args[3], # stride
78+
[0] + args[4], # padding
79+
[1] + args[5], # dilation
8080
args[6],
81-
args[7] + [0],
81+
[0] + args[7],
8282
args[8],
8383
)
8484
x = super().call_operator(
@@ -88,7 +88,7 @@ def call_operator(self, op, args, kwargs, meta):
8888
x_squeezed_meta = meta.copy()
8989
x_squeezed_meta.data["input_qparams"] = {}
9090
x_squeezed_meta.data["output_qparams"] = {}
91-
x_squeezed_shape = list(x.data.shape)[:-1]
91+
x_squeezed_shape = list(x.data.shape[:-2]) + [x.data.shape[-1]]
9292
x = super().call_operator(
9393
exir_ops.edge.aten.view_copy.default,
9494
(x, x_squeezed_shape),
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.arm_pass_utils import (
11+
create_node,
12+
get_first_fake_tensor,
13+
is_param_node,
14+
)
15+
from executorch.exir import ExportedProgram
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
from executorch.exir.pass_base import ExportPass, PassResult
18+
19+
class NormalizeDelegateIOLayoutPass(ArmPass):
    """Adjust delegated boundary tensor shapes and insert permutes at I/O.

    For every non-parameter placeholder whose fake tensor reports a
    non-identity ``dim_order()``, the placeholder's ``meta["val"]`` is replaced
    by the fake tensor reshaped to its shape permuted by that dim order, and a
    ``permute_copy`` using the inverse permutation is inserted directly after
    the placeholder. For every graph output node with a non-identity dim
    order, a ``permute_copy`` using the dim order itself is inserted and the
    output is rewired to it.

    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Used to tell parameter placeholders apart from user inputs.
        self.exported_program = exported_program

    @staticmethod
    def _inverse_permutation(perm: tuple[int, ...]) -> tuple[int, ...]:
        """Return the permutation ``inv`` such that ``inv[perm[i]] == i``."""
        inverse = [0] * len(perm)
        for idx, axis in enumerate(perm):
            inverse[axis] = idx
        return tuple(inverse)

    @staticmethod
    def _permute_shape(shape: torch.Size, perm: tuple[int, ...]) -> tuple[int, ...]:
        """Return ``shape`` reordered so that entry ``i`` is ``shape[perm[i]]``."""
        return tuple(shape[axis] for axis in perm)

    @staticmethod
    def _is_identity_dim_order(dim_order: tuple[int, ...]) -> bool:
        """Return True if ``dim_order`` is ``(0, 1, ..., len(dim_order) - 1)``."""
        return dim_order == tuple(range(len(dim_order)))

    def _normalize_input_layout(self, graph_module: torch.fx.GraphModule) -> bool:
        """Rewrite non-identity-dim-order inputs and permute them back.

        Returns:
            True if at least one placeholder was rewritten.

        """
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "placeholder" or is_param_node(self.exported_program, node):
                continue

            input_fake = get_first_fake_tensor(node)
            dim_order = input_fake.dim_order()
            if self._is_identity_dim_order(dim_order):
                continue

            # The placeholder now carries the shape permuted by its dim order.
            boundary_shape = self._permute_shape(input_fake.shape, dim_order)
            node.meta["val"] = input_fake.reshape(boundary_shape)

            # The inverse permutation maps the boundary shape back to the
            # shape the original users of the placeholder saw.
            transpose_perm = self._inverse_permutation(dim_order)
            with graph_module.graph.inserting_after(node):
                permute_node = create_node(
                    graph_module.graph,
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node, list(transpose_perm)),
                    from_node=node,
                )
                permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default(
                    node.meta["val"], list(transpose_perm)
                )

            # Snapshot the users first: rewiring mutates ``node.users``. The
            # permute itself must keep consuming the placeholder.
            users = [user for user in node.users if user != permute_node]
            for user in users:
                user.replace_input_with(node, permute_node)

            modified = True

        return modified

    def _rewrite_output_arg(
        self, arg: Any, graph_module: torch.fx.GraphModule
    ) -> tuple[Any, bool]:
        """Recursively replace output nodes that have a non-identity dim order.

        Args:
            arg: An output argument: a node, a (possibly nested) tuple or list
                of arguments, or any other value.
            graph_module: Graph module the permute nodes are inserted into.

        Returns:
            A ``(rewritten_arg, modified)`` pair. Nodes with a non-identity
            dim order are replaced by a ``permute_copy`` of themselves; tuples
            are rebuilt as tuples and lists as lists; any other value is
            returned unchanged.

        """
        if isinstance(arg, torch.fx.Node):
            output_fake = get_first_fake_tensor(arg)
            dim_order = output_fake.dim_order()
            if self._is_identity_dim_order(dim_order):
                return arg, False

            with graph_module.graph.inserting_after(arg):
                permute_node = create_node(
                    graph_module.graph,
                    exir_ops.edge.aten.permute_copy.default,
                    args=(arg, list(dim_order)),
                    from_node=arg,
                )
                permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default(
                    output_fake, list(dim_order)
                )

            return permute_node, True

        if isinstance(arg, (tuple, list)):
            modified = False
            rewritten = []
            for item in arg:
                new_item, item_modified = self._rewrite_output_arg(item, graph_module)
                rewritten.append(new_item)
                modified = modified or item_modified
            # Lists come back as plain lists, tuples as plain tuples.
            if isinstance(arg, tuple):
                return tuple(rewritten), modified
            return rewritten, modified

        return arg, False

    def _normalize_output_layout(self, graph_module: torch.fx.GraphModule) -> bool:
        """Insert permutes before the graph outputs where needed.

        Returns:
            True if the output node's arguments were rewritten.

        """
        output_node = graph_module.graph.output_node()
        rewritten_outputs, modified = self._rewrite_output_arg(
            output_node.args[0], graph_module
        )
        if modified:
            output_node.args = (rewritten_outputs,)
        return modified

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Normalize input and output layouts of ``graph_module``.

        Returns:
            A ``PassResult`` holding the (possibly re-processed) graph module
            and whether any input or output was rewritten.

        """
        modified = self._normalize_input_layout(graph_module)
        modified = self._normalize_output_layout(graph_module) or modified

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

backends/arm/_passes/rewrite_avg_pool2d_pass.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import torch
99
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.arm_pass_utils import to_2tuple
11+
from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER
1012
from executorch.backends.arm.operators.operator_validation_utils import (
1113
adjust_pooling_pad_if_needed,
1214
)
@@ -32,19 +34,25 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
3234
return super().call_operator(op, args, kwargs, meta, updated)
3335

3436
x = args[0]
35-
pad_h, pad_w = args[3]
37+
kernel = to_2tuple(args[1])
38+
39+
stride = to_2tuple(args[2]) if len(args) > 2 else ()
40+
if not stride:
41+
stride = kernel # default to kernel_size
42+
43+
pad_h, pad_w = to_2tuple(args[3]) if len(args) > 3 else (0, 0)
3644
# Make sure pad corresponds to TOSA
3745
pad = [pad_h, pad_w, pad_h, pad_w]
3846

39-
_, _, h, w = x.data.shape
40-
kernel_h, kernel_w = args[1]
41-
stride_h, stride_w = args[2]
42-
4347
ceil_mode = args[4] if len(args) > 4 else False
4448

4549
# Adjust padding if necessary
46-
pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode)
47-
pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode)
50+
pad[1] = adjust_pooling_pad_if_needed(
51+
x.data.shape[2], kernel[0], stride[0], pad[1], ceil_mode
52+
)
53+
pad[3] = adjust_pooling_pad_if_needed(
54+
x.data.shape[3], kernel[1], stride[1], pad[3], ceil_mode
55+
)
4856

4957
# Materialize zero-point constants
5058
in_qparams = meta.data.get("input_qparams", {})
@@ -63,13 +71,36 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
6371
else:
6472
acc_type = torch.float32
6573

66-
tosa_args = (args[0], input_zp, output_zp, *args[1:3], pad, acc_type)
74+
pre_permute = super().call_operator(
75+
exir_ops.edge.aten.permute_copy.default,
76+
(x, list(NHWC_ORDER)),
77+
{},
78+
meta,
79+
updated=True,
80+
)
81+
82+
tosa_args = (
83+
pre_permute,
84+
input_zp,
85+
output_zp,
86+
list(kernel),
87+
list(stride),
88+
pad,
89+
acc_type,
90+
)
6791

6892
# Emit TOSA AVG_POOL2D with normalized args
69-
return super().call_operator(
93+
tosa_avg_pool = super().call_operator(
7094
exir_ops.backend.tosa.AVG_POOL2D.default,
7195
tosa_args,
7296
{},
7397
meta,
7498
True,
7599
)
100+
return super().call_operator(
101+
exir_ops.edge.aten.permute_copy.default,
102+
(tosa_avg_pool, list(NHWC_INVERSE_ORDER)),
103+
{},
104+
meta,
105+
updated=True,
106+
)

0 commit comments

Comments
 (0)