Skip to content

Commit 968fff9

Browse files
authored
Arm backend: Add avg_pool2d_adaptive rewrite pass (pytorch#20027)
Adds pass to replace aten.adaptive_avg_pool2d with tosa.avg_pool2d_adaptive. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Co-authored-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent aca0b1a commit 968fff9

4 files changed

Lines changed: 501 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
from .replace_scalar_with_tensor_pass import ( # noqa
150150
ReplaceScalarWithTensorByProfilePass,
151151
)
152+
from .rewrite_adaptive_avg_pool2d import RewriteAdaptiveAvgPool2dPass # noqa
152153
from .rewrite_avg_pool2d_pass import RewriteAvgPool2dPass # noqa
153154
from .rewrite_bool_bitwise_to_logical_pass import ( # noqa
154155
RewriteBoolBitwiseToLogicalPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
RemovePermutesAroundElementwiseTosaOps,
132132
ReplaceInfAndLimitValuesPass,
133133
ReplaceScalarWithTensorByProfilePass,
134+
RewriteAdaptiveAvgPool2dPass,
134135
RewriteAvgPool2dPass,
135136
RewriteBoolBitwiseToLogicalPass,
136137
RewriteBoolToFp32CastViaInt8Pass,
@@ -504,6 +505,7 @@ def _tosa_pipeline(
504505
DecomposeAsStridedCopyPass(),
505506
DecomposeMaxPool2dPass(),
506507
SizeAdjustInputPass(),
508+
RewriteAdaptiveAvgPool2dPass(),
507509
RewriteAvgPool2dPass(),
508510
ComputeConstantOpsAOTPass(exported_program),
509511
FuseConstantArgsPass(exported_program),
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
11+
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
12+
ComputeConstantOpsAOTPass,
13+
)
14+
from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER
15+
from executorch.backends.arm.tosa.specification import (
16+
get_context_shape_env,
17+
get_context_spec,
18+
)
19+
from executorch.exir.dialects._ops import ops as exir_ops
20+
from executorch.exir.pass_base import ExportPass
21+
22+
23+
class RewriteAdaptiveAvgPool2dPass(ArmPass):
    """Rewrite dynamic adaptive average pooling to tosa.avg_pool2d_adaptive when
    possible.

    The condition for rewriting is that symbolic input dimensions have a known
    remainder of 0 or 1 when divided by the static output dimensions. This
    preserves the adaptive pooling regions without materializing slice/cat
    decomposition.

    Nodes are passed through unchanged when the input is not rank 4, when
    neither spatial dimension is symbolic, or when either spatial dimension
    cannot be expressed as a single kernel/stride pair.

    """

    targeted_ops = {exir_ops.edge.aten._adaptive_avg_pool2d.default}
    _passes_required_after: Set[Type[ExportPass]] = {
        ComputeConstantOpsAOTPass,
    }

    @staticmethod
    def _is_symbolic_dim(dim) -> bool:
        """Return True if ``dim`` is a ``torch.SymInt``."""
        return isinstance(dim, torch.SymInt)

    @staticmethod
    def _supports_dynamic_tosa_adaptive() -> bool:
        """Return True if the active TOSA spec is 1.x (x >= 1) with "shape".

        Any failure to obtain the context spec is treated as "not supported"
        (presumably no lowering context is active -- confirm against
        ``get_context_spec``).
        """
        try:
            tosa_spec = get_context_spec()
        except Exception:
            return False
        return (
            tosa_spec.version.major == 1
            and tosa_spec.version.minor >= 1
            and tosa_spec.support_extension("shape")
        )

    @classmethod
    def _get_pool_params(cls, input_size, output_size: int):
        """Compute ``(kernel, stride)`` for one spatial dimension.

        Args:
            input_size: Input extent along the dimension; an ``int`` or a
                ``torch.SymInt``.
            output_size: Requested output extent; must be a positive ``int``.

        Returns:
            A ``(kernel, stride)`` tuple where ``stride`` is
            ``input_size // output_size`` and ``kernel`` is ``stride`` plus the
            remainder, or ``None`` when the remainder is not provably 0 or 1
            or ``output_size`` is not a positive static integer.
        """
        if isinstance(output_size, torch.SymInt) or not isinstance(output_size, int):
            return None
        # A zero output size would raise ZeroDivisionError below and a negative
        # one would yield a negative stride; neither can be rewritten.
        if output_size <= 0:
            return None

        remainder = input_size % output_size
        if cls._is_symbolic_dim(remainder):
            # The remainder is only usable when the shape environment can pin
            # it down to one concrete value.
            shape_env = get_context_shape_env()
            try:
                remainder_range = shape_env.bound_sympy(remainder.node.expr)
            except Exception:
                return None
            if not remainder_range.is_singleton():
                return None
            remainder = int(remainder_range.upper)

        if remainder not in (0, 1):
            return None

        stride = input_size // output_size
        return stride + remainder, stride

    def _const_shape_if_static(self, values, meta):
        """Wrap ``values`` in a tosa CONST_SHAPE node when all entries are ints.

        Lists holding any non-``int`` entry (e.g. symbolic values) are returned
        unchanged.
        """
        if not all(isinstance(value, int) for value in values):
            return values
        return super().call_shape_operator(
            exir_ops.backend.tosa.CONST_SHAPE.default,
            (values,),
            {},
            meta,
        )

    @staticmethod
    def _zero_point(qparams) -> int:
        """Return the per-tensor zero point stored under key 0, else 0."""
        return qparams[0].get_zp_per_tensor() if 0 in qparams else 0

    def call_operator(self, op, args, kwargs, meta, updated=False):
        """Replace a dynamic ``_adaptive_avg_pool2d`` with AVG_POOL2D_ADAPTIVE.

        The replacement is emitted as ``permute_copy(NHWC_ORDER)`` ->
        ``tosa.AVG_POOL2D_ADAPTIVE`` -> ``permute_copy(NHWC_INVERSE_ORDER)``.

        Args:
            op: Operator being visited.
            args: Operator arguments; for targeted ops ``args[0]`` is the input
                and ``args[1]`` is the ``(output_h, output_w)`` pair.
            kwargs: Operator keyword arguments.
            meta: Node metadata; ``input_qparams``/``output_qparams`` are read
                from ``meta.data`` when present.
            updated: Forwarded to the base implementation for untouched nodes.

        Returns:
            The result of the base ``call_operator`` for nodes left unchanged,
            otherwise the final ``permute_copy`` node of the rewritten pattern.

        Raises:
            RuntimeError: If a spatial dimension is symbolic but the active
                TOSA spec is not 1.1+ with the shape extension.
        """
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta, updated)

        x = args[0]
        input_shape = x.data.shape
        # Unbatched (C, H, W) inputs cannot be unpacked as N, C, H, W and do
        # not match the rank-4 NHWC permutation used below; leave them alone
        # instead of failing with a ValueError on unpacking.
        if len(input_shape) != 4:
            return super().call_operator(op, args, kwargs, meta, updated)

        input_h, input_w = input_shape[-2:]
        if not (self._is_symbolic_dim(input_h) or self._is_symbolic_dim(input_w)):
            return super().call_operator(op, args, kwargs, meta, updated)

        # Dynamic adaptive lowering requires shape-aware TOSA support.
        if not self._supports_dynamic_tosa_adaptive():
            raise RuntimeError(
                "Dynamic adaptive_avg_pool2d rewrite requires TOSA-1.1 with the shape extension."
            )

        output_h, output_w = args[1]
        h_params = self._get_pool_params(input_h, output_h)
        w_params = self._get_pool_params(input_w, output_w)
        # Fall back when either spatial dimension cannot be expressed as one TOSA adaptive pool.
        if h_params is None or w_params is None:
            return super().call_operator(op, args, kwargs, meta, updated)

        # Emission order (pad, kernel, stride, zero points) is kept stable so
        # the resulting graph node order does not change.
        pad = self._const_shape_if_static([0, 0, 0, 0], meta)
        kernel = self._const_shape_if_static([h_params[0], w_params[0]], meta)
        stride = self._const_shape_if_static([h_params[1], w_params[1]], meta)

        in_zp_val = self._zero_point(meta.data.get("input_qparams", {}))
        input_zp = self.call_scalar(in_zp_val, meta)

        out_zp_val = self._zero_point(meta.data.get("output_qparams", {}))
        output_zp = self.call_scalar(out_zp_val, meta)

        # Integer inputs accumulate in int32; everything else in float32.
        acc_type = (
            torch.int32 if x.data.dtype in (torch.int8, torch.int16) else torch.float32
        )
        pre_permute = super().call_operator(
            exir_ops.edge.aten.permute_copy.default,
            (x, list(NHWC_ORDER)),
            {},
            meta,
            True,
        )
        tosa_args = (
            pre_permute,
            input_zp,
            output_zp,
            kernel,
            stride,
            pad,
            acc_type,
        )

        tosa_avg_pool = super().call_operator(
            exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default,
            tosa_args,
            {},
            meta,
            True,
        )
        return super().call_operator(
            exir_ops.edge.aten.permute_copy.default,
            (tosa_avg_pool, list(NHWC_INVERSE_ORDER)),
            {},
            meta,
            True,
        )

0 commit comments

Comments
 (0)