@@ -914,51 +914,12 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
914914
915915 return self ._convert_initializer_data (init , from_type , to_type )
916916
917- def _replace_tensor_name (
918- self , consumers : list [onnx .NodeProto ], original_tensor_name : str , new_tensor_name : str
919- ) -> None :
920- """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name."""
921- for consumer in consumers :
922- for idx , inp in enumerate (consumer .input ):
923- if inp == original_tensor_name :
924- consumer .input [idx ] = new_tensor_name
925-
926- def _bypass_cast_node (self , node : onnx .NodeProto ) -> None :
927- # handling only a single input and output, as we only remove cast nodes
928- assert len (node .input ) == 1
929- assert len (node .output ) == 1
930-
931- input_tensor = node .input [0 ]
932- output_tensor = node .output [0 ]
933-
934- # Check if the cast output is also a graph output
935- is_output_producer = any (output .name == output_tensor for output in self .model .graph .output )
936-
937- # If the removed cast node is producing a network output, update the producer of the cast input so
938- # the network output name is preserved.
939- if is_output_producer :
940- producers = utils .get_producer_nodes (self .model , input_tensor )
941- for producer in producers :
942- for i , prod_out in enumerate (producer .output ):
943- if prod_out == input_tensor :
944- producer .output [i ] = output_tensor
945- consumers = utils .get_consumer_nodes (self .model , prod_out )
946- if len (consumers ) > 1 :
947- self ._replace_tensor_name (consumers , prod_out , output_tensor )
948- else :
949- # Reconnect consumers of the cast output to use the cast input instead
950- consumers = utils .get_consumer_nodes (self .model , output_tensor )
951- for consumer in consumers :
952- for i , input_name in enumerate (consumer .input ):
953- if input_name == output_tensor :
954- consumer .input [i ] = input_tensor
955-
956917 def _remove_preexisting_casts (self ) -> None :
957918 nodes_to_remove = []
958919 for node in self .model .graph .node :
959920 if node .op_type == "Cast" :
960- cast_from_type = self . _get_tensor_type ( node .input [0 ])
961- cast_to_type = utils .get_cast_to_type (node )
921+ cast_from_type = onnx_utils . _get_tensor_type_by_name ( self . model , node .input [0 ])
922+ cast_to_type = onnx_utils .get_cast_to_type (node )
962923 is_fp_cast = cast_to_type in [
963924 onnx .TensorProto .FLOAT16 ,
964925 onnx .TensorProto .FLOAT ,
@@ -978,7 +939,7 @@ def _remove_preexisting_casts(self) -> None:
978939 ):
979940 continue
980941 nodes_to_remove .append (node )
981- self ._bypass_cast_node (node )
942+ onnx_utils ._bypass_cast_node (self . model , node )
982943 logger .debug (f"Removing { len (nodes_to_remove )} pre-existing casts" )
983944
984945 for node in nodes_to_remove :
@@ -1044,7 +1005,7 @@ def _add_cast(
10441005 )
10451006
10461007 if tensor_to_consumers is None :
1047- consumer_nodes = utils .get_consumer_nodes (self .model , tensor_name )
1008+ consumer_nodes = onnx_utils .get_consumer_nodes (self .model , tensor_name )
10481009 else :
10491010 consumer_nodes = tensor_to_consumers .get (tensor_name , [])
10501011 consumer_nodes = [n for n in consumer_nodes if n .name not in exclude_consumers ]
@@ -1067,7 +1028,7 @@ def _add_cast(
10671028
10681029 # Find producer node to insert cast after it
10691030 if tensor_to_producers is None :
1070- producer_nodes = utils .get_producer_nodes (self .model , tensor_name )
1031+ producer_nodes = onnx_utils .get_producer_nodes (self .model , tensor_name )
10711032 else :
10721033 producer_nodes = tensor_to_producers .get (tensor_name , [])
10731034 if producer_nodes :
@@ -1106,7 +1067,7 @@ def _cleanup_no_consumer_nodes(self):
11061067 node
11071068 for node in self .model .graph .node
11081069 if not any (
1109- out in network_outputs or utils .get_consumer_nodes (self .model , out )
1070+ out in network_outputs or onnx_utils .get_consumer_nodes (self .model , out )
11101071 for out in node .output
11111072 )
11121073 ]
@@ -1124,29 +1085,23 @@ def _cleanup_pre_output_same_type_cast(self):
11241085
11251086 for output in self .model .graph .output :
11261087 if "_cast_to_" in output .name :
1127- out_producer_nodes = utils .get_producer_nodes (self .model , output .name )
1088+ out_producer_nodes = onnx_utils .get_producer_nodes (self .model , output .name )
11281089 if len (out_producer_nodes ) == 1 and out_producer_nodes [0 ].op_type == "Cast" :
11291090 second_cast_node = out_producer_nodes [0 ]
1130- cast_producer_nodes = utils .get_producer_nodes (
1091+ cast_producer_nodes = onnx_utils .get_producer_nodes (
11311092 self .model , second_cast_node .input [0 ]
11321093 )
11331094 if len (cast_producer_nodes ) == 1 and cast_producer_nodes [0 ].op_type == "Cast" :
11341095 first_cast_node = cast_producer_nodes [0 ]
11351096 if (
1136- self ._is_same_type_cast (first_cast_node )
1137- and utils .get_cast_to_type (second_cast_node )
1097+ onnx_utils ._is_same_type_cast (self . model , first_cast_node )
1098+ and onnx_utils .get_cast_to_type (second_cast_node )
11381099 == self .high_precision_type .onnx_type
11391100 ):
11401101 logger .debug (f"Removing pre-output double cast: { first_cast_node .name } " )
1141- self ._bypass_cast_node (first_cast_node )
1102+ onnx_utils ._bypass_cast_node (self . model , first_cast_node )
11421103 self .model .graph .node .remove (first_cast_node )
11431104
1144- def _is_same_type_cast (self , node : onnx .NodeProto ) -> bool :
1145- assert node .op_type == "Cast"
1146- input_types = [self ._get_tensor_type (inp ) for inp in node .input ]
1147- output_type = utils .get_cast_to_type (node )
1148- return all (inp_type == output_type for inp_type in input_types ) and input_types is not None
1149-
11501105 def _remove_redundant_casts (self ):
11511106 """Removes both sequential casts and casts that don't change precision.
11521107
@@ -1176,7 +1131,7 @@ def _fix_network_output_names(self):
11761131 for output in self .model .graph .output :
11771132 if "_cast_to_" in output .name :
11781133 post_cast_name = output .name
1179- producer_nodes = utils .get_producer_nodes (self .model , output .name )
1134+ producer_nodes = onnx_utils .get_producer_nodes (self .model , output .name )
11801135 if (
11811136 len (producer_nodes ) == 1
11821137 and producer_nodes [0 ].op_type == "Cast"
@@ -1188,21 +1143,23 @@ def _fix_network_output_names(self):
11881143 pre_cast_name = original_name + "_pre_cast"
11891144 output .name = original_name
11901145 # Update all consumers of the original (pre-cast) output to use the pre-cast name
1191- for node in utils .get_consumer_nodes (self .model , original_name ):
1146+ for node in onnx_utils .get_consumer_nodes (self .model , original_name ):
11921147 if node == cast_node :
11931148 continue
11941149 for i , input_name in enumerate (node .input ):
11951150 if input_name == original_name :
11961151 node .input [i ] = pre_cast_name
11971152 # do not break, can use the same tensor for multiple node inputs
11981153 # Update all consumers of the post-cast output to use the original name
1199- for node in utils .get_consumer_nodes (self .model , post_cast_name ):
1154+ for node in onnx_utils .get_consumer_nodes (self .model , post_cast_name ):
12001155 for i , input_name in enumerate (node .input ):
12011156 if input_name == post_cast_name :
12021157 node .input [i ] = original_name
12031158 # do not break, can use the same tensor for multiple node inputs
12041159 # Update all producers of the original output to use the original name
1205- cast_producer_nodes = utils .get_producer_nodes (self .model , cast_node .input [0 ])
1160+ cast_producer_nodes = onnx_utils .get_producer_nodes (
1161+ self .model , cast_node .input [0 ]
1162+ )
12061163 for node in cast_producer_nodes :
12071164 for i , node_output in enumerate (node .output ):
12081165 if node_output == original_name :
@@ -1248,7 +1205,7 @@ def _sanity_check(self):
12481205
12491206 # Verify that the output tensors are not disconnected
12501207 for output in network_outputs :
1251- producer_nodes = utils .get_producer_nodes (self .model , output .name )
1208+ producer_nodes = onnx_utils .get_producer_nodes (self .model , output .name )
12521209 if len (producer_nodes ) == 0 :
12531210 logger .warning (
12541211 f"Output tensor { output .name } is disconnected. This may be benign if it's part of a cast operation "
@@ -1296,13 +1253,6 @@ def _sanity_check(self):
12961253 if not sanity_ok :
12971254 raise Exception ("Sanity Check Failed" )
12981255
1299- def _get_tensor_type (self , tensor_name ):
1300- if tensor_name in self .value_info_map :
1301- return self .value_info_map [tensor_name ].type .tensor_type .elem_type
1302- if tensor_name in self .initializer_map :
1303- return self .initializer_map [tensor_name ].data_type
1304- raise Exception (f"did not find tensor { tensor_name } " )
1305-
13061256 def _sanitize_model (self ):
13071257 graph_sanitizer = GraphSanitizer (
13081258 self .model ,
0 commit comments