Skip to content

Commit 1154d34

Browse files
authored
Arm backend: Add TOSA AVG_POOL2D op (#18972)
### Summary Adds new TOSA dialect op for AVG_POOL2D. aten.AvgPool2d nodes are replaced by tosa.AVG_POOL2D in RewriteAvgPool2dPass. op_avg_pool2d node visitor is replaced by a simpler node visitor for tosa.avg_pool2d. Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent 04955b2 commit 1154d34

11 files changed

Lines changed: 351 additions & 235 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@
143143
from .replace_scalar_with_tensor_pass import ( # noqa
144144
ReplaceScalarWithTensorByProfilePass,
145145
)
146+
from .rewrite_avg_pool2d_pass import RewriteAvgPool2dPass # noqa
146147
from .rewrite_bool_bitwise_to_logical_pass import ( # noqa
147148
RewriteBoolBitwiseToLogicalPass,
148149
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from collections.abc import Sequence
1111
from dataclasses import dataclass, field
1212

13-
import executorch.backends.arm.tosa.dialect # noqa: unused
1413
from executorch.backends.arm._passes import (
1514
AccumulateIndexPutPass,
1615
AnnotateOutputDimOrderPass,
@@ -126,6 +125,7 @@
126125
RemoveNoopPass,
127126
ReplaceInfAndLimitValuesPass,
128127
ReplaceScalarWithTensorByProfilePass,
128+
RewriteAvgPool2dPass,
129129
RewriteBoolBitwiseToLogicalPass,
130130
RewriteBoolToFp32CastViaInt8Pass,
131131
RewriteConvPass,
@@ -144,7 +144,6 @@
144144
UnsqueezeBeforeRepeatPass,
145145
UnsqueezeScalarPlaceholdersPass,
146146
)
147-
148147
from executorch.backends.arm._passes.arm_pass import ArmPass
149148
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
150149
from executorch.backends.arm.common.pipeline_config import (
@@ -463,6 +462,8 @@ def _tosa_pipeline(
463462
DecomposeSliceScatterPass(),
464463
AccumulateIndexPutPass(),
465464
DecomposeIndexTensorToGatherPass(),
465+
DecomposeAdaptiveAvgPool2dPass(),
466+
DecomposeAvgPool2dPass(),
466467
Conv1dUnsqueezePass(),
467468
]
468469
)
@@ -499,17 +500,16 @@ def _tosa_pipeline(
499500
DecomposeSoftmaxPass(),
500501
ConvertMinMaxPass(),
501502
DecomposeAnyPass(),
502-
DecomposeAdaptiveAvgPool2dPass(),
503-
DecomposeAvgPool2dPass(),
504503
DecorateFp32toInt32CastingPass(),
505-
ComputeConstantOpsAOTPass(exported_program),
506-
FuseConstantArgsPass(exported_program),
507504
ConvertExpandCopyToRepeatPass(),
508505
UnsqueezeBeforeRepeatPass(),
509506
DecomposeCumsumPass(exported_program),
510507
DecomposeAsStridedCopyPass(),
511508
DecomposeMaxPool2dPass(),
512509
SizeAdjustInputPass(),
510+
RewriteAvgPool2dPass(),
511+
ComputeConstantOpsAOTPass(exported_program),
512+
FuseConstantArgsPass(exported_program),
513513
DecomposeSelectPass(),
514514
ConvertSqueezesToViewPass(),
515515
CastToInt32Pass(),
@@ -605,6 +605,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
605605
DecomposeLayerNormPass(tfa_pass=True),
606606
DecomposeVarPass(tfa_pass=True),
607607
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
608+
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
609+
DecomposeAvgPool2dPass(tfa_pass=True),
608610
]
609611
)
610612

@@ -630,8 +632,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
630632
DecomposeDivPass(tfa_pass=True),
631633
DecomposeLinalgVectorNormPass(tfa_pass=True),
632634
DecomposeSqrtPass(tfa_pass=True),
633-
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
634-
DecomposeAvgPool2dPass(tfa_pass=True),
635635
DecomposeSoftmaxPass(
636636
tfa_pass=True,
637637
),

backends/arm/_passes/decompose_avg_pool2d_pass.py

Lines changed: 99 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import Set, Type
7+
from typing import Any, Set, Type
88

99
import torch
1010
from executorch.backends.arm._passes.arm_pass import ArmPass
@@ -17,138 +17,162 @@
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
1919

20-
edge_div_ops = (exir_ops.edge.aten.avg_pool2d.default,)
21-
aten_div_ops = (torch.ops.aten.avg_pool2d.default,)
20+
edge_avg_pool2d = (exir_ops.edge.aten.avg_pool2d.default,)
21+
aten_avg_pool2d = (torch.ops.aten.avg_pool2d.default,)
2222

2323

2424
def get_decomposition(op) -> tuple:
25-
if op in edge_div_ops:
25+
if op in edge_avg_pool2d:
2626
return (
27-
exir_ops.edge.aten.full.default,
28-
exir_ops.edge.aten.cat.default,
27+
exir_ops.edge.aten.constant_pad_nd.default,
2928
exir_ops.edge.aten.avg_pool2d.default,
3029
exir_ops.edge.aten.mul.Tensor,
3130
)
32-
if op in aten_div_ops:
31+
if op in aten_avg_pool2d:
3332
return (
34-
torch.ops.aten.full.default,
35-
torch.ops.aten.cat.default,
33+
torch.ops.aten.pad.default,
3634
torch.ops.aten.avg_pool2d.default,
3735
torch.ops.aten.mul.Tensor,
3836
)
3937
raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}")
4038

4139

40+
def _compute_post_pad(
41+
size: int,
42+
kernel: int,
43+
stride: int,
44+
pad: int,
45+
ceil_mode: bool,
46+
divisor_override,
47+
) -> int:
48+
49+
if pad == 0:
50+
return pad
51+
if ceil_mode and divisor_override is None:
52+
return pad
53+
54+
pad_adjust = adjust_pooling_pad_if_needed(size, kernel, stride, pad, ceil_mode)
55+
56+
# Padding must always be above 0, the above adjustment may return -1
57+
if pad_adjust > 0:
58+
return pad_adjust
59+
return pad
60+
61+
62+
def _get_avgpool_post_pad(
63+
h,
64+
w,
65+
kernel: tuple,
66+
stride_h,
67+
stride_w,
68+
pad_h,
69+
pad_w,
70+
ceil_mode,
71+
count_include_pad,
72+
divisor_override,
73+
) -> tuple[list[Any], list[int]]:
74+
"""Compute the post-padding configuration for avg_pool2d when pre-
75+
materializing explicit zero padding ahead of the pooling operation.
76+
77+
Given the original spatial dimensions (h, w), pooling kernel size, stride,
78+
and explicit pre-padding amounts (pad_h, pad_w), this function returns the
79+
additional padding to apply on the right and bottom edges so that avg_pool2d
80+
with count_include_pad and/or divisor_override produces the equivalent
81+
result without built-in padding.
82+
83+
"""
84+
85+
k_h, k_w = kernel
86+
post_h, post_w = (0, 0)
87+
new_pad_h, new_pad_w = pad_h, pad_w
88+
89+
if not count_include_pad:
90+
return [new_pad_h, new_pad_w], [new_pad_h, new_pad_w]
91+
92+
post_h = _compute_post_pad(h, k_h, stride_h, pad_h, ceil_mode, divisor_override)
93+
post_w = _compute_post_pad(w, k_w, stride_w, pad_w, ceil_mode, divisor_override)
94+
95+
# Return our pre-padding calculation. Turn off built-in padding.
96+
return [pad_w, post_w, pad_h, post_h], [0, 0]
97+
98+
4299
class DecomposeAvgPool2dPass(ArmPass):
43100
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
44101

45102
def call_operator(self, op, args, kwargs, meta):
46-
if op not in (edge_div_ops + aten_div_ops) or not self.allowed_to_transform(
47-
meta
48-
):
103+
if op not in (
104+
edge_avg_pool2d + aten_avg_pool2d
105+
) or not self.allowed_to_transform(meta):
49106
return super().call_operator(op, args, kwargs, meta)
50107

51-
full_op, cat_op, avgpool_op, mul_op = get_decomposition(op)
108+
pad_op, avgpool_op, mul_op = get_decomposition(op)
52109

