diff --git a/backends/xnnpack/operators/op_mean_dim.py b/backends/xnnpack/operators/op_mean_dim.py
index 663606a8880..527748f32a0 100644
--- a/backends/xnnpack/operators/op_mean_dim.py
+++ b/backends/xnnpack/operators/op_mean_dim.py
@@ -1,10 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import cast, Dict, List
+from typing import cast, Dict
 
 import torch
 from executorch.backends.xnnpack.operators.node_visitor import (
@@ -18,16 +19,17 @@
     XNNGraph,
     XNode,
 )
+from executorch.backends.xnnpack.utils.utils import normalize_mean_dims
 from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS
 
 
 @register_node_visitor
 class MeanDim(NodeVisitor):
     """
-    XNNPACK only supports a special case of mean dim in which the operation can be written
-    as Global Average Pooling. In order to be handled by xnnpack the input tensor must be 4d,
-    the dimensions to reduce must be the two innermost (-1, -2) or (-2, -1). and the flag
-    for keepdim must be set to True.
+    XNNPACK only supports the special case of mean.dim that can be lowered
+    to Global Average Pooling. The input tensor must be 4D, keepdim must be
+    True, and the reduced dimensions must normalize to [2, 3] (for example
+    [2, 3], [3, 2], [-1, -2], or [-2, -1]).
     """
 
     target = "aten.mean.dim"
@@ -51,10 +53,15 @@ def define_node(
         # output
         output_id = vals_to_ids[node]
 
+        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
+        check_or_raise(
+            len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
+        )
+
         # mean dims
-        mean_dims = cast(List[int], node.args[1])
+        mean_dims = normalize_mean_dims(node.args[1], len(input_shape))
         check_or_raise(
-            mean_dims == [-1, -2] or mean_dims == [-2, -1],
+            sorted(mean_dims) == [2, 3],
             "XNNPACK only supports mean.dim across the innermost dimensions",
         )
 
@@ -64,11 +71,6 @@ def define_node(
             "XNNPACK only supports mean.dim that keeps dims",
         )
 
-        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
-        check_or_raise(
-            len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
-        )
-
         ser_node = XNode(
             xnode_union=XNNGlobalAvgPooling2d(
                 input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 0e588af66cb..988271e1383 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -20,7 +21,7 @@
     is_quant,
     tag_as_implicit_q_dq,
 )
-from executorch.backends.xnnpack.utils.utils import get_input_node
+from executorch.backends.xnnpack.utils.utils import get_input_node, normalize_mean_dims
 from executorch.exir.backend.canonical_partitioners.config_partitioner import (
     format_target_name,
 )
@@ -515,22 +516,28 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         if not self.check_common_constraints(node, ep):
             return False
 
-        dims = node.args[1]
-        output_dims = node.meta["val"].dim()
+        input_rank = get_input_node(node, 0).meta["val"].dim()
+        keepdim = len(node.args) >= 3 and bool(node.args[2])
+        dims = normalize_mean_dims(node.args[1], input_rank)
 
-        if dims not in ([-2, -1], [-1, -2]):
+        if sorted(dims) != [2, 3]:
             why(
                 node,
                 reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
             )
             return False
 
-        if output_dims != 4:
+        if input_rank != 4:
             why(
                 node,
-                reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}",
+                reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}",
             )
             return False
+
+        if not keepdim:
+            why(node, reason="mean.dim only supports keepdim=True")
+            return False
+
         return True
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
diff --git a/backends/xnnpack/test/ops/test_mean_dim.py b/backends/xnnpack/test/ops/test_mean_dim.py
index 81a93c3e97e..da9dd79a907 100644
--- a/backends/xnnpack/test/ops/test_mean_dim.py
+++ b/backends/xnnpack/test/ops/test_mean_dim.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -15,18 +16,19 @@ def setUp(self):
         torch._dynamo.reset()
 
     class MeanDim(torch.nn.Module):
-        def __init__(self, dims):
+        def __init__(self, dims, keepdim=True):
             super().__init__()
             self.dims = dims
+            self.keepdim = keepdim
 
         def forward(self, x):
             y = x + x
-            z = torch.mean(y, self.dims, keepdim=True)
+            z = torch.mean(y, self.dims, keepdim=self.keepdim)
             return z
 
-    def _test_mean_dim(self, inputs):
+    def _test_mean_dim(self, inputs, dims=(-1, -2)):
         (
-            Tester(self.MeanDim((-1, -2)), inputs)
+            Tester(self.MeanDim(dims), inputs)
             .export()
             .check_count({"torch.ops.aten.mean.dim": 1})
             .to_edge_transform_and_lower()
@@ -45,6 +47,10 @@ def test_fp32_mean_dim(self):
         inputs = (torch.randn(1, 5, 4, 4),)
         self._test_mean_dim(inputs)
 
+    def test_fp32_mean_dim_positive_dims(self):
+        inputs = (torch.randn(1, 5, 4, 4),)
+        self._test_mean_dim(inputs, dims=(2, 3))
+
     def test_fp32_mean_dim_unsupported(self):
         """
         XNNPack mean.dim implementation only supports innermost two dimensions. As such,
@@ -72,6 +78,16 @@ def test_fp32_mean_dim_unsupported_3d(self):
             .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
         )
 
+    def test_fp32_mean_dim_unsupported_keepdim_false(self):
+        inputs = (torch.randn(1, 5, 4, 4),)
+        (
+            Tester(self.MeanDim((-1, -2), keepdim=False), inputs)
+            .export()
+            .check_count({"torch.ops.aten.mean.dim": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
+        )
+
     def test_qs8_mean_dim(self):
         inputs = (torch.randn(1, 5, 4, 4),)
         (
diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py
index a41d5bc634a..19e0832e62b 100644
--- a/backends/xnnpack/utils/utils.py
+++ b/backends/xnnpack/utils/utils.py
@@ -1,10 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, cast, Optional, Tuple
+from typing import Any, cast, List, Optional, Sequence, Tuple
 
 import torch
 
@@ -57,6 +58,15 @@ def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
     return cast(torch.fx.Node, node.args[input_index])
 
 
+def normalize_mean_dims(mean_dims: Sequence[int] | int | None, rank: int) -> List[int]:
+    """Return mean dims as non-negative indices for the given rank."""
+    if mean_dims is None:
+        return list(range(rank))
+    if isinstance(mean_dims, int):
+        mean_dims = [mean_dims]
+    return [dim % rank for dim in mean_dims]
+
+
 def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     """
     Checks if the current node is only consumed by a relu node and can be fused,
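Reviewer note: the sorted(...) == [2, 3] checks in the visitor and the partitioner both rely on the normalization semantics of the new helper. A standalone restatement of normalize_mean_dims with a few sanity checks (the function body is copied from the utils.py hunk above; the asserts are illustrative, not part of the patch):

```python
from typing import List, Optional, Sequence, Union


def normalize_mean_dims(mean_dims: Union[Sequence[int], int, None], rank: int) -> List[int]:
    """Return mean dims as non-negative indices for the given rank."""
    if mean_dims is None:
        return list(range(rank))
    if isinstance(mean_dims, int):
        mean_dims = [mean_dims]
    # Python's % maps negative dims into [0, rank), e.g. -1 % 4 == 3.
    return [dim % rank for dim in mean_dims]


# All four spellings of "the two innermost dims of a 4D tensor"
# normalize to the same canonical form.
assert sorted(normalize_mean_dims([-1, -2], 4)) == [2, 3]
assert sorted(normalize_mean_dims([-2, -1], 4)) == [2, 3]
assert sorted(normalize_mean_dims([2, 3], 4)) == [2, 3]
assert sorted(normalize_mean_dims([3, 2], 4)) == [2, 3]
# dim=None means "reduce over everything", so a 4D input normalizes to
# [0, 1, 2, 3] and correctly fails the innermost-dims check.
assert sorted(normalize_mean_dims(None, 4)) == [0, 1, 2, 3]
```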
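For context, a minimal end-to-end sketch of what this patch enables, assuming the standard ExecuTorch export-and-lower flow (torch.export.export plus to_edge_transform_and_lower with XnnpackPartitioner); the module and tensor shapes are hypothetical examples, not from the patch:

```python
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir import to_edge_transform_and_lower


class GlobalAvgPool(torch.nn.Module):
    def forward(self, x):
        # Positive innermost dims (2, 3) now lower to XNNPACK Global
        # Average Pooling, just as (-1, -2) did before this change.
        return torch.mean(x, (2, 3), keepdim=True)


ep = torch.export.export(GlobalAvgPool(), (torch.randn(1, 5, 4, 4),))
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
# keepdim=False, non-4D inputs, or dims other than the innermost two are
# rejected by check_constraints and stay in the edge dialect instead of
# being delegated to XNNPACK (see the new keepdim=False test above).
```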