|
| 1 | +# Copyright (c) Qualcomm Innovation Center, Inc. |
| 2 | +# All rights reserved |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | +from executorch.backends.qualcomm._passes.fold_qdq import FoldQDQ |
| 9 | +from executorch.backends.qualcomm._passes.utils import get_quant_attrs |
| 10 | +from executorch.backends.qualcomm.builders.node_visitor import dq_ops |
| 11 | +from executorch.backends.qualcomm.builders.utils import ( |
| 12 | + is_graph_input, |
| 13 | + is_graph_output, |
| 14 | + is_parameter, |
| 15 | +) |
| 16 | +from executorch.backends.qualcomm.utils.constants import ( |
| 17 | + QCOM_BYPASS_NODE, |
| 18 | + QCOM_FALLBACK_NODE, |
| 19 | + QCOM_QUANT_ATTRS, |
| 20 | + QCOM_QUANTIZED_IO, |
| 21 | +) |
| 22 | + |
| 23 | + |
| 24 | +class LpaiFoldQDQ(FoldQDQ): |
| 25 | + """ |
| 26 | + LPAI-specific extension of FoldQDQ. |
| 27 | +
|
| 28 | + In LPAI backend v6, there is an accuracy drop for the quantize and |
| 29 | + dequantize operations. To address this, keep the quantize/dequantize |
| 30 | + operations at the model's input and output. |
| 31 | +
|
| 32 | + For example: |
| 33 | + input -> q_1 (Fallback) -> dq_1 (Bypass) -> graph -> q_2 (Bypass) -> dq_2 (Fallback) -> output |
| 34 | +
|
| 35 | + Here, q_1 and dq_2 will fallback to CPU, while q_2 and dq_1 will be |
| 36 | + bypassed in qnn_partition and folded in qnn_preprocess. |
| 37 | + """ |
| 38 | + |
| 39 | + def _preserve_qdq(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 40 | + for n in graph_module.graph.nodes: |
| 41 | + # skip parameters & buffers (base class logic) |
| 42 | + if n.target in dq_ops and is_parameter(n.args[0], self.edge_program): |
| 43 | + self._annotate_bypass(n) |
| 44 | + continue |
| 45 | + |
| 46 | + if ( |
| 47 | + is_graph_input(n, self.edge_program) |
| 48 | + # For tagged quantized I/O, we should not fallback quantize operation. |
| 49 | + and QCOM_QUANTIZED_IO not in n.meta |
| 50 | + ): |
| 51 | + user_list = list(n.users.keys()) |
| 52 | + if len(user_list) > 0: |
| 53 | + q_node = user_list[0] |
| 54 | + q_node.meta[QCOM_FALLBACK_NODE] = True |
| 55 | + # Annotate the q_node since it will serve as the input for the first node during operator validation |
| 56 | + q_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( |
| 57 | + self.edge_program, q_node |
| 58 | + ) |
| 59 | + q_node.meta[QCOM_QUANTIZED_IO] = q_node.args[-1] |
| 60 | + dq_node = list(q_node.users.keys())[0] |
| 61 | + # Bypass dequantize op for graph validation by torch |
| 62 | + dq_node.meta[QCOM_BYPASS_NODE] = True |
| 63 | + # Make sure that the quantize operator isn't inserted for input in insert_io_qdq.py |
| 64 | + n.meta[QCOM_QUANTIZED_IO] = q_node.args[-1] |
| 65 | + elif ( |
| 66 | + is_graph_output(n) |
| 67 | + and n.target in dq_ops |
| 68 | + # For tagged quantized I/O, we should not fallback dequantize operation. |
| 69 | + and QCOM_QUANTIZED_IO not in n.args[0].args[0].meta |
| 70 | + ): |
| 71 | + n.meta[QCOM_FALLBACK_NODE] = True |
| 72 | + q_node = n.args[0] |
| 73 | + # Bypass quantize op for graph validation by torch |
| 74 | + q_node.meta[QCOM_BYPASS_NODE] = True |
| 75 | + op_node = q_node.args[0] |
| 76 | + # Make sure that the dequantize operator isn't inserted for output in insert_io_qdq.py |
| 77 | + op_node.meta[QCOM_QUANTIZED_IO] = q_node.args[-1] |
0 commit comments