Skip to content

Commit 137dedc

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for avg_pool_1d core ATen op (#18733)
1 parent ddb7762 commit 137dedc

7 files changed

Lines changed: 43 additions & 13 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
7+
from .annotate_avg_pool1d import AnnotateAvgPool1D
88
from .annotate_quant_attrs import AnnotateQuantAttrs
99
from .annotate_stack import AnnotateStack
1010
from .annotate_unbind import AnnotateUnbind
@@ -57,7 +57,7 @@
5757
from .tag_quant_io import TagQuantIO
5858

5959
__all__ = [
60-
AnnotateAdaptiveAvgPool1D,
60+
AnnotateAvgPool1D,
6161
AnnotateQuantAttrs,
6262
AnnotateStack,
6363
AnnotateUnbind,

backends/qualcomm/_passes/annotate_adaptive_avg_pool1d.py renamed to backends/qualcomm/_passes/annotate_avg_pool1d.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
67
import torch
78
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
89
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
@@ -13,21 +14,29 @@
1314
from .utils import get_quant_attrs
1415

1516

16-
class AnnotateAdaptiveAvgPool1D(ExportPass):
17+
class AnnotateAvgPool1D(ExportPass):
1718
"""
1819
Add "quant_attrs" to graph nodes' meta from the QDQ information
1920
generated after quantization process.
20-
adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze
21+
avg_pool1d and adaptive_avg_pool1d get decomposed to:
22+
unsqueeze -> avg_pool2d/adaptive_avg_pool2d -> squeeze
2123
"""
2224

25+
_SOURCE_OPS = [
26+
torch.ops.aten.avg_pool1d.default,
27+
torch.avg_pool1d,
28+
torch.ops.aten.adaptive_avg_pool1d.default,
29+
torch.adaptive_avg_pool1d,
30+
]
31+
2332
def __init__(self, edge_program: torch.export.ExportedProgram):
24-
super(AnnotateAdaptiveAvgPool1D, self).__init__()
33+
super(AnnotateAvgPool1D, self).__init__()
2534
self.edge_program = edge_program
2635

27-
def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
36+
def _annotate(self, graph_module: torch.fx.GraphModule):
2837
partitions = get_source_partitions(
2938
graph_module.graph,
30-
[torch.ops.aten.adaptive_avg_pool1d.default, torch.adaptive_avg_pool1d],
39+
self._SOURCE_OPS,
3140
)
3241
for src_partitions in partitions.values():
3342
for src_partition in src_partitions:
@@ -44,11 +53,11 @@ def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
4453
self.edge_program, list(output.users)[0]
4554
)
4655
for n in src_partition.nodes:
47-
# For adaptive_avg_pool2d and squeeze
56+
# For avg_pool2d/adaptive_avg_pool2d and squeeze
4857
if n.target != exir_ops.edge.aten.unsqueeze_copy.default:
4958
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
5059

5160
def call(self, graph_module: torch.fx.GraphModule):
52-
self._annotate_adaptive_avg_pool1d(graph_module)
61+
self._annotate(graph_module)
5362
graph_module.recompile()
5463
return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Dict
1010

1111
from executorch.backends.qualcomm._passes import (
12-
AnnotateAdaptiveAvgPool1D,
12+
AnnotateAvgPool1D,
1313
AnnotateQuantAttrs,
1414
AnnotateStack,
1515
AnnotateUnbind,
@@ -95,7 +95,7 @@ def get_capture_program_passes():
9595
# The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
9696
# If a pass is activated, it will be executed by default.
9797
default_passes_and_setting = [
98-
(AnnotateAdaptiveAvgPool1D, True),
98+
(AnnotateAvgPool1D, True),
9999
(AnnotateQuantAttrs, True),
100100
(AnnotateStack, True),
101101
(AnnotateUnbind, True),

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_passes_dependency_for_capture_program():
6060
dict: A dictionary mapping each pass to its corresponding list of dependencies.
6161
"""
6262
from executorch.backends.qualcomm._passes import (
63-
AnnotateAdaptiveAvgPool1D,
63+
AnnotateAvgPool1D,
6464
AnnotateQuantAttrs,
6565
AnnotateStack,
6666
AnnotateUnbind,
@@ -86,7 +86,7 @@ def get_passes_dependency_for_capture_program():
8686
)
8787

8888
return {
89-
AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
89+
AnnotateAvgPool1D: [RemoveRedundancy],
9090
AnnotateQuantAttrs: [
9191
ConvertBmmToMatmul,
9292
RecomposePixelUnshuffle,

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ class Atan(GeneralOpDef):
163163
[
164164
torch.ops.aten.adaptive_avg_pool1d.default,
165165
torch.ops.aten.adaptive_avg_pool2d.default,
166+
torch.ops.aten.avg_pool1d.default,
166167
torch.ops.aten.avg_pool2d.default,
167168
],
168169
QnnConstants.OpPoolAvg2d.op_name,

backends/qualcomm/tests/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,15 @@ def forward(self, x):
263263
return torch.atan(x)
264264

265265

266+
class AvgPool1D(torch.nn.Module):
267+
def __init__(self):
268+
super().__init__()
269+
self.pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1)
270+
271+
def forward(self, x):
272+
return self.pool(x)
273+
274+
266275
class AvgPool3d(torch.nn.Module):
267276
def __init__(self, kernel_size, stride, padding, ceil_mode, count_include_pad):
268277
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,11 @@ def test_qnn_backend_atan(self):
301301
module = Atan() # noqa: F405
302302
self.lower_module_and_test_output(module, sample_input)
303303

304+
def test_qnn_backend_avg_pool1d(self):
305+
module = AvgPool1D() # noqa: F405
306+
sample_input = (torch.randn(1, 512, 7),)
307+
self.lower_module_and_test_output(module, sample_input)
308+
304309
def test_qnn_backend_avg_pool2d(self):
305310
modules = [
306311
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405
@@ -2585,6 +2590,12 @@ def test_qnn_backend_atan(self):
25852590
module = self.get_qdq_module(module, sample_input)
25862591
self.lower_module_and_test_output(module, sample_input)
25872592

2593+
def test_qnn_backend_avg_pool1d(self):
2594+
module = AvgPool1D() # noqa: F405
2595+
sample_input = (torch.randn(1, 512, 7),)
2596+
module = self.get_qdq_module(module, sample_input)
2597+
self.lower_module_and_test_output(module, sample_input)
2598+
25882599
def test_qnn_backend_avg_pool2d(self):
25892600
modules = [
25902601
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405

0 commit comments

Comments
 (0)