Skip to content

Commit 45124e6

Browse files
committed
Qualcomm AI Engine Direct - Adding QNN backend support for randn core ATen op
1 parent 74c7c91 commit 45124e6

9 files changed

Lines changed: 147 additions & 7 deletions

File tree

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class LayoutTransform(ExportPass):
113113
exir_ops.edge.aten.neg.default,
114114
exir_ops.edge.aten.pow.Tensor_Scalar,
115115
exir_ops.edge.aten.prelu.default,
116+
exir_ops.edge.aten.rand.default,
117+
exir_ops.edge.aten.randn.default,
116118
exir_ops.edge.aten.reflection_pad1d.default,
117119
exir_ops.edge.aten.reflection_pad2d.default,
118120
exir_ops.edge.aten.repeat.default,

backends/qualcomm/builders/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators:
368368
+ 🚫 = Deprecated, supported with other QNN Ops
369369

370370

371-
| Operators | HTP - 98/119 Enabled |
371+
| Operators | HTP - 100/120 Enabled |
372372
|-----------|---------|
373373
| Argmax | ✓ |
374374
| Argmin | ✓ |
@@ -457,7 +457,8 @@ Please help update following table if you are contributing new operators:
457457
| PoolMax2d | ✓ |
458458
| Prelu | ✓ |
459459
| Quantize | ✓ |
460-
| Rand | ✓ |
460+
| RandomUniformLike | ✓ |
461+
| RandomNormalLike | ✓ |
461462
| ReduceMax | ✓ |
462463
| ReduceMean | ✓ |
463464
| ReduceMin | ✓ |
@@ -472,7 +473,7 @@ Please help update following table if you are contributing new operators:
472473
| ResizeNearestNeighbor | ✓ |
473474
| RoiAlign | ✗ |
474475
| RmsNorm | ✓ |
475-
| ScatterElements | ✗ |
476+
| ScatterElements | ✓ |
476477
| ScatterNd | ✓ |
477478
| Sigmoid | ✓ |
478479
| Softmax | ✓ |

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
op_prelu,
8282
op_quantize,
8383
op_rand,
84+
op_randn,
8485
op_relu,
8586
op_repeat,
8687
op_reshape,
@@ -194,6 +195,7 @@
194195
op_prelu,
195196
op_quantize,
196197
op_rand,
198+
op_randn,
197199
op_relu,
198200
op_repeat,
199201
op_reshape,
backends/qualcomm/builders/op_randn.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
9+
10+
import numpy as np
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
16+
from .qnn_constants import OpRandomNormalLike, QNN_OP_PACKAGE_NAME_QTI_AISW
17+
18+
19+
@register_node_visitor
20+
class Randn(NodeVisitor):
21+
target = ["aten.randn.default", "aten.randn_like.default"]
22+
23+
def __init__(self, *args) -> None:
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
30+
) -> PyQnnManager.PyQnnOpWrapper:
31+
output_tensor = node.meta["val"]
32+
output_shape = list(output_tensor.shape)
33+
34+
shape_data = np.array(output_shape, dtype=np.uint32)
35+
shape_dims = [len(output_shape)]
36+
37+
shape_tensor_wrapper = PyQnnManager.TensorWrapper(
38+
f"{node.name}_shape",
39+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
40+
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
41+
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
42+
{},
43+
len(shape_dims),
44+
shape_dims,
45+
[],
46+
shape_data,
47+
True,
48+
)
49+
50+
output_tensor_wrapper = self.define_tensor(
51+
node,
52+
node,
53+
output_tensor,
54+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
55+
nodes_to_wrappers,
56+
)
57+
58+
randn_op = PyQnnManager.PyQnnOpWrapper(
59+
node.name,
60+
QNN_OP_PACKAGE_NAME_QTI_AISW,
61+
OpRandomNormalLike.op_name,
62+
)
63+
64+
randn_op.AddInputTensors([shape_tensor_wrapper])
65+
randn_op.AddOutputTensors([output_tensor_wrapper])
66+
67+
randn_op.AddScalarParam(
68+
OpRandomNormalLike.param_mean,
69+
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
70+
{QCOM_DATA: np.float32(0.0)},
71+
)
72+
73+
randn_op.AddScalarParam(
74+
OpRandomNormalLike.param_scale,
75+
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
76+
{QCOM_DATA: np.float32(1.0)},
77+
)
78+
79+
return randn_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,13 @@ class OpQuantize:
504504
op_name: str = "Quantize"
505505

506506

