Skip to content

Commit 2c1369e

Browse files
committed
XNNPACK: accept positive spatial dims in mean.dim via normalization
Normalize mean.dim axis arguments before validation and partitioning so equivalent forms such as [2, 3] and [-2, -1] are accepted. Previously, only negative dim forms were allowed, causing valid 4D spatial reductions with positive dims to be rejected. This preserves existing XNNPACK constraints (4D input, keepdim=True, reduction over the innermost two dimensions) while removing dependence on axis spelling. Add a test covering positive spatial dims. This reduces EdgeTAM end-to-end runtime by ~6–7% on SME2-enabled devices. Signed-off-by: Måns Nilsson <mans.nilsson@arm.com> Change-Id: I8a56568dbe79fa3bb8327b83c453b240795491ab
1 parent eb19f24 commit 2c1369e

4 files changed

Lines changed: 58 additions & 23 deletions

File tree

backends/xnnpack/operators/op_mean_dim.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
67

7-
from typing import cast, Dict, List
8+
from typing import cast, Dict
89

910
import torch
1011
from executorch.backends.xnnpack.operators.node_visitor import (
@@ -18,16 +19,17 @@
1819
XNNGraph,
1920
XNode,
2021
)
22+
from executorch.backends.xnnpack.utils.utils import normalize_mean_dims
2123
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS
2224

2325

2426
@register_node_visitor
2527
class MeanDim(NodeVisitor):
2628
"""
27-
XNNPACK only supports a special case of mean dim in which the operation can be written
28-
as Global Average Pooling. In order to be handled by xnnpack the input tensor must be 4d,
29-
the dimensions to reduce must be the two innermost (-1, -2) or (-2, -1). and the flag
30-
for keepdim must be set to True.
29+
XNNPACK only supports the special case of mean.dim that can be lowered
30+
to Global Average Pooling. The input tensor must be 4D, keepdim must be
31+
True, and the reduced dimensions must normalize to [2, 3] (for example
32+
[2, 3], [3, 2], [-1, -2], or [-2, -1]).
3133
"""
3234

3335
target = "aten.mean.dim"
@@ -51,10 +53,15 @@ def define_node(
5153
# output
5254
output_id = vals_to_ids[node]
5355

56+
input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
57+
check_or_raise(
58+
len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
59+
)
60+
5461
# mean dims
55-
mean_dims = cast(List[int], node.args[1])
62+
mean_dims = normalize_mean_dims(node.args[1], len(input_shape))
5663
check_or_raise(
57-
mean_dims == [-1, -2] or mean_dims == [-2, -1],
64+
sorted(mean_dims) == [2, 3],
5865
"XNNPACK only supports mean.dim across the innermost dimensions",
5966
)
6067

@@ -64,11 +71,6 @@ def define_node(
6471
"XNNPACK only supports mean.dim that keeps dims",
6572
)
6673

67-
input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
68-
check_or_raise(
69-
len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
70-
)
71-
7274
ser_node = XNode(
7375
xnode_union=XNNGlobalAvgPooling2d(
7476
input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -20,7 +21,7 @@
2021
is_quant,
2122
tag_as_implicit_q_dq,
2223
)
23-
from executorch.backends.xnnpack.utils.utils import get_input_node
24+
from executorch.backends.xnnpack.utils.utils import get_input_node, normalize_mean_dims
2425
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
2526
format_target_name,
2627
)
@@ -515,22 +516,28 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
515516
if not self.check_common_constraints(node, ep):
516517
return False
517518

518-
dims = node.args[1]
519-
output_dims = node.meta["val"].dim()
519+
input_rank = get_input_node(node, 0).meta["val"].dim()
520+
keepdim = len(node.args) >= 3 and bool(node.args[2])
521+
dims = normalize_mean_dims(node.args[1], input_rank)
520522

521-
if dims not in ([-2, -1], [-1, -2]):
523+
if sorted(dims) != [2, 3]:
522524
why(
523525
node,
524526
reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
525527
)
526528
return False
527529

528-
if output_dims != 4:
530+
if input_rank != 4:
529531
why(
530532
node,
531-
reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}",
533+
reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}",
532534
)
533535
return False
536+
537+
if not keepdim:
538+
why(node, reason="mean.dim only supports keepdim=True")
539+
return False
540+
534541
return True
535542

536543
def supported_precision_types(self) -> List[ConfigPrecisionType]:

backends/xnnpack/test/ops/test_mean_dim.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -15,18 +16,19 @@ def setUp(self):
1516
torch._dynamo.reset()
1617

1718
class MeanDim(torch.nn.Module):
18-
def __init__(self, dims):
19+
def __init__(self, dims, keepdim=True):
1920
super().__init__()
2021
self.dims = dims
22+
self.keepdim = keepdim
2123

2224
def forward(self, x):
2325
y = x + x
24-
z = torch.mean(y, self.dims, keepdim=True)
26+
z = torch.mean(y, self.dims, keepdim=self.keepdim)
2527
return z
2628

27-
def _test_mean_dim(self, inputs):
29+
def _test_mean_dim(self, inputs, dims=(-1, -2)):
2830
(
29-
Tester(self.MeanDim((-1, -2)), inputs)
31+
Tester(self.MeanDim(dims), inputs)
3032
.export()
3133
.check_count({"torch.ops.aten.mean.dim": 1})
3234
.to_edge_transform_and_lower()
@@ -45,6 +47,10 @@ def test_fp32_mean_dim(self):
4547
inputs = (torch.randn(1, 5, 4, 4),)
4648
self._test_mean_dim(inputs)
4749

50+
def test_fp32_mean_dim_positive_dims(self):
51+
inputs = (torch.randn(1, 5, 4, 4),)
52+
self._test_mean_dim(inputs, dims=(2, 3))
53+
4854
def test_fp32_mean_dim_unsupported(self):
4955
"""
5056
XNNPack mean.dim implementation only supports innermost two dimensions. As such,
@@ -72,6 +78,16 @@ def test_fp32_mean_dim_unsupported_3d(self):
7278
.check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
7379
)
7480

81+
def test_fp32_mean_dim_unsupported_keepdim_false(self):
82+
inputs = (torch.randn(1, 5, 4, 4),)
83+
(
84+
Tester(self.MeanDim((-1, -2), keepdim=False), inputs)
85+
.export()
86+
.check_count({"torch.ops.aten.mean.dim": 1})
87+
.to_edge_transform_and_lower()
88+
.check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
89+
)
90+
7591
def test_qs8_mean_dim(self):
7692
inputs = (torch.randn(1, 5, 4, 4),)
7793
(

backends/xnnpack/utils/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
67

7-
from typing import Any, cast, Optional, Tuple
8+
from typing import Any, cast, List, Optional, Sequence, Tuple
89

910
import torch
1011

@@ -57,6 +58,15 @@ def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
5758
return cast(torch.fx.Node, node.args[input_index])
5859

5960

61+
def normalize_mean_dims(mean_dims: Sequence[int] | int | None, rank: int) -> List[int]:
62+
"""Return mean dims as non-negative indices for the given rank."""
63+
if mean_dims is None:
64+
return list(range(rank))
65+
if isinstance(mean_dims, int):
66+
mean_dims = [mean_dims]
67+
return [dim % rank for dim in mean_dims]
68+
69+
6070
def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
6171
"""
6272
Checks if the current node is only consumed by a relu node and can be fused,

0 commit comments

Comments (0)