Skip to content

Commit baa9888

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for the isNan core ATen op (#17941)
1 parent b97cd7b commit baa9888

6 files changed

Lines changed: 120 additions & 0 deletions

File tree

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ Please help update following table if you are contributing new operators:
440440
| HardSwish | ✓ |
441441
| InstanceNorm | ✓ |
442442
| IsInf | ✓ |
443+
| IsNan | ✓ |
443444
| L2Norm | ✗ |
444445
| LayerNorm | ✓ |
445446
| LogSoftmax | ✓ |

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
op_index_select,
5757
op_instance_norm,
5858
op_is_inf,
59+
op_is_nan,
5960
op_layer_norm,
6061
op_le,
6162
op_linear,
@@ -166,6 +167,7 @@
166167
op_index_select,
167168
op_instance_norm,
168169
op_is_inf,
170+
op_is_nan,
169171
op_layer_norm,
170172
op_le,
171173
op_linear,
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import warnings
7+
from typing import Dict
8+
9+
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
10+
11+
import torch
12+
13+
from .node_visitor import NodeVisitor
14+
from .node_visitor_manager import register_node_visitor
15+
from .qnn_constants import OpIsNan, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
17+
18+
@register_node_visitor
class IsNan(NodeVisitor):
    """Lowers the ``aten.isnan.default`` ATen op to the QNN ``IsNan`` op."""

    target = ["aten.isnan.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
    ) -> PyQnnManager.PyQnnOpWrapper:
        """Build the QNN op wrapper for an ``isnan`` node.

        Args:
            node: The FX node being lowered.
            nodes_to_wrappers: Shared cache of already-defined tensor wrappers.

        Returns:
            The wrapped QNN IsNan op, or None to skip delegation when the
            input dtype is not FP32/FP16 (the only dtypes QNN IsNan supports).
        """
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)

        if input_tensor.dtype not in (torch.float32, torch.float16):
            warnings.warn(
                "[QNN Delegate Op Builder]: QNN IsNan only supports FP32 or FP16 inputs.",
                stacklevel=1,
            )
            return None

        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            # Reuse the tensor fetched above; the original called
            # get_tensor(input_node, node) a second time here.
            input_tensor,
            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        input_tensors = [input_tensor_wrapper]

        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        output_tensors = [output_tensor_wrapper]

        isnan_op = PyQnnManager.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpIsNan.op_name,
        )
        isnan_op.AddInputTensors(input_tensors)
        isnan_op.AddOutputTensors(output_tensors)

        return isnan_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ class OpIsInf:
396396
param_detect_positive = "detect_positive"
397397

398398

399+
# QNN op-name constant for the element-wise IsNan operation; follows the
# same frozen, init-less descriptor pattern as the other Op* classes in
# this file (used only for its class-level ``op_name`` attribute).
@dataclass(init=False, frozen=True)
class OpIsNan:
    op_name: str = "IsNan"
402+
403+
399404
@dataclass(init=False, frozen=True)
400405
class OpLayerNorm:
401406
op_name: str = "LayerNorm"

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,14 @@ def forward(self, x):
13181318
return torch.isinf(x)
13191319

13201320

1321+
class IsNan(torch.nn.Module):
    """Minimal module that applies ``torch.isnan`` element-wise.

    Used by the QNN delegate tests to exercise lowering of the
    ``aten.isnan`` op.
    """

    def forward(self, x):
        # Boolean tensor: True exactly where the element is NaN.
        return torch.isnan(x)
1327+
1328+
13211329
class LargeTensorLinear(torch.nn.Module):
13221330
def __init__(self):
13231331
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,44 @@ def test_qnn_backend_is_inf(self):
12701270
)
12711271
self.lower_module_and_test_output(module, sample_input)
12721272

1273+
def test_qnn_backend_is_nan(self):
    """Lower IsNan through QNN and compare outputs for FP32 and FP16
    inputs mixing finite values, NaNs of both signs, and infinities."""
    module = IsNan()  # noqa: F405
    nan, inf = float("nan"), float("inf")
    cases = [
        ([-2.0, nan, -nan, 0.2, inf, 3.2, nan, -inf], torch.float32),
        ([-0.234, -nan, nan, -inf, 3.2, nan, 1.26, inf], torch.float16),
    ]
    for values, dtype in cases:
        sample_input = (torch.tensor(values, dtype=dtype),)
        self.lower_module_and_test_output(module, sample_input)
1310+
12731311
def test_qnn_backend_interpolate_bicubic(self):
12741312
modules = [
12751313
ResizeBicubic([2, 2], None, False), # noqa: F405

0 commit comments

Comments
 (0)