Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .convert_linear_to_conv2d import ConvertLinearToConv2d
from .convert_mha_to_sha import ConvertMhaToSha
from .convert_square_to_pow import ConvertSquareToPow
from .decompose_acos import DecomposeAcos
from .decompose_any import DecomposeAny
from .decompose_binary_alpha import DecomposeBinaryAlpha
from .decompose_cdist import DecomposeCDist
Expand Down Expand Up @@ -65,6 +66,7 @@
ConvertLinearToConv2d,
ConvertMhaToSha,
ConvertSquareToPow,
DecomposeAcos,
DecomposeAny,
DecomposeBinaryAlpha,
DecomposeCDist,
Expand Down
75 changes: 75 additions & 0 deletions backends/qualcomm/_passes/decompose_acos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta, get_const_node


class DecomposeAcos(ExportPass):
    """
    Rewrite every acos node as ``π/2 - asin(x)``.

    Both the ATen overload and its edge-dialect counterpart are matched, and
    the replacement asin/sub nodes use the same dialect as the node they
    replace.
    """

    def __init__(self):
        super().__init__()
        # Overloads this pass rewrites: ATen acos and its edge-dialect twin.
        self.acos_targets = {
            torch.ops.aten.acos.default,
            exir_ops.edge.aten.acos.default,
        }

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph

        matches = [
            candidate
            for candidate in graph.nodes
            if candidate.op == "call_function"
            and candidate.target in self.acos_targets
        ]
        if len(matches) == 0:
            # No acos in the graph: report it as unmodified.
            return PassResult(graph_module, False)

        half_pi = torch.pi / 2.0
        # Constant node shared by all edge-dialect rewrites; built on first use.
        half_pi_const = None

        for acos_node in matches:
            operand = acos_node.args[0]

            if isinstance(acos_node.target, EdgeOpOverload):
                asin_target = exir_ops.edge.aten.asin.default
                sub_target = exir_ops.edge.aten.sub.Tensor
                if half_pi_const is None:
                    half_pi_const = get_const_node(
                        graph, graph_module, "_pi_half_constant", half_pi, acos_node
                    )
                # Edge dialect: π/2 is passed as the node from get_const_node.
                minuend = half_pi_const
            else:
                asin_target = torch.ops.aten.asin.default
                sub_target = torch.ops.aten.sub.Tensor
                # ATen dialect: π/2 is passed as a plain Python float.
                minuend = half_pi

            with graph.inserting_before(acos_node):
                # Both new nodes take a copy of the replaced node's meta.
                asin_node = graph.create_node(
                    "call_function", asin_target, (operand,)
                )
                asin_node.meta = copy_meta(acos_node.meta)

                result_node = graph.create_node(
                    "call_function", sub_target, (minuend, asin_node)
                )
                result_node.meta = copy_meta(acos_node.meta)

            # Snapshot the users first: rewiring mutates acos_node.users.
            for consumer in list(acos_node.users):
                consumer.replace_input_with(acos_node, result_node)

        # The original acos nodes now have no users and are dropped here.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
39 changes: 18 additions & 21 deletions backends/qualcomm/_passes/decompose_log_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,29 @@ class DecomposeLogVariants(ExportPass):
def __init__(self) -> None:
    """Build the table mapping each log-variant overload to its decomposer."""
    super().__init__()
    handlers = {}
    # Edge dialect (post-to_edge) entries first, then ATen dialect
    # (pre-to_edge); both dialects share the same three handlers.
    for aten_ns in (exir_ops.edge.aten, torch.ops.aten):
        handlers[aten_ns.log10.default] = partial(self._decompose_log_n, n=10)
        handlers[aten_ns.log2.default] = partial(self._decompose_log_n, n=2)
        handlers[aten_ns.log1p.default] = partial(self._decompose_log_p, p=1)
    self._dispatcher = handlers

def _decompose_log_n(self, node, graph, graph_module, n):
def _decompose_log_n(self, node, graph, graph_module, const_cache, n):
input_node = node.args[0]
is_edge = node.target in self._EDGE_OPS

