1010import numpy as np
1111import torch
12- from executorch .backends .qualcomm .utils .constants import QCOM_DATA
12+ from executorch .backends .qualcomm .utils .constants import (
13+ QCOM_DATA ,
14+ QCOM_DTYPE ,
15+ QCOM_ENCODING ,
16+ QCOM_QUANT_ATTRS ,
17+ QCOM_QUANT_MAX ,
18+ QCOM_QUANT_MIN ,
19+ QCOM_SCALE ,
20+ QCOM_SCALES ,
21+ QCOM_ZERO_POINT ,
22+ QCOM_ZERO_POINTS ,
23+ )
1324
14- from .node_visitor import NodeVisitor
25+ from .node_visitor import NodeVisitor , PER_CHANNEL_ENCODING , QNN_QUANT_TYPE_MAP
1526from .node_visitor_manager import register_node_visitor
16- from .qnn_constants import OpGather , QNN_OP_PACKAGE_NAME_QTI_AISW
27+ from .qnn_constants import OpConvert , OpGather , QNN_OP_PACKAGE_NAME_QTI_AISW
1728from .utils import get_parameter
1829
1930
@@ -30,6 +41,9 @@ def define_node(
3041 nodes_to_wrappers : Dict [torch .fx .Node , PyQnnManager .TensorWrapper ],
3142 ) -> PyQnnManager .PyQnnOpWrapper :
3243 weight_node = self .get_node (node .args [0 ])
44+ is_pcq_embedding = QCOM_QUANT_ATTRS in weight_node .meta and weight_node .meta [
45+ QCOM_QUANT_ATTRS
46+ ][QCOM_ENCODING ] in (PER_CHANNEL_ENCODING )
3347 weight_tensor = get_parameter (weight_node , self .edge_program )
3448 weight_tensor_wrapper = self .define_tensor (
3549 weight_node ,
@@ -52,17 +66,41 @@ def define_node(
5266 gather_input_tensors = [weight_tensor_wrapper , indices_tensor_wrapper ]
5367
5468 output_tensor = self .get_tensor (node , node )
69+ node_name = node .name
70+ if is_pcq_embedding :
71+ node_quant_attrs = node .meta [QCOM_QUANT_ATTRS ].copy ()
72+ intermediate_quant_attrs = node .meta [QCOM_QUANT_ATTRS ].copy ()
73+ # Based on QNN HTP quantization constraints,
74+ # we should set the scale to max of scales and per-tensor quantization for embedding op
75+ intermediate_quant_attrs [QCOM_SCALE ] = (
76+ weight_node .meta [QCOM_QUANT_ATTRS ][QCOM_SCALES ].max ().item ()
77+ )
78+ intermediate_quant_attrs [QCOM_ZERO_POINT ] = (
79+ weight_node .meta [QCOM_QUANT_ATTRS ][QCOM_ZERO_POINTS ].max ().item ()
80+ )
81+ intermediate_quant_attrs [QCOM_DTYPE ] = weight_node .meta [QCOM_QUANT_ATTRS ][
82+ QCOM_DTYPE
83+ ]
84+ intermediate_quant_attrs [QCOM_QUANT_MAX ] = weight_node .meta [
85+ QCOM_QUANT_ATTRS
86+ ][QCOM_QUANT_MAX ]
87+ intermediate_quant_attrs [QCOM_QUANT_MIN ] = weight_node .meta [
88+ QCOM_QUANT_ATTRS
89+ ][QCOM_QUANT_MIN ]
90+ node .meta [QCOM_QUANT_ATTRS ] = intermediate_quant_attrs
91+ node_name += "_intermediate"
5592 output_tensor_wrapper = self .define_tensor (
5693 node ,
5794 node ,
5895 output_tensor ,
5996 PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
6097 nodes_to_wrappers ,
98+ node_name = node_name ,
6199 )
62100 gather_output_tensors = [output_tensor_wrapper ]
63101
64102 gather_op = PyQnnManager .PyQnnOpWrapper (
65- node . name ,
103+ node_name ,
66104 QNN_OP_PACKAGE_NAME_QTI_AISW ,
67105 OpGather .op_name ,
68106 )
@@ -76,4 +114,36 @@ def define_node(
76114 {QCOM_DATA : np .int32 (0 )},
77115 )
78116
79- return gather_op
117+ op_wrapper_list = [gather_op ]
118+
119+ if is_pcq_embedding :
120+ node .meta [QCOM_QUANT_ATTRS ] = node_quant_attrs
121+ act_quant_encoding , act_quant_configs = self .get_quant_encoding_conf (
122+ node , node
123+ )
124+ act_dtype = (
125+ torch .uint16
126+ if act_quant_configs [QCOM_DTYPE ] == torch .int32
127+ else act_quant_configs [QCOM_DTYPE ]
128+ )
129+ convert_tensor_wrapper = self .define_custom_tensor_wrapper (
130+ node_name = node .name ,
131+ tensor_type = PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
132+ dtype = QNN_QUANT_TYPE_MAP [act_dtype ],
133+ quant_encoding = act_quant_encoding ,
134+ quant_configs = act_quant_configs ,
135+ dims = output_tensor .size (),
136+ tensor = output_tensor ,
137+ is_fake_tensor = True ,
138+ nodes_to_wrappers = nodes_to_wrappers ,
139+ )
140+ convert_op = PyQnnManager .PyQnnOpWrapper (
141+ node .name + "_convert" ,
142+ QNN_OP_PACKAGE_NAME_QTI_AISW ,
143+ OpConvert .op_name ,
144+ )
145+ convert_op .AddInputTensors (gather_output_tensors )
146+ convert_op .AddOutputTensors ([convert_tensor_wrapper ])
147+ op_wrapper_list .append (convert_op )
148+
149+ return op_wrapper_list