Skip to content

Commit 5298e29

Browse files
qti-horodnicrascani
authored and committed
Qualcomm AI Engine Direct - Adding QNN backend support for log2, log10, log1p core ATen ops (pytorch#18542)
1 parent 0d9e792 commit 5298e29

7 files changed

Lines changed: 203 additions & 3 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .decompose_floor_divide import DecomposeFloorDivide
2323
from .decompose_glu import DecomposeGlu
2424
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
25+
from .decompose_log_variants import DecomposeLogVariants
2526
from .decompose_maxpool3d import DecomposeMaxPool3d
2627
from .decompose_minmaxdim import DecomposeMinMaxDim
2728
from .decompose_reciprocal import DecomposeReciprocal
@@ -72,6 +73,7 @@
7273
DecomposeFloorDivide,
7374
DecomposeGlu,
7475
DecomposeLinalgVectorNorm,
76+
DecomposeLogVariants,
7577
DecomposeMaxPool3d,
7678
DecomposeMinMaxDim,
7779
DecomposeReciprocal,
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
from functools import partial
9+
10+
import torch
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
14+
from .utils import copy_meta, get_const_node
15+
16+
17+
class DecomposeLogVariants(ExportPass):
    """
    Rewrite the log variants [log10, log2, log1p] into ops the QNN backend
    already supports, using the identities:

        log10(x) = log(x) / log(10)
        log2(x)  = log(x) / log(2)
        log1p(x) = log(1 + x)

    Both dialects are handled: edge dialect targets (post-to_edge, FP path)
    and ATen dialect targets (pre-to_edge, quantized path).
    """

    # Edge-dialect targets need scalar operands materialized as graph
    # constants; ATen-dialect targets accept raw Python scalars.
    _EDGE_OPS = {
        exir_ops.edge.aten.log10.default,
        exir_ops.edge.aten.log2.default,
        exir_ops.edge.aten.log1p.default,
    }

    def __init__(self) -> None:
        super().__init__()
        self._dispatcher = {}
        # log10/log2 share one handler parameterized by the base.
        for base, targets in (
            (10, (exir_ops.edge.aten.log10.default, torch.ops.aten.log10.default)),
            (2, (exir_ops.edge.aten.log2.default, torch.ops.aten.log2.default)),
        ):
            handler = partial(self._decompose_log_n, n=base)
            for target in targets:
                self._dispatcher[target] = handler
        log1p_handler = partial(self._decompose_log_p, p=1)
        for target in (exir_ops.edge.aten.log1p.default, torch.ops.aten.log1p.default):
            self._dispatcher[target] = log1p_handler

    def _decompose_log_n(self, node, graph, graph_module, n):
        """Replace log_n(x) with log(x) / log(n)."""
        src = node.args[0]

        if node.target in self._EDGE_OPS:
            log_op = exir_ops.edge.aten.log.default
            div_op = exir_ops.edge.aten.div.Tensor
            # QNN op builders reject raw scalar args, so the divisor is
            # registered as a graph constant.
            divisor = get_const_node(
                graph,
                graph_module,
                f"_log_base_{n}_constant",
                math.log(n),
                node,
            )
        else:
            log_op = torch.ops.aten.log.default
            div_op = torch.ops.aten.div.Tensor
            divisor = math.log(n)

        with graph.inserting_after(src):
            ln_node = graph.create_node("call_function", log_op, (src,))
        ln_node.meta = copy_meta(node.meta)

        with graph.inserting_after(ln_node):
            quotient = graph.create_node("call_function", div_op, (ln_node, divisor))
        quotient.meta = copy_meta(node.meta)

        # Reroute every consumer; the original node is dropped later by
        # dead-code elimination in call().
        for consumer in tuple(node.users):
            consumer.replace_input_with(node, quotient)

    def _decompose_log_p(self, node, graph, graph_module, p):
        """Replace log1p-style ops with log(p + x)."""
        src = node.args[0]

        if node.target in self._EDGE_OPS:
            add_op = exir_ops.edge.aten.add.Tensor
            log_op = exir_ops.edge.aten.log.default
            addend = get_const_node(
                graph,
                graph_module,
                f"_log1p_addend_{p}_constant",
                p,
                node,
            )
        else:
            add_op = torch.ops.aten.add.Tensor
            log_op = torch.ops.aten.log.default
            addend = p

        with graph.inserting_after(src):
            shifted = graph.create_node("call_function", add_op, (src, addend))
        shifted.meta = copy_meta(node.meta)

        with graph.inserting_after(shifted):
            ln_node = graph.create_node("call_function", log_op, (shifted,))
        ln_node.meta = copy_meta(node.meta)

        for consumer in tuple(node.users):
            consumer.replace_input_with(node, ln_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Decompose every matching node in the graph, then clean up."""
        graph = graph_module.graph

        # Snapshot the node list: handlers insert new nodes while we iterate.
        for candidate in tuple(graph.nodes):
            handler = self._dispatcher.get(candidate.target)
            if handler is not None:
                handler(candidate, graph, graph_module)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
DecomposeFloorDivide,
2828
DecomposeGlu,
2929
DecomposeLinalgVectorNorm,
30+
DecomposeLogVariants,
3031
DecomposeMaxPool3d,
3132
DecomposeMinMaxDim,
3233
DecomposeReciprocal,
@@ -96,6 +97,7 @@ def get_capture_program_passes():
9697
(ConvertBmmToMatmul, False),
9798
(DecomposeAny, True),
9899
(DecomposeColIm, True),
100+
(DecomposeLogVariants, True),
99101
(DecomposeMaxPool3d, True),
100102
(DecomposeMinMaxDim, True),
101103
(ExpandBroadcastTensorShape, True),
@@ -226,6 +228,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
226228
# TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager)
227229
self.add_pass(DecomposeReciprocal())
228230
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
231+
self.add_pass(DecomposeLogVariants())
229232
self.add_pass(ReplaceInfValues())
230233
self.add_pass(LiftConstantScalarOperands())
231234
self.add_pass(InsertReshapeForReduceOps())

backends/qualcomm/_passes/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def get_passes_dependency_for_capture_program():
6969
DecomposeAny,
7070
DecomposeColIm,
7171
DecomposeLinalgVectorNorm,
72+
DecomposeLogVariants,
7273
DecomposeMaxPool3d,
7374
ExpandBroadcastTensorShape,
7475
FixedLinearKeepDim,
@@ -96,6 +97,7 @@ def get_passes_dependency_for_capture_program():
9697
DecomposeAny: [RemoveRedundancy],
9798
DecomposeColIm: [FoldQDQ],
9899
DecomposeLinalgVectorNorm: [RemoveRedundancy],
100+
DecomposeLogVariants: [RemoveRedundancy],
99101
DecomposeMaxPool3d: [RemoveRedundancy],
100102
ExpandBroadcastTensorShape: [FoldQDQ],
101103
FixedLinearKeepDim: [FoldQDQ],
@@ -285,3 +287,25 @@ def append_qdq(
285287
dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args)
286288
dq_node.meta = copy_meta(node.meta)
287289
return dq_node
290+
291+
292+
def get_const_node(
    graph: torch.fx.Graph,
    graph_module: torch.fx.GraphModule,
    attr_name: str,
    value,
    source_node: torch.fx.Node,
) -> torch.fx.Node:
    """
    Register a scalar constant as a named buffer on the graph module and
    return a get_attr node referencing it.

    Used in edge dialect op decomposition passes where raw scalar arguments
    are not accepted by QNN op builders, which need the inputs to be graph
    nodes.

    Args:
        graph: Graph to insert the get_attr node into.
        graph_module: Module that will own the constant buffer.
        attr_name: Requested buffer name. If the name is already taken by a
            different tensor, a numeric suffix is appended to keep buffers
            distinct.
        value: Scalar value of the constant.
        source_node: Node whose meta["val"] supplies the dtype and fake mode
            for the new constant.

    Returns:
        A get_attr node whose meta["val"] is a fake tensor of the constant.
    """
    dtype = source_node.meta["val"].dtype
    tensor = torch.tensor(value, dtype=dtype)

    existing = getattr(graph_module, attr_name, None)
    if existing is None:
        graph_module.register_buffer(attr_name, tensor)
    elif not (
        isinstance(existing, torch.Tensor)
        and existing.dtype == dtype
        and torch.equal(existing, tensor)
    ):
        # The requested name is taken by a different value/dtype (e.g. the
        # same pass decomposing two nodes with different dtypes). Registering
        # under the same name would silently clobber the earlier buffer and
        # leave previously created get_attr nodes with stale meta, so pick a
        # fresh unique name instead.
        base = attr_name
        suffix = 0
        while hasattr(graph_module, attr_name):
            suffix += 1
            attr_name = f"{base}_{suffix}"
        graph_module.register_buffer(attr_name, tensor)
    # Otherwise an identical buffer is already registered — reuse it.

    fake_mode = source_node.meta["val"].fake_mode
    # Insert at the very front of the graph so the constant dominates every
    # potential user.
    with graph.inserting_before(next(iter(graph.nodes))):
        const_node = graph.get_attr(attr_name)
        const_node.meta["val"] = fake_mode.from_tensor(tensor)
    return const_node

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
to_be_implemented_operator = [
2222
exir_ops.edge.aten.adaptive_max_pool3d.default,
2323
exir_ops.edge.aten.div.Tensor_mode,
24-
exir_ops.edge.aten.log10.default,
25-
exir_ops.edge.aten.log1p.default,
26-
exir_ops.edge.aten.log2.default,
2724
exir_ops.edge.aten.max_pool3d_with_indices.default,
2825
exir_ops.edge.aten.median.default,
2926
exir_ops.edge.aten.median.dim,

backends/qualcomm/tests/models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,30 @@ def forward(self, x):
14981498
return torch.nn.functional.log_softmax(x, dim=-1)
14991499

15001500

1501+
class Log10(torch.nn.Module):
    """Minimal module wrapping torch.log10, used by the QNN delegate tests."""

    def forward(self, x):
        return torch.log10(x)
1507+
1508+
1509+
class Log1p(torch.nn.Module):
    """Minimal module wrapping torch.log1p, used by the QNN delegate tests."""

    def forward(self, x):
        return torch.log1p(x)
1515+
1516+
1517+
class Log2(torch.nn.Module):
    """Minimal module wrapping torch.log2, used by the QNN delegate tests."""

    def forward(self, x):
        return torch.log2(x)
1523+
1524+
15011525
class MaxPool2d(torch.nn.Module):
15021526
def __init__(self, kernel_size=3, stride=1, padding=1, ceil_mode=True):
15031527
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,21 @@ def test_qnn_backend_log_softmax(self):
14551455
sample_input = (torch.randn([1, 4, 8, 8]),)
14561456
self.lower_module_and_test_output(module, sample_input)
14571457

1458+
def test_qnn_backend_log10(self):
    """Lower torch.log10 through the QNN backend (FP) and check the output."""
    # Inputs are shifted away from zero so the log stays finite.
    sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
    self.lower_module_and_test_output(Log10(), sample_input)  # noqa: F405
1462+
1463+
def test_qnn_backend_log1p(self):
    """Lower torch.log1p through the QNN backend (FP) and check the output."""
    # Inputs are shifted away from zero so the log stays finite.
    sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
    self.lower_module_and_test_output(Log1p(), sample_input)  # noqa: F405
1467+
1468+
def test_qnn_backend_log2(self):
    """Lower torch.log2 through the QNN backend (FP) and check the output."""
    # Inputs are shifted away from zero so the log stays finite.
    sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
    self.lower_module_and_test_output(Log2(), sample_input)  # noqa: F405
1472+
14581473
def test_qnn_backend_maximum(self):
14591474
module = Maximum() # noqa: F405
14601475
sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
@@ -3780,6 +3795,24 @@ def test_qnn_backend_log_softmax(self):
37803795
module = self.get_qdq_module(module, sample_input)
37813796
self.lower_module_and_test_output(module, sample_input)
37823797

3798+
def test_qnn_backend_log10(self):
    """Lower a quantized torch.log10 through the QNN backend and verify it."""
    # Inputs are shifted away from zero so the log stays finite.
    sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
    qdq_module = self.get_qdq_module(Log10(), sample_input)  # noqa: F405
    self.lower_module_and_test_output(qdq_module, sample_input)
3803+
3804+
def test_qnn_backend_log1p(self):
    """Lower a quantized torch.log1p through the QNN backend and verify it."""
    # Inputs are shifted away from zero so the log stays finite.
    sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
    qdq_module = self.get_qdq_module(Log1p(), sample_input)  # noqa: F405
    self.lower_module_and_test_output(qdq_module, sample_input)
3809+
3810+
def test_qnn_backend_log2(self):
    """Lower a quantized torch.log2 through the QNN backend and verify it."""
    # Inputs are shifted away from zero so the log stays finite.
    sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
    qdq_module = self.get_qdq_module(Log2(), sample_input)  # noqa: F405
    self.lower_module_and_test_output(qdq_module, sample_input)
3815+
37833816
def test_qnn_backend_maximum(self):
37843817
module = Maximum() # noqa: F405
37853818
sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))

0 commit comments

Comments
 (0)