
XNNPACK: accept positive spatial dims in mean.dim via normalization#19206

Open
mansnils wants to merge 1 commit into pytorch:main from mansnils:xnnpack

Conversation

@mansnils
Collaborator

@mansnils mansnils commented Apr 29, 2026

Normalize mean.dim axis arguments before validation and partitioning so equivalent forms such as [2, 3] and [-2, -1] are accepted. Previously, only negative dim forms were allowed, causing valid 4D spatial reductions with positive dims to be rejected.

This preserves existing XNNPACK constraints (4D input, keepdim=True, reduction over the innermost two dimensions) while removing dependence on axis spelling.
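The normalization idea can be sketched in a few lines; the helper name below is illustrative (the PR's actual helper lives in backends/xnnpack/utils/utils.py):

```python
def normalize_dims(dims, rank):
    """Map possibly-negative axes onto canonical non-negative indices."""
    if isinstance(dims, int):
        dims = [dims]
    return sorted(d % rank for d in dims)

# For a 4D input, both axis spellings describe the same spatial reduction:
print(normalize_dims([2, 3], 4))    # [2, 3]
print(normalize_dims([-2, -1], 4))  # [2, 3]
```

After canonicalization, the constraint check only has to compare against one form, regardless of how the model author spelled the axes.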

Add a test covering positive spatial dims.

This reduces EdgeTAM end-to-end runtime by ~6–7% on SME2-enabled devices.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell

Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
Change-Id: I8a56568dbe79fa3bb8327b83c453b240795491ab
Copilot AI review requested due to automatic review settings April 29, 2026 14:39
@mansnils mansnils requested a review from digantdesai as a code owner April 29, 2026 14:39
@pytorch-bot

pytorch-bot Bot commented Apr 29, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19206

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 Awaiting Approval, 4 New Failures, 1 Cancelled Job, 4 Unrelated Failures

As of commit 2c1369e with merge base eb19f24:

AWAITING APPROVAL - The following workflow needs approval before CI can run:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 29, 2026
@mansnils mansnils added partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm and removed CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. labels Apr 29, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Normalize mean.dim axis arguments so XNNPACK accepts equivalent positive/negative spatial dim specifications (e.g., (2, 3) and (-2, -1)) while preserving existing lowering constraints.

Changes:

  • Added normalize_mean_dims(...) helper to canonicalize mean reduction dims to non-negative indices.
  • Updated XNNPACK partitioning + lowering checks to validate normalized dims and keepdim=True.
  • Added test coverage for positive spatial dims and keepdim=False unsupported behavior.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

Changed files:

  • backends/xnnpack/utils/utils.py: Introduces dim-normalization helper for mean.dim axis arguments.
  • backends/xnnpack/partition/config/generic_node_configs.py: Uses normalized dims in MeanDimConfig constraint checks; adds keepdim=True constraint.
  • backends/xnnpack/operators/op_mean_dim.py: Uses normalized dims during serialization/lowering to XNNPACK GlobalAvgPooling.
  • backends/xnnpack/test/ops/test_mean_dim.py: Adds tests for positive dims support and keepdim=False unsupported case.


Comment on lines +519 to 522
input_rank = get_input_node(node, 0).meta["val"].dim()
keepdim = len(node.args) >= 3 and bool(node.args[2])
dims = normalize_mean_dims(node.args[1], input_rank)


Copilot AI Apr 29, 2026


mean.dim has a dtype kwarg (see operator signature). The current constraints ignore node.kwargs['dtype'], but lowering to XNNPACK GlobalAvgPooling cannot honor a requested dtype change. Please explicitly reject partitioning when dtype is set (non-None) to avoid silent dtype mismatches.
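A hedged sketch of such a guard; the function name and the kwargs shape are illustrative, not the PR's actual code:

```python
def dtype_is_unset(node_kwargs: dict) -> bool:
    # XNNPACK GlobalAvgPooling preserves the input dtype, so a requested
    # dtype override cannot be honored and should block partitioning.
    return node_kwargs.get("dtype") is None

# Partitioning would be rejected whenever a dtype override is present:
print(dtype_is_unset({}))                    # True  -> safe to partition
print(dtype_is_unset({"dtype": "float16"}))  # False -> reject
```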

Comment on lines +519 to 535

+    input_rank = get_input_node(node, 0).meta["val"].dim()
+    keepdim = len(node.args) >= 3 and bool(node.args[2])
+    dims = normalize_mean_dims(node.args[1], input_rank)

-    if dims not in ([-2, -1], [-1, -2]):
+    if sorted(dims) != [2, 3]:
         why(
             node,
             reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
         )
         return False

-    if output_dims != 4:
+    if input_rank != 4:
         why(
             node,
-            reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}",
+            reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}",
         )
         return False
Copy link

Copilot AI Apr 29, 2026


The constraint checks sorted(dims) != [2, 3] before verifying input_rank == 4. For non-4D inputs this will always fail the dims check first, producing a less accurate why(...) reason and making the rank-specific message unreachable. Consider checking input_rank != 4 (and returning) before normalizing/validating dims.
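An illustrative reordering under that suggestion (the function and message text are hypothetical, not the PR's code):

```python
def check_mean_dim_constraints(input_rank: int, dims: list) -> tuple:
    # Check rank first so a non-4D input produces the rank-specific message
    # instead of always failing the dims check.
    if input_rank != 4:
        return False, f"mean.dim only supports 4D tensors, got rank {input_rank}"
    if sorted(d % input_rank for d in dims) != [2, 3]:
        return False, "mean.dim only supports reducing the innermost two dims"
    return True, ""

print(check_mean_dim_constraints(3, [2, 3]))    # rank-specific rejection
print(check_mean_dim_constraints(4, [-2, -1]))  # accepted
```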

@@ -64,11 +71,6 @@ def define_node(
"XNNPACK only supports mean.dim that keeps dims",
)

Copy link

Copilot AI Apr 29, 2026


aten.mean.dim supports a dtype kwarg, but this lowering ignores it and always emits XNNGlobalAvgPooling2d, which will preserve the input dtype. Please add a guard to ensure node.kwargs.get('dtype') is None (or otherwise reject) so models requesting a different dtype don't get incorrect results.

Suggested change
# dtype
check_or_raise(
node.kwargs.get("dtype") is None,
"XNNPACK only supports mean.dim without a dtype override",
)

Comment on lines +63 to +69
if mean_dims is None:
return list(range(rank))
if isinstance(mean_dims, int):
mean_dims = [mean_dims]
return [dim % rank for dim in mean_dims]


Copy link

Copilot AI Apr 29, 2026


normalize_mean_dims normalizes via dim % rank, which can silently turn out-of-range dims (e.g., dim == rank) into a different valid axis. This can make constraint checks accept an axis specification that PyTorch would treat as invalid. Consider validating each dim is within [-rank, rank-1] (and that rank > 0) before normalizing, and raising/handling invalid dims explicitly.
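A quick demonstration of the wrap-around behavior this comment describes:

```python
rank = 4
# Python's % maps any integer into [0, rank), so an out-of-range axis is
# silently folded onto a valid one instead of raising:
print(4 % rank)   # 0: dim == rank becomes axis 0
print(-5 % rank)  # 3: a dim below -rank also lands on a valid axis
```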

Suggested change
if mean_dims is None:
return list(range(rank))
if isinstance(mean_dims, int):
mean_dims = [mean_dims]
return [dim % rank for dim in mean_dims]
check_or_raise(rank > 0, f"Expected rank > 0, got rank={rank}")
if mean_dims is None:
return list(range(rank))
if isinstance(mean_dims, int):
mean_dims = [mean_dims]
normalized_dims = []
for dim in mean_dims:
check_or_raise(
-rank <= dim < rank,
f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], got {dim})",
)
normalized_dims.append(dim % rank)
return normalized_dims

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 29, 2026

Labels

  • ciflow/trunk
  • CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
  • partner: arm (backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)


2 participants