File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -543,10 +543,17 @@ def get_anchors(
543543
544544 # If the following node is a fusable activation, quantize together with activation
545545 output = [(conv_node ,)]
546- if len (
547- conv_node .users
548- ) == 1 and self .neutron_target_info .is_supported_fused_activation__aten (
549- activation := next (iter (conv_node .users ))
546+ if len (conv_node .users ) == 1 and (
547+ self .neutron_target_info .is_supported_fused_activation__aten (
548+ activation := next (iter (conv_node .users ))
549+ )
550+ or (
551+ self .is_qat
552+ and _is_batch_norm (activation )
553+ and self .neutron_target_info .is_supported_fused_activation__aten (
554+ activation := next (iter (activation .users ))
555+ )
556+ )
550557 ):
551558 activation_quantizer = self .neutron_quantizer .op_to_quantizer [
552559 activation .target
@@ -555,6 +562,14 @@ def get_anchors(
555562 output = []
556563 activation .meta ["quantization_annotation" ].input_qspec_map = {}
557564
565+ if isinstance (bn := next (iter (conv_node .users )), Node ) and _is_batch_norm (
566+ bn
567+ ):
568+ bn_quantizer = self .neutron_quantizer .op_to_quantizer [bn .target ]
569+ bn_quantizer .annotate (gm )
570+ bn .meta ["quantization_annotation" ].input_qspec_map = {}
571+ bn .meta ["quantization_annotation" ].output_qspec = None
572+
558573 # In order for QAT to be numerically correct, there should be no quantization between
559574 # convolution node and batch norm node.
560575 if self .is_qat :
You can’t perform that action at this time.
0 commit comments