@@ -1147,42 +1147,13 @@ def _is_same_type_cast(self, node: onnx.NodeProto) -> bool:
11471147 output_type = utils .get_cast_to_type (node )
11481148 return all (inp_type == output_type for inp_type in input_types ) and input_types is not None
11491149
def _is_sequential_cast(self, node: onnx.NodeProto) -> bool:
    """Return True when ``node`` is the first of two back-to-back Casts and can be bypassed.

    A Cast to a higher precision followed by a Cast to a lower (or equal) precision
    makes the first Cast a no-op for the final result, so it can be removed. The
    reverse order (low precision first, then high precision) changes the values and
    must be kept.

    The first Cast is only reported as removable when the following Cast is its
    *sole* consumer: any other consumer still depends on the tensor produced by
    the first Cast.

    Args:
        node: The Cast node to inspect.

    Returns:
        True if ``node`` feeds exactly one node, that node is a Cast, both cast
        target types appear in the precision ordering below, and the first
        target type is at least as precise as the second.
    """
    assert node.op_type == "Cast"

    # Ordered from highest to lowest precision; a smaller index means higher precision.
    # NOTE(review): FLOAT16 is ranked above BFLOAT16 here although the two formats
    # trade exponent range against mantissa width differently - confirm that
    # dropping a FLOAT16 cast ahead of a BFLOAT16 cast is intended.
    precision_order = [
        TensorProto.DOUBLE,
        TensorProto.FLOAT,
        TensorProto.FLOAT16,
        TensorProto.BFLOAT16,
    ]

    # Look at every consumer, not only the Cast ones: if anything besides a single
    # Cast reads this output, bypassing the node would hand that reader the
    # un-cast tensor.
    consumers = list(utils.get_consumer_nodes(self.model, node.output[0]))
    if len(consumers) != 1 or consumers[0].op_type != "Cast":
        return False

    first_cast_type = utils.get_cast_to_type(node)
    second_cast_type = utils.get_cast_to_type(consumers[0])

    return (
        first_cast_type in precision_order
        and second_cast_type in precision_order
        and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
    )
