Skip to content

Commit 8e8e957

Browse files
authored
XNNPACK: accept positive spatial dims in mean.dim via normalization (#19206)
Normalize mean.dim axis arguments before validation and partitioning so equivalent forms such as [2, 3] and [-2, -1] are accepted. Previously, only negative dim forms were allowed, causing valid 4D spatial reductions with positive dims to be rejected. This preserves existing XNNPACK constraints (4D input, keepdim=True, reduction over the innermost two dimensions) while removing dependence on axis spelling. Add a test covering positive spatial dims. This reduces EdgeTAM end-to-end runtime by ~6–7% on SME2 enabled devices. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell --------- Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
1 parent a1ebbcc commit 8e8e957

4 files changed

Lines changed: 98 additions & 25 deletions

File tree

backends/xnnpack/operators/op_mean_dim.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
67

7-
from typing import cast, Dict, List
8+
from typing import cast, Dict
89

910
import torch
1011
from executorch.backends.xnnpack.operators.node_visitor import (
@@ -18,16 +19,17 @@
1819
XNNGraph,
1920
XNode,
2021
)
22+
from executorch.backends.xnnpack.utils.utils import normalize_mean_dims
2123
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS
2224

2325

2426
@register_node_visitor
2527
class MeanDim(NodeVisitor):
2628
"""
27-
XNNPACK only supports a special case of mean dim in which the operation can be written
28-
as Global Average Pooling. In order to be handled by xnnpack the input tensor must be 4d,
29-
the dimensions to reduce must be the two innermost (-1, -2) or (-2, -1). and the flag
30-
for keepdim must be set to True.
29+
XNNPACK only supports the special case of mean.dim that can be lowered
30+
to Global Average Pooling. The input tensor must be 4D, keepdim must be
31+
True, and the reduced dimensions must normalize to [2, 3] (for example
32+
[2, 3], [3, 2], [-1, -2], or [-2, -1]).
3133
"""
3234

3335
target = "aten.mean.dim"
@@ -51,10 +53,22 @@ def define_node(
5153
# output
5254
output_id = vals_to_ids[node]
5355

56+
input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
57+
check_or_raise(
58+
len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
59+
)
60+
61+
# This visitor serializes mean.dim as Global Average Pooling, which has
62+
# no field for an explicit dtype override.
63+
check_or_raise(
64+
node.kwargs.get("dtype") is None,
65+
"XNNPACK does not support mean.dim with dtype",
66+
)
67+
5468
# mean dims
55-
mean_dims = cast(List[int], node.args[1])
69+
mean_dims = normalize_mean_dims(node.args[1], len(input_shape))
5670
check_or_raise(
57-
mean_dims == [-1, -2] or mean_dims == [-2, -1],
71+
sorted(mean_dims) == [2, 3],
5872
"XNNPACK only supports mean.dim across the innermost dimensions",
5973
)
6074

@@ -64,11 +78,6 @@ def define_node(
6478
"XNNPACK only supports mean.dim that keeps dims",
6579
)
6680

67-
input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
68-
check_or_raise(
69-
len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
70-
)
71-
7281
ser_node = XNode(
7382
xnode_union=XNNGlobalAvgPooling2d(
7483
input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -20,7 +21,7 @@
2021
is_quant,
2122
tag_as_implicit_q_dq,
2223
)
23-
from executorch.backends.xnnpack.utils.utils import get_input_node
24+
from executorch.backends.xnnpack.utils.utils import get_input_node, normalize_mean_dims
2425
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
2526
format_target_name,
2627
)
@@ -515,22 +516,38 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
515516
if not self.check_common_constraints(node, ep):
516517
return False
517518

518-
dims = node.args[1]
519-
output_dims = node.meta["val"].dim()
520-
521-
if dims not in ([-2, -1], [-1, -2]):
519+
input_rank = get_input_node(node, 0).meta["val"].dim()
520+
if input_rank != 4:
522521
why(
523522
node,
524-
reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
523+
reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}",
525524
)
526525
return False
527526

528-
if output_dims != 4:
527+
# This path lowers mean.dim to XNNPACK Global Average Pooling, which
528+
# cannot encode an explicit dtype override.
529+
if node.kwargs.get("dtype") is not None:
530+
why(node, reason="mean.dim does not support dtype")
531+
return False
532+
533+
keepdim = len(node.args) >= 3 and bool(node.args[2])
534+
try:
535+
dims = normalize_mean_dims(node.args[1], input_rank)
536+
except ValueError as error:
537+
why(node, reason=f"mean.dim has invalid dims: {error}")
538+
return False
539+
540+
if sorted(dims) != [2, 3]:
529541
why(
530542
node,
531-
reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}",
543+
reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
532544
)
533545
return False
546+
547+
if not keepdim:
548+
why(node, reason="mean.dim only supports keepdim=True")
549+
return False
550+
534551
return True
535552

