@@ -1259,8 +1259,35 @@ def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.No
12591259 return [n for n in model .graph .node if tensor_name in n .output ]
12601260
12611261
1262- def _get_tensor_type_by_name (model : onnx .ModelProto , tensor_name : str ):
1263- """Get the tensor element type. Searches value_info, initializers, inputs, and outputs."""
1262+ def _build_tensor_type_map (model : onnx .ModelProto ) -> dict [str , int ]:
1263+ """Build an O(1) name-to-element-type lookup from all graph tensors."""
1264+ type_map : dict [str , int ] = {}
1265+ for vi in model .graph .value_info :
1266+ type_map [vi .name ] = vi .type .tensor_type .elem_type
1267+ for init in model .graph .initializer :
1268+ type_map [init .name ] = init .data_type
1269+ for inp in model .graph .input :
1270+ type_map [inp .name ] = inp .type .tensor_type .elem_type
1271+ for out in model .graph .output :
1272+ type_map [out .name ] = out .type .tensor_type .elem_type
1273+ return type_map
1274+
1275+
1276+ def _get_tensor_type_by_name (
1277+ model : onnx .ModelProto , tensor_name : str , type_map : dict [str , int ] | None = None
1278+ ):
1279+ """Get the tensor element type. Searches value_info, initializers, inputs, and outputs.
1280+
1281+ Args:
1282+ model: The ONNX model (used as fallback when type_map is not provided).
1283+ tensor_name: Name of the tensor to look up.
1284+ type_map: Pre-built lookup from _build_tensor_type_map for O(1) access.
1285+ When called in a loop, pass this to avoid repeated linear scans.
1286+ """
1287+ if type_map is not None :
1288+ if tensor_name in type_map :
1289+ return type_map [tensor_name ]
1290+ raise Exception (f"did not find tensor { tensor_name } " )
12641291 for vi in model .graph .value_info :
12651292 if vi .name == tensor_name :
12661293 return vi .type .tensor_type .elem_type
@@ -1286,11 +1313,13 @@ def _replace_tensor_name(
12861313 consumer .input [idx ] = new_tensor_name
12871314
12881315
1289- def _is_same_type_cast (model : onnx .ModelProto , node : onnx .NodeProto ) -> bool :
1316+ def _is_same_type_cast (
1317+ model : onnx .ModelProto , node : onnx .NodeProto , type_map : dict [str , int ] | None = None
1318+ ) -> bool :
12901319 assert node .op_type == "Cast"
1291- input_types = [_get_tensor_type_by_name (model , inp ) for inp in node .input ]
1320+ input_types = [_get_tensor_type_by_name (model , inp , type_map ) for inp in node .input ]
12921321 output_type = get_cast_to_type (node )
1293- return all (inp_type == output_type for inp_type in input_types ) and input_types is not None
1322+ return bool ( input_types ) and all (inp_type == output_type for inp_type in input_types )
12941323
12951324
12961325def _is_sequential_cast (model : onnx .ModelProto , node : onnx .NodeProto ) -> bool :
@@ -1407,11 +1436,12 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
14071436 Returns:
14081437 onnx.ModelProto: Model with redundant casts removed.
14091438 """
1439+ type_map = _build_tensor_type_map (onnx_model )
14101440 nodes_to_remove = []
14111441 for node in onnx_model .graph .node :
14121442 if node .op_type == "Cast" :
14131443 # Find cast nodes that don't change precision
1414- if _is_same_type_cast (onnx_model , node ):
1444+ if _is_same_type_cast (onnx_model , node , type_map ):
14151445 nodes_to_remove .append (node )
14161446 _bypass_cast_node (onnx_model , node )
14171447 logger .debug (f"Found redundant same-type cast: { node .name } " )
0 commit comments