diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py
index 60afa6bf4d2..8be1ad345da 100644
--- a/backends/nxp/quantizer/patterns.py
+++ b/backends/nxp/quantizer/patterns.py
@@ -543,10 +543,17 @@ def get_anchors(
         # If the following node is a fusable activation, quantize together with activation
         output = [(conv_node,)]
 
-        if len(
-            conv_node.users
-        ) == 1 and self.neutron_target_info.is_supported_fused_activation__aten(
-            activation := next(iter(conv_node.users))
+        if len(conv_node.users) == 1 and (
+            self.neutron_target_info.is_supported_fused_activation__aten(
+                activation := next(iter(conv_node.users))
+            )
+            or (
+                self.is_qat
+                and _is_batch_norm(activation)
+                and self.neutron_target_info.is_supported_fused_activation__aten(
+                    activation := next(iter(activation.users))
+                )
+            )
         ):
             activation_quantizer = self.neutron_quantizer.op_to_quantizer[
                 activation.target
@@ -555,6 +562,14 @@ def get_anchors(
             output = []
             activation.meta["quantization_annotation"].input_qspec_map = {}
 
+            if isinstance(bn := next(iter(conv_node.users)), Node) and _is_batch_norm(
+                bn
+            ):
+                bn_quantizer = self.neutron_quantizer.op_to_quantizer[bn.target]
+                bn_quantizer.annotate(gm)
+                bn.meta["quantization_annotation"].input_qspec_map = {}
+                bn.meta["quantization_annotation"].output_qspec = None
+
         # In order for QAT to be numerically correct, there should be no quantization between
         # convolution node and batch norm node.
         if self.is_qat:
diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py
index cd403868a96..df9b71a76f7 100644
--- a/backends/nxp/quantizer/utils.py
+++ b/backends/nxp/quantizer/utils.py
@@ -219,6 +219,9 @@ def calibrate_and_quantize(
 
     m = convert_pt2e(m)
 
-    m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module
+    if is_qat:
+        m = QuantizeFusedConvBnBiasAtenPass(
+            default_zero_bias=False, symmetric_quant=True
+        )(m).graph_module
 
     return m
diff --git a/backends/nxp/tests/generic_tests/test_integration.py b/backends/nxp/tests/generic_tests/test_integration.py
index e8d2a6faf26..fe157b44c48 100644
--- a/backends/nxp/tests/generic_tests/test_integration.py
+++ b/backends/nxp/tests/generic_tests/test_integration.py
@@ -29,7 +29,7 @@ def test_conv_fc_softmax__to_executorch_program(use_qat):
     delegation_info = get_delegation_info(program.graph_module)
     assert delegation_info.num_delegated_subgraphs == 1
     assert delegation_info.num_non_delegated_nodes == 11
-    assert delegation_info.num_delegated_nodes == 14
+    assert delegation_info.num_delegated_nodes == 13
 
     for node in program.graph.nodes:
         # Make sure Convolution and AddMM are delegated
diff --git a/backends/nxp/tests/generic_tests/test_qdq_clustering_conv.py b/backends/nxp/tests/generic_tests/test_qdq_clustering_conv.py
index 21e330a48ae..6db55347452 100644
--- a/backends/nxp/tests/generic_tests/test_qdq_clustering_conv.py
+++ b/backends/nxp/tests/generic_tests/test_qdq_clustering_conv.py
@@ -16,18 +16,16 @@ def test_conv2d_partitioner():
     lowered_module = edge_program.exported_program().graph_module.lowered_module_0
     nodes = list(lowered_module.original_module.graph.nodes)
 
-    assert len(nodes) == 13
+    assert len(nodes) == 9
 
-    q_x_node = nodes[6]
-    dq_w_node = nodes[7]
-    dq_x_node = nodes[8]
-    dq_bias_node = nodes[9]
-    conv_node = nodes[10]
-    q_y_node = nodes[11]
+    q_x_node = nodes[3]
+    dq_x_node = nodes[4]
+    dq_w_node = nodes[5]
+    conv_node = nodes[6]
+    q_y_node = nodes[7]
 
     assert "cluster" not in q_x_node.meta
     assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster"
     assert dq_x_node.meta["cluster"] == "aten_convolution_default_cluster"
-    assert dq_bias_node.meta["cluster"] == "aten_convolution_default_cluster"
     assert conv_node.meta["cluster"] == "aten_convolution_default_cluster"
     assert q_y_node.meta["cluster"] == "aten_convolution_default_cluster"
diff --git a/backends/transforms/quantize_fused_convbn_bias_pass.py b/backends/transforms/quantize_fused_convbn_bias_pass.py
index f1c599e05ba..f20e666fb01 100644
--- a/backends/transforms/quantize_fused_convbn_bias_pass.py
+++ b/backends/transforms/quantize_fused_convbn_bias_pass.py
@@ -171,6 +171,7 @@ def _quantize_fused_conv_bias(
     set_param,
     get_weight_scale_tensor,
     default_zero_bias=False,
+    use_symmetric_quantization=False,
 ):
     """Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
 
@@ -188,6 +189,7 @@
         set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
        get_weight_scale_tensor: Callable(node) -> Tensor.
         default_zero_bias: If True, create zero bias for conv nodes without bias.
+        use_symmetric_quantization: If True, use a symmetric quantization range.
 
     Returns:
         True if any modifications were made.
@@ -236,6 +238,7 @@
             else torch.empty(bias.shape, dtype=torch.float32)
         )
 
+        quant_min = -(2**31) + 1 if use_symmetric_quantization else -(2**31)
         if isinstance(weight_dequant.args[1], torch.fx.node.Node):
             weight_scale = get_weight_scale_tensor(weight_dequant.args[1])
             bias_scale = input_dequant.args[1] * weight_scale
@@ -246,7 +249,7 @@
                 bias_scale,
                 bias_zp,
                 0,
-                -(2**31),
+                quant_min,
                 2**31 - 1,
                 torch.int32,
             )
@@ -267,7 +270,7 @@
                     scale_node,
                     zp_node,
                     0,
-                    -(2**31),
+                    quant_min,
                     2**31 - 1,
                     torch.int32,
                 ),
@@ -279,14 +282,14 @@
             bias_scale = input_dequant.args[1] * weight_scale
 
             qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
-                bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
+                bias, bias_scale, 0, quant_min, 2**31 - 1, torch.int32
             )
             set_param(bias_node, qbias)
 
             with graph_module.graph.inserting_before(node):
                 bias_dequant = graph_module.graph.call_function(
                     dq_per_tensor,
-                    (bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
+                    (bias_node, bias_scale, 0, quant_min, 2**31 - 1, torch.int32),
                 )
                 bias_dequant.meta["val"] = dequant_val
                 node.replace_input_with(bias_node, bias_dequant)
@@ -306,9 +309,12 @@
         exported_program can be omitted.
     """
 
-    def __init__(self, exported_program=None, default_zero_bias=False) -> None:
+    def __init__(
+        self, exported_program=None, default_zero_bias=False, symmetric_quant=False
+    ) -> None:
         self.exported_program = exported_program
         self.default_zero_bias = default_zero_bias
+        self.symmetric_quantization = symmetric_quant
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:
         ep = self.exported_program
@@ -351,5 +357,6 @@ def get_scale(node):
             set_param=set_param,
             get_weight_scale_tensor=get_scale,
             default_zero_bias=self.default_zero_bias,
+            use_symmetric_quantization=self.symmetric_quantization,
         )
         return PassResult(graph_module, modified)
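
Note (not part of the patch): a minimal sketch of the symmetric int32 range that symmetric_quant=True selects for QAT bias quantization. The bias values and scale below are hypothetical, and the helper mirrors the quant_min/quant_max arithmetic from _quantize_fused_conv_bias in plain PyTorch rather than calling the quantized_decomposed ops.

import torch

def quantize_bias(bias, bias_scale, use_symmetric_quantization):
    # Mirrors the range selection above: the symmetric range drops the most
    # negative int32 code so that quant_min == -quant_max, matching the
    # symmetric weight quantization used during QAT.
    quant_min = -(2**31) + 1 if use_symmetric_quantization else -(2**31)
    quant_max = 2**31 - 1
    # Affine quantization with zero_point = 0: round, clamp, cast to int32.
    return torch.round(bias / bias_scale).clamp(quant_min, quant_max).to(torch.int32)

# Hypothetical values; in the pass, bias_scale = input_scale * weight_scale.
bias = torch.tensor([0.5, -1.25, 3.0])
print(quantize_bias(bias, 1e-4, use_symmetric_quantization=True))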