@@ -532,45 +532,61 @@ def _get_input_nodes_with_float_output(self, node: Node) -> list[Node]:
532532 def _get_user_nodes_with_float_input (self , node : Node ) -> list [Node ]:
533533 return [n for n in node .users .keys () if has_float_output (node )]
534534
535+ def _skip_shared_qspec_from_io (self , node : Node , qspec : QuantizationSpec ) -> bool :
536+ return node .op in ("placeholder" , "output" ) and qspec .dtype == torch .uint8
537+
538+ def _maybe_enqueue_shared_node (
539+ self , neighbor : Node , shared_nodes : set [Node ], bfs_queue : list [Node ]
540+ ) -> None :
541+ if neighbor .target in self .targets and neighbor not in shared_nodes :
542+ if not self ._is_annotated (neighbor ):
543+ bfs_queue .append (neighbor )
544+
545+ def _append_output_qspec (self , node : Node , adjacent_qspecs : list [Any ]) -> None :
546+ if not self ._is_annotated (node ):
547+ return
548+ output_qspec = node .meta .get ( # type: ignore[union-attr]
549+ Q_ANNOTATION_KEY
550+ ).output_qspec
551+ if output_qspec is None :
552+ return
553+ if self ._skip_shared_qspec_from_io (node , output_qspec ):
554+ return
555+ adjacent_qspecs .append (output_qspec )
556+
557+ def _append_input_qspec (
558+ self , user_node : Node , input_node : Node , adjacent_qspecs : list [Any ]
559+ ) -> None :
560+ if not self ._is_annotated (user_node ):
561+ return
562+ qspec_map = user_node .meta .get (Q_ANNOTATION_KEY )
563+ if qspec_map is None :
564+ return
565+ if input_node not in qspec_map .input_qspec_map :
566+ return
567+ input_qspec = qspec_map .input_qspec_map [input_node ]
568+ if input_qspec is None :
569+ return
570+ if self ._skip_shared_qspec_from_io (user_node , input_qspec ):
571+ return
572+ adjacent_qspecs .append (input_qspec )
573+
535574 def _get_shared_clique (self , root_node : Node ) -> tuple [set [Node ], list [Any ]]:
536575 shared_nodes = set ()
537576 bfs_queue = [root_node ]
538- adjacent_qspecs = []
577+ adjacent_qspecs : list [ Any ] = []
539578
540579 while bfs_queue :
541580 node = bfs_queue .pop (0 )
542581 shared_nodes .add (node )
543582
544583 for input_node in node .all_input_nodes :
545- if input_node .target in self .targets and input_node not in shared_nodes :
546- if not self ._is_annotated (input_node ):
547- bfs_queue .append (input_node )
548- if self ._is_annotated (input_node ):
549- output_qspec = input_node .meta .get ( # type: ignore[union-attr]
550- Q_ANNOTATION_KEY
551- ).output_qspec
552- if output_qspec is not None :
553- adjacent_qspecs .append (output_qspec )
584+ self ._maybe_enqueue_shared_node (input_node , shared_nodes , bfs_queue )
585+ self ._append_output_qspec (input_node , adjacent_qspecs )
554586
555587 for output_node in node .users .keys ():
556- if (
557- output_node .target in self .targets
558- and output_node not in shared_nodes
559- ):
560- if not self ._is_annotated (output_node ):
561- bfs_queue .append (output_node )
562- if (
563- self ._is_annotated (output_node )
564- and node
565- in output_node .meta .get ( # type: ignore[union-attr]
566- Q_ANNOTATION_KEY
567- ).input_qspec_map
568- ):
569- input_qspec = output_node .meta .get ( # type: ignore[union-attr]
570- Q_ANNOTATION_KEY
571- ).input_qspec_map [node ]
572- if input_qspec is not None :
573- adjacent_qspecs .append (input_qspec )
588+ self ._maybe_enqueue_shared_node (output_node , shared_nodes , bfs_queue )
589+ self ._append_input_qspec (output_node , node , adjacent_qspecs )
574590
575591 return shared_nodes , adjacent_qspecs
576592
0 commit comments