Skip to content

Commit ec8b798

Browse files
committed
Arm backend: Use channels-first
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Change-Id: I7584e519cd876d63f1cafe98710f1c7fb0378581
1 parent 75ba558 commit ec8b798

24 files changed

Lines changed: 1438 additions & 73 deletions

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
QuantizeClampArgumentsPass,
106106
)
107107
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
108+
from .fuse_canceling_transposes_pass import FuseCancelingTransposesPass # noqa
108109
from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa
109110
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
110111
from .fuse_constant_ops_pass import ( # noqa
@@ -129,6 +130,7 @@
129130
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
130131
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
131132
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
133+
from .normalize_delegate_io_layout_pass import NormalizeDelegateIOLayoutPass # noqa
132134
from .normalize_index_put_bool_index_tensor_pass import ( # noqa
133135
NormalizeIndexPutBoolIndexTensorPass,
134136
)
@@ -158,6 +160,7 @@
158160
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
159161
from .rewrite_matmul import RewriteMatmulPass # noqa
160162
from .rewrite_pad import RewritePadPass # noqa
163+
from .rewrite_pool_pass import RewritePoolPass # noqa
161164
from .rewrite_slice import RewriteSlicePass # noqa
162165
from .rewrite_upsample import RewriteUpsamplePass # noqa
163166
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
DecorateFp32toInt32CastingPass,
102102
FoldAndAnnotateQParamsPass,
103103
FuseBatchNorm2dPass,
104+
FuseCancelingTransposesPass,
104105
FuseConsecutiveConcatShapesPass,
105106
FuseConsecutiveRescalesPass,
106107
FuseConstantArgsPass,
@@ -116,6 +117,7 @@
116117
InsertTableOpsPass,
117118
MatchArgDtypePass,
118119
MatchArgRanksPass,
120+
NormalizeDelegateIOLayoutPass,
119121
NormalizeIndexPutBoolIndexTensorPass,
120122
NormalizeIndexPutNoneIndicesPass,
121123
NormalizeWhileInitialArgsPass,
@@ -135,11 +137,11 @@
135137
RewriteLeLtToGeGtPass,
136138
RewriteMatmulPass,
137139
RewritePadPass,
140+
RewritePoolPass,
138141
RewriteSlicePass,
139142
RewriteUpsamplePass,
140143
ScalarsToAttributePass,
141144
SizeAdjustInputPass,
142-
ToTosaMemoryFormatPass,
143145
UnsqueezeBeforeRepeatPass,
144146
UnsqueezeScalarPlaceholdersPass,
145147
)
@@ -497,9 +499,12 @@ def _tosa_pipeline(
497499
[
498500
RewriteUpsamplePass(),
499501
RewriteConvPass(exported_program),
502+
RewritePoolPass(),
500503
RewriteMatmulPass(),
501504
RewritePadPass(),
502505
RewriteSlicePass(),
506+
NormalizeDelegateIOLayoutPass(exported_program),
507+
FuseCancelingTransposesPass(),
503508
InsertConstShapesPass(),
504509
]
505510
)
@@ -510,7 +515,6 @@ def _tosa_pipeline(
510515
CastInt64BuffersToInt32Pass(exported_program),
511516
FuseEqualPlaceholdersPass(exported_program),
512517
FuseConsecutiveConcatShapesPass(),
513-
ToTosaMemoryFormatPass(exported_program),
514518
RemoveNoopPass(),
515519
InsertRescalePass(),
516520
]

backends/arm/_passes/decompose_int16_activation_conv_pass.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,66 @@ def __init__(self) -> None:
3030
super().__init__()
3131

3232
_passes_required_after: Set[Type[ExportPass]] = set()
33+
_NHWC_ORDER = [0, 2, 3, 1]
34+
_NHWC_INVERSE_ORDER = [0, 3, 1, 2]
35+
_NDHWC_ORDER = [0, 2, 3, 4, 1]
36+
_NDHWC_INVERSE_ORDER = [0, 4, 1, 2, 3]
3337

3438
def bias_view_shape(
    self, bias: torch.Tensor, activation_rank: int
) -> Sequence[int]:
    """Return a view shape that broadcasts *bias* over a conv output.

    The channel count goes in dim 1; every other dim is 1, padded out to
    ``activation_rank`` so elementwise addition broadcasts per channel.
    """
    trailing_ones = [1] * (activation_rank - 2)
    return [1, bias.shape[0]] + trailing_ones
3943

