Skip to content

Commit 15e8bf7

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for trunc core ATen Op (pytorch#18543)
1 parent 25545f7 commit 15e8bf7

7 files changed

Lines changed: 119 additions & 1 deletion

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .decompose_silu import DecomposeSilu
3131
from .decompose_threshold import DecomposeThreshold
3232
from .decompose_triu import DecomposeTriu
33+
from .decompose_trunc import DecomposeTrunc
3334
from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
3435
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
3536
from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -81,6 +82,7 @@
8182
DecomposeSilu,
8283
DecomposeThreshold,
8384
DecomposeTriu,
85+
DecomposeTrunc,
8486
DecomposeWrapWithAutocast,
8587
ExpandBroadcastTensorShape,
8688
FixedLinearKeepDim,
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
from .utils import copy_meta
13+
14+
15+
class DecomposeTrunc(ExportPass):
    """
    Decompose trunc via the identity: trunc(x) = sign(x) * floor(abs(x)).

    Every ``aten.trunc.default`` call_function node (ATen or edge dialect) is
    rewritten into ``mul(sign(x), floor(abs(x)))``. The replacement ops are
    taken from the same dialect as the trunc node they replace, and each new
    node receives a copy of the trunc node's meta.
    """

    def __init__(self):
        super(DecomposeTrunc, self).__init__()
        # Both dialects are matched because this pass is registered in more
        # than one pipeline.
        self.trunc_targets = {
            torch.ops.aten.trunc.default,
            exir_ops.edge.aten.trunc.default,
        }

    @staticmethod
    def _create_node(graph, op, args, meta_source):
        """Create a call_function node whose meta is copied from meta_source."""
        new_node = graph.create_node("call_function", op, args)
        new_node.meta = copy_meta(meta_source.meta)
        return new_node

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite all trunc nodes in ``graph_module`` in place.

        Returns:
            PassResult wrapping the same graph module; its ``modified`` flag is
            True only if a trunc node was replaced or dead code was removed.
        """
        graph = graph_module.graph
        modified = False

        # Iterate over a snapshot: nodes are inserted while walking the graph.
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.trunc_targets:
                continue

            input_node = node.args[0]
            # Stay in the dialect of the node being replaced.
            aten = (
                exir_ops.edge.aten
                if isinstance(node.target, EdgeOpOverload)
                else torch.ops.aten
            )

            # Anchor the new nodes directly before the trunc node: that spot is
            # always after the input and before every user, regardless of
            # where the input node sits in the graph.
            with graph.inserting_before(node):
                sign_node = self._create_node(
                    graph, aten.sign.default, (input_node,), node
                )
                abs_node = self._create_node(
                    graph, aten.abs.default, (input_node,), node
                )
                floor_node = self._create_node(
                    graph, aten.floor.default, (abs_node,), node
                )
                mul_node = self._create_node(
                    graph, aten.mul.Tensor, (sign_node, floor_node), node
                )

            node.replace_all_uses_with(mul_node)
            modified = True

        # The now-unused trunc nodes are removed by dead code elimination.
        dce_changed = graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, modified or bool(dce_changed))

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
DecomposeSilu,
3636
DecomposeThreshold,
3737
DecomposeTriu,
38+
DecomposeTrunc,
3839
DecomposeWrapWithAutocast,
3940
ExpandBroadcastTensorShape,
4041
FixedLinearKeepDim,
@@ -100,6 +101,7 @@ def get_capture_program_passes():
100101
(DecomposeLogVariants, True),
101102
(DecomposeMaxPool3d, True),
102103
(DecomposeMinMaxDim, True),
104+
(DecomposeTrunc, True),
103105
(ExpandBroadcastTensorShape, True),
104106
(FixedLinearKeepDim, True),
105107
(FoldQDQ, True),
@@ -219,6 +221,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
219221
self.add_pass(DecomposeSilu())
220222
self.add_pass(DecomposeThreshold())
221223
self.add_pass(DecomposeTriu())
224+
self.add_pass(DecomposeTrunc())
222225
self.add_pass(DecomposeWrapWithAutocast())
223226
self.add_pass(DecomposeEinsum())
224227
self.add_pass(DecomposeExpM1())

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def get_passes_dependency_for_capture_program():
7171
DecomposeLinalgVectorNorm,
7272
DecomposeLogVariants,
7373
DecomposeMaxPool3d,
74+
DecomposeTrunc,
7475
ExpandBroadcastTensorShape,
7576
FixedLinearKeepDim,
7677
FoldQDQ,
@@ -99,6 +100,7 @@ def get_passes_dependency_for_capture_program():
99100
DecomposeLinalgVectorNorm: [RemoveRedundancy],
100101
DecomposeLogVariants: [RemoveRedundancy],
101102
DecomposeMaxPool3d: [RemoveRedundancy],
103+
DecomposeTrunc: [RemoveRedundancy],
102104
ExpandBroadcastTensorShape: [FoldQDQ],
103105
FixedLinearKeepDim: [FoldQDQ],
104106
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
exir_ops.edge.aten.median.dim,
2727
exir_ops.edge.aten.round.decimals,
2828
exir_ops.edge.aten.le.Scalar,
29-
exir_ops.edge.aten.trunc.default,
3029
]
3130

3231
constant_operator = [

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,6 +2315,14 @@ def forward(self, x):
23152315
return mask + x
23162316

23172317

2318+
class Trunc(torch.nn.Module):
    """Minimal module that applies ``torch.trunc`` to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Drop the fractional part of every element (round toward zero).
        truncated = torch.trunc(x)
        return truncated
2324+
2325+
23182326
class Unbind(torch.nn.Module):
23192327
def __init__(self):
23202328
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,11 @@ def test_qnn_backend_triu(self):
19181918
index += 1
19191919
self.lower_module_and_test_output(module, sample_input)
19201920

1921+
def test_qnn_backend_trunc(self):
    """Lower the floating-point Trunc module and check its output."""
    trunc_module = Trunc()  # noqa: F405
    inputs = (torch.randn(3, 4),)
    self.lower_module_and_test_output(trunc_module, inputs)
1925+
19211926
def test_qnn_backend_unflatten(self):
19221927
module = Unflatten(dim=1, sizes=(2, 3, 4)) # noqa: F405
19231928
sample_input = (torch.randn([1, 24]),)
@@ -4319,6 +4324,12 @@ def test_qnn_backend_triu(self):
43194324
qdq_module = self.get_qdq_module(module, sample_input)
43204325
self.lower_module_and_test_output(qdq_module, sample_input)
43214326

4327+
def test_qnn_backend_trunc(self):
    """Lower the QDQ version of the Trunc module and check its output."""
    float_module = Trunc()  # noqa: F405
    inputs = (torch.randn(3, 4),)
    qdq_module = self.get_qdq_module(float_module, inputs)
    self.lower_module_and_test_output(qdq_module, inputs)
4332+
43224333
def test_qnn_backend_unflatten(self):
43234334
module = Unflatten(dim=1, sizes=(2, 3, 4)) # noqa: F405
43244335
sample_input = (torch.randn([1, 24]),)

0 commit comments

Comments
 (0)