Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
from .annotate_avg_pool1d import AnnotateAvgPool1D
from .annotate_quant_attrs import AnnotateQuantAttrs
from .annotate_stack import AnnotateStack
from .annotate_unbind import AnnotateUnbind
Expand Down Expand Up @@ -56,7 +56,7 @@
from .tag_quant_io import TagQuantIO

__all__ = [
AnnotateAdaptiveAvgPool1D,
AnnotateAvgPool1D,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
Expand All @@ -13,21 +14,29 @@
from .utils import get_quant_attrs


class AnnotateAdaptiveAvgPool1D(ExportPass):
class AnnotateAvgPool1D(ExportPass):
"""
Add "quant_attrs" to graph nodes' meta from the QDQ information
generated after quantization process.
adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze
avg_pool1d and adaptive_avg_pool1d get decomposed to:
unsqueeze -> avg_pool2d/adaptive_avg_pool2d -> squeeze
"""

_SOURCE_OPS = [
torch.ops.aten.avg_pool1d.default,
torch.avg_pool1d,
torch.ops.aten.adaptive_avg_pool1d.default,
torch.adaptive_avg_pool1d,
]

def __init__(self, edge_program: torch.export.ExportedProgram):
super(AnnotateAdaptiveAvgPool1D, self).__init__()
super(AnnotateAvgPool1D, self).__init__()
self.edge_program = edge_program

def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
def _annotate(self, graph_module: torch.fx.GraphModule):
partitions = get_source_partitions(
graph_module.graph,
[torch.ops.aten.adaptive_avg_pool1d.default, torch.adaptive_avg_pool1d],
self._SOURCE_OPS,
)
for src_partitions in partitions.values():
for src_partition in src_partitions:
Expand All @@ -44,11 +53,11 @@ def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
self.edge_program, list(output.users)[0]
)
for n in src_partition.nodes:
# For adaptive_avg_pool2d and squeeze
# For avg_pool2d/adaptive_avg_pool2d and squeeze
if n.target != exir_ops.edge.aten.unsqueeze_copy.default:
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

def call(self, graph_module: torch.fx.GraphModule):
self._annotate_adaptive_avg_pool1d(graph_module)
self._annotate(graph_module)
graph_module.recompile()
return PassResult(graph_module, True)
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Dict

from executorch.backends.qualcomm._passes import (
AnnotateAdaptiveAvgPool1D,
AnnotateAvgPool1D,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_capture_program_passes():
# The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
# If a pass is activated, it will be executed by default.
default_passes_and_setting = [
(AnnotateAdaptiveAvgPool1D, True),
(AnnotateAvgPool1D, True),
(AnnotateQuantAttrs, True),
(AnnotateStack, True),
(AnnotateUnbind, True),
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_passes_dependency_for_capture_program():
dict: A dictionary mapping each pass to its corresponding list of dependencies.
"""
from executorch.backends.qualcomm._passes import (
AnnotateAdaptiveAvgPool1D,
AnnotateAvgPool1D,
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
Expand All @@ -86,7 +86,7 @@ def get_passes_dependency_for_capture_program():
)

return {
AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
AnnotateAvgPool1D: [RemoveRedundancy],
AnnotateQuantAttrs: [
ConvertBmmToMatmul,
RecomposePixelUnshuffle,
Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/quantizer/annotators/htp_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class Atan(GeneralOpDef):
[
torch.ops.aten.adaptive_avg_pool1d.default,
torch.ops.aten.adaptive_avg_pool2d.default,
torch.ops.aten.avg_pool1d.default,
torch.ops.aten.avg_pool2d.default,
],
QnnConstants.OpPoolAvg2d.op_name,
Expand Down
9 changes: 9 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ def forward(self, x):
return torch.atan(x)


class AvgPool1D(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1)

def forward(self, x):
return self.pool(x)


class AvgPool3d(torch.nn.Module):
def __init__(self, kernel_size, stride, padding, ceil_mode, count_include_pad):
super().__init__()
Expand Down
11 changes: 11 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ def test_qnn_backend_atan(self):
module = Atan() # noqa: F405
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_avg_pool1d(self):
module = AvgPool1D() # noqa: F405
sample_input = (torch.randn(1, 512, 7),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_avg_pool2d(self):
modules = [
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405
Expand Down Expand Up @@ -2536,6 +2541,12 @@ def test_qnn_backend_atan(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_avg_pool1d(self):
module = AvgPool1D() # noqa: F405
sample_input = (torch.randn(1, 512, 7),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_avg_pool2d(self):
modules = [
AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405
Expand Down
Loading