Skip to content

Commit b772b64

Browse files
committed
Arm backend: Fix remove tosa_dim_order review comments
- Fix [] stride in avg/max_pool2d + add tests - Fix meta-data of rescale in rewrite_upsample - Merge conv weight permutes into single helper function - Nits: Remove dead code, stale comments, TEMP path Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Change-Id: I6aa9221467a575e1c42a40cc5ca7237a810f782d
1 parent 5ac1d3c commit b772b64

10 files changed

Lines changed: 234 additions & 254 deletions

backends/arm/_passes/arm_pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,12 @@ def get_cond_while_submodules_nested(
397397
}
398398
# collect cond/while submodules (using mapping indices)
399399
return _get_control_flow_submodules(graph_module, mapping)
400+
401+
402+
def to_2tuple(value) -> tuple:
    """Normalize a scalar or short sequence to a tuple, pairing where possible.

    Args:
        value: An ``int``, a sequence of ints, or ``None``.

    Returns:
        ``(value, value)`` for an ``int`` and ``(v, v)`` for a one-element
        sequence. ``None`` and an empty sequence both give ``()`` so callers
        can test the result for truthiness to detect "not specified". Any
        other sequence is returned as a tuple with its elements unchanged.
    """
    # Treat None like an omitted argument instead of failing in len() below.
    if value is None:
        return ()
    if isinstance(value, int):
        return (value, value)
    if len(value) == 1:
        return (value[0], value[0])
    # Covers the regular 2-element case as well as the empty sequence.
    return tuple(value)

backends/arm/_passes/rewrite_avg_pool2d_pass.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.arm_pass_utils import to_2tuple
1011
from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER
1112
from executorch.backends.arm.operators.operator_validation_utils import (
1213
adjust_pooling_pad_if_needed,
@@ -33,19 +34,25 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
3334
return super().call_operator(op, args, kwargs, meta, updated)
3435

3536
x = args[0]
36-
pad_h, pad_w = args[3] if len(args) > 3 else (0, 0)
37+
kernel = to_2tuple(args[1])
38+
39+
stride = to_2tuple(args[2]) if len(args) > 2 else ()
40+
if not stride:
41+
stride = kernel # default to kernel_size
42+
43+
pad_h, pad_w = to_2tuple(args[3]) if len(args) > 3 else (0, 0)
3744
# Make sure pad corresponds to TOSA
3845
pad = [pad_h, pad_w, pad_h, pad_w]
3946

40-
_, _, h, w = x.data.shape
41-
kernel_h, kernel_w = args[1]
42-
stride_h, stride_w = args[2] if len(args) > 2 else (1, 1)
43-
4447
ceil_mode = args[4] if len(args) > 4 else False
4548

4649
# Adjust padding if necessary
47-
pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode)
48-
pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode)
50+
pad[1] = adjust_pooling_pad_if_needed(
51+
x.data.shape[2], kernel[0], stride[0], pad[1], ceil_mode
52+
)
53+
pad[3] = adjust_pooling_pad_if_needed(
54+
x.data.shape[3], kernel[1], stride[1], pad[3], ceil_mode
55+
)
4956

5057
# Materialize zero-point constants
5158
in_qparams = meta.data.get("input_qparams", {})
@@ -76,8 +83,8 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
7683
pre_permute,
7784
input_zp,
7885
output_zp,
79-
[kernel_h, kernel_w],
80-
[stride_h, stride_w],
86+
list(kernel),
87+
list(stride),
8188
pad,
8289
acc_type,
8390
)

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 79 additions & 221 deletions
Large diffs are not rendered by default.

backends/arm/_passes/rewrite_max_pool2d_pass.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Set, Type
77

