forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecompose_acos.py
More file actions
75 lines (59 loc) · 2.35 KB
/
decompose_acos.py
File metadata and controls
75 lines (59 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from .utils import copy_meta, get_const_node
class DecomposeAcos(ExportPass):
    """
    Rewrite every ``acos`` call in a graph as ``π/2 - asin(x)``.

    Both the ATen overload (``torch.ops.aten.acos.default``) and its edge
    dialect counterpart are recognised. Each matched node is replaced by an
    ``asin`` node feeding a ``sub`` node of the same dialect; the original
    ``acos`` node is then left without users and removed by dead-code
    elimination.
    """

    def __init__(self):
        super().__init__()
        # Call targets this pass rewrites: the ATen op and its edge-dialect twin.
        self.acos_targets = {
            torch.ops.aten.acos.default,
            exir_ops.edge.aten.acos.default,
        }

    @staticmethod
    def _replacement_ops(edge_dialect):
        """Return the ``(asin, sub)`` overloads for the dialect of the replaced node."""
        if edge_dialect:
            return exir_ops.edge.aten.asin.default, exir_ops.edge.aten.sub.Tensor
        return torch.ops.aten.asin.default, torch.ops.aten.sub.Tensor

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Decompose all ``acos`` nodes in ``graph_module``.

        Args:
            graph_module: module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping ``graph_module``. Its modified flag is
            False when no ``acos`` node was found (the graph is then left
            untouched) and True otherwise.
        """
        graph = graph_module.graph
        # Snapshot the matches up front: the loop below inserts new nodes.
        matches = [
            candidate
            for candidate in graph.nodes
            if candidate.op == "call_function"
            and candidate.target in self.acos_targets
        ]
        if not matches:
            return PassResult(graph_module, False)

        half_pi = torch.pi / 2.0
        # Edge-dialect graphs receive π/2 as a graph node, created lazily via
        # get_const_node on the first edge match and reused afterwards; ATen
        # graphs take the Python float directly as the sub operand.
        half_pi_const = None

        for acos_node in matches:
            operand = acos_node.args[0]
            edge_dialect = isinstance(acos_node.target, EdgeOpOverload)
            asin_target, sub_target = self._replacement_ops(edge_dialect)

            if edge_dialect:
                if half_pi_const is None:
                    half_pi_const = get_const_node(
                        graph, graph_module, "_pi_half_constant", half_pi, acos_node
                    )
                minuend = half_pi_const
            else:
                minuend = half_pi

            with graph.inserting_before(acos_node):
                asin_result = graph.create_node(
                    "call_function", asin_target, (operand,)
                )
                asin_result.meta = copy_meta(acos_node.meta)
                difference = graph.create_node(
                    "call_function", sub_target, (minuend, asin_result)
                )
                difference.meta = copy_meta(acos_node.meta)

            # Point every consumer of acos(x) at π/2 - asin(x) instead.
            for consumer in list(acos_node.users):
                consumer.replace_input_with(acos_node, difference)

        # The acos nodes are now unused; drop them and regenerate the code.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)