26 changes: 14 additions & 12 deletions backends/xnnpack/operators/op_mean_dim.py
@@ -1,10 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from typing import cast, Dict, List
+from typing import cast, Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
@@ -18,16 +19,17 @@
    XNNGraph,
    XNode,
)
+from executorch.backends.xnnpack.utils.utils import normalize_mean_dims
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
class MeanDim(NodeVisitor):
"""
XNNPACK only supports a special case of mean dim in which the operation can be written
as Global Average Pooling. In order to be handled by xnnpack the input tensor must be 4d,
the dimensions to reduce must be the two innermost (-1, -2) or (-2, -1). and the flag
for keepdim must be set to True.
XNNPACK only supports the special case of mean.dim that can be lowered
to Global Average Pooling. The input tensor must be 4D, keepdim must be
True, and the reduced dimensions must normalize to [2, 3] (for example
[2, 3], [3, 2], [-1, -2], or [-2, -1]).
"""

target = "aten.mean.dim"
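
For orientation, a minimal sketch (not part of the diff) of the calls this visitor can and cannot lower:

import torch

x = torch.randn(1, 5, 4, 4)  # 4D input, as the lowering requires

# Supported: the two innermost dims reduced with keepdim=True; this is
# exactly a Global Average Pooling over the spatial dimensions.
y = torch.mean(x, dim=(-1, -2), keepdim=True)   # shape (1, 5, 1, 1)

# Not lowered by this visitor: a non-innermost dim, or keepdim=False.
z = torch.mean(x, dim=(1,), keepdim=True)       # shape (1, 1, 4, 4)
w = torch.mean(x, dim=(-1, -2), keepdim=False)  # shape (1, 5)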
@@ -51,10 +53,15 @@ def define_node(
        # output
        output_id = vals_to_ids[node]

+        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
+        check_or_raise(
+            len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
+        )
+
        # mean dims
-        mean_dims = cast(List[int], node.args[1])
+        mean_dims = normalize_mean_dims(node.args[1], len(input_shape))
        check_or_raise(
-            mean_dims == [-1, -2] or mean_dims == [-2, -1],
+            sorted(mean_dims) == [2, 3],
            "XNNPACK only supports mean.dim across the innermost dimensions",
        )

@@ -64,11 +71,6 @@
            "XNNPACK only supports mean.dim that keeps dims",
        )

Review comment from Copilot AI (Apr 29, 2026):

aten.mean.dim supports a dtype kwarg, but this lowering ignores it and always emits XNNGlobalAvgPooling2d, which will preserve the input dtype. Please add a guard to ensure node.kwargs.get('dtype') is None (or otherwise reject) so models requesting a different dtype don't get incorrect results.

Suggested change:

        # dtype
        check_or_raise(
            node.kwargs.get("dtype") is None,
            "XNNPACK only supports mean.dim without a dtype override",
        )
-        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
-        check_or_raise(
-            len(input_shape) == 4, "Require input to mean.dim be 4 dimensional"
-        )
-
        ser_node = XNode(
            xnode_union=XNNGlobalAvgPooling2d(
                input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS
19 changes: 13 additions & 6 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -20,7 +21,7 @@
    is_quant,
    tag_as_implicit_q_dq,
)
-from executorch.backends.xnnpack.utils.utils import get_input_node
+from executorch.backends.xnnpack.utils.utils import get_input_node, normalize_mean_dims
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)
@@ -515,22 +516,28 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        if not self.check_common_constraints(node, ep):
            return False

-        dims = node.args[1]
-        output_dims = node.meta["val"].dim()
+        input_rank = get_input_node(node, 0).meta["val"].dim()
+        keepdim = len(node.args) >= 3 and bool(node.args[2])
+        dims = normalize_mean_dims(node.args[1], input_rank)

Review comment from Copilot AI (Apr 29, 2026), on lines +519 to 522:

mean.dim has a dtype kwarg (see operator signature). The current constraints ignore node.kwargs['dtype'], but lowering to XNNPACK GlobalAvgPooling cannot honor a requested dtype change. Please explicitly reject partitioning when dtype is set (non-None) to avoid silent dtype mismatches.
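
One way such a guard could look inside check_constraints (an illustrative sketch only; neither the diff nor the reviewer supplies this code):

        # Hypothetical guard: GlobalAvgPooling preserves the input dtype, so
        # reject partitioning whenever a dtype override is requested.
        if node.kwargs.get("dtype") is not None:
            why(node, reason="mean.dim with a dtype override is not supported")
            return False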
-        if dims not in ([-2, -1], [-1, -2]):
+        if sorted(dims) != [2, 3]:
            why(
                node,
                reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
            )
            return False

-        if output_dims != 4:
+        if input_rank != 4:
            why(
                node,
-                reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}",
+                reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}",
            )
            return False
Review comment from Copilot AI (Apr 29, 2026), on lines +519 to 535:

The constraint checks sorted(dims) != [2, 3] before verifying input_rank == 4. For non-4D inputs this will always fail the dims check first, producing a less accurate why(...) reason and making the rank-specific message unreachable. Consider checking input_rank != 4 (and returning) before normalizing/validating dims.
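
A sketch of the reordering the reviewer describes (illustrative only, not code from the diff):

        # Hypothetical reordering: validate the rank first so the rank-specific
        # message is reachable, then normalize and validate the dims.
        input_rank = get_input_node(node, 0).meta["val"].dim()
        if input_rank != 4:
            why(node, reason=f"mean.dim only supports 4D tensors, got rank {input_rank}")
            return False

        dims = normalize_mean_dims(node.args[1], input_rank)
        if sorted(dims) != [2, 3]:
            why(node, reason="mean.dim only supports the innermost two dimensions")
            return False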

+        if not keepdim:
+            why(node, reason="mean.dim only supports keepdim=True")
+            return False
+
        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
24 changes: 20 additions & 4 deletions backends/xnnpack/test/ops/test_mean_dim.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -15,18 +16,19 @@ def setUp(self):
        torch._dynamo.reset()

    class MeanDim(torch.nn.Module):
-        def __init__(self, dims):
+        def __init__(self, dims, keepdim=True):
            super().__init__()
            self.dims = dims
+            self.keepdim = keepdim

        def forward(self, x):
            y = x + x
-            z = torch.mean(y, self.dims, keepdim=True)
+            z = torch.mean(y, self.dims, keepdim=self.keepdim)
            return z

-    def _test_mean_dim(self, inputs):
+    def _test_mean_dim(self, inputs, dims=(-1, -2)):
        (
-            Tester(self.MeanDim((-1, -2)), inputs)
+            Tester(self.MeanDim(dims), inputs)
            .export()
            .check_count({"torch.ops.aten.mean.dim": 1})
            .to_edge_transform_and_lower()
@@ -45,6 +47,10 @@ def test_fp32_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4),)
        self._test_mean_dim(inputs)

+    def test_fp32_mean_dim_positive_dims(self):
+        inputs = (torch.randn(1, 5, 4, 4),)
+        self._test_mean_dim(inputs, dims=(2, 3))
+
    def test_fp32_mean_dim_unsupported(self):
        """
        XNNPack mean.dim implementation only supports innermost two dimensions. As such,
@@ -72,6 +78,16 @@ def test_fp32_mean_dim_unsupported_3d(self):
            .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
        )

+    def test_fp32_mean_dim_unsupported_keepdim_false(self):
+        inputs = (torch.randn(1, 5, 4, 4),)
+        (
+            Tester(self.MeanDim((-1, -2), keepdim=False), inputs)
+            .export()
+            .check_count({"torch.ops.aten.mean.dim": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
+        )
+
    def test_qs8_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4),)
        (
12 changes: 11 additions & 1 deletion backends/xnnpack/utils/utils.py
@@ -1,10 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, cast, Optional, Tuple
+from typing import Any, cast, List, Optional, Sequence, Tuple

import torch

@@ -57,6 +58,15 @@ def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
    return cast(torch.fx.Node, node.args[input_index])


+def normalize_mean_dims(mean_dims: Sequence[int] | int | None, rank: int) -> List[int]:
+    """Return mean dims as non-negative indices for the given rank."""
+    if mean_dims is None:
+        return list(range(rank))
+    if isinstance(mean_dims, int):
+        mean_dims = [mean_dims]
+    return [dim % rank for dim in mean_dims]
+
+
Review comment from Copilot AI (Apr 29, 2026), on lines +63 to +69:

normalize_mean_dims normalizes via dim % rank, which can silently turn out-of-range dims (e.g., dim == rank) into a different valid axis. This can make constraint checks accept an axis specification that PyTorch would treat as invalid. Consider validating each dim is within [-rank, rank-1] (and that rank > 0) before normalizing, and raising/handling invalid dims explicitly.

Suggested change:

-    if mean_dims is None:
-        return list(range(rank))
-    if isinstance(mean_dims, int):
-        mean_dims = [mean_dims]
-    return [dim % rank for dim in mean_dims]
+    check_or_raise(rank > 0, f"Expected rank > 0, got rank={rank}")
+    if mean_dims is None:
+        return list(range(rank))
+    if isinstance(mean_dims, int):
+        mean_dims = [mean_dims]
+    normalized_dims = []
+    for dim in mean_dims:
+        check_or_raise(
+            -rank <= dim < rank,
+            f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], got {dim})",
+        )
+        normalized_dims.append(dim % rank)
+    return normalized_dims
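
To make the normalization concrete, a quick sketch (assuming the function as merged, without the suggested range check):

from executorch.backends.xnnpack.utils.utils import normalize_mean_dims

print(normalize_mean_dims([-1, -2], 4))  # [3, 2]; sorted() == [2, 3], so accepted
print(normalize_mean_dims([2, 3], 4))    # [2, 3], accepted
print(normalize_mean_dims(None, 4))      # [0, 1, 2, 3], rejected by the [2, 3] check
print(normalize_mean_dims(3, 4))         # [3], rejected by the [2, 3] check
# As the reviewer notes, an out-of-range dim wraps silently here:
print(normalize_mean_dims([4, 3], 4))    # [0, 3], where PyTorch would raise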
def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
"""
Checks if the current node is only consumed by a relu node and can be fused,