Skip to content

Commit 11f363a

Browse files
Arm backend: Generalize RemovePermutesAroundElementwiseTosaOps (#20238)
- Use is_param_node for finding constant placeholders
- Ensure ops are not modified by multiple subgraphs

---------

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent cdb9413 commit 11f363a

5 files changed

Lines changed: 179 additions & 9 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def _tosa_pipeline(
618618
RewriteMatmulPass(),
619619
RewritePadPass(),
620620
FuseViewCopyTransformPass(),
621-
RemovePermutesAroundElementwiseTosaOps(),
621+
RemovePermutesAroundElementwiseTosaOps(exported_program),
622622
CanonicalizeViewCopyPermutePass(),
623623
FuseCascadedTransposeOrPermuteOps(),
624624
RewriteHighRankSingletonPermutePass(),

backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import torch
7+
8+
from executorch.backends.arm._passes.arm_pass_utils import is_param_node
69
from executorch.backends.arm._passes.insert_table_ops import TableOps
710
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
811
RemovePermutesAroundElementwiseOps,
912
)
13+
from executorch.exir import ExportedProgram
1014
from executorch.exir.dialects._ops import ops as exir_ops
1115

1216

1317
class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
14-
def __init__(self) -> None:
18+
def __init__(self, exported_program: ExportedProgram) -> None:
1519
super().__init__(
1620
extra_permutable_ops={
1721
*TableOps.unary_table_ops.keys(),
@@ -20,16 +24,19 @@ def __init__(self) -> None:
2024
exir_ops.backend.tosa.TABLE.default,
2125
}
2226
)
27+
self.exported_program = exported_program
28+
29+
def _is_constant(self, node: torch.fx.Node) -> bool:
    """Report whether ``node`` should be treated as a constant input.

    The parent class is consulted first; if it does not classify the node
    as constant, ``is_param_node`` is asked using the exported program held
    by this pass. The second lookup supplements the parent's detection
    (which the original comment here describes as a fragile string match)
    instead of replacing it.
    """
    parent_verdict = super()._is_constant(node)
    if parent_verdict:
        return parent_verdict
    return is_param_node(self.exported_program, node)
2332

2433
def permute_subgraph(self, subgraph) -> bool:
    """Permute ``subgraph`` while leaving constant inputs of TABLE ops alone.

    Constant edges whose consumer is ``tosa.TABLE`` are removed from
    ``subgraph.constant_edges_in`` (the attribute is reassigned on the
    object passed in) before the parent implementation runs, because TABLE
    lookup inputs are already tied to the table layout.

    Returns whatever the parent ``permute_subgraph`` returns.
    """
    subgraph.constant_edges_in = {
        (constant, consumer)
        for constant, consumer in subgraph.constant_edges_in
        if consumer.target != exir_ops.backend.tosa.TABLE.default
    }
    return super().permute_subgraph(subgraph)

backends/arm/test/misc/test_transpose_counts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def forward(self, x: torch.Tensor):
453453
Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3
454454
),
455455
"model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase(
456-
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4
456+
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 2
457457
),
458458
"model_6_gru_linear": TransposeCountCase(
459459
Model6GruLinear(), (torch.randn(2, 16, 8),), 2

backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from types import SimpleNamespace
7+
from typing import cast
8+
69
import torch
710
from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import (
811
RemovePermutesAroundElementwiseTosaOps,
@@ -11,12 +14,29 @@
1114
TosaLoweringContext,
1215
TosaSpecification,
1316
)
17+
from executorch.exir import ExportedProgram
1418
from executorch.exir.dialects._ops import ops as exir_ops
1519

1620
TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT")
21+
TOSA_FP_SPEC = TosaSpecification.create_from_string("TOSA-1.0+FP")
1722
PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default
1823
RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default
19-
TABLE_TARGET = exir_ops.backend.tosa.TABLE.default
24+
MUL_TARGET = exir_ops.edge.aten.mul.Tensor
25+
ADD_TARGET = exir_ops.edge.aten.add.Tensor
26+
ERF_TARGET = exir_ops.edge.aten.erf.default
27+
28+
29+
def _fake_exported_program() -> ExportedProgram:
    """Build a minimal stand-in for an ``ExportedProgram`` used by these tests.

    The stub exposes only ``graph_signature`` with empty
    ``inputs_to_buffers``, ``inputs_to_lifted_tensor_constants`` and
    ``inputs_to_parameters`` mappings. ``cast`` merely satisfies the type
    checker; the object returned is a plain ``SimpleNamespace``.
    """
    # NOTE(review): presumably these three mappings are all the pass reads
    # from the exported program -- confirm against is_param_node.
    empty_signature = SimpleNamespace(
        inputs_to_buffers={},
        inputs_to_lifted_tensor_constants={},
        inputs_to_parameters={},
    )
    stub = SimpleNamespace(graph_signature=empty_signature)
    return cast(ExportedProgram, stub)
2040

2141

2242
def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int:
@@ -52,8 +72,125 @@ def test_remove_permutes_around_rescale_tosa_INT() -> None:
5272
graph_module = torch.fx.GraphModule({}, graph)
5373

5474
with TosaLoweringContext(TOSA_INT_SPEC):
55-
result = RemovePermutesAroundElementwiseTosaOps().call(graph_module)
75+
result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call(
76+
graph_module
77+
)
5678

5779
assert result.modified
5880
assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0
5981
assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1
82+
83+
84+
def test_remove_permutes_around_gelu_with_folded_scalar_constants_tosa_FP() -> None:
    """Run the pass on a hand-built erf-based GELU-style chain (TOSA FP spec).

    The graph permutes ``x`` with ``[0, 2, 3, 1]``, applies
    ``mul -> erf -> add -> mul`` against three ``(1, 1, 1, 1)`` placeholder
    scalars, multiplies that result with the permuted input and permutes
    back with ``[0, 3, 1, 2]``. The pass must report a modification and
    leave exactly three permute nodes and one erf node.
    """
    graph = torch.fx.Graph()
    permuted_shape = (1, 3, 4, 2)

    def call(target, args, shape):
        # Create a call_function node and attach a random tensor of `shape`
        # as its "val" metadata.
        node = graph.create_node("call_function", target, args=args)
        node.meta["val"] = torch.randn(*shape)
        return node

    x = graph.placeholder("x")
    x.meta["val"] = torch.randn(1, 2, 3, 4)

    scalar_constants = []
    for i in range(3):
        const = graph.placeholder(f"c_scalar_{i}")
        const.meta["val"] = torch.randn(1, 1, 1, 1)
        scalar_constants.append(const)
    scale_in, offset, scale_out = scalar_constants

    permute_in = call(PERMUTE_TARGET, (x, [0, 2, 3, 1]), permuted_shape)
    mul_0 = call(MUL_TARGET, (permute_in, scale_in), permuted_shape)
    erf = call(ERF_TARGET, (mul_0,), permuted_shape)
    add = call(ADD_TARGET, (erf, offset), permuted_shape)
    mul_1 = call(MUL_TARGET, (add, scale_out), permuted_shape)
    mul_2 = call(MUL_TARGET, (permute_in, mul_1), permuted_shape)
    permute_out = call(PERMUTE_TARGET, (mul_2, [0, 3, 1, 2]), (1, 2, 3, 4))
    graph.output(permute_out)

    graph_module = torch.fx.GraphModule({}, graph)

    with TosaLoweringContext(TOSA_FP_SPEC):
        tosa_pass = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program())
        result = tosa_pass.call(graph_module)

    assert result.modified
    assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 3
    assert _count_nodes(result.graph_module, ERF_TARGET) == 1
