Skip to content

Commit 7f02bd7

Browse files
authored
Qualcomm AI Engine Direct - Fix for documentation and reuse dead_code_elimination_pass (#18644)
1 parent 4c67d96 commit 7f02bd7

40 files changed

Lines changed: 81 additions & 80 deletions

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
1212
from executorch.exir.pass_base import ExportPass, PassResult
13+
from executorch.exir.passes import dead_code_elimination_pass
1314
from torch._guards import detect_fake_mode
1415

1516
from .utils import append_qdq, copy_meta
@@ -199,6 +200,5 @@ def call(self, graph_module: torch.fx.GraphModule):
199200
for user in node.users.copy():
200201
user.replace_input_with(node, squeeze_node)
201202

202-
graph.eliminate_dead_code()
203-
graph_module.recompile()
203+
dead_code_elimination_pass(graph_module)
204204
return PassResult(graph_module, True)

backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.pass_base import ExportPass, PassResult
13+
from executorch.exir.passes import dead_code_elimination_pass
1314
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1415

1516

@@ -78,6 +79,5 @@ def call(self, graph_module: torch.fx.GraphModule):
7879
for user in output.users.copy():
7980
user.replace_input_with(output, matmul_node)
8081

81-
graph.eliminate_dead_code()
82-
graph_module.recompile()
82+
dead_code_elimination_pass(graph_module)
8383
return PassResult(graph_module, True)

backends/qualcomm/_passes/convert_linear_to_conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from executorch.backends.qualcomm._passes.utils import append_qdq, copy_meta
99
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
1010
from executorch.exir.pass_base import ExportPass, PassResult
11+
from executorch.exir.passes import dead_code_elimination_pass
1112
from torch.fx import GraphModule
1213
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
1314

@@ -227,6 +228,5 @@ def call(self, graph_module: GraphModule):
227228
node.replace_all_uses_with(y)
228229
graph.erase_node(node)
229230

230-
graph.eliminate_dead_code()
231-
graph_module.recompile()
231+
dead_code_elimination_pass(graph_module)
232232
return PassResult(graph_module, True)

backends/qualcomm/_passes/convert_square_to_pow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66
import torch
77
from executorch.exir.pass_base import ExportPass, PassResult
8+
from executorch.exir.passes import dead_code_elimination_pass
89

910
from .utils import copy_meta
1011

@@ -33,6 +34,5 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
3334
for user in node.users.copy():
3435
user.replace_input_with(node, pow_node)
3536

36-
graph.eliminate_dead_code()
37-
graph_module.recompile()
37+
dead_code_elimination_pass(graph_module)
3838
return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_any.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from executorch.exir import to_edge
99
from executorch.exir.pass_base import ExportPass, PassResult
10+
from executorch.exir.passes import dead_code_elimination_pass
1011

1112
from .utils import merge_decomposed_graph
1213

@@ -59,6 +60,5 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
5960
)
6061
graph.erase_node(node)
6162

62-
graph.eliminate_dead_code()
63-
graph_module.recompile()
63+
dead_code_elimination_pass(graph_module)
6464
return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_binary_alpha.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
from executorch.exir.pass_base import ExportPass, PassResult
9+
from executorch.exir.passes import dead_code_elimination_pass
910

1011
from .utils import copy_meta
1112

@@ -56,6 +57,5 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
5657
mul_node,
5758
)
5859

59-
graph.eliminate_dead_code()
60-
graph_module.recompile()
60+
dead_code_elimination_pass(graph_module)
6161
return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_cdist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
from executorch.exir.pass_base import ExportPass, PassResult
9+
from executorch.exir.passes import dead_code_elimination_pass
910

1011
from .utils import merge_decomposed_graph
1112

@@ -64,6 +65,5 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
6465
)
6566
graph.erase_node(node)
6667

67-
graph.eliminate_dead_code()
68-
graph_module.recompile()
68+
dead_code_elimination_pass(graph_module)
6969
return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_col_im.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from executorch.exir.dialects._ops import ops as exir_ops
88
from executorch.exir.pass_base import ExportPass, PassResult
9+
from executorch.exir.passes import dead_code_elimination_pass
910

1011
from .utils import copy_meta
1112

@@ -117,5 +118,5 @@ def _decompose_col2im(self, graph_module: torch.fx.GraphModule):
117118
def call(self, graph_module: torch.fx.GraphModule):
118119
self._decompose_im2col(graph_module)
119120
self._decompose_col2im(graph_module)
120-
graph_module.recompile()
121+
dead_code_elimination_pass(graph_module)
121122
return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
from executorch.exir.pass_base import ExportPass, PassResult
9+
from executorch.exir.passes import dead_code_elimination_pass
910
from torch.fx.experimental.proxy_tensor import make_fx
1011

1112
from .utils import merge_decomposed_graph
@@ -46,6 +47,5 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4647
)
4748
graph.erase_node(node)
4849

49-
graph.eliminate_dead_code()
50-
graph_module.recompile()
50+
dead_code_elimination_pass(graph_module)
5151
return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_expm1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
from executorch.exir.pass_base import ExportPass, PassResult
9+
from executorch.exir.passes import dead_code_elimination_pass
910

1011
from .utils import copy_meta
1112

@@ -41,6 +42,5 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4142
for user in node.users.copy():
4243
user.replace_input_with(node, sub_node)
4344

44-
graph.eliminate_dead_code()
45-
graph_module.recompile()
45+
dead_code_elimination_pass(graph_module)
4646
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)