Skip to content

Commit 4c67d96

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for acos core ATen Op (#18743)
1 parent 037cbeb commit 4c67d96

7 files changed

Lines changed: 169 additions & 21 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .convert_linear_to_conv2d import ConvertLinearToConv2d
1414
from .convert_mha_to_sha import ConvertMhaToSha
1515
from .convert_square_to_pow import ConvertSquareToPow
16+
from .decompose_acos import DecomposeAcos
1617
from .decompose_any import DecomposeAny
1718
from .decompose_binary_alpha import DecomposeBinaryAlpha
1819
from .decompose_cdist import DecomposeCDist
@@ -65,6 +66,7 @@
6566
ConvertLinearToConv2d,
6667
ConvertMhaToSha,
6768
ConvertSquareToPow,
69+
DecomposeAcos,
6870
DecomposeAny,
6971
DecomposeBinaryAlpha,
7072
DecomposeCDist,
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
from .utils import copy_meta, get_const_node
13+
14+
15+
class DecomposeAcos(ExportPass):
    """
    Decompose acos using the identity: acos(x) = π/2 - asin(x).

    Handles both the pre-to_edge ATen dialect (``torch.ops.aten.acos``) and
    the post-to_edge edge dialect (``exir_ops.edge.aten.acos``), since this
    pass runs in both the annotation and capture pipelines.
    """

    def __init__(self):
        # Modern zero-argument super() — equivalent to the legacy
        # super(DecomposeAcos, self).__init__() form.
        super().__init__()
        self.acos_targets = {
            torch.ops.aten.acos.default,
            exir_ops.edge.aten.acos.default,
        }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite every acos node as sub(π/2, asin(x)).

        Returns a PassResult whose ``modified`` flag is True only when at
        least one acos node was decomposed.
        """
        graph = graph_module.graph

        acos_nodes = [
            n
            for n in graph.nodes
            if n.op == "call_function" and n.target in self.acos_targets
        ]
        if not acos_nodes:
            # Nothing to do — report the graph as unmodified.
            return PassResult(graph_module, False)

        pi_half = torch.pi / 2.0
        # Lazily-created graph constant for π/2, shared by every
        # edge-dialect acos node so we only materialize it once.
        pi_half_node = None

        for node in acos_nodes:
            input_node = node.args[0]
            is_edge = isinstance(node.target, EdgeOpOverload)

            # Pick dialect-matching replacement ops so the rewritten graph
            # stays in the same dialect as the node being replaced.
            asin_op = (
                exir_ops.edge.aten.asin.default
                if is_edge
                else torch.ops.aten.asin.default
            )
            sub_op = (
                exir_ops.edge.aten.sub.Tensor if is_edge else torch.ops.aten.sub.Tensor
            )

            # Edge dialect gets π/2 as a constant graph node; the ATen
            # dialect path passes the plain Python float directly.
            if is_edge and pi_half_node is None:
                pi_half_node = get_const_node(
                    graph, graph_module, "_pi_half_constant", pi_half, node
                )

            sub_arg = pi_half_node if is_edge else pi_half

            with graph.inserting_before(node):
                asin_node = graph.create_node("call_function", asin_op, (input_node,))
                # asin(x) has the same shape/dtype as acos(x), so reusing the
                # original node's meta is valid for both new nodes.
                asin_node.meta = copy_meta(node.meta)

                sub_node = graph.create_node(
                    "call_function", sub_op, (sub_arg, asin_node)
                )
                sub_node.meta = copy_meta(node.meta)

            # Redirect all consumers to the decomposed result; the original
            # acos node becomes dead and is removed below.
            for user in node.users.copy():
                user.replace_input_with(node, sub_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_log_variants.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,29 @@ class DecomposeLogVariants(ExportPass):
3131
def __init__(self) -> None:
3232
super().__init__()
3333
self._dispatcher = {
34-
# Edge dialect (post-to_edge) - FP
34+
# Edge dialect (post-to_edge)
3535
exir_ops.edge.aten.log10.default: partial(self._decompose_log_n, n=10),
3636
exir_ops.edge.aten.log2.default: partial(self._decompose_log_n, n=2),
3737
exir_ops.edge.aten.log1p.default: partial(self._decompose_log_p, p=1),
38-
# ATen dialect (pre-to_edge) - Quantized
38+
# ATen dialect (pre-to_edge)
3939
torch.ops.aten.log10.default: partial(self._decompose_log_n, n=10),
4040
torch.ops.aten.log2.default: partial(self._decompose_log_n, n=2),
4141
torch.ops.aten.log1p.default: partial(self._decompose_log_p, p=1),
4242
}
4343

44-
def _decompose_log_n(self, node, graph, graph_module, n):
44+
def _decompose_log_n(self, node, graph, graph_module, const_cache, n):
4545
input_node = node.args[0]
4646
is_edge = node.target in self._EDGE_OPS
4747

4848
if is_edge:
4949
log_op = exir_ops.edge.aten.log.default
5050
div_op = exir_ops.edge.aten.div.Tensor
51-
div_arg = get_const_node(
52-
graph,
53-
graph_module,
54-
f"_log_base_{n}_constant",
55-
math.log(n),
56-
node,
57-
)
58-
51+
attr_name = f"_log_base_{n}_constant"
52+
if attr_name not in const_cache:
53+
const_cache[attr_name] = get_const_node(
54+
graph, graph_module, attr_name, math.log(n), node
55+
)
56+
div_arg = const_cache[attr_name]
5957
else:
6058
log_op = torch.ops.aten.log.default
6159
div_op = torch.ops.aten.div.Tensor
@@ -74,21 +72,19 @@ def _decompose_log_n(self, node, graph, graph_module, n):
7472
for user in node.users.copy():
7573
user.replace_input_with(node, div_node)
7674

77-
def _decompose_log_p(self, node, graph, graph_module, p):
75+
def _decompose_log_p(self, node, graph, graph_module, const_cache, p):
7876
input_node = node.args[0]
7977
is_edge = node.target in self._EDGE_OPS
8078

8179
if is_edge:
8280
add_op = exir_ops.edge.aten.add.Tensor
8381
log_op = exir_ops.edge.aten.log.default
84-
add_arg = get_const_node(
85-
graph,
86-
graph_module,
87-
f"_log1p_addend_{p}_constant",
88-
p,
89-
node,
90-
)
91-
82+
attr_name = f"_log1p_addend_{p}_constant"
83+
if attr_name not in const_cache:
84+
const_cache[attr_name] = get_const_node(
85+
graph, graph_module, attr_name, p, node
86+
)
87+
add_arg = const_cache[attr_name]
9288
else:
9389
add_op = torch.ops.aten.add.Tensor
9490
log_op = torch.ops.aten.log.default
@@ -107,10 +103,11 @@ def _decompose_log_p(self, node, graph, graph_module, p):
107103

108104
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
109105
graph = graph_module.graph
106+
const_cache = {}
110107

111108
for node in list(graph.nodes):
112109
if node.target in self._dispatcher:
113-
self._dispatcher[node.target](node, graph, graph_module)
110+
self._dispatcher[node.target](node, graph, graph_module, const_cache)
114111

115112
graph.eliminate_dead_code()
116113
graph_module.recompile()

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ConvertLinearToConv2d,
1919
ConvertMhaToSha,
2020
ConvertSquareToPow,
21+
DecomposeAcos,
2122
DecomposeAny,
2223
DecomposeBinaryAlpha,
2324
DecomposeCDist,
@@ -96,6 +97,7 @@ def get_capture_program_passes():
9697
(AnnotateStack, True),
9798
(AnnotateUnbind, True),
9899
(ConvertBmmToMatmul, False),
100+
(DecomposeAcos, True),
99101
(DecomposeAny, True),
100102
(DecomposeColIm, True),
101103
(DecomposeLogVariants, True),
@@ -213,6 +215,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
213215
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
214216
self.add_pass(RecomposeRmsNorm(quantization_capture=True))
215217
self.add_pass(ReplaceArangeArgs())
218+
self.add_pass(DecomposeAcos())
216219
self.add_pass(DecomposeBinaryAlpha())
217220
self.add_pass(DecomposeCDist())
218221
self.add_pass(DecomposeMaxPool3d(quantization_capture=True))

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def get_passes_dependency_for_capture_program():
6666
AnnotateUnbind,
6767
CanonicalizeConv,
6868
ConvertBmmToMatmul,
69+
DecomposeAcos,
6970
DecomposeAny,
7071
DecomposeColIm,
7172
DecomposeLinalgVectorNorm,
@@ -95,6 +96,7 @@ def get_passes_dependency_for_capture_program():
9596
AnnotateStack: [RemoveRedundancy],
9697
AnnotateUnbind: [RemoveRedundancy],
9798
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
99+
DecomposeAcos: [RemoveRedundancy],
98100
DecomposeAny: [RemoveRedundancy],
99101
DecomposeColIm: [FoldQDQ],
100102
DecomposeLinalgVectorNorm: [RemoveRedundancy],

backends/qualcomm/tests/models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ def forward(self, x):
4141
return torch.abs(x)
4242

4343

44+
class Acos(torch.nn.Module):
    """Minimal module wrapping the element-wise arccosine operator.

    Inputs are expected to lie in [-1, 1], the domain of acos.
    """

    def forward(self, x):
        return torch.acos(x)
50+
51+
52+
class AcosMultiNode(torch.nn.Module):
    """Applies acos to two independent tensors.

    Produces a graph containing multiple acos nodes, exercising the
    multi-node path of the decomposition pass.
    """

    def forward(self, x, y):
        first = torch.acos(x)
        second = torch.acos(y)
        return first, second
58+
59+
4460
class AdaptiveMaxPool2D(torch.nn.Module):
4561
def __init__(self, output_size, return_indices=False):
4662
super().__init__()
@@ -1498,6 +1514,14 @@ def forward(self, x):
14981514
return torch.nn.functional.log_softmax(x, dim=-1)
14991515

15001516

1517+
class LogVariantsMultiNode(torch.nn.Module):
    """Emits several log-variant ops (log10 twice, log2, log1p) in one graph.

    Exercises passes that must handle repeated occurrences of the same
    decomposed operator within a single module.
    """

    def forward(self, x, y):
        log10_x = torch.log10(x)
        log10_y = torch.log10(y)
        log2_x = torch.log2(x)
        log1p_x = torch.log1p(x)
        return log10_x, log10_y, log2_x, log1p_x
1523+
1524+
15011525
class Log10(torch.nn.Module):
15021526
def __init__(self):
15031527
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,19 @@ def test_qnn_backend_abs(self):
125125
sample_input = (torch.randn(1, 2, 3, 4),)
126126
self.lower_module_and_test_output(module, sample_input)
127127

128+
def test_qnn_backend_acos(self):
    """FP lowering test for a single acos op."""
    # Uniform samples in [-1, 1), the valid domain of acos.
    inputs = (torch.rand(3, 4) * 2 - 1,)
    self.lower_module_and_test_output(Acos(), inputs)  # noqa: F405
132+
133+
def test_qnn_backend_acos_multi_node(self):
    """FP lowering test with two acos nodes, including the ±1 boundaries."""
    first = torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0])
    second = torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0])
    module = AcosMultiNode()  # noqa: F405
    self.lower_module_and_test_output(module, (first, second))
140+
128141
def test_qnn_backend_adaptive_avg_pool1d(self):
129142
module = AdaptiveAvgPool1D() # noqa: F405
130143
sample_input = (torch.randn(1, 512, 7),)
@@ -1455,6 +1468,14 @@ def test_qnn_backend_log_softmax(self):
14551468
sample_input = (torch.randn([1, 4, 8, 8]),)
14561469
self.lower_module_and_test_output(module, sample_input)
14571470

1471+
def test_qnn_backend_log_variants_multi_node(self):
    """FP lowering test covering repeated log10/log2/log1p nodes."""
    # Strictly positive inputs keep every log variant well-defined.
    make_input = lambda: torch.abs(torch.rand(2, 3, 4)) + 0.1  # noqa: E731
    module = LogVariantsMultiNode()  # noqa: F405
    self.lower_module_and_test_output(module, (make_input(), make_input()))
1478+
14581479
def test_qnn_backend_log10(self):
14591480
module = Log10() # noqa: F405
14601481
sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
@@ -2362,6 +2383,21 @@ def test_qnn_backend_abs(self):
23622383
module = self.get_qdq_module(module, sample_input)
23632384
self.lower_module_and_test_output(module, sample_input)
23642385

2386+
def test_qnn_backend_acos(self):
    """Quantized (QDQ) lowering test for a single acos op."""
    # Uniform samples in [-1, 1), the valid domain of acos.
    inputs = (torch.rand(3, 4) * 2 - 1,)
    qdq_module = self.get_qdq_module(Acos(), inputs)  # noqa: F405
    self.lower_module_and_test_output(qdq_module, inputs)
2391+
2392+
def test_qnn_backend_acos_multi_node(self):
    """Quantized (QDQ) lowering test with two acos nodes, including ±1 boundaries."""
    inputs = (
        torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0]),
        torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0]),
    )
    qdq_module = self.get_qdq_module(AcosMultiNode(), inputs)  # noqa: F405
    self.lower_module_and_test_output(qdq_module, inputs)
2400+
23652401
def test_qnn_backend_adaptive_avg_pool1d(self):
23662402
module = AdaptiveAvgPool1D() # noqa: F405
23672403
sample_input = (torch.randn(1, 512, 7),)
@@ -3813,6 +3849,15 @@ def test_qnn_backend_log_softmax(self):
38133849
module = self.get_qdq_module(module, sample_input)
38143850
self.lower_module_and_test_output(module, sample_input)
38153851

3852+
def test_qnn_backend_log_variants_multi_node(self):
    """Quantized (QDQ) lowering test covering repeated log10/log2/log1p nodes."""
    # Strictly positive inputs keep every log variant well-defined.
    inputs = tuple(torch.abs(torch.rand(2, 3, 4)) + 0.1 for _ in range(2))
    qdq_module = self.get_qdq_module(LogVariantsMultiNode(), inputs)  # noqa: F405
    self.lower_module_and_test_output(qdq_module, inputs)
3860+
38163861
def test_qnn_backend_log10(self):
38173862
module = Log10() # noqa: F405
38183863
sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)

0 commit comments

Comments
 (0)