if is_edge:
log_op = exir_ops.edge.aten.log.default
div_op = exir_ops.edge.aten.div.Tensor
div_arg = get_const_node(
graph,
graph_module,
f"_log_base_{n}_constant",
math.log(n),
node,
)

attr_name = f"_log_base_{n}_constant"
if attr_name not in const_cache:
const_cache[attr_name] = get_const_node(
graph, graph_module, attr_name, math.log(n), node
)
div_arg = const_cache[attr_name]
else:
log_op = torch.ops.aten.log.default
div_op = torch.ops.aten.div.Tensor
Expand All @@ -74,21 +72,19 @@ def _decompose_log_n(self, node, graph, graph_module, n):
for user in node.users.copy():
user.replace_input_with(node, div_node)

def _decompose_log_p(self, node, graph, graph_module, p):
def _decompose_log_p(self, node, graph, graph_module, const_cache, p):
input_node = node.args[0]
is_edge = node.target in self._EDGE_OPS

if is_edge:
add_op = exir_ops.edge.aten.add.Tensor
log_op = exir_ops.edge.aten.log.default
add_arg = get_const_node(
graph,
graph_module,
f"_log1p_addend_{p}_constant",
p,
node,
)

attr_name = f"_log1p_addend_{p}_constant"
if attr_name not in const_cache:
const_cache[attr_name] = get_const_node(
graph, graph_module, attr_name, p, node
)
add_arg = const_cache[attr_name]
else:
add_op = torch.ops.aten.add.Tensor
log_op = torch.ops.aten.log.default
Expand All @@ -107,10 +103,11 @@ def _decompose_log_p(self, node, graph, graph_module, p):

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
const_cache = {}

for node in list(graph.nodes):
if node.target in self._dispatcher:
self._dispatcher[node.target](node, graph, graph_module)
self._dispatcher[node.target](node, graph, graph_module, const_cache)

graph.eliminate_dead_code()
graph_module.recompile()
Expand Down
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ConvertLinearToConv2d,
ConvertMhaToSha,
ConvertSquareToPow,
DecomposeAcos,
DecomposeAny,
DecomposeBinaryAlpha,
DecomposeCDist,
Expand Down Expand Up @@ -96,6 +97,7 @@ def get_capture_program_passes():
(AnnotateStack, True),
(AnnotateUnbind, True),
(ConvertBmmToMatmul, False),
(DecomposeAcos, True),
(DecomposeAny, True),
(DecomposeColIm, True),
(DecomposeLogVariants, True),
Expand Down Expand Up @@ -213,6 +215,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
self.add_pass(RecomposeRmsNorm(quantization_capture=True))
self.add_pass(ReplaceArangeArgs())
self.add_pass(DecomposeAcos())
self.add_pass(DecomposeBinaryAlpha())
self.add_pass(DecomposeCDist())
self.add_pass(DecomposeMaxPool3d(quantization_capture=True))
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_passes_dependency_for_capture_program():
AnnotateUnbind,
CanonicalizeConv,
ConvertBmmToMatmul,
DecomposeAcos,
DecomposeAny,
DecomposeColIm,
DecomposeLinalgVectorNorm,
Expand Down Expand Up @@ -95,6 +96,7 @@ def get_passes_dependency_for_capture_program():
AnnotateStack: [RemoveRedundancy],
AnnotateUnbind: [RemoveRedundancy],
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
DecomposeAcos: [RemoveRedundancy],
DecomposeAny: [RemoveRedundancy],
DecomposeColIm: [FoldQDQ],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
Expand Down
24 changes: 24 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ def forward(self, x):
return torch.abs(x)


class Acos(torch.nn.Module):
    """Single-op test model: returns the elementwise arccosine of its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        angles = torch.acos(x)
        return angles


class AcosMultiNode(torch.nn.Module):
    """Test model with two independent acos ops, one per input tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        acos_x = torch.acos(x)
        acos_y = torch.acos(y)
        return acos_x, acos_y


