@@ -167,8 +167,8 @@ def _add_bias(
         weight_node: torch.fx.Node,
     ) -> torch.fx.Node:
         output_channels = get_first_fake_tensor(node).shape[1]
-        # add a node containging zeros if quantized, use int32, otherwise use float32
-        if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
+        # add a node containing zeros if quantized, use int32, otherwise use float32
+        if self._is_quantized_conv(node):
             bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32)
         else:
             output_dtype = node.meta["val"].dtype
@@ -188,9 +188,40 @@ def _add_bias(
         node.update_arg(2, bias_node)
         return bias_node

-    def insert_output_rescale(self, graph_module, node):
-        input_qparams = get_input_qparams(node)
-        output_qparams = get_output_qparams(node)[0]
+    def _is_quantized_conv(self, node: torch.fx.Node) -> bool:
+        return bool(node.meta.get("input_qparams", {}))
+
+    def _get_effective_output_qparams(self, node: torch.fx.Node):
+        """Return the quantized output domain for a conv node.
+
+        Quantization annotation may place output qparams on a following
+        activation instead of on the conv itself. If that activation is not
+        fuseable, it survives as a quantized ``clamp`` and still owns the
+        branch output qparams needed for the conv output rescale.
+
+        """
+        output_qparams = node.meta.get("output_qparams", {})
+        if output_qparams:
+            return output_qparams
+
+        users = list(node.users)
+        if len(users) != 1:
+            raise ValueError(
+                f"RewriteConvPass: No output quantization parameter found in node {node}\n"
+                f"original_aten={node.meta.get('original_aten', 'None')}"
+            )
+
+        activation = users[0]
+        if activation.target == exir_ops.edge.aten.clamp.default:
+            activation_output_qparams = activation.meta.get("output_qparams", {})
+            if activation_output_qparams:
+                return activation_output_qparams
+
+        return get_output_qparams(node)
+
+    def insert_output_rescale(self, graph_module, source_node, conv_node):
+        input_qparams = get_input_qparams(source_node)
+        output_qparams = self._get_effective_output_qparams(source_node)[0]
         weight_qparams = input_qparams[1]
         input_qparams = input_qparams[0]
         is_per_channel = weight_qparams.per_channel
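
The docstring above carries the key reasoning: when quantization annotation parks the output qparams on a non-fused activation, the helper walks to the conv's single user to recover them. Below is a standalone sketch of that fallback against plain `torch.fx`, with stand-ins for everything ExecuTorch-specific (`torch.clamp` substitutes for the edge-dialect `exir_ops.edge.aten.clamp.default` target, strings substitute for real qparams objects):

```python
import torch
import torch.fx

CLAMP_TARGET = torch.clamp  # stand-in for exir_ops.edge.aten.clamp.default

def effective_output_qparams(node: torch.fx.Node):
    # Prefer qparams annotated on the node itself.
    qparams = node.meta.get("output_qparams", {})
    if qparams:
        return qparams
    # Fall back to a lone clamp user that survived activation fusion.
    users = list(node.users)
    if len(users) == 1 and users[0].target is CLAMP_TARGET:
        qparams = users[0].meta.get("output_qparams", {})
        if qparams:
            return qparams
    raise ValueError(f"no output qparams reachable from {node}")

g = torch.fx.Graph()
x = g.placeholder("x")
conv = g.call_function(torch.conv2d, args=(x,))        # conv: no output qparams
clamp = g.call_function(CLAMP_TARGET, args=(conv, 0))  # activation owns them
g.output(clamp)
clamp.meta["output_qparams"] = {0: "qparams_from_clamp"}

print(effective_output_qparams(conv))  # {0: 'qparams_from_clamp'}
```
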
@@ -207,18 +238,18 @@ def insert_output_rescale(self, graph_module, node):
                 itertools.cycle([output_qparams.get_scale_per_tensor()]),
             )
         ]
-        with graph_module.graph.inserting_after(node):
+        with graph_module.graph.inserting_after(conv_node):
             rescale_node = create_node(
                 graph=graph_module.graph,
                 op_target=exir_ops.backend.tosa.RESCALE.default,
                 args=(
-                    node,
+                    conv_node,
                     output_qparams.dtype,
                     post_conv2d_scale,
                     0,
                     output_qparams.get_zp_per_tensor(),
                 ),
-                from_node=node,
+                from_node=source_node,
             )
         return rescale_node

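
The `post_conv2d_scale` values passed to `RESCALE` come from the comprehension whose tail is visible in the context lines above. The comprehension body sits outside this hunk, but the zip pattern matches the standard conv requantization identity, one factor per output channel: `(input_scale * weight_scale) / output_scale`, with the per-tensor input and output scales cycled against the per-channel weight scales. A quick numeric sanity check with made-up, power-of-two scales:

```python
import itertools

# Hypothetical scales: per-channel weights, per-tensor input/output.
weight_scales = [0.5, 0.25, 0.125]
input_scale, output_scale = 0.125, 0.25

post_conv2d_scale = [
    (in_s * w_s) / out_s
    for in_s, w_s, out_s in zip(
        itertools.cycle([input_scale]),
        weight_scales,
        itertools.cycle([output_scale]),
    )
]
print(post_conv2d_scale)  # [0.25, 0.125, 0.0625]
```

The `RESCALE` op then multiplies the int32 accumulator by these factors and adds the output zero-point to land in the requested output dtype.
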
@@ -347,15 +378,17 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
                 tosa_node_fake_tensor.dtype == torch.int32
                 and input_fake_tensor.dtype == torch.int8
             ):
-                output_rescale = self.insert_output_rescale(graph_module, tosa_op)
+                output_rescale = self.insert_output_rescale(graph_module, node, tosa_op)
                 node.replace_all_uses_with(output_rescale)
             elif (
                 tosa_node_fake_tensor.dtype == torch.int32
                 and input_fake_tensor.dtype == torch.int16
             ):
                 has_bias = len(node.meta["input_qparams"]) > 2
                 if not has_bias:
-                    output_rescale = self.insert_output_rescale(graph_module, tosa_op)
+                    output_rescale = self.insert_output_rescale(
+                        graph_module, node, tosa_op
+                    )
                     node.replace_all_uses_with(output_rescale)
                 else:
                     node.replace_all_uses_with(tosa_op)
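
Both call sites now pass the pair explicitly: `node`, the edge-dialect conv that still owns the qparams metadata, becomes `source_node`, while `tosa_op`, the freshly created TOSA conv, becomes `conv_node`. A minimal `torch.fx` sketch of that two-node contract, with hypothetical stand-ins (`torch.mul` in place of the TOSA `RESCALE` op, a bare float in place of real qparams):

```python
import torch
import torch.fx

def insert_output_rescale_sketch(graph, source_node, conv_node):
    # Metadata is read from the original (source) node ...
    scale = source_node.meta.get("scale", 1.0)
    # ... but graph wiring follows the rewritten (conv) node.
    with graph.inserting_after(conv_node):
        return graph.call_function(torch.mul, args=(conv_node, scale))

g = torch.fx.Graph()
x = g.placeholder("x")
edge_conv = g.call_function(torch.conv2d, args=(x,))  # plays the role of `node`
edge_conv.meta["scale"] = 0.05
tosa_conv = g.call_function(torch.conv2d, args=(x,))  # plays the role of `tosa_op`
g.output(edge_conv)

rescale = insert_output_rescale_sketch(g, edge_conv, tosa_conv)
edge_conv.replace_all_uses_with(rescale)
g.lint()  # wiring stays topologically valid
```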