Skip to content

Commit c26dcf3

Browse files
Arm backend: Test avgpool non-square kernels
Add a regression test for AVG_POOL2D output shape with a non-square kernel and height-only padding. Change-Id: Ib2f0c15720aa7ba5c15a7406bafbc2d37aa4fa5a Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 7fc33ab commit c26dcf3

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

backends/arm/test/misc/test_tosa_dialect_avg_pool2d_adaptive.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,27 @@
1919
from torch._subclasses.fake_tensor import FakeTensorMode
2020

2121

22+
def test_avg_pool2d_tosa_non_square_kernel_output_shape():
23+
with TosaLoweringContext(
24+
TosaSpecification.create_from_string("TOSA-1.0+FP")
25+
), FakeTensorMode() as mode:
26+
x = mode.from_tensor(torch.randn((1, 20, 20, 8), dtype=torch.float32))
27+
input_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
28+
output_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32))
29+
30+
output = exir_ops.backend.tosa.AVG_POOL2D.default(
31+
x,
32+
input_zp,
33+
output_zp,
34+
[2, 3],
35+
[2, 1],
36+
[1, 1, 0, 0],
37+
torch.float32,
38+
)
39+
40+
assert tuple(output.shape) == (1, 11, 18, 8)
41+
42+
2243
def test_avg_pool2d_adaptive_tosa_INT():
2344
sample_inputs = [
2445
(

backends/arm/tosa/dialect/ops/avg_pool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def compute_avg_pool2d_output_shape(
136136
pad: List[IntLikeType] | List[int],
137137
op: str = "AVG_POOL2D",
138138
) -> List[IntLikeType]:
139-
"""Compute the output shape for NCHW avg-pool."""
139+
"""Compute the output shape for NHWC avg-pool."""
140140

141141
if x.dim() != 4:
142142
raise TosaValueError(f"{op} requires a 4D tensor, got {x.dim()}D", op=op)

0 commit comments

Comments
 (0)