-
Notifications
You must be signed in to change notification settings - Fork 966
XNNPACK: accept positive spatial dims in mean.dim via normalization #19206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
@@ -20,7 +21,7 @@ | |
| is_quant, | ||
| tag_as_implicit_q_dq, | ||
| ) | ||
| from executorch.backends.xnnpack.utils.utils import get_input_node | ||
| from executorch.backends.xnnpack.utils.utils import get_input_node, normalize_mean_dims | ||
| from executorch.exir.backend.canonical_partitioners.config_partitioner import ( | ||
| format_target_name, | ||
| ) | ||
|
|
@@ -515,22 +516,28 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: | |
| if not self.check_common_constraints(node, ep): | ||
| return False | ||
|
|
||
| dims = node.args[1] | ||
| output_dims = node.meta["val"].dim() | ||
| input_rank = get_input_node(node, 0).meta["val"].dim() | ||
| keepdim = len(node.args) >= 3 and bool(node.args[2]) | ||
| dims = normalize_mean_dims(node.args[1], input_rank) | ||
|
|
||
|
Comment on lines
+519
to
522
|
||
| if dims not in ([-2, -1], [-1, -2]): | ||
| if sorted(dims) != [2, 3]: | ||
| why( | ||
| node, | ||
| reason="mean.dim only supports averaging 4D tensors across the innermost dimensions", | ||
| ) | ||
| return False | ||
|
|
||
| if output_dims != 4: | ||
| if input_rank != 4: | ||
| why( | ||
| node, | ||
| reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}", | ||
| reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {input_rank}", | ||
| ) | ||
| return False | ||
|
Comment on lines
+519
to
535
|
||
|
|
||
| if not keepdim: | ||
| why(node, reason="mean.dim only supports keepdim=True") | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
| def supported_precision_types(self) -> List[ConfigPrecisionType]: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,10 +1,11 @@ | ||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||||||||||||||||||||||||||||||||
| # All rights reserved. | ||||||||||||||||||||||||||||||||||||||||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||
| # This source code is licensed under the BSD-style license found in the | ||||||||||||||||||||||||||||||||||||||||
| # LICENSE file in the root directory of this source tree. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from typing import Any, cast, Optional, Tuple | ||||||||||||||||||||||||||||||||||||||||
| from typing import Any, cast, List, Optional, Sequence, Tuple | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -57,6 +58,15 @@ def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node: | |||||||||||||||||||||||||||||||||||||||
| return cast(torch.fx.Node, node.args[input_index]) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def normalize_mean_dims(mean_dims: Sequence[int] | int | None, rank: int) -> List[int]: | ||||||||||||||||||||||||||||||||||||||||
| """Return mean dims as non-negative indices for the given rank.""" | ||||||||||||||||||||||||||||||||||||||||
| if mean_dims is None: | ||||||||||||||||||||||||||||||||||||||||
| return list(range(rank)) | ||||||||||||||||||||||||||||||||||||||||
| if isinstance(mean_dims, int): | ||||||||||||||||||||||||||||||||||||||||
| mean_dims = [mean_dims] | ||||||||||||||||||||||||||||||||||||||||
| return [dim % rank for dim in mean_dims] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+63
to
+69
|
||||||||||||||||||||||||||||||||||||||||
| if mean_dims is None: | |
| return list(range(rank)) | |
| if isinstance(mean_dims, int): | |
| mean_dims = [mean_dims] | |
| return [dim % rank for dim in mean_dims] | |
| check_or_raise(rank > 0, f"Expected rank > 0, got rank={rank}") | |
| if mean_dims is None: | |
| return list(range(rank)) | |
| if isinstance(mean_dims, int): | |
| mean_dims = [mean_dims] | |
| normalized_dims = [] | |
| for dim in mean_dims: | |
| check_or_raise( | |
| -rank <= dim < rank, | |
| f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], got {dim})", | |
| ) | |
| normalized_dims.append(dim % rank) | |
| return normalized_dims |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`aten.mean.dim` supports a `dtype` kwarg, but this lowering ignores it and always emits `XNNGlobalAvgPooling2d`, which will preserve the input dtype. Please add a guard to ensure `node.kwargs.get('dtype') is None` (or otherwise reject) so models requesting a different dtype don't get incorrect results.