Skip to content

Commit 837c266

Browse files
author
Baris Demir
committed
Arm backend: Decompose avg_pool2d count_include_pad via pad
Signed-off-by: Baris Demir <baris.demir@arm.com>
Change-Id: I9d4e1cdafe1200c325bd474a7d3685ba39b55228
1 parent ffc4705 commit 837c266

5 files changed

Lines changed: 36 additions & 39 deletions

File tree

backends/cortex_m/ops/op_quantized_avg_pool2d.cpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,6 @@ Tensor& quantized_avg_pool2d_out(
3131

3232
const int64_t dilation_values[2] = {1, 1};
3333
const Int64ArrayRef dilation(dilation_values, 2);
34-
const bool ceil_mode = false;
35-
3634
CmsisPool2DConfig pool_config;
3735
if (!prepare_cmsis_pool2d_config(
3836
context,
@@ -43,7 +41,7 @@ Tensor& quantized_avg_pool2d_out(
4341
stride,
4442
padding,
4543
dilation,
46-
ceil_mode,
44+
false,
4745
activation_min,
4846
activation_max,
4947
pool_config)) {
@@ -57,7 +55,7 @@ Tensor& quantized_avg_pool2d_out(
5755
const int8_t* input_data = input.const_data_ptr<int8_t>();
5856
int8_t* output_data = out.mutable_data_ptr<int8_t>();
5957

60-
arm_cmsis_nn_status status = arm_avgpool_s8(
58+
const arm_cmsis_nn_status status = arm_avgpool_s8(
6159
&cmsis_ctx,
6260
&pool_config.pool_params,
6361
&pool_config.input_dims,

backends/cortex_m/ops/operators.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,16 +1110,16 @@ def quantized_avg_pool2d_meta(
11101110
multiplier: int,
11111111
shift: int,
11121112
) -> torch.Tensor:
1113-
kernel = _ensure_tuple2(kernel_size)
1114-
stride_vals = _ensure_tuple2(stride)
1115-
padding_vals = _ensure_tuple2(padding)
1116-
dilation_vals = (1, 1)
1117-
1118-
output_shape = _compute_max_pool2d_output_shape(
1119-
input.shape, kernel, stride_vals, padding_vals, dilation_vals
1113+
output = F.avg_pool2d(
1114+
input.to(torch.float),
1115+
kernel_size,
1116+
stride=stride,
1117+
padding=padding,
1118+
ceil_mode=False,
1119+
count_include_pad=False,
11201120
)
11211121
return torch.empty(
1122-
output_shape,
1122+
output.shape,
11231123
dtype=torch.int8,
11241124
device=input.device,
11251125
memory_format=torch.channels_last,
@@ -1136,21 +1136,20 @@ def quantized_avg_pool2d_impl(
11361136
multiplier: int,
11371137
shift: int,
11381138
) -> torch.Tensor:
1139-
11401139
dequant_input = dequantize_per_tensor_cmsis(input, zero_point, multiplier, shift)
11411140

11421141
kernel = _ensure_tuple2(kernel_size)
11431142
stride_vals = _ensure_tuple2(stride)
11441143
padding_vals = _ensure_tuple2(padding)
11451144

1146-
# TODO: implement count_include_pad=True, ceil_mode=True, dilation != 1.
1145+
# TODO: implement dilation != 1.
11471146
result = F.avg_pool2d(
11481147
dequant_input,
11491148
kernel,
11501149
stride=stride_vals,
11511150
padding=padding_vals,
1152-
count_include_pad=False,
11531151
ceil_mode=False,
1152+
count_include_pad=False,
11541153
)
11551154
result = quantize_per_tensor_cmsis(result, zero_point, multiplier, shift)
11561155
output = torch.clamp(result, -128, 127)

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,14 +360,29 @@ def _get_avg_pool2d_replacement(self, args, meta):
360360
divisor_override = args[6] if len(args) > 6 else None
361361
divisor_override_val = self._unwrap_argument(divisor_override)
362362

363-
if ceil_mode or count_include_pad or divisor_override_val is not None:
363+
if ceil_mode or divisor_override_val is not None:
364364
return exir_ops.edge.aten.avg_pool2d.default, args
365365

366+
input_arg = args[0]
367+
avg_padding = padding
368+
if count_include_pad:
369+
# Decompose count_include_pad=True into explicit input padding.
370+
pad_h, pad_w = padding
371+
pre_pad = [0, 0, pad_h, pad_w]
372+
post_pad = [0, 0, pad_h, pad_w]
373+
input_arg = super().call_operator(
374+
exir_ops.edge.cortex_m.pad.default,
375+
(input_arg, pre_pad, post_pad, int(zero_point)),
376+
{},
377+
NodeMetadata({}),
378+
)
379+
avg_padding = [0, 0]
380+
366381
args = (
367-
args[0],
382+
input_arg,
368383
kernel_size,
369384
stride,
370-
padding,
385+
avg_padding,
371386
zero_point,
372387
output_mult,
373388
output_shift,

backends/cortex_m/quantizer/pattern_checkers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,7 @@ def check_pattern(cls, pattern):
291291
return False
292292
node = pattern[0]
293293
ceil_mode = cast(bool, node.args[4]) if len(node.args) > 4 else False
294-
count_include_pad = cast(bool, node.args[5]) if len(node.args) > 5 else True
295-
return not (ceil_mode or count_include_pad)
294+
return not ceil_mode
296295

297296
@classmethod
298297
def check_quantization_config(

backends/cortex_m/test/ops/test_avg_pool2d.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -59,13 +59,6 @@ def forward(self, x): # noqa: D102
5959
CortexMAvgPool2d(kernel_size=3, stride=2, padding=1),
6060
(ramp_tensor(0, 15, (1, 1, 4, 4)),),
6161
),
62-
}
63-
64-
test_cases_fp = {
65-
"avgpool_3x3_s2_pad1_ceil": McuTestCase(
66-
CortexMAvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True),
67-
(ramp_tensor(0, 15, (1, 1, 4, 4)),),
68-
),
6962
"avgpool_3x3_s2_pad1_countinc": McuTestCase(
7063
CortexMAvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=True),
7164
(ramp_tensor(0, 15, (1, 1, 4, 4)),),
@@ -76,19 +69,12 @@ def forward(self, x): # noqa: D102
7669
@parametrize("test_case", test_cases)
7770
def test_dialect_avg_pool2d(test_case):
7871
tester = CortexMTester(test_case.model, test_case.example_inputs)
72+
ops_after = dict(test_case.model.ops_after_transforms)
73+
if test_case.model.pool.count_include_pad:
74+
ops_after["executorch_exir_dialects_edge__ops_cortex_m_pad_default"] = 1
7975
tester.test_dialect(
8076
test_case.model.ops_before_transforms,
81-
test_case.model.ops_after_transforms,
82-
qtol=1,
83-
)
84-
85-
86-
@parametrize("test_case", test_cases_fp)
87-
def test_dialect_avg_pool2d_fp(test_case):
88-
tester = CortexMTester(test_case.model, test_case.example_inputs)
89-
tester.test_dialect(
90-
{"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1},
91-
{"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1},
77+
ops_after,
9278
qtol=1,
9379
)
9480

0 commit comments

Comments
 (0)