Skip to content

Commit 405d85f

Browse files
committed
NXP backend: Add option to select symmetric quantization range
1 parent dddf85a commit 405d85f

2 files changed

Lines changed: 15 additions & 6 deletions

File tree

backends/nxp/quantizer/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ def calibrate_and_quantize(
219219

220220
m = convert_pt2e(m)
221221

222-
m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module
222+
m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=False, symmetric_quant=True)(
223+
m
224+
).graph_module
223225

224226
return m

backends/transforms/quantize_fused_convbn_bias_pass.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def _quantize_fused_conv_bias(
171171
set_param,
172172
get_weight_scale_tensor,
173173
default_zero_bias=False,
174+
use_symmetric_quantization=False,
174175
):
175176
"""Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
176177
@@ -188,6 +189,7 @@ def _quantize_fused_conv_bias(
188189
set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
189190
get_weight_scale_tensor: Callable(node) -> Tensor.
190191
default_zero_bias: If True, create zero bias for conv nodes without bias.
192+
use_symmetric_quantization: If True, uses symmetric quantization range.
191193
192194
Returns:
193195
True if any modifications were made.
@@ -236,6 +238,7 @@ def _quantize_fused_conv_bias(
236238
else torch.empty(bias.shape, dtype=torch.float32)
237239
)
238240

241+
quant_min = -(2**31) + 1 if use_symmetric_quantization else -(2**31)
239242
if isinstance(weight_dequant.args[1], torch.fx.node.Node):
240243
weight_scale = get_weight_scale_tensor(weight_dequant.args[1])
241244
bias_scale = input_dequant.args[1] * weight_scale
@@ -246,7 +249,7 @@ def _quantize_fused_conv_bias(
246249
bias_scale,
247250
bias_zp,
248251
0,
249-
-(2**31),
252+
quant_min,
250253
2**31 - 1,
251254
torch.int32,
252255
)
@@ -267,7 +270,7 @@ def _quantize_fused_conv_bias(
267270
scale_node,
268271
zp_node,
269272
0,
270-
-(2**31),
273+
quant_min,
271274
2**31 - 1,
272275
torch.int32,
273276
),
@@ -279,14 +282,14 @@ def _quantize_fused_conv_bias(
279282
bias_scale = input_dequant.args[1] * weight_scale
280283

281284
qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
282-
bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
285+
bias, bias_scale, 0, quant_min, 2**31 - 1, torch.int32
283286
)
284287
set_param(bias_node, qbias)
285288

286289
with graph_module.graph.inserting_before(node):
287290
bias_dequant = graph_module.graph.call_function(
288291
dq_per_tensor,
289-
(bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
292+
(bias_node, bias_scale, 0, quant_min, 2**31 - 1, torch.int32),
290293
)
291294
bias_dequant.meta["val"] = dequant_val
292295
node.replace_input_with(bias_node, bias_dequant)
@@ -306,9 +309,12 @@ class QuantizeFusedConvBnBiasAtenPass(PassBase):
306309
exported_program can be omitted.
307310
"""
308311

309-
def __init__(self, exported_program=None, default_zero_bias=False) -> None:
312+
def __init__(
313+
self, exported_program=None, default_zero_bias=False, symmetric_quant=False
314+
) -> None:
310315
self.exported_program = exported_program
311316
self.default_zero_bias = default_zero_bias
317+
self.symmetric_quantization = symmetric_quant
312318

313319
def call(self, graph_module: fx.GraphModule) -> PassResult:
314320
ep = self.exported_program
@@ -351,5 +357,6 @@ def get_scale(node):
351357
set_param=set_param,
352358
get_weight_scale_tensor=get_scale,
353359
default_zero_bias=self.default_zero_bias,
360+
use_symmetric_quantization=self.symmetric_quantization,
354361
)
355362
return PassResult(graph_module, modified)

0 commit comments

Comments (0)