backends/nxp/quantizer/patterns.py (19 additions & 4 deletions)

@@ -543,10 +543,17 @@ def get_anchors(
 
         # If the following node is a fusable activation, quantize together with activation
         output = [(conv_node,)]
-        if len(
-            conv_node.users
-        ) == 1 and self.neutron_target_info.is_supported_fused_activation__aten(
-            activation := next(iter(conv_node.users))
+        if len(conv_node.users) == 1 and (
+            self.neutron_target_info.is_supported_fused_activation__aten(
+                activation := next(iter(conv_node.users))
+            )
+            or (
+                self.is_qat
+                and _is_batch_norm(activation)
+                and self.neutron_target_info.is_supported_fused_activation__aten(
+                    activation := next(iter(activation.users))
+                )
+            )
         ):
             activation_quantizer = self.neutron_quantizer.op_to_quantizer[
                 activation.target
@@ -555,6 +562,14 @@ def get_anchors(
             output = []
             activation.meta["quantization_annotation"].input_qspec_map = {}
 
+            if isinstance(bn := next(iter(conv_node.users)), Node) and _is_batch_norm(
+                bn
+            ):
+                bn_quantizer = self.neutron_quantizer.op_to_quantizer[bn.target]
+                bn_quantizer.annotate(gm)
+                bn.meta["quantization_annotation"].input_qspec_map = {}
+                bn.meta["quantization_annotation"].output_qspec = None
+
         # In order for QAT to be numerically correct, there should be no quantization between
         # convolution node and batch norm node.
         if self.is_qat:
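For context, the new `or` branch handles the QAT case, where the BatchNorm node still sits between the convolution and its activation at annotation time, so the fusable activation is found one hop further down the graph. A minimal sketch of the module shape being matched (illustrative only, not taken from this PR):

import torch

class ConvBnRelu(torch.nn.Module):
    """Conv -> BatchNorm -> ReLU chain, as it appears in a QAT-exported graph."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        # During QAT export the batch norm is still a separate graph node, so
        # the quantizer must look through it to reach relu, and must leave the
        # conv -> bn edge un-annotated: inserting Q/DQ there would break QAT
        # numerics, as the comment at the end of the hunk above notes.
        return torch.relu(self.bn(self.conv(x)))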
backends/nxp/quantizer/utils.py (4 additions & 1 deletion)

@@ -219,6 +219,9 @@ def calibrate_and_quantize(
 
     m = convert_pt2e(m)
 
-    m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module
+    if is_qat:
+        m = QuantizeFusedConvBnBiasAtenPass(
+            default_zero_bias=False, symmetric_quant=True
+        )(m).graph_module
 
     return m
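A hedged sketch of where this call site sits in the overall PT2E flow. The prepare/convert helpers come from torch.ao.quantization.quantize_pt2e; the import path for the pass is inferred from the file location below, and the surrounding function and arguments are assumptions, not this file's actual code:

from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
    QuantizeFusedConvBnBiasAtenPass,
)
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)

def quantize_sketch(model, example_inputs, quantizer, is_qat: bool):
    # Insert observers (PTQ) or fake-quant modules (QAT) into the exported model.
    m = prepare_qat_pt2e(model, quantizer) if is_qat else prepare_pt2e(model, quantizer)
    m(*example_inputs)  # calibration pass (PTQ) or a QAT forward pass
    m = convert_pt2e(m)
    if is_qat:
        # Only the QAT path leaves fused conv+bn biases un-annotated, so the
        # bias fix-up pass now runs for QAT only, with a symmetric int32 range
        # and without inserting default zero biases.
        m = QuantizeFusedConvBnBiasAtenPass(
            default_zero_bias=False, symmetric_quant=True
        )(m).graph_module
    return m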
backends/nxp/tests/generic_tests/test_integration.py (1 addition & 1 deletion)

@@ -29,7 +29,7 @@ def test_conv_fc_softmax__to_executorch_program(use_qat):
     delegation_info = get_delegation_info(program.graph_module)
     assert delegation_info.num_delegated_subgraphs == 1
     assert delegation_info.num_non_delegated_nodes == 11
-    assert delegation_info.num_delegated_nodes == 14
+    assert delegation_info.num_delegated_nodes == 13
 
     for node in program.graph.nodes:
         # Make sure Convolution and AddMM are delegated
backends/nxp/tests/generic_tests/test_qdq_clustering_conv.py (6 additions & 8 deletions)

@@ -16,18 +16,16 @@ def test_conv2d_partitioner():
     lowered_module = edge_program.exported_program().graph_module.lowered_module_0
     nodes = list(lowered_module.original_module.graph.nodes)
 
-    assert len(nodes) == 13
+    assert len(nodes) == 9
 
-    q_x_node = nodes[6]
-    dq_w_node = nodes[7]
-    dq_x_node = nodes[8]
-    dq_bias_node = nodes[9]
-    conv_node = nodes[10]
-    q_y_node = nodes[11]
+    q_x_node = nodes[3]
+    dq_x_node = nodes[4]
+    dq_w_node = nodes[5]
+    conv_node = nodes[6]
+    q_y_node = nodes[7]
 
     assert "cluster" not in q_x_node.meta
     assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster"
     assert dq_x_node.meta["cluster"] == "aten_convolution_default_cluster"
-    assert dq_bias_node.meta["cluster"] == "aten_convolution_default_cluster"
    assert conv_node.meta["cluster"] == "aten_convolution_default_cluster"
    assert q_y_node.meta["cluster"] == "aten_convolution_default_cluster"
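The cluster shrinks from 13 nodes to 9 and the dq_bias assertions are gone, consistent with the bias Q/DQ pair no longer being produced at annotation time (the bias is now quantized later by QuantizeFusedConvBnBiasAtenPass). A small sketch of how such cluster tags can be inspected when updating these indices (illustrative, not part of the test):

# Print every node of the lowered module together with its QDQ cluster tag, if any.
for i, node in enumerate(lowered_module.original_module.graph.nodes):
    print(i, node.name, node.meta.get("cluster", "<no cluster>"))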
backends/transforms/quantize_fused_convbn_bias_pass.py (12 additions & 5 deletions)

@@ -171,6 +171,7 @@ def _quantize_fused_conv_bias(
     set_param,
     get_weight_scale_tensor,
     default_zero_bias=False,
+    use_symmetric_quantization=False,
 ):
     """Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
 
@@ -188,6 +189,7 @@ def _quantize_fused_conv_bias(
         set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
         get_weight_scale_tensor: Callable(node) -> Tensor.
         default_zero_bias: If True, create zero bias for conv nodes without bias.
+        use_symmetric_quantization: If True, use a symmetric quantization range.
 
     Returns:
         True if any modifications were made.
@@ -236,6 +238,7 @@ def _quantize_fused_conv_bias(
         else torch.empty(bias.shape, dtype=torch.float32)
     )
 
+    quant_min = -(2**31) + 1 if use_symmetric_quantization else -(2**31)
     if isinstance(weight_dequant.args[1], torch.fx.node.Node):
         weight_scale = get_weight_scale_tensor(weight_dequant.args[1])
         bias_scale = input_dequant.args[1] * weight_scale
@@ -246,7 +249,7 @@ def _quantize_fused_conv_bias(
             bias_scale,
             bias_zp,
             0,
-            -(2**31),
+            quant_min,
             2**31 - 1,
             torch.int32,
         )
@@ -267,7 +270,7 @@ def _quantize_fused_conv_bias(
                     scale_node,
                     zp_node,
                     0,
-                    -(2**31),
+                    quant_min,
                     2**31 - 1,
                     torch.int32,
                 ),
@@ -279,14 +282,14 @@ def _quantize_fused_conv_bias(
         bias_scale = input_dequant.args[1] * weight_scale
 
         qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
-            bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
+            bias, bias_scale, 0, quant_min, 2**31 - 1, torch.int32
         )
         set_param(bias_node, qbias)
 
         with graph_module.graph.inserting_before(node):
             bias_dequant = graph_module.graph.call_function(
                 dq_per_tensor,
-                (bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
+                (bias_node, bias_scale, 0, quant_min, 2**31 - 1, torch.int32),
             )
             bias_dequant.meta["val"] = dequant_val
             node.replace_input_with(bias_node, bias_dequant)
@@ -306,9 +309,12 @@ class QuantizeFusedConvBnBiasAtenPass(PassBase):
     exported_program can be omitted.
     """
 
-    def __init__(self, exported_program=None, default_zero_bias=False) -> None:
+    def __init__(
+        self, exported_program=None, default_zero_bias=False, symmetric_quant=False
+    ) -> None:
         self.exported_program = exported_program
         self.default_zero_bias = default_zero_bias
+        self.symmetric_quantization = symmetric_quant
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:
         ep = self.exported_program
@@ -351,5 +357,6 @@ def get_scale(node):
             set_param=set_param,
             get_weight_scale_tensor=get_scale,
             default_zero_bias=self.default_zero_bias,
+            use_symmetric_quantization=self.symmetric_quantization,
         )
         return PassResult(graph_module, modified)
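On the symmetric range: dropping the single value -(2**31) makes the int32 range symmetric around zero, so +x and -x always quantize to integers of equal magnitude; with the asymmetric minimum -(2**31), a negative input can reach one step further than any positive input. A small illustration using the same op the pass uses (the scale and bias values are arbitrary assumptions, not from the PR):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  registers the quantized_decomposed ops

bias = torch.tensor([-0.5, 0.0, 0.5])
scale = 2**-20  # arbitrary small scale, far from the int32 limits

q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    bias, scale, 0, -(2**31) + 1, 2**31 - 1, torch.int32
)
# q == [-524288, 0, 524288]: zero_point 0 plus a symmetric clamp range keeps
# the quantized bias an odd function of its float value.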