507+
@dataclass(init=False, frozen=True)
508+
class OpRandomNormalLike:
509+
op_name: str = "RandomNormalLike"
510+
param_mean: str = "mean"
511+
param_scale: str = "scale"
512+
513+
507514
@dataclass(init=False, frozen=True)
508515
class OpRandomUniformLike:
509516
op_name: str = "RandomUniformLike"

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ class ColIm(GeneralOpDef):
343343
torch.ops.aten.zeros_like.default,
344344
torch.ops.aten.ones.default,
345345
torch.ops.aten.ones_like.default,
346+
torch.ops.aten.rand.default,
347+
torch.ops.aten.randn.default,
346348
],
347349
qnn_op=None,
348350
)

backends/qualcomm/quantizer/annotators/lpai_rules.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ class ColIm(GeneralOpDef):
271271
torch.ops.aten.zeros_like.default,
272272
torch.ops.aten.ones.default,
273273
torch.ops.aten.ones_like.default,
274+
torch.ops.aten.rand.default,
275+
torch.ops.aten.randn.default,
274276
],
275277
qnn_op=None,
276278
)

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,6 +1929,14 @@ def forward(self, x):
19291929
return torch.rand_like(x) + x
19301930

19311931

1932+
class Randn(torch.nn.Module):
1933+
def __init__(self):
1934+
super().__init__()
1935+
1936+
def forward(self, x):
1937+
return torch.randn_like(x) + x
1938+
1939+
19321940
class Reciprocal(torch.nn.Module):
19331941
def __init__(self):
19341942
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,6 +1803,25 @@ def test_qnn_backend_prelu(self):
18031803
self.lower_module_and_test_output(module, sample_input)
18041804

18051805
def test_qnn_backend_rand(self):
1806+
module = Rand() # noqa: F405
1807+
sample_inputs = [
1808+
(torch.randn(3, 4, 5),),
1809+
(torch.randn(2, 8),),
1810+
(
1811+
torch.randn(
1812+
10,
1813+
),
1814+
),
1815+
(torch.randn(1, 3, 32, 32),),
1816+
]
1817+
for i, sample_input in enumerate(sample_inputs):
1818+
with self.subTest(i=i):
1819+
self.lower_module_and_test_output(
1820+
module, sample_input, assert_output_equal=False
1821+
)
1822+
1823+
def test_qnn_backend_randn(self):
1824+
module = Randn() # noqa: F405
18061825
sample_inputs = [
18071826
(torch.randn(3, 4, 5),),
18081827
(torch.randn(2, 8),),
@@ -1815,7 +1834,6 @@ def test_qnn_backend_rand(self):
18151834
]
18161835
for i, sample_input in enumerate(sample_inputs):
18171836
with self.subTest(i=i):
1818-
module = Rand() # noqa: F405
18191837
self.lower_module_and_test_output(
18201838
module, sample_input, assert_output_equal=False
18211839
)
@@ -4380,6 +4398,7 @@ def test_qnn_backend_prelu(self):
43804398
self.lower_module_and_test_output(module, sample_input)
43814399

43824400
def test_qnn_backend_rand(self):
4401+
module = Rand() # noqa: F405
43834402
sample_inputs = [
43844403
(torch.randn(3, 4, 5),),
43854404
(torch.randn(2, 8),),
@@ -4392,10 +4411,28 @@ def test_qnn_backend_rand(self):
43924411
]
43934412
for i, sample_input in enumerate(sample_inputs):
43944413
with self.subTest(i=i):
4395-
module = Rand() # noqa: F405
4396-
module = self.get_qdq_module(module, sample_input)
4414+
qdq_module = self.get_qdq_module(module, sample_input)
43974415
self.lower_module_and_test_output(
4398-
module, sample_input, assert_output_equal=False
4416+
qdq_module, sample_input, assert_output_equal=False
4417+
)
4418+
4419+
def test_qnn_backend_randn(self):
4420+
module = Randn() # noqa: F405
4421+
sample_inputs = [
4422+
(torch.randn(3, 4, 5),),
4423+
(torch.randn(2, 8),),
4424+
(
4425+
torch.randn(
4426+
10,
4427+
),
4428+
),
4429+
(torch.randn(1, 3, 32, 32),),
4430+
]
4431+
for i, sample_input in enumerate(sample_inputs):
4432+
with self.subTest(i=i):
4433+
qdq_module = self.get_qdq_module(module, sample_input)
4434+
self.lower_module_and_test_output(
4435+
qdq_module, sample_input, assert_output_equal=False
43994436
)
44004437

44014438
def test_qnn_backend_reciprocal(self):

0 commit comments

Comments (0)