Skip to content

Commit dc3f012

Browse files
committed
Qualcomm AI Engine Direct - Adding QNN backend support for atan2 core ATen Op
1 parent 069a793 commit dc3f012

6 files changed

Lines changed: 312 additions & 0 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .convert_square_to_pow import ConvertSquareToPow
1616
from .decompose_acos import DecomposeAcos
1717
from .decompose_any import DecomposeAny
18+
from .decompose_atan2 import DecomposeAtan2
1819
from .decompose_binary_alpha import DecomposeBinaryAlpha
1920
from .decompose_cdist import DecomposeCDist
2021
from .decompose_col_im import DecomposeColIm
@@ -70,6 +71,7 @@
7071
ConvertSquareToPow,
7172
DecomposeAcos,
7273
DecomposeAny,
74+
DecomposeAtan2,
7375
DecomposeBinaryAlpha,
7476
DecomposeCDist,
7577
DecomposeColIm,
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
from .utils import copy_meta, create_node, get_const_node
13+
14+
15+
class DecomposeAtan2(ExportPass):
    """
    Decompose atan2(y, x) with full piecewise definition:
    atan2(y, x) =
        atan(y/x)      if x > 0
        atan(y/x) + π  if x < 0, y >= 0
        atan(y/x) - π  if x < 0, y < 0
        +π/2           if x = 0, y > 0
        -π/2           if x = 0, y < 0
        0              if x = 0, y = 0
    """

    # (edge-dialect overload, aten overload) pairs; the correct one is
    # selected per decomposed node depending on which dialect it lives in.
    _OPS = {
        "eq": (exir_ops.edge.aten.eq.Tensor, torch.ops.aten.eq.Tensor),
        "lt": (exir_ops.edge.aten.lt.Tensor, torch.ops.aten.lt.Tensor),
        "gt": (exir_ops.edge.aten.gt.Tensor, torch.ops.aten.gt.Tensor),
        "ge": (exir_ops.edge.aten.ge.Tensor, torch.ops.aten.ge.Tensor),
        "where": (exir_ops.edge.aten.where.self, torch.ops.aten.where.self),
        "div": (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor),
        "atan": (exir_ops.edge.aten.atan.default, torch.ops.aten.atan.default),
        "add": (exir_ops.edge.aten.add.Tensor, torch.ops.aten.add.Tensor),
    }

    # (edge _to_copy, aten _to_copy) used to promote integer inputs to float.
    _TO_FLOAT_OP = (
        exir_ops.edge.aten._to_copy.default,
        torch.ops.aten._to_copy.default,
    )

    def __init__(self):
        super(DecomposeAtan2, self).__init__()
        # All atan2 variants this pass rewrites (aten and edge dialect).
        self.atan2_targets = {
            torch.ops.aten.atan2.default,
            torch.ops.aten.atan2.out,
            exir_ops.edge.aten.atan2.default,
        }

    def _get_op(self, name, is_edge):
        """Return the edge-dialect or aten overload registered under *name*."""
        return self._OPS[name][0] if is_edge else self._OPS[name][1]

    def _cast_to_float(self, graph, node, meta, is_edge):
        """Insert a cast from integer to float if the input is not floating-point.

        Returns the cast node, or *node* unchanged when it is already float
        (or when no fake-tensor "val" meta is available to inspect).
        """
        node_val = node.meta.get("val")
        if node_val is not None and not node_val.is_floating_point():
            to_float_op = self._TO_FLOAT_OP[0] if is_edge else self._TO_FLOAT_OP[1]
            cast_node = graph.create_node(
                "call_function", to_float_op, (node,), {"dtype": torch.float32}
            )
            # NOTE(review): meta here is copied from the atan2 node, so the
            # cast node's "val" reflects the atan2 output, not the cast's own
            # input — confirm downstream shape/dtype checks tolerate this.
            cast_node.meta = copy_meta(meta)
            return cast_node
        return node

    def _get_constants(self, graph, graph_module, node, is_edge, const_cache):
        """Return the scalar constants used by the decomposition.

        For edge-dialect graphs the constants must be materialized as graph
        nodes (created via get_const_node and memoized in *const_cache* so
        each constant is created at most once per pass run); for aten graphs
        plain Python floats are passed directly as scalar arguments.
        """
        if is_edge:

            def make_const(name, val):
                if name not in const_cache:
                    const_cache[name] = get_const_node(
                        graph, graph_module, name, val, node
                    )
                return const_cache[name]

            return {
                "zero": make_const("_atan2_zero", 0.0),
                "one": make_const("_atan2_one", 1.0),
                "pi": make_const("_atan2_pi", torch.pi),
                "neg_pi": make_const("_atan2_neg_pi", -torch.pi),
                "pi_half": make_const("_atan2_pi_half", torch.pi / 2),
                "neg_pi_half": make_const("_atan2_neg_pi_half", -torch.pi / 2),
            }
        return {
            "zero": 0.0,
            "one": 1.0,
            "pi": torch.pi,
            "neg_pi": -torch.pi,
            "pi_half": torch.pi / 2,
            "neg_pi_half": -torch.pi / 2,
        }

    def call(self, graph_module: torch.fx.GraphModule):
        """Replace every atan2 node in the graph with its piecewise decomposition.

        Fix over the original: the pass now tracks whether anything was
        rewritten and only recompiles / reports modification in that case.
        The original unconditionally returned ``PassResult(graph_module,
        True)``, which tells a fix-point pass manager the graph changed even
        when no atan2 node was present.
        """
        graph = graph_module.graph
        const_cache = {}
        modified = False
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.atan2_targets:
                continue
            modified = True
            # torch.atan2(input, other) == atan2(y=input, x=other).
            y_node, x_node = node.args[0], node.args[1]
            is_edge = isinstance(node.target, EdgeOpOverload)
            meta = node.meta

            with graph.inserting_before(node):
                # Integer inputs must be promoted to float before div/atan.
                y_node = self._cast_to_float(graph, y_node, meta, is_edge)
                x_node = self._cast_to_float(graph, x_node, meta, is_edge)

                consts = self._get_constants(
                    graph, graph_module, node, is_edge, const_cache
                )

                # Avoid division by zero: where x == 0, divide by 1 instead;
                # those lanes are overwritten by the x == 0 branch below.
                x_eq_zero = create_node(
                    graph,
                    self._get_op("eq", is_edge),
                    (x_node, consts["zero"]),
                    meta,
                    callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
                )
                safe_x = create_node(
                    graph,
                    self._get_op("where", is_edge),
                    (x_eq_zero, consts["one"], x_node),
                    meta,
                )
                ratio = create_node(
                    graph,
                    self._get_op("div", is_edge),
                    (y_node, safe_x),
                    meta,
                )

                base = create_node(
                    graph,
                    self._get_op("atan", is_edge),
                    (ratio,),
                    meta,
                )

                # Quadrant correction: for x < 0 add +π (y >= 0) or -π (y < 0).
                x_lt_zero = create_node(
                    graph,
                    self._get_op("lt", is_edge),
                    (x_node, consts["zero"]),
                    meta,
                    callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
                )
                y_ge_zero = create_node(
                    graph,
                    self._get_op("ge", is_edge),
                    (y_node, consts["zero"]),
                    meta,
                    callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
                )
                y_sign_pi = create_node(
                    graph,
                    self._get_op("where", is_edge),
                    (y_ge_zero, consts["pi"], consts["neg_pi"]),
                    meta,
                )
                adjustment = create_node(
                    graph,
                    self._get_op("where", is_edge),
                    (x_lt_zero, y_sign_pi, consts["zero"]),
                    meta,
                )
                adjusted = create_node(
                    graph,
                    self._get_op("add", is_edge),
                    (base, adjustment),
                    meta,
                )

                # x == 0 column: ±π/2 depending on sign of y, 0 when y == 0.
                y_gt_zero = create_node(
                    graph,
                    self._get_op("gt", is_edge),
                    (y_node, consts["zero"]),
                    meta,
                    callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
                )
                x_zero_result = create_node(
                    graph,
                    self._get_op("where", is_edge),
                    (y_gt_zero, consts["pi_half"], consts["neg_pi_half"]),
                    meta,
                )

                y_eq_zero = create_node(
                    graph,
                    self._get_op("eq", is_edge),
                    (y_node, consts["zero"]),
                    meta,
                    callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
                )
                x_zero_final = create_node(
                    graph,
                    self._get_op("where", is_edge),
                    (y_eq_zero, consts["zero"], x_zero_result),
                    meta,
                )

                result = create_node(
                    graph,
                    self._get_op("where", is_edge),
                    (x_eq_zero, x_zero_final, adjusted),
                    meta,
                )

            # Reroute all consumers; the orphaned atan2 node is removed by
            # eliminate_dead_code below.
            for user in node.users.copy():
                user.replace_input_with(node, result)

        if modified:
            graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ConvertSquareToPow,