44+
def _insert_rescale_with_optional_transpose(
    self,
    input_node,
    output_dtype,
    scales: list[float],
    in_zp: int,
    out_zp: int,
    meta,
    activation_rank: int,
):
    """Insert a TOSA RESCALE, transposing to channels-last when needed.

    A single (per-tensor) scale is rescaled in place. Per-channel scales
    need the channel dimension last, so the input is permuted to
    (N)HWC / (N)DHWC first and permuted back to channels-first afterwards.

    Raises:
        RuntimeError: if ``activation_rank`` is not 4 or 5 for a
            per-channel rescale.
    """
    if len(scales) <= 1:
        # Per-tensor scale: no layout change required.
        return super().call_operator(
            exir_ops.backend.tosa.RESCALE.default,
            (input_node, output_dtype, scales, in_zp, out_zp),
            {},
            meta,
        )

    # Map activation rank to (to-channels-last, back-to-channels-first).
    permute_orders = {
        4: (self._NHWC_ORDER, self._NHWC_INVERSE_ORDER),
        5: (self._NDHWC_ORDER, self._NDHWC_INVERSE_ORDER),
    }
    if activation_rank not in permute_orders:
        raise RuntimeError(
            f"Unsupported rank {activation_rank} for per-channel rescale"
        )
    to_channels_last, to_channels_first = permute_orders[activation_rank]

    nhwc_input = super().call_operator(
        exir_ops.edge.aten.permute_copy.default,
        (input_node, to_channels_last),
        {},
        meta,
    )
    nhwc_rescaled = super().call_operator(
        exir_ops.backend.tosa.RESCALE.default,
        (nhwc_input, output_dtype, scales, in_zp, out_zp),
        {},
        meta,
    )
    return super().call_operator(
        exir_ops.edge.aten.permute_copy.default,
        (nhwc_rescaled, to_channels_first),
        {},
        meta,
    )
92+
4093
def call_operator(self, op, args, kwargs, meta):
4194
if op != exir_ops.edge.aten.convolution.default:
4295
return super().call_operator(op, args, kwargs, meta)
@@ -112,11 +165,14 @@ def call_operator(self, op, args, kwargs, meta):
112165
conv_rescale_factors = [1.0] * len(bias_scale)
113166
final_output_scale = [b / conv_output_scale for b in bias_scale]
114167

115-
conv_output = super().call_operator(
116-
exir_ops.backend.tosa.RESCALE.default,
117-
(convolution, torch.int32, conv_rescale_factors, 0, 0),
118-
{},
168+
conv_output = self._insert_rescale_with_optional_transpose(
169+
convolution,
170+
torch.int32,
171+
conv_rescale_factors,
172+
0,
173+
0,
119174
new_meta,
175+
activation_rank,
120176
)
121177

122178
add = super().call_operator(
@@ -126,17 +182,14 @@ def call_operator(self, op, args, kwargs, meta):
126182
new_meta,
127183
)
128184

129-
res_rescale = super().call_operator(
130-
exir_ops.backend.tosa.RESCALE.default,
131-
(
132-
add,
133-
output_dtype,
134-
final_output_scale,
135-
0,
136-
0,
137-
),
138-
{},
185+
res_rescale = self._insert_rescale_with_optional_transpose(
186+
add,
187+
output_dtype,
188+
final_output_scale,
189+
0,
190+
0,
139191
new_meta,
192+
activation_rank,
140193
)
141194

142195
else:
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes.arm_pass import ArmPass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
class FuseCancelingTransposesPass(ArmPass):
    """Fold chains of consecutive ``aten.permute_copy`` nodes.

    For ``permute(permute(...permute(x, p0)..., pN-1), pN)`` the
    permutations are composed into a single order applied to ``x``:
    - a non-identity composition keeps a single ``permute_copy``
    - an identity composition removes the chain altogether
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    @staticmethod
    def _compose_permutations(
        first: list[int] | tuple[int, ...],
        second: list[int] | tuple[int, ...],
    ) -> list[int]:
        # Applying ``first`` then ``second`` equals one permute by this order.
        return [first[index] for index in second]

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        permute_target = exir_ops.edge.aten.permute_copy.default
        graph = graph_module.graph
        changed = False

        for perm_node in list(graph.nodes):
            if perm_node.op != "call_function" or perm_node.target != permute_target:
                continue

            outer_order = perm_node.args[1]
            if not isinstance(outer_order, (list, tuple)):
                continue

            base = perm_node.args[0]
            total_order = list(outer_order)
            fold_count = 0

            # Walk producer permutes upstream, composing their orders.
            while (
                isinstance(base, torch.fx.Node)
                and base.op == "call_function"
                and base.target == permute_target
            ):
                inner_order = base.args[1]
                if not isinstance(inner_order, (list, tuple)):
                    break
                if len(inner_order) != len(total_order):
                    break
                total_order = self._compose_permutations(
                    list(inner_order), total_order
                )
                base = base.args[0]
                fold_count += 1

            if fold_count == 0:
                continue

            if total_order == list(range(len(total_order))):
                # The chain cancels out completely; bypass it.
                perm_node.replace_all_uses_with(base)
            else:
                perm_node.update_arg(0, base)
                perm_node.update_arg(1, total_order)
            changed = True

        if changed:
            graph.eliminate_dead_code()
            graph.lint()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, changed)

0 commit comments

Comments
 (0)