class AdaptiveMaxPool2D(torch.nn.Module):
def __init__(self, output_size, return_indices=False):
super().__init__()
Expand Down Expand Up @@ -1498,6 +1514,14 @@ def forward(self, x):
return torch.nn.functional.log_softmax(x, dim=-1)


class LogVariantsMultiNode(torch.nn.Module):
    """Test model mixing log10 (on both inputs), log2 and log1p in one graph."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        log10_x = torch.log10(x)
        log10_y = torch.log10(y)
        log2_x = torch.log2(x)
        log1p_x = torch.log1p(x)
        return log10_x, log10_y, log2_x, log1p_x


class Log10(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
45 changes: 45 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,19 @@ def test_qnn_backend_abs(self):
sample_input = (torch.randn(1, 2, 3, 4),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_acos(self):
    """Run the single-acos model through lower_module_and_test_output."""
    model = Acos()  # noqa: F405
    # Rescale uniform [0, 1) samples to [-1, 1), inside acos's real domain.
    uniform = torch.rand(3, 4)
    inputs = (uniform * 2 - 1,)
    self.lower_module_and_test_output(model, inputs)

def test_qnn_backend_acos_multi_node(self):
    """Run AcosMultiNode on fixed values that include the ±1 endpoints."""
    model = AcosMultiNode()  # noqa: F405
    first = torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0])
    second = torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0])
    self.lower_module_and_test_output(model, (first, second))

def test_qnn_backend_adaptive_avg_pool1d(self):
module = AdaptiveAvgPool1D() # noqa: F405
sample_input = (torch.randn(1, 512, 7),)
Expand Down Expand Up @@ -1455,6 +1468,14 @@ def test_qnn_backend_log_softmax(self):
sample_input = (torch.randn([1, 4, 8, 8]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_log_variants_multi_node(self):
    """Run LogVariantsMultiNode on two strictly positive random inputs."""
    model = LogVariantsMultiNode()  # noqa: F405
    # rand() is non-negative, so adding 0.1 keeps every element >= 0.1.
    inputs = tuple(torch.abs(torch.rand(2, 3, 4)) + 0.1 for _ in range(2))
    self.lower_module_and_test_output(model, inputs)

def test_qnn_backend_log10(self):
module = Log10() # noqa: F405
sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
Expand Down Expand Up @@ -2362,6 +2383,21 @@ def test_qnn_backend_abs(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_acos(self):
    """Pass the single-acos model through get_qdq_module, then lower and test it."""
    float_model = Acos()  # noqa: F405
    # Rescale uniform [0, 1) samples to [-1, 1), inside acos's real domain.
    uniform = torch.rand(3, 4)
    inputs = (uniform * 2 - 1,)
    qdq_model = self.get_qdq_module(float_model, inputs)
    self.lower_module_and_test_output(qdq_model, inputs)

def test_qnn_backend_acos_multi_node(self):
    """Pass AcosMultiNode through get_qdq_module, then lower and test it."""
    float_model = AcosMultiNode()  # noqa: F405
    # Fixed values that include the ±1 endpoints of acos's domain.
    first = torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0])
    second = torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0])
    inputs = (first, second)
    qdq_model = self.get_qdq_module(float_model, inputs)
    self.lower_module_and_test_output(qdq_model, inputs)

def test_qnn_backend_adaptive_avg_pool1d(self):
module = AdaptiveAvgPool1D() # noqa: F405
sample_input = (torch.randn(1, 512, 7),)
Expand Down Expand Up @@ -3813,6 +3849,15 @@ def test_qnn_backend_log_softmax(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_log_variants_multi_node(self):
    """Pass LogVariantsMultiNode through get_qdq_module, then lower and test it."""
    float_model = LogVariantsMultiNode()  # noqa: F405
    # rand() is non-negative, so adding 0.1 keeps every element >= 0.1.
    inputs = tuple(torch.abs(torch.rand(2, 3, 4)) + 0.1 for _ in range(2))
    qdq_model = self.get_qdq_module(float_model, inputs)
    self.lower_module_and_test_output(qdq_model, inputs)

def test_qnn_backend_log10(self):
module = Log10() # noqa: F405
sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
Expand Down
Loading