Skip to content

Commit 524f037

Browse files
Arm backend: Add TOSA.AVG_POOL2D_ADAPTIVE op (#19150)
Adds adaptive avg_pool2d TOSA dialect op to be used for TOSA-1.1. Also addresses a bug in tosa.avg_pool2d fake tracing. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent d6f1625 commit 524f037

5 files changed

Lines changed: 516 additions & 54 deletions

File tree

backends/arm/_passes/rewrite_avg_pool2d_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
4242

4343
pad_h, pad_w = to_2tuple(args[3]) if len(args) > 3 else (0, 0)
4444
# Make sure pad corresponds to TOSA
45-
pad = [pad_h, pad_w, pad_h, pad_w]
45+
pad = [pad_h, pad_h, pad_w, pad_w]
4646

4747
ceil_mode = args[4] if len(args) > 4 else False
4848

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import executorch.backends.arm.tosa.dialect # noqa: F401
8+
import pytest
9+
import torch
10+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
11+
from executorch.backends.arm.tosa.dialect.ops.avg_pool2d import (
12+
validate_avg_pool2d_dtype,
13+
)
14+
from executorch.backends.arm.tosa.specification import (
15+
TosaLoweringContext,
16+
TosaSpecification,
17+
)
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
from torch._subclasses.fake_tensor import FakeTensorMode
20+
21+
22+
def test_avg_pool2d_adaptive_tosa_INT():
23+
sample_inputs = [
24+
(
25+
(
26+
torch.randint(-128, 127, (1, 20, 20, 8), dtype=torch.int8),
27+
torch.zeros((1,), dtype=torch.int8),
28+
torch.zeros((1,), dtype=torch.int8),
29+
[3, 3],
30+
[2, 2],
31+
[1, 1, 1, 1],
32+
torch.int32,
33+
),
34+
(1, 10, 10, 8),
35+
torch.int8,
36+
),
37+
(
38+
(
39+
torch.randint(-32768, 32767, (1, 9, 13, 4), dtype=torch.int16),
40+
torch.zeros((1,), dtype=torch.int16),
41+
torch.zeros((1,), dtype=torch.int16),
42+
[2, 4],
43+
[1, 3],
44+
[0, 0, 1, 1],
45+
torch.int32,
46+
),
47+
(1, 8, 4, 4),
48+
torch.int16,
49+
),
50+
]
51+
52+
with TosaLoweringContext(
53+
TosaSpecification.create_from_string("TOSA-1.1+INT+int16")
54+
), FakeTensorMode() as mode:
55+
for sample_input, expected_output_shape, expected_output_type in sample_inputs:
56+
output = exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default(
57+
*tuple(
58+
[
59+
mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
60+
for i in sample_input
61+
]
62+
)
63+
)
64+
assert output.dtype == expected_output_type
65+
assert tuple(output.shape) == expected_output_shape
66+
67+
68+
def test_avg_pool2d_adaptive_tosa_FP():
69+
sample_inputs = [
70+
(
71+
(
72+
torch.randn((1, 20, 20, 8), dtype=torch.float32),
73+
torch.zeros((1,), dtype=torch.float32),
74+
torch.zeros((1,), dtype=torch.float32),
75+
[3, 3],
76+
[2, 2],
77+
[1, 1, 1, 1],
78+
torch.float32,
79+
),
80+
(1, 10, 10, 8),
81+
torch.float32,
82+
),
83+
(
84+
(
85+
torch.randn((1, 9, 13, 4), dtype=torch.bfloat16),
86+
torch.zeros((1,), dtype=torch.bfloat16),
87+
torch.zeros((1,), dtype=torch.bfloat16),
88+
[2, 4],
89+
[1, 3],
90+
[0, 0, 1, 1],
91+
torch.float32,
92+
),
93+
(1, 8, 4, 4),
94+
torch.bfloat16,
95+
),
96+
]
97+
98+
with TosaLoweringContext(
99+
TosaSpecification.create_from_string("TOSA-1.1+FP+bf16")
100+
), FakeTensorMode() as mode:
101+
for sample_input, expected_output_shape, expected_output_type in sample_inputs:
102+
output = exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default(
103+
*tuple(
104+
[
105+
mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
106+
for i in sample_input
107+
]
108+
)
109+
)
110+
assert output.dtype == expected_output_type
111+
assert tuple(output.shape) == expected_output_shape
112+
113+
114+
def test_avg_pool2d_adaptive_accepts_remainder_one_mapping():
115+
with TosaLoweringContext(
116+
TosaSpecification.create_from_string("TOSA-1.1+FP")
117+
), FakeTensorMode() as mode:
118+
x = mode.from_tensor(torch.randn((1, 5, 5, 4), dtype=torch.float32))
119+
input_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
120+
output_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
121+
122+
output = exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default(
123+
x,
124+
input_zp,
125+
output_zp,
126+
[3, 3],
127+
[2, 2],
128+
[0, 0, 0, 0],
129+
torch.float32,
130+
)
131+
132+
assert tuple(output.shape) == (1, 2, 2, 4)
133+
134+
135+
def test_avg_pool2d_adaptive_rejects_irregular_single_op_mapping():
136+
with TosaLoweringContext(
137+
TosaSpecification.create_from_string("TOSA-1.1+FP")
138+
), FakeTensorMode() as mode:
139+
x = mode.from_tensor(torch.randn((1, 8, 8, 4), dtype=torch.float32))
140+
input_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
141+
output_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
142+
143+
with pytest.raises(
144+
TosaValueError, match=r"input_size % output_size in \{0, 1\}"
145+
):
146+
exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default(
147+
x,
148+
input_zp,
149+
output_zp,
150+
[3, 3],
151+
[2, 2],
152+
[0, 0, 0, 0],
153+
torch.float32,
154+
)
155+
156+
157+
@pytest.mark.parametrize(
158+
"spec_str,input_dtype,zero_point_dtype,acc_type",
159+
[
160+
("TOSA-1.0+INT", torch.int8, torch.int8, torch.int32),
161+
("TOSA-1.1+INT+int16", torch.int16, torch.int16, torch.int32),
162+
("TOSA-1.0+FP", torch.float16, torch.float16, torch.float16),
163+
("TOSA-1.0+FP", torch.float16, torch.float16, torch.float32),
164+
("TOSA-1.0+FP", torch.float32, torch.float32, torch.float32),
165+
("TOSA-1.1+FP+bf16", torch.bfloat16, torch.bfloat16, torch.float32),
166+
],
167+
)
168+
def test_validate_avg_pool2d_dtype_accepts_spec_supported_combinations(
169+
spec_str: str,
170+
input_dtype: torch.dtype,
171+
zero_point_dtype: torch.dtype,
172+
acc_type: torch.dtype,
173+
):
174+
spec = TosaSpecification.create_from_string(spec_str)
175+
x = torch.zeros((1, 2, 8, 8), dtype=input_dtype)
176+
input_zp = torch.zeros((1,), dtype=zero_point_dtype)
177+
output_zp = torch.zeros((1,), dtype=zero_point_dtype)
178+
179+
validate_avg_pool2d_dtype(spec, x, input_zp, output_zp, acc_type, op="AVG_POOL2D")
180+
181+
182+
@pytest.mark.parametrize(
183+
"spec_str,input_dtype,zero_point_dtype,acc_type,match",
184+
[
185+
(
186+
"TOSA-1.0+FP",
187+
torch.float32,
188+
torch.int32,
189+
torch.float32,
190+
"input zero-point dtype",
191+
),
192+
(
193+
"TOSA-1.0+FP",
194+
torch.float32,
195+
torch.float32,
196+
torch.int32,
197+
"accumulator type must be one of",
198+
),
199+
(
200+
"TOSA-1.0+INT",
201+
torch.int16,
202+
torch.int16,
203+
torch.int32,
204+
"Unsupported input dtype",
205+
),
206+
(
207+
"TOSA-1.0+INT",
208+
torch.uint8,
209+
torch.uint8,
210+
torch.int32,
211+
"Unsupported input dtype",
212+
),
213+
],
214+
)
215+
def test_validate_avg_pool2d_dtype_rejects_invalid_combinations(
216+
spec_str: str,
217+
input_dtype: torch.dtype,
218+
zero_point_dtype: torch.dtype,
219+
acc_type: torch.dtype,
220+
match: str,
221+
):
222+
spec = TosaSpecification.create_from_string(spec_str)
223+
x = torch.zeros((1, 2, 8, 8), dtype=input_dtype)
224+
input_zp = torch.zeros((1,), dtype=zero_point_dtype)
225+
output_zp = torch.zeros((1,), dtype=zero_point_dtype)
226+
227+
with pytest.raises(TosaValueError, match=match):
228+
validate_avg_pool2d_dtype(
229+
spec,
230+
x,
231+
input_zp,
232+
output_zp,
233+
acc_type,
234+
op="AVG_POOL2D",
235+
)
236+
237+
238+
@pytest.mark.parametrize(
239+
"op_target",
240+
[
241+
exir_ops.backend.tosa.AVG_POOL2D.default,
242+
exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default,
243+
],
244+
)
245+
def test_avg_pool2d_ops_reject_invalid_parameter_lengths(op_target):
246+
with TosaLoweringContext(
247+
TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
248+
), FakeTensorMode() as mode:
249+
x = mode.from_tensor(torch.randn((1, 8, 8, 4), dtype=torch.float32))
250+
input_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
251+
output_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
252+
253+
with pytest.raises(TosaValueError, match="expects kernel of length 2"):
254+
op_target(
255+
x,
256+
input_zp,
257+
output_zp,
258+
[2],
259+
[2, 2],
260+
[0, 0, 0, 0],
261+
torch.float32,
262+
)
263+
264+
with pytest.raises(TosaValueError, match="stride of length 2"):
265+
op_target(
266+
x,
267+
input_zp,
268+
output_zp,
269+
[2, 2],
270+
[2],
271+
[0, 0, 0, 0],
272+
torch.float32,
273+
)
274+
275+
with pytest.raises(TosaValueError, match="pad of length 4"):
276+
op_target(
277+
x,
278+
input_zp,
279+
output_zp,
280+
[2, 2],
281+
[2, 2],
282+
[0, 0, 0],
283+
torch.float32,
284+
)
285+
286+
287+
def test_avg_pool2d_adaptive_no_target_requires_tosa_1_1():
288+
with TosaLoweringContext(
289+
TosaSpecification.create_from_string("TOSA-1.0+FP")
290+
), FakeTensorMode() as mode:
291+
x = mode.from_tensor(torch.randn((1, 8, 8, 4), dtype=torch.float32))
292+
input_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
293+
output_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
294+
with pytest.raises(TosaValueError, match="support AVG_POOL2D_ADAPTIVE"):
295+
exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default(
296+
x,
297+
input_zp,
298+
output_zp,
299+
[2, 2],
300+
[2, 2],
301+
[0, 0, 0, 0],
302+
torch.float32,
303+
)

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401
77
avg_pool2d,
8+
avg_pool2d_adaptive,
89
conv2d,
910
conv3d,
1011
custom,

0 commit comments

Comments
 (0)