import logging
import operator
from collections import deque
from typing import List, Optional

import torch
@@ -250,23 +251,61 @@ class SharedQspecQuantizer(Quantizer, QuantizerReporterUser):
250251 torch .ops .aten .clone .default ,
251252 torch .ops .aten .lift_fresh_copy .default ,
252253 torch .ops .aten .detach_ .default ,
254+ torch .ops .aten .alias .default ,
255+ torch .ops .aten .alias_copy .default ,
256+ torch .ops .aten .copy_ .default ,
257+ torch .ops .aten .detach_copy .default ,
258+ torch .ops .aten .unfold_copy .default ,
259+ torch .ops .aten .unbind .int ,
253260 # Min/Max/Mean
254261 torch .ops .aten .minimum .default ,
255262 torch .ops .aten .maximum .default ,
263+ torch .ops .aten .min .dim ,
264+ torch .ops .aten .max .dim ,
265+ torch .ops .aten .amin .default ,
266+ torch .ops .aten .amax .default ,
256267 # Data shuffling
257268 torch .ops .aten .permute .default ,
258269 torch .ops .aten .permute_copy .default ,
259- torch .ops .aten .transpose .Dimname ,
260270 torch .ops .aten .transpose .int ,
261271 torch .ops .aten .transpose_copy .int ,
262272 torch .ops .aten .t_copy .default ,
263273 torch .ops .aten .t .default ,
274+ torch .ops .aten .repeat .default ,
275+ torch .ops .aten .repeat_interleave .self_int ,
276+ torch .ops .aten .expand_copy .default ,
277+ torch .ops .aten .expand .default ,
278+ torch .ops .aten .select .int ,
279+ torch .ops .aten .select_copy .int ,
280+ torch .ops .aten .slice .Tensor ,
281+ torch .ops .aten .slice_copy .Tensor ,
282+ torch .ops .aten .split .Tensor ,
283+ torch .ops .aten .split_with_sizes .default ,
284+ torch .ops .aten .split_copy .Tensor ,
285+ torch .ops .aten .tile .default ,
286+ torch .ops .aten .flip .default ,
287+ torch .ops .aten .index_select .default ,
288+ torch .ops .aten .index_put .default ,
289+ torch .ops .aten .contiguous .default ,
290+ torch .ops .aten .as_strided_copy .default ,
291+ torch .ops .aten .pixel_shuffle .default ,
292+ torch .ops .aten .pixel_unshuffle .default ,
293+ torch .ops .aten .cat .default ,
294+ torch .ops .aten .concatenate .default ,
295+ torch .ops .aten .stack .default ,
296+ torch .ops .aten .dropout .default ,
297+ torch .ops .aten .dropout_ .default ,
298+ torch .ops .aten .chunk .default ,
299+ torch .ops .aten .index .Tensor ,
300+ torch .ops .aten .gather .default ,
301+ operator .getitem ,
264302 # Change shape
265303 torch .ops .aten .squeeze .default ,
266304 torch .ops .aten .squeeze_copy .default ,
267305 torch .ops .aten .squeeze_copy .dim ,
268306 torch .ops .aten .squeeze .dim ,
269307 torch .ops .aten .squeeze .dims ,
308+ torch .ops .aten .squeeze_ .dim ,
270309 torch .ops .aten .unsqueeze .default ,
271310 torch .ops .aten .unsqueeze_copy .default ,
272311 torch .ops .aten .reshape .default ,
@@ -279,22 +318,50 @@ class SharedQspecQuantizer(Quantizer, QuantizerReporterUser):
279318 # Padding
280319 torch .ops .aten .pad .default ,
281320 torch .ops .aten .constant_pad_nd .default ,
321+ # Activation functions
322+ torch .ops .aten .clamp .default ,
323+ torch .ops .aten .clamp .Tensor ,
324+ torch .ops .aten .hardtanh .default ,
325+ torch .ops .aten .hardtanh_ .default ,
326+ torch .ops .aten .relu .default ,
327+ torch .ops .aten .relu_ .default ,
328+ # Logic ops
329+ torch .ops .aten .eq .Tensor ,
330+ torch .ops .aten .eq .Scalar ,
331+ torch .ops .aten .ne .Tensor ,
332+ torch .ops .aten .ne .Scalar ,
333+ torch .ops .aten .ge .Tensor ,
334+ torch .ops .aten .ge .Scalar ,
335+ torch .ops .aten .gt .Tensor ,
336+ torch .ops .aten .gt .Scalar ,
337+ torch .ops .aten .le .Tensor ,
338+ torch .ops .aten .le .Scalar ,
339+ torch .ops .aten .lt .Tensor ,
340+ torch .ops .aten .lt .Scalar ,
341+ torch .ops .aten .where .self ,
342+ torch .ops .aten .where .default ,
343+ torch .ops .higher_order .while_loop ,
344+ torch .ops .higher_order .cond ,
282345 ]
283346
284347 def __init__ (self , targets : Optional [List [OpOverload ]] = None ) -> None :
285348 super ().__init__ ()
286349 if targets is None :
287350 self .targets = self .SHARED_QSPEC_OPS_DEFAULT
351+ self .support_config_path = (
352+ __name__ + f".{ self .__class__ .__name__ } .SHARED_QSPEC_OPS_DEFAULT"
353+ )
288354 else :
289355 self .targets = targets
356+ self .support_config_path = (
357+ f"CUSTOM TARGETS: { ', ' .join ([str (target ) for target in targets ])} "
358+ )
290359
291360 def get_quantizer_info (self ):
292361 name = self .__class__ .__name__
293362 targeted_nodes_description = ""
294363 quantization_config_path = "SHARED_QCONFIG"
295- support_config_path = (
296- __name__ + f".{ self .__class__ .__name__ } .SHARED_QSPEC_OPS_DEFAULT"
297- )
364+ support_config_path = self .support_config_path
298365 return QuantizerInfo (
299366 name ,
300367 targeted_nodes_description ,
@@ -319,35 +386,38 @@ def _get_shared_clique(self, root_node: Node) -> set[Node]:
319386 """
320387 shared_nodes = set ()
321388 bfs_queue = [root_node ]
322- adjacent_qspecs = set ()
389+ adjacent_qspecs = []
323390
324391 while bfs_queue :
325392 node = bfs_queue .pop (0 )
326393 shared_nodes .add (node )
327394
328395 # Neighbours may either be other shared nodes, annotated nodes, or non-annotated (float) nodes.
329- for input_node in self . _get_input_nodes_with_float_output ( node ) :
396+ for input_node in node . all_input_nodes :
330397 if input_node .target in self .targets and input_node not in shared_nodes :
331398 if not self ._is_annotated (input_node ):
332399 bfs_queue .append (input_node )
333400 if self ._is_annotated (input_node ):
334- output_qspec = input_node .meta .get (
335- Q_ANNOTATION_KEY , None
336- ).output_qspec
337- adjacent_qspecs .add (output_qspec )
401+ output_qspec = input_node .meta .get (Q_ANNOTATION_KEY ).output_qspec
402+ if output_qspec is not None :
403+ adjacent_qspecs .append (output_qspec )
338404
339- for output_node in self . _get_user_nodes_with_float_input ( node ):
405+ for output_node in node . users . keys ( ):
340406 if (
341407 output_node .target in self .targets
342408 and output_node not in shared_nodes
343409 ):
344410 if not self ._is_annotated (output_node ):
345411 bfs_queue .append (output_node )
346- if self ._is_annotated (output_node ):
412+ if (
413+ self ._is_annotated (output_node )
414+ and node in output_node .meta .get (Q_ANNOTATION_KEY ).input_qspec_map
415+ ):
347416 input_qspec = output_node .meta .get (
348- Q_ANNOTATION_KEY , None
417+ Q_ANNOTATION_KEY
349418 ).input_qspec_map [node ]
350- adjacent_qspecs .add (input_qspec )
419+ if input_qspec is not None :
420+ adjacent_qspecs .append (input_qspec )
351421
352422 return shared_nodes , adjacent_qspecs
353423
@@ -357,6 +427,21 @@ def _annotate_shared_cluster(self, root_node: Node) -> None:
357427 SharedQuantizationSpec.
358428 """
359429
430+ if (
431+ len (self ._get_input_nodes_with_float_output (root_node )) == 0
432+ and len (self ._get_user_nodes_with_float_input (root_node )) == 0
433+ ):
434+ self .report_reject (
435+ [root_node ],
436+ "No float inputs nor outputs to annotate" ,
437+ )
438+ mark_node_as_annotated (
439+ root_node ,
440+ {},
441+ None ,
442+ )
443+ return
444+
360445 shared_nodes , adjacent_qspecs = self ._get_shared_clique (root_node )
361446 node_order = {node : index for index , node in enumerate (root_node .graph .nodes )}
362447 ordered_nodes = sorted (shared_nodes , key = lambda node : node_order .get (node , 0 ))
@@ -369,10 +454,21 @@ def _annotate_shared_cluster(self, root_node: Node) -> None:
369454 # This means that we need to make sure that the root node of the shared_qspec
370455 # has an input node with a quantization spec, so that an observer is created.
371456
372- if len (adjacent_qspecs ) == 1 :
373- root_node_first_input = self ._get_input_nodes_with_float_output (root_node )[
374- 0
375- ]
457+ if len (adjacent_qspecs ) > 0 :
458+ # Warn if multiple different adjacent qspecs are found.
459+ if len (adjacent_qspecs ) > 1 :
460+ logger .warning (
461+ f"Multiple adjacent quantization specs found for { ', ' .join ([n .name for n in ordered_nodes ])} , all nodes will share the input quantization spec of { root_node .name } ."
462+ )
463+
464+ root_node_float_inputs = self ._get_input_nodes_with_float_output (root_node )
465+ if len (root_node_float_inputs ) == 0 :
466+ self .report_reject (
467+ ordered_nodes ,
468+ "Couldn't find any floating point input to base shared quantization spec on." ,
469+ )
470+ return
471+ root_node_first_input = root_node_float_inputs [0 ]
376472
377473 # Make all nodes share qspec with the root node's first input
378474 shared_qspec = SharedQuantizationSpec ((root_node_first_input , root_node ))
@@ -386,25 +482,21 @@ def _annotate_shared_cluster(self, root_node: Node) -> None:
386482 else :
387483 output_qspec = shared_qspec
388484 mark_node_as_annotated (
389- node , input_qspec_map , output_qspec , self .reporter , self
485+ node ,
486+ input_qspec_map ,
487+ output_qspec ,
390488 )
391489
392490 # Force the root qspec to be the adjacent spec
393- root_node .meta [Q_ANNOTATION_KEY ].input_qspec_map [
394- root_node_first_input
395- ] = adjacent_qspecs . pop ( )
491+ root_node .meta [Q_ANNOTATION_KEY ].input_qspec_map [root_node_first_input ] = (
492+ adjacent_qspecs [ 0 ]
493+ )
396494 self .report_accept (ordered_nodes )
397495
398- elif len (adjacent_qspecs ) == 0 :
399- self .report_reject (
400- ordered_nodes ,
401- "Couldn't find any adjacent quantization spec to base shared quantization spec on." ,
402- )
403- return
404496 else :
405497 self .report_reject (
406498 ordered_nodes ,
407- "Found multiple adjacent quantization specs to base shared quantization spec on." ,
499+ "Couldn't find any adjacent quantization spec to base shared quantization spec on. You may however quantize these nodes manually if required ." ,
408500 )
409501 return
410502
0 commit comments