XNNPACK: accept positive spatial dims in mean.dim via normalization#19206
XNNPACK: accept positive spatial dims in mean.dim via normalization#19206mansnils wants to merge 1 commit intopytorch:mainfrom
Conversation
Normalize mean.dim axis arguments before validation and partitioning so equivalent forms such as [2, 3] and [-2, -1] are accepted. Previously, only negative dim forms were allowed, causing valid 4D spatial reductions with positive dims to be rejected. This preserves existing XNNPACK constraints (4D input, keepdim=True, reduction over the innermost two dimensions) while removing dependence on axis spelling. Add a test covering positive spatial dims. This reduces EdgeTAM end-to-end runtime by ~6–7% on SME2 enabled devices. Signed-off-by: Måns Nilsson <mans.nilsson@arm.com> Change-Id: I8a56568dbe79fa3bb8327b83c453b240795491ab
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19206
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 Awaiting Approval, 4 New Failures, 1 Cancelled Job, 4 Unrelated FailuresAs of commit 2c1369e with merge base eb19f24 ( NEW FAILURES - The following jobs have failed:
CANCELLED JOB - The following job was cancelled. Please retry:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
Normalize mean.dim axis arguments so XNNPACK accepts equivalent positive/negative spatial dim specifications (e.g., (2, 3) and (-2, -1)) while preserving existing lowering constraints.
Changes:
- Added
normalize_mean_dims(...)helper to canonicalize mean reduction dims to non-negative indices. - Updated XNNPACK partitioning + lowering checks to validate normalized dims and
keepdim=True. - Added test coverage for positive spatial dims and
keepdim=Falseunsupported behavior.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
backends/xnnpack/utils/utils.py |
Introduces dim-normalization helper for mean.dim axis arguments. |
backends/xnnpack/partition/config/generic_node_configs.py |
Uses normalized dims in MeanDimConfig constraint checks; adds keepdim=True constraint. |
backends/xnnpack/operators/op_mean_dim.py |
Uses normalized dims during serialization/lowering to XNNPACK GlobalAvgPooling. |
backends/xnnpack/test/ops/test_mean_dim.py |
Adds tests for positive dims support and keepdim=False unsupported case. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| input_rank = get_input_node(node, 0).meta["val"].dim() | ||
| keepdim = len(node.args) >= 3 and bool(node.args[2]) | ||
| dims = normalize_mean_dims(node.args[1], input_rank) | ||
|
|
There was a problem hiding this comment.
mean.dim has a dtype kwarg (see operator signature). The current constraints ignore node.kwargs['dtype'], but lowering to XNNPACK GlobalAvgPooling cannot honor a requested dtype change. Please explicitly reject partitioning when dtype is set (non-None) to avoid silent dtype mismatches.
| input_rank = get_input_node(node, 0).meta["val"].dim() | ||
| keepdim = len(node.args) >= 3 and bool(node.args[2]) | ||
| dims = normalize_mean_dims(node.args[1], input_rank) | ||
|
|
||
| if dims not in ([-2, -1], [-1, -2]): | ||
| if sorted(dims) != [2, 3]: | ||
| why( | ||
| node, | ||
| reason="mean.dim only supports averaging 4D tensors across the innermost dimensions", | ||
| ) | ||
| return False | ||
|
|
||
| if output_dims != 4: | ||
| if input_rank != 4: | ||
| why( | ||
| node, | ||
| reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}", | ||
| reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}", | ||
| ) | ||
| return False |
There was a problem hiding this comment.
The constraint checks sorted(dims) != [2, 3] before verifying input_rank == 4. For non-4D inputs this will always fail the dims check first, producing a less accurate why(...) reason and making the rank-specific message unreachable. Consider checking input_rank != 4 (and returning) before normalizing/validating dims.
| @@ -64,11 +71,6 @@ def define_node( | |||
| "XNNPACK only supports mean.dim that keeps dims", | |||
| ) | |||
|
|
|||
There was a problem hiding this comment.
aten.mean.dim supports a dtype kwarg, but this lowering ignores it and always emits XNNGlobalAvgPooling2d, which will preserve the input dtype. Please add a guard to ensure node.kwargs.get('dtype') is None (or otherwise reject) so models requesting a different dtype don't get incorrect results.
| # dtype | |
| check_or_raise( | |
| node.kwargs.get("dtype") is None, | |
| "XNNPACK only supports mean.dim without a dtype override", | |
| ) |
| if mean_dims is None: | ||
| return list(range(rank)) | ||
| if isinstance(mean_dims, int): | ||
| mean_dims = [mean_dims] | ||
| return [dim % rank for dim in mean_dims] | ||
|
|
||
|
|
There was a problem hiding this comment.
normalize_mean_dims normalizes via dim % rank, which can silently turn out-of-range dims (e.g., dim == rank) into a different valid axis. This can make constraint checks accept an axis specification that PyTorch would treat as invalid. Consider validating each dim is within [-rank, rank-1] (and that rank > 0) before normalizing, and raising/handling invalid dims explicitly.
| if mean_dims is None: | |
| return list(range(rank)) | |
| if isinstance(mean_dims, int): | |
| mean_dims = [mean_dims] | |
| return [dim % rank for dim in mean_dims] | |
| check_or_raise(rank > 0, f"Expected rank > 0, got rank={rank}") | |
| if mean_dims is None: | |
| return list(range(rank)) | |
| if isinstance(mean_dims, int): | |
| mean_dims = [mean_dims] | |
| normalized_dims = [] | |
| for dim in mean_dims: | |
| check_or_raise( | |
| -rank <= dim < rank, | |
| f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], got {dim})", | |
| ) | |
| normalized_dims.append(dim % rank) | |
| return normalized_dims |
Normalize mean.dim axis arguments before validation and partitioning so equivalent forms such as [2, 3] and [-2, -1] are accepted. Previously, only negative dim forms were allowed, causing valid 4D spatial reductions with positive dims to be rejected.
This preserves existing XNNPACK constraints (4D input, keepdim=True, reduction over the innermost two dimensions) while removing dependence on axis spelling.
Add a test covering positive spatial dims.
This reduces EdgeTAM end-to-end runtime by ~6–7% on SME2 enabled devices.
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell