Skip to content

Commit 9e3a35a

Browse files
committed
Cache model information for faster lookup
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent e223a2e commit 9e3a35a

1 file changed

Lines changed: 36 additions & 6 deletions

File tree

modelopt/onnx/utils.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,8 +1259,35 @@ def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.No
12591259
return [n for n in model.graph.node if tensor_name in n.output]
12601260

12611261

1262-
def _get_tensor_type_by_name(model: onnx.ModelProto, tensor_name: str):
1263-
"""Get the tensor element type. Searches value_info, initializers, inputs, and outputs."""
1262+
def _build_tensor_type_map(model: onnx.ModelProto) -> dict[str, int]:
1263+
"""Build an O(1) name-to-element-type lookup from all graph tensors."""
1264+
type_map: dict[str, int] = {}
1265+
for vi in model.graph.value_info:
1266+
type_map[vi.name] = vi.type.tensor_type.elem_type
1267+
for init in model.graph.initializer:
1268+
type_map[init.name] = init.data_type
1269+
for inp in model.graph.input:
1270+
type_map[inp.name] = inp.type.tensor_type.elem_type
1271+
for out in model.graph.output:
1272+
type_map[out.name] = out.type.tensor_type.elem_type
1273+
return type_map
1274+
1275+
1276+
def _get_tensor_type_by_name(
1277+
model: onnx.ModelProto, tensor_name: str, type_map: dict[str, int] | None = None
1278+
):
1279+
"""Get the tensor element type. Searches value_info, initializers, inputs, and outputs.
1280+
1281+
Args:
1282+
model: The ONNX model (used as fallback when type_map is not provided).
1283+
tensor_name: Name of the tensor to look up.
1284+
type_map: Pre-built lookup from _build_tensor_type_map for O(1) access.
1285+
When called in a loop, pass this to avoid repeated linear scans.
1286+
"""
1287+
if type_map is not None:
1288+
if tensor_name in type_map:
1289+
return type_map[tensor_name]
1290+
raise Exception(f"did not find tensor {tensor_name}")
12641291
for vi in model.graph.value_info:
12651292
if vi.name == tensor_name:
12661293
return vi.type.tensor_type.elem_type
@@ -1286,11 +1313,13 @@ def _replace_tensor_name(
12861313
consumer.input[idx] = new_tensor_name
12871314

12881315

1289-
def _is_same_type_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
1316+
def _is_same_type_cast(
1317+
model: onnx.ModelProto, node: onnx.NodeProto, type_map: dict[str, int] | None = None
1318+
) -> bool:
12901319
assert node.op_type == "Cast"
1291-
input_types = [_get_tensor_type_by_name(model, inp) for inp in node.input]
1320+
input_types = [_get_tensor_type_by_name(model, inp, type_map) for inp in node.input]
12921321
output_type = get_cast_to_type(node)
1293-
return all(inp_type == output_type for inp_type in input_types) and input_types is not None
1322+
return bool(input_types) and all(inp_type == output_type for inp_type in input_types)
12941323

12951324

12961325
def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
@@ -1407,11 +1436,12 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
14071436
Returns:
14081437
onnx.ModelProto: Model with redundant casts removed.
14091438
"""
1439+
type_map = _build_tensor_type_map(onnx_model)
14101440
nodes_to_remove = []
14111441
for node in onnx_model.graph.node:
14121442
if node.op_type == "Cast":
14131443
# Find cast nodes that don't change precision
1414-
if _is_same_type_cast(onnx_model, node):
1444+
if _is_same_type_cast(onnx_model, node, type_map):
14151445
nodes_to_remove.append(node)
14161446
_bypass_cast_node(onnx_model, node)
14171447
logger.debug(f"Found redundant same-type cast: {node.name}")

0 commit comments

Comments
 (0)