Skip to content

Commit 3a13e41

Browse files
Arm backend: Reject Squeeze no-op partition (pytorch#19662)
Reject delegates of squeeze op that are no-op in the partitioner cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Christoffer J.L <christoffer.johanssonlundqvist@arm.com>
1 parent 5eb8492 commit 3a13e41

2 files changed

Lines changed: 38 additions & 1 deletion

File tree

backends/arm/test/ops/test_squeeze.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -16,6 +16,7 @@
1616
from executorch.backends.arm.test.tester.test_pipeline import (
1717
EthosU55PipelineINT,
1818
EthosU85PipelineINT,
19+
OpNotSupportedPipeline,
1920
TosaPipelineFP,
2021
TosaPipelineINT,
2122
VgfPipeline,
@@ -60,6 +61,12 @@ def forward(self, x: torch.Tensor):
6061
return x.squeeze()
6162

6263

64+
unsupported_cases = {
65+
"squeeze_dim_no_effect": lambda: (torch.randn(3, 4, 5), 1),
66+
"squeeze_no_effect": lambda: (torch.randn(3, 4, 5),),
67+
}
68+
69+
6370
##############
6471
## Squeeze ###
6572
##############
@@ -137,6 +144,16 @@ def test_squeeze_dim_vgf_quant(test_data: Tuple):
137144
pipeline.run()
138145

139146

147+
def test_squeeze_no_target_not_delegated() -> None:
148+
pipeline = OpNotSupportedPipeline[input_t1](
149+
Squeeze(),
150+
unsupported_cases["squeeze_no_effect"](),
151+
{"executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1},
152+
n_expected_delegates=0,
153+
)
154+
pipeline.run()
155+
156+
140157
#################
141158
## SqueezeDim ###
142159
#################
@@ -214,6 +231,16 @@ def test_squeeze_dim_vgf_quant_2(test_data: Tuple):
214231
pipeline.run()
215232

216233

234+
def test_squeeze_dim_no_target_not_delegated() -> None:
235+
pipeline = OpNotSupportedPipeline[Tuple[torch.Tensor, int]](
236+
SqueezeDim(),
237+
unsupported_cases["squeeze_dim_no_effect"](),
238+
{"executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1},
239+
n_expected_delegates=0,
240+
)
241+
pipeline.run()
242+
243+
217244
##################
218245
## SqueezeDims ###
219246
##################

backends/arm/tosa/partitioner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ def _is_noop_expand(node: torch.fx.node.Node) -> bool:
104104
return all(m == 1 for m in multiples) and not changes_rank
105105

106106

107+
def _is_noop_squeeze(node: torch.fx.Node) -> bool:
108+
if node.target != exir_ops.edge.aten.squeeze_copy.dims:
109+
return False
110+
else:
111+
input_tensor = get_first_fake_tensor(ensure_type(torch.fx.Node, node.args[0]))
112+
output_tensor = get_first_fake_tensor(node)
113+
return input_tensor.shape == output_tensor.shape
114+
115+
107116
def _is_view_copy(node: torch.fx.node.Node) -> bool:
108117
return node.target == exir_ops.edge.aten.view_copy.default
109118

@@ -388,6 +397,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
388397
or _is_noop_expand(node)
389398
or _is_noop_detach_copy(node)
390399
or _is_noop_to_dim_order_copy(node)
400+
or _is_noop_squeeze(node)
391401
or _is_view_copy(node)
392402
or _is_noop_as_strided_copy(node)
393403
or node.target in Q_OPS

0 commit comments

Comments
 (0)