Skip to content

Commit 04955b2

Browse files
Arm backend: Add MAX_POOL2D tosa dialect op (#18970)
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 70fd62b commit 04955b2

9 files changed

Lines changed: 226 additions & 46 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
from .rewrite_inplace_arithmetic_pass import RewriteInplaceArithmeticPass # noqa
158158
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
159159
from .rewrite_matmul import RewriteMatmulPass # noqa
160+
from .rewrite_max_pool2d_pass import RewriteMaxPool2dPass # noqa
160161
from .rewrite_pad import RewritePadPass # noqa
161162
from .rewrite_slice import RewriteSlicePass # noqa
162163
from .rewrite_upsample import RewriteUpsamplePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
RewriteInplaceArithmeticPass,
135135
RewriteLeLtToGeGtPass,
136136
RewriteMatmulPass,
137+
RewriteMaxPool2dPass,
137138
RewritePadPass,
138139
RewriteSlicePass,
139140
RewriteUpsamplePass,
@@ -526,6 +527,7 @@ def _tosa_pipeline(
526527
self.add_passes(
527528
[
528529
RewriteUpsamplePass(),
530+
RewriteMaxPool2dPass(),
529531
RewriteConvPass(exported_program),
530532
RewriteMatmulPass(),
531533
RewritePadPass(),
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm.operators.operator_validation_utils import (
10+
adjust_pooling_pad_if_needed,
11+
)
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass
14+
15+
# Edge-dialect operator overloads that RewriteMaxPool2dPass matches against.
edge_max_pool2d_ops = (exir_ops.edge.aten.max_pool2d.default,)
16+
17+
18+
def _to_2tuple(value):
19+
if isinstance(value, int):
20+
return (value, value)
21+
if len(value) == 1:
22+
return (value[0], value[0])
23+
return tuple(value)
24+
25+
26+
class RewriteMaxPool2dPass(ArmPass):
    """Rewrite edge ``aten.max_pool2d`` ops to the TOSA ``MAX_POOL2D`` dialect op.

    Kernel, stride and padding are normalized to explicit per-axis values and
    the padding is expanded to the four-element TOSA layout. Ops with a
    dilation other than ``(1, 1)`` are left untouched.
    """

    # No follow-up passes are declared as required after this one.
    _passes_required_after: Set[Type[ExportPass]] = set()

    def call_operator(self, op, args, kwargs, meta):
        """Replace a matching max_pool2d call with ``tosa.MAX_POOL2D``.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call, read in the order
                ``(input, kernel_size, stride, padding, dilation, ceil_mode)``;
                everything after ``kernel_size`` is optional.
            kwargs: Keyword arguments of the call; forwarded unchanged when the
                op is not rewritten.
            meta: Node metadata forwarded to the emitted op.

        Returns:
            The result of ``super().call_operator`` for either the original op
            (no match, or unsupported dilation) or the TOSA ``MAX_POOL2D`` op.
        """
        if op not in edge_max_pool2d_ops:
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        kernel = _to_2tuple(args[1])

        # A missing, None or empty stride means "use the kernel size". A bare
        # int is accepted as well, consistent with the other parameters.
        raw_stride = args[2] if len(args) > 2 else None
        if raw_stride is None or (
            not isinstance(raw_stride, int) and len(raw_stride) == 0
        ):
            stride = kernel
        else:
            stride = _to_2tuple(raw_stride)

        padding = _to_2tuple(args[3]) if len(args) > 3 else (0, 0)
        dilation = _to_2tuple(args[4]) if len(args) > 4 else (1, 1)
        ceil_mode = args[5] if len(args) > 5 else False

        # The emitted op carries no dilation argument, so dilated pools are
        # left in their original form.
        if dilation != (1, 1):
            return super().call_operator(op, args, kwargs, meta)

        # TOSA MAX_POOL2D pad order is [top, bottom, left, right]
        pad = [padding[0], padding[0], padding[1], padding[1]]
        # Only the trailing pads (bottom, right) are adjusted. Indexing
        # shape[2] / shape[3] assumes a 4D input with the spatial dims last
        # -- TODO confirm against the callers' layout guarantees.
        pad[1] = adjust_pooling_pad_if_needed(
            x.data.shape[2], kernel[0], stride[0], pad[1], ceil_mode
        )
        pad[3] = adjust_pooling_pad_if_needed(
            x.data.shape[3], kernel[1], stride[1], pad[3], ceil_mode
        )

        return super().call_operator(
            exir_ops.backend.tosa.MAX_POOL2D.default,
            (x, list(kernel), list(stride), pad),
            {},
            meta,
            updated=True,
        )

backends/arm/_passes/size_adjust_input_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
expand_around_channel,
1212
)
1313
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
14+
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
1415
from executorch.backends.arm.tosa.specification import get_context_shape_env
1516
from executorch.exir.dialects._ops import ops as exir_ops
1617
from executorch.exir.pass_base import ExportPass, PassResult
@@ -201,6 +202,7 @@ class SizeAdjustInputPass(ArmPass):
201202

202203
_passes_required_after: Set[Type[ExportPass]] = {
203204
RewriteConvPass,
205+
RewriteMaxPool2dPass,
204206
}
205207

206208
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

backends/arm/operators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
op_gt,
3333
op_log,
3434
op_logical_not,
35-
op_max_pool2d,
3635
op_maximum,
3736
op_minimum,
3837
op_mul,
@@ -55,6 +54,7 @@
5554
op_tosa_depthwise_conv2d,
5655
op_tosa_gather,
5756
op_tosa_matmul,
57+
op_tosa_max_pool2d,
5858
op_tosa_pad,
5959
op_tosa_rescale,
6060
op_tosa_resize,

backends/arm/operators/op_max_pool2d.py renamed to backends/arm/operators/op_tosa_max_pool2d.py

Lines changed: 10 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.operators.operator_validation_utils import (
17-
adjust_pooling_pad_if_needed,
1817
validate_num_inputs,
1918
validate_same_dtype,
2019
validate_valid_dtype,
@@ -24,10 +23,9 @@
2423

2524
@register_node_visitor
2625
class MaxPool2dVisitor(NodeVisitor):
27-
target = "aten.max_pool2d.default"
26+
"""Visitor for lowering TOSA MAX_POOL2D operator."""
2827

29-
def __init__(self, *args):
30-
super().__init__(*args)
28+
target = "tosa.MAX_POOL2D.default"
3129

3230
def define_node(
3331
self,
@@ -36,59 +34,26 @@ def define_node(
3634
inputs: List[TosaArg],
3735
output: TosaArg,
3836
) -> None:
39-
validate_num_inputs(self.target, inputs, [3, 4, 5, 6])
37+
validate_num_inputs(self.target, inputs, [4])
4038
validate_same_dtype(self.target, [inputs[0], output], ts)
39+
40+
input_tensor, kernel, stride, pad = inputs
41+
4142
supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16]
4243
if self.tosa_spec.support_extension("int16"):
4344
supported_dtypes.append(ts.DType.INT16)
4445
validate_valid_dtype(
4546
self.target,
46-
[inputs[0], output],
47+
[input_tensor, output],
4748
supported_dtypes,
4849
self.tosa_spec,
4950
)
5051

51-
input_tensor = inputs[0]
52-
kernel_size = inputs[1].special
53-
stride = inputs[2].special
54-
55-
if len(inputs) == 6:
56-
ceil_mode = bool(inputs[5].number)
57-
else:
58-
ceil_mode = False
59-
60-
try:
61-
pad_size_list = inputs[3].special
62-
pad_size_list = [
63-
pad_size_list[0],
64-
pad_size_list[0],
65-
pad_size_list[1],
66-
pad_size_list[1],
67-
]
68-
except (IndexError, AttributeError):
69-
pad_size_list = [0, 0, 0, 0]
70-
71-
# Adjust the padding as necessary
72-
pad_size_list[1] = adjust_pooling_pad_if_needed(
73-
input_tensor.shape[2],
74-
kernel_size[0],
75-
stride[0],
76-
pad_size_list[1],
77-
ceil_mode,
78-
)
79-
pad_size_list[3] = adjust_pooling_pad_if_needed(
80-
input_tensor.shape[3],
81-
kernel_size[1],
82-
stride[1],
83-
pad_size_list[3],
84-
ceil_mode,
85-
)
86-
8752
attr = ts.TosaSerializerAttribute()
8853
attr.MaxPool2dAttribute(
89-
kernel=kernel_size,
90-
stride=stride,
91-
pad=pad_size_list,
54+
kernel=kernel.special,
55+
stride=stride.special,
56+
pad=pad.special,
9257
nan_mode=ts.NanPropagationMode.PROPAGATE,
9358
)
9459

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast, Dict, Protocol, Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.remove_getitem_pass import RemoveGetItemPass
10+
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
13+
14+
# Input signature shared by the test modules below: a 1-tuple holding one tensor.
input_t = Tuple[torch.Tensor]
15+
16+
17+
class ModuleWithInputs(Protocol):
    """Structural type for test modules that provide their own example inputs."""

    def get_inputs(self) -> input_t:
        """Return the inputs to run the module with."""
        ...
19+
20+
21+
class MaxPool2dWithStride(torch.nn.Module):
    """Max pool with a 2x2 kernel and an explicitly given stride of 2."""

    def get_inputs(self) -> input_t:
        """Return one random tensor of shape (1, 3, 8, 8)."""
        sample = torch.rand(1, 3, 8, 8)
        return (sample,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pooling to ``x``."""
        pooled = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
        return pooled
27+
28+
29+
class MaxPool2dWithoutStride(torch.nn.Module):
    """Max pool with a kernel size of 3 and no stride argument."""

    def get_inputs(self) -> input_t:
        """Return one random tensor of shape (1, 3, 8, 8)."""
        sample = torch.rand(1, 3, 8, 8)
        return (sample,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pooling to ``x``."""
        pooled = torch.nn.functional.max_pool2d(x, kernel_size=3)
        return pooled
35+
36+
37+
class MaxPool2dListKernel(torch.nn.Module):
    """Max pool whose kernel size is given as the two-element list [2, 3]."""

    def get_inputs(self) -> input_t:
        """Return one random tensor of shape (1, 3, 8, 8)."""
        sample = torch.rand(1, 3, 8, 8)
        return (sample,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pooling to ``x``."""
        pooled = torch.nn.functional.max_pool2d(x, kernel_size=[2, 3])
        return pooled
43+
44+
45+
# Parametrization table: test id -> module instance under test.
modules: Dict[str, ModuleWithInputs] = dict(
    max_pool2d_with_stride=MaxPool2dWithStride(),
    max_pool2d_without_stride=MaxPool2dWithoutStride(),
    max_pool2d_list_kernel=MaxPool2dListKernel(),
)
50+
51+
52+
@common.parametrize("module", modules)
def test_rewrite_max_pool2d_tosa(module: ModuleWithInputs) -> None:
    """Check that one edge max-pool op becomes one TOSA MAX_POOL2D op."""
    expected_before = {
        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1,
    }
    expected_after = {
        "executorch_exir_dialects_backend__ops_tosa_MAX_POOL2D_default": 1,
    }

    nn_module = cast(torch.nn.Module, module)
    example_inputs = module.get_inputs()
    pipeline = PassPipeline[input_t](
        nn_module,
        example_inputs,
        ops_before_pass=expected_before,
        ops_after_pass=expected_after,
        pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass],
    )

    # Cannot run aten graph with tosa dialect ops
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
depthwise_conv2d,
1111
gather,
1212
matmul,
13+
max_pool2d,
1314
pad,
1415
rescale,
1516
resize,
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import List, Union
7+
8+
import torch
9+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
10+
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
11+
from executorch.backends.arm.tosa.specification import (
12+
get_context_spec,
13+
TosaSpecification,
14+
)
15+
16+
17+
@register_fake_tosa_op(
    "MAX_POOL2D(Tensor input, int[2] kernel, int[2] stride, SymInt[4] pad) -> Tensor",
    TosaSpecification.all_versions_and_profiles(),
)
def MAX_POOL2D(
    x: torch.Tensor,
    kernel: List[int],
    stride: List[int],
    pad: List[Union[int, torch.SymInt]],
) -> torch.Tensor:
    """Compute output meta for a TOSA MAX_POOL2D operation.

    Args:
        x: 4D input tensor; dimensions 2 and 3 are the pooled ones.
        kernel: Two kernel sizes, one per pooled dimension.
        stride: Two strides, one per pooled dimension.
        pad: Four pad amounts in the order [top, bottom, left, right].

    Returns:
        An uninitialized tensor with the dtype of ``x`` and the pooled shape.

    Raises:
        TosaValueError: If the dtype is unsupported by the active TOSA spec,
            ``x`` is not 4D, an argument has the wrong length, or a kernel or
            stride value is smaller than 1.
    """
    tosa_spec = get_context_spec()

    supported_int_types = [torch.int8]
    supported_float_types = [
        torch.float16,
        torch.float32,
    ]
    # bf16 and int16 are only accepted when the spec enables the extension.
    if tosa_spec.support_extension("bf16"):
        supported_float_types.append(torch.bfloat16)
    if tosa_spec.support_extension("int16"):
        supported_int_types.append(torch.int16)

    if x.dtype in supported_int_types:
        if not tosa_spec.support_integer():
            raise TosaValueError(
                f"TOSA spec {tosa_spec} doesn't support integer pools", op="MAX_POOL2D"
            )
    elif x.dtype in supported_float_types:
        if not tosa_spec.support_float():
            raise TosaValueError(
                f"TOSA spec {tosa_spec} doesn't support float pools", op="MAX_POOL2D"
            )
    else:
        raise TosaValueError(
            f"Unsupported input dtype {x.dtype} for TOSA MAX_POOL2D", op="MAX_POOL2D"
        )

    if x.dim() != 4:
        raise TosaValueError(
            f"MAX_POOL2D requires a 4D tensor, got {x.dim()}D", op="MAX_POOL2D"
        )

    if len(kernel) != 2 or len(stride) != 2 or len(pad) != 4:
        raise TosaValueError(
            f"MAX_POOL2D expects kernel of length 2, stride of length 2, pad of "
            f"length 4; got kernel={kernel}, stride={stride}, pad={pad}",
            op="MAX_POOL2D",
        )

    # Reject non-positive values up front: a zero stride would otherwise
    # escape as a ZeroDivisionError from the shape arithmetic below.
    if min(*kernel, *stride) < 1:
        raise TosaValueError(
            f"MAX_POOL2D expects kernel and stride values >= 1; got "
            f"kernel={kernel}, stride={stride}",
            op="MAX_POOL2D",
        )

    n, c, h, w = x.shape
    k_h, k_w = kernel
    s_h, s_w = stride
    # TOSA MAX_POOL2D pad order is [top, bottom, left, right]
    p_top, p_bot, p_left, p_right = pad

    h_out = (h + p_top + p_bot - k_h) // s_h + 1
    w_out = (w + p_left + p_right - k_w) // s_w + 1
    return torch.empty(size=[n, c, h_out, w_out], dtype=x.dtype)

0 commit comments

Comments
 (0)