Skip to content

Commit 47b71d8

Browse files
authored
Arm backend: Add MAX_POOL2D_ADAPTIVE lowering support (pytorch#19801)
Adds TOSA-1.1 backend-op support for MAX_POOL2D_ADAPTIVE and decomposition of irregular symbolic cases produced by dynamic max_pool2d lowering. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Co-authored-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent 37effad commit 47b71d8

14 files changed

Lines changed: 801 additions & 58 deletions

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .convert_to_clamp_pass import ConvertToClampPass # noqa
2828
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2929
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
30+
from .decompose_adaptive_max_pool2d_pass import DecomposeAdaptiveMaxPool2dPass # noqa
3031
from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
3132
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
3233
from .decompose_any_pass import DecomposeAnyPass # noqa
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER
11+
from executorch.backends.arm.tosa.dialect.ops.max_pool2d import (
12+
compute_max_pool2d_output_shape,
13+
)
14+
from executorch.backends.arm.tosa.specification import get_context_shape_env
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
from executorch.exir.pass_base import ExportPass, NodeMetadata
17+
18+
19+
class DecomposeAdaptiveMaxPool2dPass(ArmPass):
20+
"""Decompose irregular TOSA MAX_POOL2D_ADAPTIVE into per-bin slices.
21+
22+
For dynamic-shape cases where ``MAX_POOL2D_ADAPTIVE`` cannot directly map
23+
pooling regions (input_size % output_size not in {0, 1}), materialize
24+
adaptive bins via ``tosa.SLICE`` and pool each bin to 1x1 with
25+
``MAX_POOL2D_ADAPTIVE``.
26+
27+
"""
28+
29+
_passes_required_after: Set[Type[ExportPass]] = set()
30+
31+
@staticmethod
32+
def _is_static_dim(dim) -> bool:
33+
return not isinstance(dim, torch.SymInt)
34+
35+
def _symbolic_bin_bounds(self, input_size, output_size: int, out_idx: int, meta):
36+
# Compute symbolic slice bounds directly via Python arithmetic
37+
start = (input_size * out_idx) // output_size
38+
end = (input_size * (out_idx + 1) + (output_size - 1)) // output_size
39+
size = end - start
40+
return start, size
41+
42+
def _emit_tosa_slice(self, x, start_h, size_h, start_w, size_w, meta):
43+
# Slice the transposed NHWC tensor along its spatial axes.
44+
batch = x.data.shape[0]
45+
channel = x.data.shape[3]
46+
start = [0, start_h, start_w, 0]
47+
size = [batch, size_h, size_w, channel]
48+
return super().call_operator(
49+
exir_ops.backend.tosa.SLICE.default,
50+
(x, start, size),
51+
{},
52+
meta,
53+
True,
54+
)
55+
56+
def _emit_adaptive_max_pool(self, x_slice, size_h, size_w, meta):
57+
# Use direct lists for kernel, stride, and pad
58+
kernel = [size_h, size_w]
59+
stride = [1, 1]
60+
pad = [0, 0, 0, 0]
61+
pad = super().call_shape_operator(
62+
exir_ops.backend.tosa.CONST_SHAPE.default,
63+
(pad,),
64+
{},
65+
meta,
66+
)
67+
kernel = [size_h, size_w]
68+
if all(isinstance(k, int) for k in kernel):
69+
kernel = super().call_shape_operator(
70+
exir_ops.backend.tosa.CONST_SHAPE.default,
71+
(kernel,),
72+
{},
73+
meta,
74+
)
75+
if all(isinstance(s, int) for s in stride):
76+
stride = super().call_shape_operator(
77+
exir_ops.backend.tosa.CONST_SHAPE.default,
78+
(stride,),
79+
{},
80+
meta,
81+
)
82+
return super().call_operator(
83+
exir_ops.backend.tosa.MAX_POOL2D_ADAPTIVE.default,
84+
(x_slice, kernel, stride, pad),
85+
{},
86+
meta,
87+
True,
88+
)
89+
90+
def _is_directly_representable(self, input_size, output_size) -> bool:
91+
if isinstance(output_size, torch.SymInt):
92+
return False
93+
if self._is_static_dim(input_size):
94+
return input_size % output_size in (0, 1)
95+
96+
try:
97+
remainder_range = get_context_shape_env().bound_sympy(
98+
(input_size % output_size).node.expr
99+
)
100+
except Exception:
101+
return False
102+
return remainder_range.is_singleton() and remainder_range.upper in (0, 1)
103+
104+
def _decompose_irregular(self, x, output_size_h: int, output_size_w: int, meta):
105+
metadata_dict = dict(meta.data)
106+
metadata_dict["input_qparams"] = {}
107+
metadata_dict["output_qparams"] = {}
108+
meta_with_no_qparams = NodeMetadata(metadata_dict)
109+
110+
x_nhwc = super().call_operator(
111+
exir_ops.edge.aten.permute_copy.default,
112+
(x, list(NHWC_ORDER)),
113+
{},
114+
meta,
115+
True,
116+
)
117+
input_h_shape = x_nhwc.data.shape[1]
118+
input_w_shape = x_nhwc.data.shape[2]
119+
120+
rows = []
121+
for out_i in range(output_size_h):
122+
cols = []
123+
start_h, size_h = self._symbolic_bin_bounds(
124+
input_h_shape, output_size_h, out_i, meta_with_no_qparams
125+
)
126+
for out_j in range(output_size_w):
127+
start_w, size_w = self._symbolic_bin_bounds(
128+
input_w_shape, output_size_w, out_j, meta_with_no_qparams
129+
)
130+
x_slice = self._emit_tosa_slice(
131+
x_nhwc, start_h, size_h, start_w, size_w, meta_with_no_qparams
132+
)
133+
cols.append(
134+
self._emit_adaptive_max_pool(
135+
x_slice, size_h, size_w, meta_with_no_qparams
136+
)
137+
)
138+
139+
rows.append(
140+
super().call_operator(
141+
exir_ops.edge.aten.cat.default,
142+
(cols, 2),
143+
{},
144+
meta_with_no_qparams,
145+
True,
146+
)
147+
if len(cols) > 1
148+
else cols[0]
149+
)
150+
151+
out_nhwc = (
152+
super().call_operator(
153+
exir_ops.edge.aten.cat.default,
154+
(rows, 1),
155+
{},
156+
meta_with_no_qparams,
157+
True,
158+
)
159+
if len(rows) > 1
160+
else rows[0]
161+
)
162+
return super().call_operator(
163+
exir_ops.edge.aten.permute_copy.default,
164+
(out_nhwc, list(NHWC_INVERSE_ORDER)),
165+
{},
166+
meta,
167+
True,
168+
)
169+
170+
def call_operator(self, op, args, kwargs, meta, updated=False):
171+
if op != exir_ops.backend.tosa.MAX_POOL2D_ADAPTIVE.default:
172+
return super().call_operator(op, args, kwargs, meta, updated)
173+
174+
x, kernel, stride, pad = args
175+
output_shape = compute_max_pool2d_output_shape(
176+
x.data.permute(0, 2, 3, 1),
177+
kernel,
178+
stride,
179+
pad,
180+
op="MAX_POOL2D_ADAPTIVE",
181+
)
182+
output_size_h = output_shape[1]
183+
output_size_w = output_shape[2]
184+
185+
if isinstance(output_size_h, torch.SymInt) or isinstance(
186+
output_size_w, torch.SymInt
187+
):
188+
return super().call_operator(op, args, kwargs, meta, updated)
189+
190+
if output_size_h <= 1 and output_size_w <= 1:
191+
return super().call_operator(op, args, kwargs, meta, updated)
192+
193+
input_size_h, input_size_w = x.data.shape[2], x.data.shape[3]
194+
# If both spatial dimensions satisfy the direct-representability criterion
195+
# (input_size % output_size is 0 or 1 for static sizes, or symbolically
196+
# guaranteed in [0,1]), we can invoke the TOSA MAX_POOL2D_ADAPTIVE operator
197+
# directly instead of decomposing into individual bins.
198+
if self._is_directly_representable(
199+
input_size_h, output_size_h
200+
) and self._is_directly_representable(input_size_w, output_size_w):
201+
return super().call_operator(op, args, kwargs, meta, updated)
202+
203+
return self._decompose_irregular(x, output_size_h, output_size_w, meta)

backends/arm/_passes/decompose_avg_pool2d_pass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import Any, Set, Type
7+
from typing import Set, Type
88

99
import torch
1010
from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass
@@ -38,13 +38,13 @@ def get_decomposition(op) -> tuple:
3838

3939

4040
def _compute_post_pad(
41-
size: int,
41+
size: int | torch.SymInt,
4242
kernel: int,
4343
stride: int,
44-
pad: int,
44+
pad: int | torch.SymInt,
4545
ceil_mode: bool,
4646
divisor_override,
47-
) -> int:
47+
) -> int | torch.SymInt:
4848