88
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm._passes.arm_pass_utils import to_2tuple
910
from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER
1011
from executorch.backends.arm.operators.operator_validation_utils import (
1112
adjust_pooling_pad_if_needed,
@@ -16,14 +17,6 @@
1617
edge_max_pool2d_ops = (exir_ops.edge.aten.max_pool2d.default,)
1718

1819

19-
def _to_2tuple(value):
20-
if isinstance(value, int):
21-
return (value, value)
22-
if len(value) == 1:
23-
return (value[0], value[0])
24-
return tuple(value)
25-
26-
2720
class RewriteMaxPool2dPass(ArmPass):
2821
"""Rewrite max_pool2d ops to TOSA MAX_POOL2D."""
2922

@@ -34,19 +27,23 @@ def call_operator(self, op, args, kwargs, meta):
3427
return super().call_operator(op, args, kwargs, meta)
3528

3629
x = args[0]
37-
kernel = _to_2tuple(args[1])
38-
39-
if len(args) > 2 and args[2] is not None:
40-
stride = _to_2tuple(args[2])
41-
else:
42-
stride = kernel
30+
kernel = args[1]
31+
stride = to_2tuple(args[2]) if len(args) > 2 else ()
32+
if not stride:
33+
stride = kernel # default to kernel_size
4334

44-
padding = _to_2tuple(args[3]) if len(args) > 3 else (0, 0)
45-
dilation = _to_2tuple(args[4]) if len(args) > 4 else (1, 1)
35+
padding = to_2tuple(args[3]) if len(args) > 3 else (0, 0)
36+
dilation = to_2tuple(args[4]) if len(args) > 4 else (1, 1)
4637
ceil_mode = args[5] if len(args) > 5 else False
4738

48-
if dilation != (1, 1):
49-
return super().call_operator(op, args, kwargs, meta)
39+
if not dilation == (1, 1):
40+
from executorch.backends.arm._passes.decompose_maxpool2d_with_dilation_pass import (
41+
DecomposeMaxPool2dPass,
42+
)
43+
44+
raise RuntimeError(
45+
f"Dilation > 1 is not supported for tosa.MAX_POOL2D, has {DecomposeMaxPool2dPass.__name__} run?"
46+
)
5047

5148
# TOSA MAX_POOL2D pad order is [top, bottom, left, right]
5249
pad = [padding[0], padding[0], padding[1], padding[1]]

backends/arm/_passes/rewrite_upsample.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,10 @@ def call(self, graph_module):
227227
rescale_node = create_node(
228228
graph_module.graph,
229229
exir_ops.backend.tosa.RESCALE.default,
230+
from_node=node,
230231
)
232+
rescale_node.meta["val"] = node_replacement_fake
233+
231234
if input_dtype == torch.int16:
232235
tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = (
233236
TosaSpecialDtype.INT48

backends/arm/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -41,6 +41,8 @@
4141
NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
4242
NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5)
4343

44+
OHWI_ORDER: Final = (1, 2, 3, 0)
45+
ODHWI_ORDER: Final = (0, 2, 3, 4, 1)
4446
HWCM_ORDER: Final = (2, 3, 0, 1)
4547

4648
MAX_RANK: Final = 6

backends/arm/test/misc/test_transpose_counts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def forward(self, x: torch.Tensor, dim: int):
169169
return torch.cumsum(x, dim)
170170

171171

172-
class ConvMaxPoolResidualLinear(torch.nn.Module):
172+
class Model1ConvMaxPoolResidualLinear(torch.nn.Module):
173173
def __init__(self):
174174
super().__init__()
175175
self.conv = torch.nn.Conv1d(8, 8, kernel_size=3, padding=1)
@@ -427,7 +427,7 @@ def forward(self, x):
427427
0,
428428
),
429429
"model_1_conv_maxpool_residual_linear": TransposeCountCase(
430-
ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5
430+
Model1ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5
431431
),
432432
"model_2_conv_mha_linear_layernorm": TransposeCountCase(
433433
Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 11
@@ -486,7 +486,7 @@ def forward(self, x):
486486
torch.randn(2, 2, 2, 3).to(memory_format=torch.channels_last),
487487
torch.randn(2, 2, 3, 4).to(memory_format=torch.channels_last),
488488
),
489-
2, # The test crashes before reaching the transpose count
489+
2,
490490
),
491491
"pixel_shuffle_channels_last": TransposeCountCase(
492492
PixelShuffleModule(),
@@ -526,7 +526,7 @@ def forward(self, x):
526526
"cumsum_rank4_dim3_channels_last": TransposeCountCase(
527527
CumsumModule(),
528528
(torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last), 3),
529-
1, # The test crashes before reaching the transpose count
529+
1,
530530
),
531531
}
532532

backends/arm/test/ops/test_conv2d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,5 @@ def test_convolution_2d_u85_INT_a16w8(test_data: input_t):
758758
a16w8_quantization=True,
759759
use_to_edge_transform_and_lower=True,
760760
per_channel_quantization=per_channel_quantization,
761-
custom_path="TEMP",
762761
)
763762
pipeline.run()

backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from executorch.backends.arm._passes.rewrite_avg_pool2d_pass import RewriteAvgPool2dPass
1010
from executorch.backends.arm.test import common
1111
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
12+
from executorch.backends.test.harness.stages import StageType
13+
from executorch.exir.dialects._ops import ops as exir_ops
1214

1315
input_t = Tuple[torch.Tensor]
1416

@@ -41,6 +43,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4143
return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3])
4244

4345

