@@ -171,6 +171,7 @@ def _quantize_fused_conv_bias(
171171 set_param ,
172172 get_weight_scale_tensor ,
173173 default_zero_bias = False ,
174+ use_symmetric_quantization = False ,
174175):
175176 """Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
176177
@@ -188,6 +189,7 @@ def _quantize_fused_conv_bias(
188189 set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
189190 get_weight_scale_tensor: Callable(node) -> Tensor.
190191 default_zero_bias: If True, create zero bias for conv nodes without bias.
192+ use_symmetric_quantization: If True, use a symmetric quantization range.
191193
192194 Returns:
193195 True if any modifications were made.
@@ -236,6 +238,7 @@ def _quantize_fused_conv_bias(
236238 else torch .empty (bias .shape , dtype = torch .float32 )
237239 )
238240
241+ quant_min = - (2 ** 31 ) + 1 if use_symmetric_quantization else - (2 ** 31 )
239242 if isinstance (weight_dequant .args [1 ], torch .fx .node .Node ):
240243 weight_scale = get_weight_scale_tensor (weight_dequant .args [1 ])
241244 bias_scale = input_dequant .args [1 ] * weight_scale
@@ -246,7 +249,7 @@ def _quantize_fused_conv_bias(
246249 bias_scale ,
247250 bias_zp ,
248251 0 ,
249- - ( 2 ** 31 ) ,
252+ quant_min ,
250253 2 ** 31 - 1 ,
251254 torch .int32 ,
252255 )
@@ -267,7 +270,7 @@ def _quantize_fused_conv_bias(
267270 scale_node ,
268271 zp_node ,
269272 0 ,
270- - ( 2 ** 31 ) ,
273+ quant_min ,
271274 2 ** 31 - 1 ,
272275 torch .int32 ,
273276 ),
@@ -279,14 +282,14 @@ def _quantize_fused_conv_bias(
279282 bias_scale = input_dequant .args [1 ] * weight_scale
280283
281284 qbias = torch .ops .quantized_decomposed .quantize_per_tensor .default (
282- bias , bias_scale , 0 , - ( 2 ** 31 ) , 2 ** 31 - 1 , torch .int32
285+ bias , bias_scale , 0 , quant_min , 2 ** 31 - 1 , torch .int32
283286 )
284287 set_param (bias_node , qbias )
285288
286289 with graph_module .graph .inserting_before (node ):
287290 bias_dequant = graph_module .graph .call_function (
288291 dq_per_tensor ,
289- (bias_node , bias_scale , 0 , - ( 2 ** 31 ) , 2 ** 31 - 1 , torch .int32 ),
292+ (bias_node , bias_scale , 0 , quant_min , 2 ** 31 - 1 , torch .int32 ),
290293 )
291294 bias_dequant .meta ["val" ] = dequant_val
292295 node .replace_input_with (bias_node , bias_dequant )
@@ -306,9 +309,12 @@ class QuantizeFusedConvBnBiasAtenPass(PassBase):
306309 exported_program can be omitted.
307310 """
308311
309- def __init__ (self , exported_program = None , default_zero_bias = False ) -> None :
312+ def __init__ (
313+ self , exported_program = None , default_zero_bias = False , symmetric_quant = False
314+ ) -> None :
310315 self .exported_program = exported_program
311316 self .default_zero_bias = default_zero_bias
317+ self .symmetric_quantization = symmetric_quant
312318
313319 def call (self , graph_module : fx .GraphModule ) -> PassResult :
314320 ep = self .exported_program
@@ -351,5 +357,6 @@ def get_scale(node):
351357 set_param = set_param ,
352358 get_weight_scale_tensor = get_scale ,
353359 default_zero_bias = self .default_zero_bias ,
360+ use_symmetric_quantization = self .symmetric_quantization ,
354361 )
355362 return PassResult (graph_module , modified )
0 commit comments