145+
146+
147+
def test_remove_permutes_skips_stale_shared_boundary_subgraph_tosa_FP() -> None:
    """Run the pass on two elementwise regions that share a boundary permute.

    Graph: ``x -> permute[0, 3, 1, 2] -> mul(self, self) ->
    permute[0, 2, 3, 1] -> mul(., p_layer_norm_weight) ->
    permute[0, 3, 1, 2] -> output``. The middle permute is the output
    boundary of the first ``mul`` and the input boundary of the second.
    The pass must report a modification, leave exactly one permute, and
    keep the original weight placeholder as the second ``mul``'s operand.
    """
    graph = torch.fx.Graph()
    outer_shape = (1, 16, 16, 8)
    inner_shape = (1, 8, 16, 16)

    def call(target, args, shape):
        # Create a call_function node and attach a random tensor of `shape`
        # as its "val" metadata.
        node = graph.create_node("call_function", target, args=args)
        node.meta["val"] = torch.randn(*shape)
        return node

    x = graph.placeholder("x")
    x.meta["val"] = torch.randn(*outer_shape)

    channel_const = graph.placeholder("p_layer_norm_weight")
    channel_const.meta["val"] = torch.randn(1, 1, 1, 8)

    permute_in = call(PERMUTE_TARGET, (x, [0, 3, 1, 2]), inner_shape)
    first_mul = call(MUL_TARGET, (permute_in, permute_in), inner_shape)
    shared_permute = call(PERMUTE_TARGET, (first_mul, [0, 2, 3, 1]), outer_shape)
    second_mul = call(MUL_TARGET, (shared_permute, channel_const), outer_shape)
    permute_out = call(PERMUTE_TARGET, (second_mul, [0, 3, 1, 2]), inner_shape)
    graph.output(permute_out)

    graph_module = torch.fx.GraphModule({}, graph)

    with TosaLoweringContext(TOSA_FP_SPEC):
        tosa_pass = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program())
        result = tosa_pass.call(graph_module)

    assert result.modified
    assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 1
    assert second_mul.args[1] is channel_const