2121
DecomposeAcos,
2222
DecomposeAny,
23+
DecomposeAtan2,
2324
DecomposeBinaryAlpha,
2425
DecomposeCDist,
2526
DecomposeColIm,
@@ -104,6 +105,7 @@ def get_capture_program_passes():
104105
(ConvertBmmToMatmul, False),
105106
(DecomposeAcos, True),
106107
(DecomposeAny, True),
108+
(DecomposeAtan2, True),
107109
(DecomposeColIm, True),
108110
(DecomposeLogVariants, True),
109111
(DecomposeMaxPool3d, True),
@@ -226,6 +228,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
226228
self.add_pass(RecomposeRmsNorm(quantization_capture=True))
227229
self.add_pass(ReplaceArangeArgs())
228230
self.add_pass(DecomposeAcos())
231+
self.add_pass(DecomposeAtan2())
229232
self.add_pass(DecomposeBinaryAlpha())
230233
self.add_pass(DecomposeCDist())
231234
self.add_pass(DecomposeMaxPool3d(quantization_capture=True))

backends/qualcomm/_passes/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def get_passes_dependency_for_capture_program():
6767
ConvertBmmToMatmul,
6868
DecomposeAcos,
6969
DecomposeAny,
70+
DecomposeAtan2,
7071
DecomposeColIm,
7172
DecomposeLinalgVectorNorm,
7273
DecomposeLogVariants,
@@ -99,6 +100,7 @@ def get_passes_dependency_for_capture_program():
99100
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
100101
DecomposeAcos: [RemoveRedundancy],
101102
DecomposeAny: [RemoveRedundancy],
103+
DecomposeAtan2: [RemoveRedundancy],
102104
DecomposeColIm: [FoldQDQ],
103105
DecomposeLinalgVectorNorm: [RemoveRedundancy],
104106
DecomposeLogVariants: [RemoveRedundancy],
@@ -315,3 +317,9 @@ def get_const_node(
315317
const_node = graph.get_attr(attr_name)
316318
const_node.meta["val"] = fake_mode.from_tensor(tensor)
317319
return const_node
320+
321+
322+
def create_node(graph, target, args, meta, callback=None):
    """Insert a ``call_function`` node for *target* on *graph*.

    The new node's meta is copied from *meta*, optionally transformed by
    *callback*. Returns the created node.
    """
    new_node = graph.create_node("call_function", target, args)
    new_node.meta = copy_meta(meta, callback)
    return new_node

backends/qualcomm/tests/models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,22 @@ def forward(self, x):
263263
return torch.atan(x)
264264

265265

266+
class Atan2(torch.nn.Module):
    """Single-op module computing element-wise ``torch.atan2(x, y)``."""

    def forward(self, x, y):
        return torch.atan2(x, y)
272+
273+
274+
class Atan2MultiNode(torch.nn.Module):
    """Module producing two independent ``torch.atan2`` nodes in one graph."""

    def forward(self, x1, y1, x2, y2):
        return torch.atan2(x1, y1), torch.atan2(x2, y2)
280+
281+
266282
class AvgPool1D(torch.nn.Module):
267283
def __init__(self):
268284
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,46 @@ def test_qnn_backend_atan(self):
313313
module = Atan() # noqa: F405
314314
self.lower_module_and_test_output(module, sample_input)
315315

316+
def test_qnn_backend_atan2(self):
    """Lower atan2 modules (float and int inputs) and compare against eager."""
    test_comb = [
        {
            QCOM_MODULE: [Atan2()],  # noqa: F405
            QCOM_SAMPLE_INPUTS: [
                (
                    torch.tensor(
                        [1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 0.0], dtype=torch.float32
                    ),
                    torch.tensor(
                        [1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0], dtype=torch.float32
                    ),
                ),
                (
                    torch.tensor([1, 1, -1, -1, 1, -1, 0], dtype=torch.int32),
                    torch.tensor([1, -1, -1, 1, 0, 0, 0], dtype=torch.int32),
                ),
            ],
        },
        {
            QCOM_MODULE: [Atan2MultiNode()],  # noqa: F405
            QCOM_SAMPLE_INPUTS: [
                (
                    torch.tensor([1.0, -1.0, 1.0, -1.0]),
                    torch.tensor([1.0, -1.0, -1.0, 1.0]),
                    torch.tensor([1.0, -1.0, 1.0, -1.0]),
                    torch.tensor([-1.0, 1.0, 0.0, 0.0]),
                )
            ],
        },
    ]

    # Flatten every (module, sample_input) pairing, keeping the same order
    # (and therefore the same subTest indices) as the original nested loops.
    cases = [
        (module, sample_input)
        for comb in test_comb
        for module in comb[QCOM_MODULE]
        for sample_input in comb[QCOM_SAMPLE_INPUTS]
    ]
    for index, (module, sample_input) in enumerate(cases):
        with self.subTest(i=index):
            self.lower_module_and_test_output(module, sample_input)
355+
316356
def test_qnn_backend_avg_pool1d(self):
317357
module = AvgPool1D() # noqa: F405
318358
sample_input = (torch.randn(1, 512, 7),)
@@ -2696,6 +2736,39 @@ def test_qnn_backend_atan(self):
26962736
module = self.get_qdq_module(module, sample_input)
26972737
self.lower_module_and_test_output(module, sample_input)
26982738

2739+
def test_qnn_backend_atan2(self):
    """Quantize (QDQ) atan2 modules, lower them, and compare against eager."""
    test_comb = [
        {
            QCOM_MODULE: [Atan2()],  # noqa: F405
            QCOM_SAMPLE_INPUTS: [
                (
                    torch.tensor([1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 0.0]),
                    torch.tensor([1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0]),
                )
            ],
        },
        {
            QCOM_MODULE: [Atan2MultiNode()],  # noqa: F405
            QCOM_SAMPLE_INPUTS: [
                (
                    torch.tensor([1.0, -1.0, 1.0, -1.0]),
                    torch.tensor([1.0, -1.0, -1.0, 1.0]),
                    torch.tensor([1.0, -1.0, 1.0, -1.0]),
                    torch.tensor([-1.0, 1.0, 0.0, 0.0]),
                )
            ],
        },
    ]

    # Flatten every (module, sample_input) pairing, keeping the same order
    # (and therefore the same subTest indices) as the original nested loops.
    cases = [
        (module, sample_input)
        for comb in test_comb
        for module in comb[QCOM_MODULE]
        for sample_input in comb[QCOM_SAMPLE_INPUTS]
    ]
    for index, (module, sample_input) in enumerate(cases):
        with self.subTest(i=index):
            qdq_module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(qdq_module, sample_input)
2771+
26992772
def test_qnn_backend_avg_pool1d(self):
27002773
module = AvgPool1D() # noqa: F405
27012774
sample_input = (torch.randn(1, 512, 7),)

0 commit comments

Comments
 (0)