# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
7- import copy
87from typing import Any , Dict , Tuple
98
109import executorch .backends .qualcomm .python .PyQnnManagerAdaptor as PyQnnManager
@@ -151,8 +150,12 @@ def _get_tensor(node, index):
151150 def make_qnn_per_block_config (self , node : torch .fx .Node , quant_attrs : Dict ):
152151 import math
153152
154- quant_config = copy .deepcopy (quant_attrs )
155- scales , scale_offset , quantized_scales = quant_attrs [QCOM_SCALE ], [], []
153+ quant_config = {
154+ QCOM_DTYPE : quant_attrs [QCOM_DTYPE ],
155+ QCOM_QUANT_MIN : quant_attrs [QCOM_QUANT_MIN ],
156+ QCOM_QUANT_MAX : quant_attrs [QCOM_QUANT_MAX ],
157+ }
158+ scales = quant_attrs [QCOM_SCALE ]
156159 # channel in observers defaults to zero
157160 num_channels = node .meta ["val" ].shape [0 ]
158161 user_0 = self .get_first_user (node )
@@ -170,17 +173,23 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
170173 PyQnnManager .Qnn_BlockwiseExpansionBlockScaleStorageType_t .QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8
171174 )
172175
176+ scale_offset_arr = np .empty (
177+ num_channels , dtype = [("scale" , np .float32 ), ("offset" , np .int32 )]
178+ )
179+ # move channel axis to dim 0 for transpose_conv case
180+ candidates = scales if ch_axis == 0 else scales .transpose (0 , 1 )
181+ candidates = candidates .reshape (num_channels , - 1 )
182+ # find max scale per channel
183+ max_scales = candidates .amax (dim = - 1 ) / num_steps
184+ # quantize scales per channel
185+ q_scales = torch .clamp (
186+ input = torch .round (input = candidates / max_scales .unsqueeze (- 1 )),
187+ min = 1 ,
188+ max = 2 ** bitwidth_of_scale ,
189+ ).to (quant_scales_dtype )
190+ # symmetric quantization is required
173191 for ch in range (num_channels ):
174- candidates = scales [ch ] if ch_axis == 0 else scales [:, ch , ...]
175- max_scale = candidates .reshape (1 , - 1 ).amax (dim = - 1 ) / num_steps
176- q_scales = torch .clamp (
177- input = torch .round (input = candidates / max_scale ),
178- min = 1 ,
179- max = 2 ** bitwidth_of_scale ,
180- ).to (quant_scales_dtype )
181- quantized_scales .append (q_scales )
182- # symmetric quantization is required
183- scale_offset .append (PyQnnManager .Qnn_ScaleOffset_t (max_scale , 0 ))
192+ scale_offset_arr [ch ] = (float (max_scales [ch ]), 0 )
184193
185194 # skip dequantize op, e.g. frozen_param -> dq -> conv2d
186195 user_0 = self .get_first_user (node )
@@ -195,9 +204,9 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
195204 else :
196205 raise AttributeError ("undetermined axis for block quantization" )
197206
198- quant_config [QCOM_NUM_BLOCKS_PER_AXIS ] = quantized_scales [ 0 ] .shape . numel ()
199- quant_config [QCOM_BLOCK_SCALE_OFFSET ] = scale_offset
200- quant_config [QCOM_BLOCK_SCALES ] = torch . cat ( quantized_scales ).detach ().numpy ()
207+ quant_config [QCOM_NUM_BLOCKS_PER_AXIS ] = q_scales .shape [ 1 ]
208+ quant_config [QCOM_BLOCK_SCALE_OFFSET ] = scale_offset_arr
209+ quant_config [QCOM_BLOCK_SCALES ] = q_scales . flatten ( ).detach ().numpy ()
201210 # e.g. if use 16 bit for quantized scales, we need to expand 16 - 4 = 12 bits
202211 quant_config [QCOM_BLOCK_SCALE_BITWIDTH ] = (
203212 int (math .log2 (torch .iinfo (quant_scales_dtype ).max + 1 )) - bitwidth_of_scale
@@ -209,20 +218,23 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
209218 )
210219
211220 def make_qnn_per_channel_config (self , node : torch .fx .Node , quant_attrs : Dict ):
212- quant_config = copy .deepcopy (quant_attrs )
221+ quant_config = {
222+ QCOM_DTYPE : quant_attrs [QCOM_DTYPE ],
223+ QCOM_QUANT_MAX : quant_attrs [QCOM_QUANT_MAX ],
224+ QCOM_QUANT_MIN : quant_attrs [QCOM_QUANT_MIN ],
225+ }
213226
214227 scales = quant_attrs [QCOM_SCALES ]
215228 zero_points = quant_attrs [QCOM_ZERO_POINTS ]
216229 assert len (scales ) == len (
217230 zero_points
218231 ), f"Per channel encoding of node { node } , has different size for scales { len (scales )} and zero_points { len (zero_points )} "
219232
220- scale_offset = []
233+ scale_offset_arr = np .empty (
234+ len (scales ), dtype = [("scale" , np .float32 ), ("offset" , np .int32 )]
235+ )
221236 for i in range (len (scales )):
222- # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
223- scale_offset .append (
224- PyQnnManager .Qnn_ScaleOffset_t (scales [i ], - zero_points [i ])
225- )
237+ scale_offset_arr [i ] = (float (scales [i ]), int (- zero_points [i ]))
226238
227239 # skip dequantize op, e.g. frozen_param -> dq -> conv2d
228240 user_0 = self .get_first_user (node )
@@ -234,7 +246,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
234246 else :
235247 quant_config [QCOM_AXIS ] = quant_attrs [QCOM_AXIS ]
236248
237- quant_config [QCOM_SCALE_OFFSET ] = scale_offset
249+ quant_config [QCOM_SCALE_OFFSET ] = scale_offset_arr
238250 # special case for 4 bits
239251 if (
240252 quant_config [QCOM_DTYPE ] == torch .int8
@@ -251,7 +263,12 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
251263 )
252264
253265 def make_qnn_per_tensor_config (self , quant_attrs : Dict ):
254- quant_config = copy .deepcopy (quant_attrs )
266+ quant_config = {
267+ QCOM_DTYPE : quant_attrs [QCOM_DTYPE ],
268+ QCOM_SCALE : quant_attrs [QCOM_SCALE ],
269+ QCOM_QUANT_MAX : quant_attrs [QCOM_QUANT_MAX ],
270+ QCOM_QUANT_MIN : quant_attrs [QCOM_QUANT_MIN ],
271+ }
255272 # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
256273 quant_config [QCOM_OFFSET ] = - quant_attrs [QCOM_ZERO_POINT ]
257274 # special case for 4 bits