53110
x = args[0]
54-
full_kwargs = {"device": x.data.device, "dtype": x.data.dtype}
55111
kernel_h, kernel_w = args[1]
56112
kernel_size = kernel_h * kernel_w
113+
57114
if len(args) > 2 and args[2] is not None:
58115
stride_h, stride_w = args[2]
59116
else:
60117
stride_h, stride_w = kernel_h, kernel_w
61-
pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
118+
pad_h, pad_w = args[3] if len(args) > 3 else (0, 0)
62119
ceil_mode = args[4] if len(args) > 4 else False
63120
count_include_pad = args[5] if len(args) > 5 else True
64121
divisor_override = args[6] if len(args) > 6 else None
65122

66123
n, c, h, w = x.data.shape
67-
post_pad_w, post_pad_h = (0, 0)
68124

69125
# Count_include_pad == False means that we use a different divisor for edge elements
70126
# When divisor_override is set, this will be overriden anyways.
71127
# It is easier to replace a constant divisor, so set count_include_pad == True
72128
if divisor_override is not None:
73129
count_include_pad = True
74130

75-
# Add width padding manually if count_include_pad
76-
if count_include_pad and pad_w > 0:
77-
pre_pad_shape = [n, c, h, pad_w]
78-
pre_pad = super().call_operator(
79-
full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True
80-
)
81-
82-
if ceil_mode and divisor_override is None:
83-
post_pad_w = pad_w
84-
else:
85-
post_pad_w = adjust_pooling_pad_if_needed(
86-
w, kernel_w, stride_w, pad_w, ceil_mode
87-
)
88-
89-
if post_pad_w > 0:
90-
post_pad_shape = [n, c, h, post_pad_w]
91-
post_pad = super().call_operator(
92-
full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True
93-
)
94-
cat_nodes = [pre_pad, x, post_pad]
95-
else:
96-
cat_nodes = [pre_pad, x]
97-
98-
x = super().call_operator(
99-
cat_op, (cat_nodes, 3), kwargs, meta, updated=True
100-
)
101-
new_pad_w = 0
102-
103-
# Add height padding manually if count_include_pad
104-
if count_include_pad and pad_h > 0:
105-
pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w]
106-
pre_pad = super().call_operator(
107-
full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True
108-
)
131+
pad, new_pad = _get_avgpool_post_pad(
132+
h,
133+
w,
134+
args[1],
135+
stride_h,
136+
stride_w,
137+
pad_h,
138+
pad_w,
139+
ceil_mode,
140+
count_include_pad,
141+
divisor_override,
142+
)
109143

110-
if ceil_mode and divisor_override is None:
111-
post_pad_h = pad_h
112-
else:
113-
post_pad_h = adjust_pooling_pad_if_needed(
114-
h, kernel_h, stride_h, pad_h, ceil_mode
115-
)
116-
117-
if post_pad_h > 0:
118-
post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w]
119-
post_pad = super().call_operator(
120-
full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True
121-
)
122-
cat_nodes = [pre_pad, x, post_pad]
144+
if count_include_pad and (pad_h > 0 or pad_w > 0):
145+
if op in aten_avg_pool2d:
146+
pad_args = (x, pad, "constant", 0.0)
123147
else:
124-
cat_nodes = [pre_pad, x]
148+
pad_args = (x, pad, 0.0)
125149

126150
x = super().call_operator(
127-
cat_op, (cat_nodes, 2), kwargs, meta, updated=True
151+
pad_op,
152+
pad_args,
153+
{},
154+
meta,
155+
updated=True,
128156
)
129-
new_pad_h = 0
130157

131158
avgpool_args = (
132159
x,
133160
args[1],
134161
[stride_h, stride_w],
135-
[new_pad_h, new_pad_w],
162+
new_pad,
136163
ceil_mode,
137164
False,
138165
)
166+
139167
x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta, updated=True)
140168

141-
# Multiply by factor (kernel_size / divisor_override) if divisor_override
142169
if divisor_override is not None and divisor_override != kernel_size:
143-
override_multiplier = super().call_operator(
144-
full_op,
145-
([1, 1, 1, 1], kernel_size / divisor_override),
146-
full_kwargs,
170+
x = super().call_operator(
171+
mul_op,
172+
(x, super().call_scalar(kernel_size / divisor_override, meta)),
173+
{},
147174
meta,
148175
updated=True,
149176
)
150-
x = super().call_operator(
151-
mul_op, (x, override_multiplier), kwargs, meta, updated=True
152-
)
153177

154178
return x

0 commit comments

Comments
 (0)