4949
if pad == 0:
5050
return pad
@@ -70,7 +70,7 @@ def _get_avgpool_post_pad(
7070
ceil_mode,
7171
count_include_pad,
7272
divisor_override,
73-
) -> tuple[list[Any], list[int]]:
73+
) -> tuple[list[int | torch.SymInt], list[int | torch.SymInt]]:
7474
"""Compute the post-padding configuration for avg_pool2d when pre-
7575
materializing explicit zero padding ahead of the pooling operation.
7676

backends/arm/_passes/insert_dynamic_padding.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class InsertDynamicPaddingPass(ArmOpTargetedPass):
3030
target_ops = (
3131
exir_ops.backend.tosa.CONV2D.default,
3232
exir_ops.backend.tosa.DEPTHWISE_CONV2D.default,
33+
exir_ops.backend.tosa.MAX_POOL2D.default,
3334
)
3435

3536
def _is_dynamic_padding(
@@ -45,23 +46,29 @@ def _is_dynamic_padding(
4546
def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue:
4647
if op not in self.target_ops:
4748
return super().call_operator(op, args, kwargs, meta, updated)
48-
padding = args[4]
49+
if op == exir_ops.backend.tosa.MAX_POOL2D.default:
50+
padding_index = 3
51+
else:
52+
padding_index = 4
53+
padding = args[padding_index]
4954
if not self._is_dynamic_padding(padding):
5055
return super().call_operator(op, args, kwargs, meta, updated)
5156

5257
# Create a pad op before conv2d
5358
input_tensor = args[0]
5459

55-
zero_padding = [0, 0, 0, 0]
56-
NC_padding = super().call_shape_operator(
60+
zero_padding_pair = [0, 0]
61+
zero_spatial_padding = [0, 0, 0, 0]
62+
N_padding = super().call_shape_operator(
5763
exir_ops.backend.tosa.CONST_SHAPE.default,
58-
(zero_padding,),
64+
(zero_padding_pair,),
5965
{},
6066
meta,
6167
True,
6268
)
69+
C_padding = N_padding
6370

64-
padding_shape_args = [NC_padding, padding]
71+
padding_shape_args = [N_padding, padding, C_padding]
6572

6673
padding_shape = super().call_shape_operator(
6774
exir_ops.backend.tosa.CONCAT_SHAPE.default,
@@ -85,5 +92,5 @@ def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue:
8592
)
8693
new_conv2d_args = list(args)
8794
new_conv2d_args[0] = pad_res
88-
new_conv2d_args[4] = zero_padding
95+
new_conv2d_args[padding_index] = zero_spatial_padding
8996
return super().call_operator(op, tuple(new_conv2d_args), kwargs, meta, updated)

0 commit comments

Comments
 (0)