1179-
11801150 def _remove_redundant_casts (self ):
11811151 """Removes both sequential casts and casts that don't change precision.
11821152
11831153 This method optimizes the graph by removing unnecessary cast operations that either:
11841154 1. Don't actually change the data type
11851155 2. Could be replaced by a single cast operation
1156+ 3. Can be folded into a preceding Constant node
11861157 """
11871158 if self .custom_ops :
11881159 self .model = self ._propagate_types_shapes_custom_ops (self .model )
@@ -1198,35 +1169,7 @@ def _remove_redundant_casts(self):
11981169 check_type = True ,
11991170 )
12001171
1201- nodes_to_remove = []
1202- for node in self .model .graph .node :
1203- if node .op_type == "Cast" :
1204- # Find cast nodes that don't change precision
1205- if self ._is_same_type_cast (node ):
1206- nodes_to_remove .append (node )
1207- self ._bypass_cast_node (node )
1208- logger .debug (f"Found redundant same-type cast: { node .name } " )
1209- continue
1210-
1211- # Find sequential casts that don't change precision
1212- if self ._is_sequential_cast (node ):
1213- nodes_to_remove .append (node )
1214- self ._bypass_cast_node (node )
1215- logger .debug (f"Found removable double-cast: { node .name } " )
1216-
1217- # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
1218- if self ._is_foldable_constant_cast_pattern (node ):
1219- nodes_to_remove .append (node )
1220- cast_producers = utils .get_producer_nodes (self .model , node .input [0 ])
1221- assert len (cast_producers ) == 1 and cast_producers [0 ].op_type == "Constant"
1222- constant_producer = cast_producers [0 ]
1223- self ._convert_constant_values (constant_producer , node )
1224- self ._bypass_cast_node (node )
1225- logger .debug (f"Found foldable Constant->Cast pattern, removing { node .name } " )
1226-
1227- logger .debug (f"Removing redundant casts: { [n .name for n in nodes_to_remove ]} " )
1228- for node in nodes_to_remove :
1229- self .model .graph .node .remove (node )
1172+ self .model = onnx_utils .remove_redundant_casts (self .model )
12301173
12311174 def _fix_network_output_names (self ):
12321175 modified = False
@@ -1360,80 +1303,6 @@ def _get_tensor_type(self, tensor_name):
13601303 return self .initializer_map [tensor_name ].data_type
13611304 raise Exception (f"did not find tensor { tensor_name } " )
13621305
def _convert_constant_values(self, const_node, cast_node: onnx.NodeProto) -> None:
    """Replace ``const_node`` with a Constant whose value is pre-cast to ``cast_node``'s target type.

    The replacement Constant reuses the original node's name and output names and is
    inserted at the original node's position in ``self.model.graph.node``, so existing
    consumers keep referring to the same tensor name. This method only converts the
    constant data; removing the Cast node is left to the caller.

    Args:
        const_node: The Constant node to convert. Its first attribute is read as the
            tensor value.
        cast_node: The Cast node whose target type determines the new data type.

    Raises:
        ValueError: If ``const_node`` is not found in ``self.model.graph.node``.
    """
    original_tensor = const_node.attribute[0].t
    if original_tensor.data_type == onnx.TensorProto.BFLOAT16:
        original_data = onnx_utils.read_f16_tensor_as_fp32(original_tensor)
    else:
        original_data = onnx.numpy_helper.to_array(original_tensor)

    # Precompute the casted value and wrap it in a TensorProto.
    cast_to_type = utils.get_cast_to_type(cast_node)
    if cast_to_type == onnx.TensorProto.BFLOAT16:
        # bfloat16 is not a native numpy dtype, so the TensorProto is built by hand
        # from the raw 16-bit patterns.
        casted_data = original_data.astype(ml_dtypes.bfloat16)
        tensor_proto = onnx.TensorProto()
        tensor_proto.data_type = onnx.TensorProto.BFLOAT16
        tensor_proto.dims.extend(casted_data.shape)
        tensor_proto.raw_data = casted_data.view(np.uint16).tobytes()
    else:
        cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type)
        tensor_proto = onnx.numpy_helper.from_array(original_data.astype(cast_dtype))
    tensor_proto.name = const_node.output[0]

    new_const_node = onnx.helper.make_node(
        "Constant",
        inputs=[],
        outputs=const_node.output,
        value=tensor_proto,
        name=const_node.name,
    )

    # Consumers are wired by tensor name and the new node keeps the original output
    # names, so no consumer input needs to be rewritten; only the node is swapped.
    graph_nodes = self.model.graph.node
    const_idx = next((i for i, n in enumerate(graph_nodes) if n == const_node), -1)
    if const_idx < 0:
        raise ValueError(f"Constant node {const_node.name} not found in graph")

    graph_nodes.remove(const_node)
    graph_nodes.insert(const_idx, new_const_node)
    # The Cast node is the sole consumer of the Constant node, guaranteed by _is_foldable_constant_cast_pattern
    cast_node.input[0] = new_const_node.output[0]
1421-
def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
    """Tell whether ``node`` is a Cast fed by a Constant that has no other consumer.

    Args:
        node: The Cast node to inspect.

    Returns:
        True when the Cast's first input has exactly one producer, that producer is
        a Constant node, and ``node`` is the only consumer of the Constant's output.
    """
    assert node.op_type == "Cast"

    producers = utils.get_producer_nodes(self.model, node.input[0])
    # Anything other than a single Constant producer cannot be folded.
    if len(producers) != 1 or producers[0].op_type != "Constant":
        return False

    constant = producers[0]
    constant_consumers = utils.get_consumer_nodes(self.model, constant.output[0])
    # Folding rewrites the Constant's value, so the Cast must be its only reader.
    return len(constant_consumers) == 1 and constant_consumers[0] == node
1436-
14371306 def _sanitize_model (self ):
14381307 graph_sanitizer = GraphSanitizer (
14391308 self .model ,
0 commit comments