diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index bbf5ea52741..6b9af389a12 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -13,6 +13,7 @@
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_mha_to_sha import ConvertMhaToSha
 from .convert_square_to_pow import ConvertSquareToPow
+from .decompose_acos import DecomposeAcos
 from .decompose_any import DecomposeAny
 from .decompose_binary_alpha import DecomposeBinaryAlpha
 from .decompose_cdist import DecomposeCDist
@@ -65,6 +66,7 @@
     ConvertLinearToConv2d,
     ConvertMhaToSha,
     ConvertSquareToPow,
+    DecomposeAcos,
     DecomposeAny,
     DecomposeBinaryAlpha,
     DecomposeCDist,
diff --git a/backends/qualcomm/_passes/decompose_acos.py b/backends/qualcomm/_passes/decompose_acos.py
new file mode 100644
index 00000000000..f83b18f11fc
--- /dev/null
+++ b/backends/qualcomm/_passes/decompose_acos.py
@@ -0,0 +1,75 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_meta, get_const_node
+
+
+class DecomposeAcos(ExportPass):
+    """
+    Decompose acos using the identity: acos(x) = π/2 - asin(x).
+    """
+
+    def __init__(self):
+        super(DecomposeAcos, self).__init__()
+        self.acos_targets = {
+            torch.ops.aten.acos.default,
+            exir_ops.edge.aten.acos.default,
+        }
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+
+        acos_nodes = [
+            n
+            for n in graph.nodes
+            if n.op == "call_function" and n.target in self.acos_targets
+        ]
+        if not acos_nodes:
+            return PassResult(graph_module, False)
+
+        pi_half = torch.pi / 2.0
+        pi_half_node = None
+
+        for node in acos_nodes:
+            input_node = node.args[0]
+            is_edge = isinstance(node.target, EdgeOpOverload)
+
+            asin_op = (
+                exir_ops.edge.aten.asin.default
+                if is_edge
+                else torch.ops.aten.asin.default
+            )
+            sub_op = (
+                exir_ops.edge.aten.sub.Tensor if is_edge else torch.ops.aten.sub.Tensor
+            )
+
+            if is_edge and pi_half_node is None:
+                pi_half_node = get_const_node(
+                    graph, graph_module, "_pi_half_constant", pi_half, node
+                )
+
+            sub_arg = pi_half_node if is_edge else pi_half
+
+            with graph.inserting_before(node):
+                asin_node = graph.create_node("call_function", asin_op, (input_node,))
+                asin_node.meta = copy_meta(node.meta)
+
+                sub_node = graph.create_node(
+                    "call_function", sub_op, (sub_arg, asin_node)
+                )
+                sub_node.meta = copy_meta(node.meta)
+
+            for user in node.users.copy():
+                user.replace_input_with(node, sub_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/decompose_log_variants.py b/backends/qualcomm/_passes/decompose_log_variants.py
index f8569122958..2b394806b68 100644
--- a/backends/qualcomm/_passes/decompose_log_variants.py
+++ b/backends/qualcomm/_passes/decompose_log_variants.py
@@ -31,31 +31,29 @@ class DecomposeLogVariants(ExportPass):
     def __init__(self) -> None:
         super().__init__()
         self._dispatcher = {
-            # Edge dialect (post-to_edge) - FP
+            # Edge dialect (post-to_edge)
             exir_ops.edge.aten.log10.default: partial(self._decompose_log_n, n=10),
             exir_ops.edge.aten.log2.default: partial(self._decompose_log_n, n=2),
             exir_ops.edge.aten.log1p.default: partial(self._decompose_log_p, p=1),
-            # ATen dialect (pre-to_edge) - Quantized
+            # ATen dialect (pre-to_edge)
             torch.ops.aten.log10.default: partial(self._decompose_log_n, n=10),
             torch.ops.aten.log2.default: partial(self._decompose_log_n, n=2),
             torch.ops.aten.log1p.default: partial(self._decompose_log_p, p=1),
         }
 
-    def _decompose_log_n(self, node, graph, graph_module, n):
+    def _decompose_log_n(self, node, graph, graph_module, const_cache, n):
         input_node = node.args[0]
         is_edge = node.target in self._EDGE_OPS
 
         if is_edge:
             log_op = exir_ops.edge.aten.log.default
             div_op = exir_ops.edge.aten.div.Tensor
-            div_arg = get_const_node(
-                graph,
-                graph_module,
-                f"_log_base_{n}_constant",
-                math.log(n),
-                node,
-            )
-
+            attr_name = f"_log_base_{n}_constant"
+            if attr_name not in const_cache:
+                const_cache[attr_name] = get_const_node(
+                    graph, graph_module, attr_name, math.log(n), node
+                )
+            div_arg = const_cache[attr_name]
         else:
             log_op = torch.ops.aten.log.default
             div_op = torch.ops.aten.div.Tensor
@@ -74,21 +72,19 @@ def _decompose_log_n(self, node, graph, graph_module, n):
         for user in node.users.copy():
             user.replace_input_with(node, div_node)
 
-    def _decompose_log_p(self, node, graph, graph_module, p):
+    def _decompose_log_p(self, node, graph, graph_module, const_cache, p):
         input_node = node.args[0]
         is_edge = node.target in self._EDGE_OPS
 
         if is_edge:
             add_op = exir_ops.edge.aten.add.Tensor
             log_op = exir_ops.edge.aten.log.default
-            add_arg = get_const_node(
-                graph,
-                graph_module,
-                f"_log1p_addend_{p}_constant",
-                p,
-                node,
-            )
-
+            attr_name = f"_log1p_addend_{p}_constant"
+            if attr_name not in const_cache:
+                const_cache[attr_name] = get_const_node(
+                    graph, graph_module, attr_name, p, node
+                )
+            add_arg = const_cache[attr_name]
         else:
             add_op = torch.ops.aten.add.Tensor
             log_op = torch.ops.aten.log.default
@@ -107,10 +103,11 @@ def _decompose_log_p(self, node, graph, graph_module, p):
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph = graph_module.graph
+        const_cache = {}
 
         for node in list(graph.nodes):
             if node.target in self._dispatcher:
-                self._dispatcher[node.target](node, graph, graph_module)
+                self._dispatcher[node.target](node, graph, graph_module, const_cache)
 
         graph.eliminate_dead_code()
         graph_module.recompile()
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 19d855be9d9..07f3ff8e9fc 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -18,6 +18,7 @@
     ConvertLinearToConv2d,
     ConvertMhaToSha,
     ConvertSquareToPow,
+    DecomposeAcos,
     DecomposeAny,
     DecomposeBinaryAlpha,
     DecomposeCDist,
@@ -96,6 +97,7 @@ def get_capture_program_passes():
         (AnnotateStack, True),
         (AnnotateUnbind, True),
         (ConvertBmmToMatmul, False),
+        (DecomposeAcos, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),
         (DecomposeLogVariants, True),
@@ -213,6 +215,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
         self.add_pass(RecomposeRmsNorm(quantization_capture=True))
         self.add_pass(ReplaceArangeArgs())
+        self.add_pass(DecomposeAcos())
         self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeMaxPool3d(quantization_capture=True))
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index 07be2b597f1..93d4f6d7992 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -66,6 +66,7 @@ def get_passes_dependency_for_capture_program():
         AnnotateUnbind,
         CanonicalizeConv,
         ConvertBmmToMatmul,
+        DecomposeAcos,
         DecomposeAny,
         DecomposeColIm,
         DecomposeLinalgVectorNorm,
@@ -95,6 +96,7 @@ def get_passes_dependency_for_capture_program():
         AnnotateStack: [RemoveRedundancy],
         AnnotateUnbind: [RemoveRedundancy],
         ConvertBmmToMatmul: [RecomposePixelUnshuffle],
+        DecomposeAcos: [RemoveRedundancy],
         DecomposeAny: [RemoveRedundancy],
         DecomposeColIm: [FoldQDQ],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index da6b4bec66c..53c3a33825d 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -41,6 +41,22 @@ def forward(self, x):
         return torch.abs(x)
 
 
+class Acos(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.acos(x)
+
+
+class AcosMultiNode(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.acos(x), torch.acos(y)
+
+
 class AdaptiveMaxPool2D(torch.nn.Module):
     def __init__(self, output_size, return_indices=False):
         super().__init__()
@@ -1498,6 +1514,14 @@ def forward(self, x):
         return torch.nn.functional.log_softmax(x, dim=-1)
 
 
+class LogVariantsMultiNode(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.log10(x), torch.log10(y), torch.log2(x), torch.log1p(x)
+
+
 class Log10(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 4e23f43c2ea..be4ba665823 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -125,6 +125,19 @@ def test_qnn_backend_abs(self):
         sample_input = (torch.randn(1, 2, 3, 4),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_acos(self):
+        module = Acos()  # noqa: F405
+        sample_input = (torch.rand(3, 4) * 2 - 1,)
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_acos_multi_node(self):
+        module = AcosMultiNode()  # noqa: F405
+        sample_input = (
+            torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0]),
+            torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0]),
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_adaptive_avg_pool1d(self):
         module = AdaptiveAvgPool1D()  # noqa: F405
         sample_input = (torch.randn(1, 512, 7),)
@@ -1455,6 +1468,14 @@ def test_qnn_backend_log_softmax(self):
         sample_input = (torch.randn([1, 4, 8, 8]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_log_variants_multi_node(self):
+        module = LogVariantsMultiNode()  # noqa: F405
+        sample_input = (
+            torch.abs(torch.rand(2, 3, 4)) + 0.1,
+            torch.abs(torch.rand(2, 3, 4)) + 0.1,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log10(self):
         module = Log10()  # noqa: F405
         sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
@@ -2362,6 +2383,21 @@ def test_qnn_backend_abs(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_acos(self):
+        module = Acos()  # noqa: F405
+        sample_input = (torch.rand(3, 4) * 2 - 1,)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_acos_multi_node(self):
+        module = AcosMultiNode()  # noqa: F405
+        sample_input = (
+            torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0]),
+            torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0]),
+        )
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_adaptive_avg_pool1d(self):
         module = AdaptiveAvgPool1D()  # noqa: F405
         sample_input = (torch.randn(1, 512, 7),)
@@ -3813,6 +3849,15 @@ def test_qnn_backend_log_softmax(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_log_variants_multi_node(self):
+        module = LogVariantsMultiNode()  # noqa: F405
+        sample_input = (
+            torch.abs(torch.rand(2, 3, 4)) + 0.1,
+            torch.abs(torch.rand(2, 3, 4)) + 0.1,
+        )
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log10(self):
         module = Log10()  # noqa: F405
         sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
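For reviewers who want to poke at DecomposeAcos without going through the full QNN lowering flow, the sketch below runs the pass directly on an exported ATen graph and checks that acos is rewritten as pi/2 - asin(x). It is illustrative only: the toy Acos module mirrors the model added to tests/models.py, and it assumes torch.export keeps aten.acos.default in the graph and that invoking the pass's call() on the exported graph_module is acceptable for a standalone check.

```python
import torch
from executorch.backends.qualcomm._passes import DecomposeAcos


class Acos(torch.nn.Module):  # toy module, same shape as the test model
    def forward(self, x):
        return torch.acos(x)


x = torch.rand(3, 4) * 2 - 1  # keep inputs inside acos's domain [-1, 1]
ep = torch.export.export(Acos(), (x,))  # ATen dialect (pre-to_edge)

# Run the pass on the exported graph module and inspect the rewritten graph.
result = DecomposeAcos().call(ep.graph_module)
targets = {
    n.target
    for n in result.graph_module.graph.nodes
    if n.op == "call_function"
}
assert torch.ops.aten.acos.default not in targets  # acos decomposed away
assert torch.ops.aten.asin.default in targets      # rewritten via pi/2 - asin(x)

# The identity itself holds numerically on the valid domain.
torch.testing.assert_close(torch.acos(x), torch.pi / 2 - torch.asin(x))
```

The edge-dialect path differs only in that the pi/2 constant is materialized once as a graph constant via get_const_node and reused across all acos nodes, mirroring the const_cache reuse added to DecomposeLogVariants.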