536553
def supported_precision_types(self) -> List[ConfigPrecisionType]:

backends/xnnpack/test/ops/test_mean_dim.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -15,18 +16,23 @@ def setUp(self):
1516
torch._dynamo.reset()
1617

1718
class MeanDim(torch.nn.Module):
18-
def __init__(self, dims):
19+
def __init__(self, dims, keepdim=True, dtype=None):
1920
super().__init__()
2021
self.dims = dims
22+
self.keepdim = keepdim
23+
self.dtype = dtype
2124

2225
def forward(self, x):
2326
y = x + x
24-
z = torch.mean(y, self.dims, keepdim=True)
27+
if self.dtype is None:
28+
z = torch.mean(y, self.dims, keepdim=self.keepdim)
29+
else:
30+
z = torch.mean(y, self.dims, keepdim=self.keepdim, dtype=self.dtype)
2531
return z
2632

27-
def _test_mean_dim(self, inputs):
33+
def _test_mean_dim(self, inputs, dims=(-1, -2)):
2834
(
29-
Tester(self.MeanDim((-1, -2)), inputs)
35+
Tester(self.MeanDim(dims), inputs)
3036
.export()
3137
.check_count({"torch.ops.aten.mean.dim": 1})
3238
.to_edge_transform_and_lower()
@@ -45,6 +51,10 @@ def test_fp32_mean_dim(self):
4551
inputs = (torch.randn(1, 5, 4, 4),)
4652
self._test_mean_dim(inputs)
4753

54+
def test_fp32_mean_dim_positive_dims(self):
55+
inputs = (torch.randn(1, 5, 4, 4),)
56+
self._test_mean_dim(inputs, dims=(2, 3))
57+
4858
def test_fp32_mean_dim_unsupported(self):
4959
"""
5060
XNNPack mean.dim implementation only supports innermost two dimensions. As such,
@@ -72,6 +82,26 @@ def test_fp32_mean_dim_unsupported_3d(self):
7282
.check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
7383
)
7484

85+
def test_fp32_mean_dim_unsupported_keepdim_false(self):
86+
inputs = (torch.randn(1, 5, 4, 4),)
87+
(
88+
Tester(self.MeanDim((-1, -2), keepdim=False), inputs)
89+
.export()
90+
.check_count({"torch.ops.aten.mean.dim": 1})
91+
.to_edge_transform_and_lower()
92+
.check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
93+
)
94+
95+
def test_fp32_mean_dim_unsupported_dtype(self):
96+
inputs = (torch.randn(1, 5, 4, 4),)
97+
(
98+
Tester(self.MeanDim((-1, -2), dtype=torch.float64), inputs)
99+
.export()
100+
.check_count({"torch.ops.aten.mean.dim": 1})
101+
.to_edge_transform_and_lower()
102+
.check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
103+
)
104+
75105
def test_qs8_mean_dim(self):
76106
inputs = (torch.randn(1, 5, 4, 4),)
77107
(

backends/xnnpack/utils/utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
67

7-
from typing import Any, cast, Optional, Tuple
8+
from typing import Any, cast, List, Optional, Sequence, Tuple
89

910
import torch
1011

@@ -57,6 +58,22 @@ def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
5758
return cast(torch.fx.Node, node.args[input_index])
5859

5960

61+
def normalize_mean_dims(mean_dims: Sequence[int] | int | None, rank: int) -> List[int]:
62+
"""Return mean dims as non-negative indices for the given rank."""
63+
if rank <= 0:
64+
raise ValueError(f"Expected rank > 0, got {rank}")
65+
if mean_dims is None:
66+
return list(range(rank))
67+
if isinstance(mean_dims, int):
68+
mean_dims = [mean_dims]
69+
normalized_dims = []
70+
for dim in mean_dims:
71+
if dim < -rank or dim >= rank:
72+
raise ValueError(f"Dimension out of range: {dim} for rank {rank}")
73+
normalized_dims.append(dim % rank)
74+
return normalized_dims
75+
76+
6077
def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
6178
"""
6279
Checks if the current node is only consumed by a relu node and can be fused,

0 commit comments

Comments
 (0)