46+
class AvgPool2dScalarPadding(torch.nn.Module):
    """avg_pool2d test module using scalar kernel size, stride and padding."""

    def get_inputs(self) -> input_t:
        """Return a one-element tuple holding a random (1, 3, 8, 8) tensor."""
        sample = torch.rand(1, 3, 8, 8)
        return (sample,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply avg_pool2d with kernel_size=3, stride=2 and padding=1."""
        pooled = torch.nn.functional.avg_pool2d(
            x, kernel_size=3, stride=2, padding=1
        )
        return pooled
52+
53+
54+
class AvgPool2dWithEmptyStride(torch.nn.Module):
    """avg_pool2d test module that passes an explicitly empty stride list."""

    def get_inputs(self) -> input_t:
        """Return a one-element tuple holding a random (1, 3, 8, 8) tensor."""
        sample = torch.rand(1, 3, 8, 8)
        return (sample,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply avg_pool2d with kernel_size=[2, 3] and stride=[]."""
        pooled = torch.nn.functional.avg_pool2d(
            x, kernel_size=[2, 3], stride=[]
        )
        return pooled
60+
61+
4462
modules: Dict[str, ModuleWithInputs] = {
4563
"avg_pool2d_with_stride": AvgPool2dWithStride(),
4664
"avg_pool2d_without_stride": AvgPool2dWithoutStride(),
@@ -67,3 +85,42 @@ def test_rewrite_avg_pool2d_tosa(module: ModuleWithInputs) -> None:
6785
"run_method_and_compare_outputs"
6886
) # Cannot run aten graph with tosa dialect ops
6987
pipeline.run()
88+
89+
90+
def _get_tosa_avg_pool2d_node(
    pipeline: PassPipeline[input_t],
) -> torch.fx.Node:
    """Return the only TOSA AVG_POOL2D call node in the RUN_PASSES artifact.

    Asserts that exactly one such node exists in the graph.
    """
    run_passes_artifact = pipeline.tester.get_artifact(StageType.RUN_PASSES)
    graph = run_passes_artifact.exported_program().graph_module.graph

    matches = []
    for candidate in graph.nodes:
        if (
            candidate.op == "call_function"
            and candidate.target == exir_ops.backend.tosa.AVG_POOL2D.default
        ):
            matches.append(candidate)

    assert len(matches) == 1
    return matches[0]
106+
107+
108+
def test_rewrite_avg_pool2d_tosa_empty_stride_uses_kernel_size() -> None:
    """Check that an empty stride is rewritten to the kernel size [2, 3]."""
    model = AvgPool2dWithEmptyStride()
    expected_before = {
        "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1,
    }
    expected_after = {
        "executorch_exir_dialects_backend__ops_tosa_AVG_POOL2D_default": 1,
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2,
    }

    pipeline = PassPipeline[input_t](
        model,
        model.get_inputs(),
        ops_before_pass=expected_before,
        ops_after_pass=expected_after,
        pass_list=[RewriteAvgPool2dPass],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

    pooled_node = _get_tosa_avg_pool2d_node(pipeline)
    # args[4] is presumably the stride operand of the TOSA op; it must equal
    # the kernel size used by the module.
    assert pooled_node.args[4] == [2, 3]

backends/arm/test/passes/test_rewrite_max_pool2d_pass.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
1111
from executorch.backends.arm.test import common
1212
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
13+
from executorch.backends.test.harness.stages import StageType
14+
from executorch.exir.dialects._ops import ops as exir_ops
1315

1416
input_t = Tuple[torch.Tensor]
1517

@@ -42,6 +44,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4244
return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3])
4345

4446

47+
class MaxPool2dWithEmptyStride(torch.nn.Module):
    """max_pool2d test module that passes an explicitly empty stride list."""

    def get_inputs(self) -> input_t:
        """Return a one-element tuple holding a random (1, 3, 8, 8) tensor."""
        sample = torch.rand(1, 3, 8, 8)
        return (sample,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply max_pool2d with kernel_size=[2, 3] and stride=[]."""
        pooled = torch.nn.functional.max_pool2d(
            x, kernel_size=[2, 3], stride=[]
        )
        return pooled
53+
54+
4555
modules: Dict[str, ModuleWithInputs] = {
4656
"max_pool2d_with_stride": MaxPool2dWithStride(),
4757
"max_pool2d_without_stride": MaxPool2dWithoutStride(),
@@ -67,3 +77,41 @@ def test_rewrite_max_pool2d_tosa(module: ModuleWithInputs) -> None:
6777
"run_method_and_compare_outputs"
6878
) # Cannot run aten graph with tosa dialect ops
6979
pipeline.run()
80+
81+
82+
def _get_tosa_max_pool2d_node(
    pipeline: PassPipeline[input_t],
) -> torch.fx.Node:
    """Return the only TOSA MAX_POOL2D call node in the RUN_PASSES artifact.

    Asserts that exactly one such node exists in the graph.
    """
    run_passes_artifact = pipeline.tester.get_artifact(StageType.RUN_PASSES)
    graph = run_passes_artifact.exported_program().graph_module.graph

    matches = []
    for candidate in graph.nodes:
        if (
            candidate.op == "call_function"
            and candidate.target == exir_ops.backend.tosa.MAX_POOL2D.default
        ):
            matches.append(candidate)

    assert len(matches) == 1
    return matches[0]
98+
99+
100+
def test_rewrite_max_pool2d_tosa_empty_stride_uses_kernel_size() -> None:
    """Check that an empty stride is rewritten to the kernel size [2, 3]."""
    model = MaxPool2dWithEmptyStride()
    expected_before = {
        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1,
    }
    expected_after = {
        "executorch_exir_dialects_backend__ops_tosa_MAX_POOL2D_default": 1,
    }

    pipeline = PassPipeline[input_t](
        model,
        model.get_inputs(),
        ops_before_pass=expected_before,
        ops_after_pass=expected_after,
        pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

    pooled_node = _get_tosa_max_pool2d_node(pipeline)
    # args[2] is presumably the stride operand of the TOSA op; it must equal
    # the kernel size used by the module.
    assert pooled_node.args[2] == [2, 3]

0 commit comments

Comments
 (0)