backends/transforms/remove_permutes_around_elementwise_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,10 @@ def is_node_permutable(self, node: torch.fx.Node) -> bool:
400400
return self._is_pointwise(node.target)
401401

402402
def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901
403+
# Ensure that the subgraph's edges have not been modified by an earlier rewrite before applying changes.
404+
if not self._subgraph_edges_are_current(subgraph):
405+
return False
406+
403407
# Validate: every view_copy node's permutation rank must match its
404408
# input tensor rank. A mismatch can occur when a squeeze/unsqueeze
405409
# view is reached via upstream traversal with a permutation that was
@@ -495,6 +499,28 @@ def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901
495499

496500
return True
497501

502+
def _subgraph_edges_are_current(self, subgraph: Subgraph) -> bool:
    """Check that every edge recorded on ``subgraph`` still exists as recorded.

    Returns ``False`` when any of the following holds, which indicates an
    earlier rewrite invalidated this candidate:

    * an incoming edge whose source is no longer a ``permute_copy`` node
      or is no longer an input of its consumer;
    * an outgoing edge whose sink is no longer a ``permute_copy`` node or
      is no longer a user of its producer;
    * a constant edge whose constant is no longer an input of its user.

    Returns ``True`` otherwise.
    """
    stale_incoming = any(
        source.target != exir_ops.edge.aten.permute_copy.default
        or source not in consumer.all_input_nodes
        for source, consumer in subgraph.edges_in
    )
    if stale_incoming:
        return False

    stale_outgoing = any(
        sink.target != exir_ops.edge.aten.permute_copy.default
        or sink not in producer.users
        for producer, sink in subgraph.edges_out
    )
    if stale_outgoing:
        return False

    return all(
        constant in consumer.all_input_nodes
        for constant, consumer in subgraph.constant_edges_in
    )
523+
498524
def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
499525
dim = get_arg(node, "dim", int)
500526
set_arg(node, "dim", start_permute[dim])

0 commit comments

Comments
 (0)