Skip to content

Commit d9688da

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for atan2 core ATen ops (#19051)
1 parent d8da621 commit d9688da

7 files changed

Lines changed: 343 additions & 0 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .convert_square_to_pow import ConvertSquareToPow
1616
from .decompose_acos import DecomposeAcos
1717
from .decompose_any import DecomposeAny
18+
from .decompose_atan2 import DecomposeAtan2
1819
from .decompose_binary_alpha import DecomposeBinaryAlpha
1920
from .decompose_cdist import DecomposeCDist
2021
from .decompose_col_im import DecomposeColIm
@@ -70,6 +71,7 @@
7071
ConvertSquareToPow,
7172
DecomposeAcos,
7273
DecomposeAny,
74+
DecomposeAtan2,
7375
DecomposeBinaryAlpha,
7476
DecomposeCDist,
7577
DecomposeColIm,
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
from .utils import copy_meta, create_node, get_const_node
13+
14+
15+
class DecomposeAtan2(ExportPass):
16+
"""
17+
Decompose atan2(y, x) with full piecewise definition:
18+
atan2(y, x) =
19+
atan(y/x) if x > 0
20+
atan(y/x) + π if x < 0, y >= 0
21+
atan(y/x) - π if x < 0, y < 0
22+
+π/2 if x = 0, y > 0
23+
-π/2 if x = 0, y < 0
24+
0 if x = 0, y = 0
25+
"""
26+
27+
_OPS = {
28+
"eq": (exir_ops.edge.aten.eq.Tensor, torch.ops.aten.eq.Tensor),
29+
"lt": (exir_ops.edge.aten.lt.Tensor, torch.ops.aten.lt.Tensor),
30+
"gt": (exir_ops.edge.aten.gt.Tensor, torch.ops.aten.gt.Tensor),
31+
"ge": (exir_ops.edge.aten.ge.Tensor, torch.ops.aten.ge.Tensor),
32+
"where": (exir_ops.edge.aten.where.self, torch.ops.aten.where.self),
33+
"div": (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor),
34+
"atan": (exir_ops.edge.aten.atan.default, torch.ops.aten.atan.default),
35+
"add": (exir_ops.edge.aten.add.Tensor, torch.ops.aten.add.Tensor),
36+
}
37+
38+
_TO_FLOAT_OP = (
39+
exir_ops.edge.aten._to_copy.default,
40+
torch.ops.aten._to_copy.default,
41+
)
42+
43+
def __init__(self):
44+
super(DecomposeAtan2, self).__init__()
45+
self.atan2_targets = {
46+
torch.ops.aten.atan2.default,
47+
torch.ops.aten.atan2.out,
48+
exir_ops.edge.aten.atan2.default,
49+
}
50+
51+
def _get_op(self, name, is_edge):
52+
return self._OPS[name][0] if is_edge else self._OPS[name][1]
53+
54+
def _cast_to_float(self, graph, node, meta, is_edge):
55+
"""Insert a cast from integer to float if the input is not floating-point."""
56+
node_val = node.meta.get("val")
57+
if node_val is not None and not node_val.is_floating_point():
58+
to_float_op = self._TO_FLOAT_OP[0] if is_edge else self._TO_FLOAT_OP[1]
59+
cast_node = graph.create_node(
60+
"call_function", to_float_op, (node,), {"dtype": torch.float32}
61+
)
62+
cast_node.meta = copy_meta(meta)
63+
return cast_node
64+
return node
65+
66+
def _get_constants(self, graph, graph_module, node, is_edge, const_cache):
67+
if is_edge:
68+
69+
def make_const(name, val):
70+
if name not in const_cache:
71+
const_cache[name] = get_const_node(
72+
graph, graph_module, name, val, node
73+
)
74+
return const_cache[name]
75+
76+
return {
77+
"zero": make_const("_atan2_zero", 0.0),
78+
"one": make_const("_atan2_one", 1.0),
79+
"pi": make_const("_atan2_pi", torch.pi),
80+
"neg_pi": make_const("_atan2_neg_pi", -torch.pi),
81+
"pi_half": make_const("_atan2_pi_half", torch.pi / 2),
82+
"neg_pi_half": make_const("_atan2_neg_pi_half", -torch.pi / 2),
83+
}
84+
return {
85+
"zero": 0.0,
86+
"one": 1.0,
87+
"pi": torch.pi,
88+
"neg_pi": -torch.pi,
89+
"pi_half": torch.pi / 2,
90+
"neg_pi_half": -torch.pi / 2,
91+
}
92+
93+
def call(self, graph_module: torch.fx.GraphModule):
94+
graph = graph_module.graph
95+
const_cache = {}
96+
for node in list(graph.nodes):
97+
if node.op == "call_function" and node.target in self.atan2_targets:
98+
y_node, x_node = node.args[0], node.args[1]
99+
is_edge = isinstance(node.target, EdgeOpOverload)
100+
meta = node.meta
101+
102+
with graph.inserting_before(node):
103+
y_node = self._cast_to_float(graph, y_node, meta, is_edge)
104+
x_node = self._cast_to_float(graph, x_node, meta, is_edge)
105+
106+
consts = self._get_constants(
107+
graph, graph_module, node, is_edge, const_cache
108+
)
109+
110+
x_eq_zero = create_node(
111+
graph,
112+
self._get_op("eq", is_edge),
113+
(x_node, consts["zero"]),
114+
meta,
115+
callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
116+
)
117+
safe_x = create_node(
118+
graph,
119+
self._get_op("where", is_edge),
120+
(x_eq_zero, consts["one"], x_node),
121+
meta,
122+
)
123+
ratio = create_node(
124+
graph,
125+
self._get_op("div", is_edge),
126+
(y_node, safe_x),
127+
meta,
128+
)
129+
130+
base = create_node(
131+
graph,
132+
self._get_op("atan", is_edge),
133+
(ratio,),
134+
meta,
135+
)
136+
137+
x_lt_zero = create_node(
138+
graph,
139+
self._get_op("lt", is_edge),
140+
(x_node, consts["zero"]),
141+
meta,
142+
callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
143+
)
144+
y_ge_zero = create_node(
145+
graph,
146+
self._get_op("ge", is_edge),
147+
(y_node, consts["zero"]),
148+
meta,
149+
callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
150+
)
151+
y_sign_pi = create_node(
152+
graph,
153+
self._get_op("where", is_edge),
154+
(y_ge_zero, consts["pi"], consts["neg_pi"]),
155+
meta,
156+
)
157+
adjustment = create_node(
158+
graph,
159+
self._get_op("where", is_edge),
160+
(x_lt_zero, y_sign_pi, consts["zero"]),
161+
meta,
162+
)
163+
adjusted = create_node(
164+
graph,
165+
self._get_op("add", is_edge),
166+
(base, adjustment),
167+
meta,
168+
)
169+
170+
y_gt_zero = create_node(
171+
graph,
172+
self._get_op("gt", is_edge),
173+
(y_node, consts["zero"]),
174+
meta,
175+
callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
176+
)
177+
x_zero_result = create_node(
178+
graph,
179+
self._get_op("where", is_edge),
180+
(y_gt_zero, consts["pi_half"], consts["neg_pi_half"]),
181+
meta,
182+
)
183+
184+
y_eq_zero = create_node(
185+
graph,
186+
self._get_op("eq", is_edge),
187+
(y_node, consts["zero"]),
188+
meta,
189+
callback=lambda m: {**m, "val": m["val"].to(torch.bool)},
190+
)
191+
x_zero_final = create_node(
192+
graph,
193+
self._get_op("where", is_edge),
194+
(y_eq_zero, consts["zero"], x_zero_result),
195+
meta,
196+
)
197+
198+
result = create_node(
199+
graph,
200+
self._get_op("where", is_edge),
201+
(x_eq_zero, x_zero_final, adjusted),
202+
meta,
203+
)
204+
205+
for user in node.users.copy():
206+
user.replace_input_with(node, result)
207+
208+
graph.eliminate_dead_code()
209+
graph_module.recompile()
210+
return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ConvertSquareToPow,
2121
DecomposeAcos,
2222
DecomposeAny,
23+
DecomposeAtan2,
2324
DecomposeBinaryAlpha,
2425
DecomposeCDist,
2526
DecomposeColIm,
@@ -104,6 +105,7 @@ def get_capture_program_passes():
104105
(ConvertBmmToMatmul, False),
105106
(DecomposeAcos, True),
106107
(DecomposeAny, True),
108+
(DecomposeAtan2, True),
107109
(DecomposeColIm, True),
108110
(DecomposeLogVariants, True),
109111
(DecomposeMaxPool3d, True),
@@ -226,6 +228,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
226228
self.add_pass(RecomposeRmsNorm(quantization_capture=True))
227229
self.add_pass(ReplaceArangeArgs())
228230
self.add_pass(DecomposeAcos())
231+
self.add_pass(DecomposeAtan2())
229232
self.add_pass(DecomposeBinaryAlpha())
230233
self.add_pass(DecomposeCDist())
231234
self.add_pass(DecomposeMaxPool3d(quantization_capture=True))

backends/qualcomm/_passes/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def get_passes_dependency_for_capture_program():
6767
ConvertBmmToMatmul,
6868
DecomposeAcos,
6969
DecomposeAny,
70+
DecomposeAtan2,
7071
DecomposeColIm,
7172
DecomposeLinalgVectorNorm,
7273
DecomposeLogVariants,
@@ -99,6 +100,7 @@ def get_passes_dependency_for_capture_program():
99100
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
100101
DecomposeAcos: [RemoveRedundancy],
101102
DecomposeAny: [RemoveRedundancy],
103+
DecomposeAtan2: [RemoveRedundancy],
102104
DecomposeColIm: [FoldQDQ],
103105
DecomposeLinalgVectorNorm: [RemoveRedundancy],
104106
DecomposeLogVariants: [RemoveRedundancy],
@@ -315,3 +317,9 @@ def get_const_node(
315317
const_node = graph.get_attr(attr_name)
316318
const_node.meta["val"] = fake_mode.from_tensor(tensor)
317319
return const_node
320+
321+
322+
def create_node(graph, target, args, meta, callback=None):
323+
node = graph.create_node("call_function", target, args)
324+
node.meta = copy_meta(meta, callback)
325+
return node

backends/qualcomm/builders/README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,37 @@ Please help update following table if you are contributing new operators:
490490
| TransPoseConv3d | &check; |
491491
| Unpack | &check; |
492492

493+
494+
## Additional Operators Supported via Passes
495+
The following PyTorch operators are supported through decomposition or annotation passes located in `backends/qualcomm/_passes/`. These ops do not have a direct 1:1 mapping to a QNN operator but are handled by transforming them into supported QNN operations.
496+
497+
| PyTorch Op | Decomposition Pass |
498+
|---|---|
499+
| `aten.acos` | `DecomposeAcos` |
500+
| `aten.adaptive_avg_pool1d`, `aten.avg_pool1d` | `AnnotateAvgPool1D` |
501+
| `aten.any` | `DecomposeAny` |
502+
| `aten.atan2.default`, `aten.atan2.out` | `DecomposeAtan2` |
503+
| `aten.add` (with alpha), `aten.sub` (with alpha) | `DecomposeBinaryAlpha` |
504+
| `aten.cdist` | `DecomposeCDist` |
505+
| `aten.im2col`, `aten.col2im` | `DecomposeColIm` |
506+
| `aten.einsum` | `DecomposeEinsum` |
507+
| `aten.special_expm1` | `DecomposeExpM1` |
508+
| `aten.floor_divide` | `DecomposeFloorDivide` |
509+
| `aten.glu` | `DecomposeGlu` |
510+
| `aten.linalg_vector_norm` | `DecomposeLinalgVectorNorm` |
511+
| `aten.log10`, `aten.log2`, `aten.log1p` | `DecomposeLogVariants` |
512+
| `aten.max_pool3d` | `DecomposeMaxPool3d` |
513+
| `aten.min.dim`, `aten.max.dim` | `DecomposeMinMaxDim` |
514+
| `aten.reciprocal` | `DecomposeReciprocal` |
515+
| `aten.reflection_pad1d` | PyTorch built-in decomposition |
516+
| `aten.reflection_pad2d` | `DecomposePad` |
517+
| `aten.remainder.Scalar`, `aten.remainder.Tensor` | `DecomposeRemainder` |
518+
| `aten.roll` | `DecomposeRoll` |
519+
| `aten.silu` | `DecomposeSilu` |
520+
| `aten.threshold` | `DecomposeThreshold` |
521+
| `aten.triu` | `DecomposeTriu` |
522+
| `aten.trunc` | `DecomposeTrunc` |
523+
493524
## Issues
494525
Please refer to the [issue section](../README.md#issues) for more information.
495526

backends/qualcomm/tests/models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,22 @@ def forward(self, x):
263263
return torch.atan(x)
264264

265265

266+
class Atan2(torch.nn.Module):
267+
def __init__(self):
268+
super().__init__()
269+
270+
def forward(self, x, y):
271+
return torch.atan2(x, y)
272+
273+
274+
class Atan2MultiNode(torch.nn.Module):
275+
def __init__(self):
276+
super().__init__()
277+
278+
def forward(self, x1, y1, x2, y2):
279+
return torch.atan2(x1, y1), torch.atan2(x2, y2)
280+
281+
266282
class AvgPool1D(torch.nn.Module):
267283
def __init__(self):
268284
super().__init__()

0 commit comments

Comments
 (0)