Skip to content

Commit 05c33b2

Browse files
committed
Restore logic for remove_redundant_casts
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 1c6a3ba commit 05c33b2

5 files changed

Lines changed: 223 additions & 212 deletions

File tree

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -130,7 +130,7 @@ def remove_disconnected_outputs(self) -> None:
130130
"""Remove disconnected outputs from the model."""
131131
tensors_to_remove = []
132132
for tensor in self.model.graph.output:
133-
if not utils.get_producer_nodes(self.model, tensor.name):
133+
if not onnx_utils.get_producer_nodes(self.model, tensor.name):
134134
tensors_to_remove.append(tensor)
135135
logger.debug(f"Found disconnected output: {tensor.name}")
136136

@@ -279,7 +279,7 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
279279
# Find variance computation branch
280280
pow_nodes = [
281281
n
282-
for n in utils.get_consumer_nodes(self.model, sub_node.output[0])
282+
for n in onnx_utils.get_consumer_nodes(self.model, sub_node.output[0])
283283
if n.op_type == "Pow"
284284
]
285285
if len(pow_nodes) != 1:
@@ -303,8 +303,8 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
303303

304304
# Find Div node
305305
# Find the Div node that consumes both sqrt and sub outputs
306-
sqrt_consumers = utils.get_consumer_nodes(self.model, sqrt_node.output[0])
307-
sub_consumers = utils.get_consumer_nodes(self.model, sub_node.output[0])
306+
sqrt_consumers = onnx_utils.get_consumer_nodes(self.model, sqrt_node.output[0])
307+
sub_consumers = onnx_utils.get_consumer_nodes(self.model, sub_node.output[0])
308308

309309
div_nodes = [n for n in sqrt_consumers if n in sub_consumers and n.op_type == "Div"]
310310
if len(div_nodes) != 1:
@@ -342,14 +342,14 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
342342
div_node,
343343
]
344344

345-
consumers = utils.get_consumer_nodes(self.model, div_node.output[0])
345+
consumers = onnx_utils.get_consumer_nodes(self.model, div_node.output[0])
346346
if len(consumers) == 1 and consumers[0].op_type == "Mul":
347347
mul_node = consumers[0]
348348
scale = self._get_initializer_value(mul_node.input[1], return_array=True)
349349
final_node = mul_node
350350
nodes_to_remove.append(mul_node)
351351

352-
consumers = utils.get_consumer_nodes(self.model, mul_node.output[0])
352+
consumers = onnx_utils.get_consumer_nodes(self.model, mul_node.output[0])
353353
if len(consumers) == 1 and consumers[0].op_type == "Add":
354354
add_node = consumers[0]
355355
bias = self._get_initializer_value(add_node.input[1], return_array=True)
@@ -457,7 +457,7 @@ def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
457457

458458
def _find_insertion_point(self, input_name: str) -> int:
459459
"""Find the correct insertion point for the new LayerNorm node."""
460-
producer_nodes = utils.get_producer_nodes(self.model, input_name)
460+
producer_nodes = onnx_utils.get_producer_nodes(self.model, input_name)
461461
if not producer_nodes:
462462
return 0
463463

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 18 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -914,51 +914,12 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
914914

915915
return self._convert_initializer_data(init, from_type, to_type)
916916

917-
def _replace_tensor_name(
918-
self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
919-
) -> None:
920-
"""Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name."""
921-
for consumer in consumers:
922-
for idx, inp in enumerate(consumer.input):
923-
if inp == original_tensor_name:
924-
consumer.input[idx] = new_tensor_name
925-
926-
def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
927-
# handling only a single input and output, as we only remove cast nodes
928-
assert len(node.input) == 1
929-
assert len(node.output) == 1
930-
931-
input_tensor = node.input[0]
932-
output_tensor = node.output[0]
933-
934-
# Check if the cast output is also a graph output
935-
is_output_producer = any(output.name == output_tensor for output in self.model.graph.output)
936-
937-
# If the removed cast node is producing a network output, update the producer of the cast input so
938-
# the network output name is preserved.
939-
if is_output_producer:
940-
producers = utils.get_producer_nodes(self.model, input_tensor)
941-
for producer in producers:
942-
for i, prod_out in enumerate(producer.output):
943-
if prod_out == input_tensor:
944-
producer.output[i] = output_tensor
945-
consumers = utils.get_consumer_nodes(self.model, prod_out)
946-
if len(consumers) > 1:
947-
self._replace_tensor_name(consumers, prod_out, output_tensor)
948-
else:
949-
# Reconnect consumers of the cast output to use the cast input instead
950-
consumers = utils.get_consumer_nodes(self.model, output_tensor)
951-
for consumer in consumers:
952-
for i, input_name in enumerate(consumer.input):
953-
if input_name == output_tensor:
954-
consumer.input[i] = input_tensor
955-
956917
def _remove_preexisting_casts(self) -> None:
957918
nodes_to_remove = []
958919
for node in self.model.graph.node:
959920
if node.op_type == "Cast":
960-
cast_from_type = self._get_tensor_type(node.input[0])
961-
cast_to_type = utils.get_cast_to_type(node)
921+
cast_from_type = onnx_utils._get_tensor_type_by_name(self.model, node.input[0])
922+
cast_to_type = onnx_utils.get_cast_to_type(node)
962923
is_fp_cast = cast_to_type in [
963924
onnx.TensorProto.FLOAT16,
964925
onnx.TensorProto.FLOAT,
@@ -978,7 +939,7 @@ def _remove_preexisting_casts(self) -> None:
978939
):
979940
continue
980941
nodes_to_remove.append(node)
981-
self._bypass_cast_node(node)
942+
onnx_utils._bypass_cast_node(self.model, node)
982943
logger.debug(f"Removing {len(nodes_to_remove)} pre-existing casts")
983944

984945
for node in nodes_to_remove:
@@ -1044,7 +1005,7 @@ def _add_cast(
10441005
)
10451006

10461007
if tensor_to_consumers is None:
1047-
consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
1008+
consumer_nodes = onnx_utils.get_consumer_nodes(self.model, tensor_name)
10481009
else:
10491010
consumer_nodes = tensor_to_consumers.get(tensor_name, [])
10501011
consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers]
@@ -1067,7 +1028,7 @@ def _add_cast(
10671028

10681029
# Find producer node to insert cast after it
10691030
if tensor_to_producers is None:
1070-
producer_nodes = utils.get_producer_nodes(self.model, tensor_name)
1031+
producer_nodes = onnx_utils.get_producer_nodes(self.model, tensor_name)
10711032
else:
10721033
producer_nodes = tensor_to_producers.get(tensor_name, [])
10731034
if producer_nodes:
@@ -1106,7 +1067,7 @@ def _cleanup_no_consumer_nodes(self):
11061067
node
11071068
for node in self.model.graph.node
11081069
if not any(
1109-
out in network_outputs or utils.get_consumer_nodes(self.model, out)
1070+
out in network_outputs or onnx_utils.get_consumer_nodes(self.model, out)
11101071
for out in node.output
11111072
)
11121073
]
@@ -1124,29 +1085,23 @@ def _cleanup_pre_output_same_type_cast(self):
11241085

11251086
for output in self.model.graph.output:
11261087
if "_cast_to_" in output.name:
1127-
out_producer_nodes = utils.get_producer_nodes(self.model, output.name)
1088+
out_producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name)
11281089
if len(out_producer_nodes) == 1 and out_producer_nodes[0].op_type == "Cast":
11291090
second_cast_node = out_producer_nodes[0]
1130-
cast_producer_nodes = utils.get_producer_nodes(
1091+
cast_producer_nodes = onnx_utils.get_producer_nodes(
11311092
self.model, second_cast_node.input[0]
11321093
)
11331094
if len(cast_producer_nodes) == 1 and cast_producer_nodes[0].op_type == "Cast":
11341095
first_cast_node = cast_producer_nodes[0]
11351096
if (
1136-
self._is_same_type_cast(first_cast_node)
1137-
and utils.get_cast_to_type(second_cast_node)
1097+
onnx_utils._is_same_type_cast(self.model, first_cast_node)
1098+
and onnx_utils.get_cast_to_type(second_cast_node)
11381099
== self.high_precision_type.onnx_type
11391100
):
11401101
logger.debug(f"Removing pre-output double cast: {first_cast_node.name}")
1141-
self._bypass_cast_node(first_cast_node)
1102+
onnx_utils._bypass_cast_node(self.model, first_cast_node)
11421103
self.model.graph.node.remove(first_cast_node)
11431104

1144-
def _is_same_type_cast(self, node: onnx.NodeProto) -> bool:
1145-
assert node.op_type == "Cast"
1146-
input_types = [self._get_tensor_type(inp) for inp in node.input]
1147-
output_type = utils.get_cast_to_type(node)
1148-
return all(inp_type == output_type for inp_type in input_types) and input_types is not None
1149-
11501105
def _remove_redundant_casts(self):
11511106
"""Removes both sequential casts and casts that don't change precision.
11521107
@@ -1176,7 +1131,7 @@ def _fix_network_output_names(self):
11761131
for output in self.model.graph.output:
11771132
if "_cast_to_" in output.name:
11781133
post_cast_name = output.name
1179-
producer_nodes = utils.get_producer_nodes(self.model, output.name)
1134+
producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name)
11801135
if (
11811136
len(producer_nodes) == 1
11821137
and producer_nodes[0].op_type == "Cast"
@@ -1188,21 +1143,23 @@ def _fix_network_output_names(self):
11881143
pre_cast_name = original_name + "_pre_cast"
11891144
output.name = original_name
11901145
# Update all consumers of the original (pre-cast) output to use the pre-cast name
1191-
for node in utils.get_consumer_nodes(self.model, original_name):
1146+
for node in onnx_utils.get_consumer_nodes(self.model, original_name):
11921147
if node == cast_node:
11931148
continue
11941149
for i, input_name in enumerate(node.input):
11951150
if input_name == original_name:
11961151
node.input[i] = pre_cast_name
11971152
# do not break, can use the same tensor for multiple node inputs
11981153
# Update all consumers of the post-cast output to use the original name
1199-
for node in utils.get_consumer_nodes(self.model, post_cast_name):
1154+
for node in onnx_utils.get_consumer_nodes(self.model, post_cast_name):
12001155
for i, input_name in enumerate(node.input):
12011156
if input_name == post_cast_name:
12021157
node.input[i] = original_name
12031158
# do not break, can use the same tensor for multiple node inputs
12041159
# Update all producers of the original output to use the original name
1205-
cast_producer_nodes = utils.get_producer_nodes(self.model, cast_node.input[0])
1160+
cast_producer_nodes = onnx_utils.get_producer_nodes(
1161+
self.model, cast_node.input[0]
1162+
)
12061163
for node in cast_producer_nodes:
12071164
for i, node_output in enumerate(node.output):
12081165
if node_output == original_name:
@@ -1248,7 +1205,7 @@ def _sanity_check(self):
12481205

12491206
# Verify that the output tensors are not disconnected
12501207
for output in network_outputs:
1251-
producer_nodes = utils.get_producer_nodes(self.model, output.name)
1208+
producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name)
12521209
if len(producer_nodes) == 0:
12531210
logger.warning(
12541211
f"Output tensor {output.name} is disconnected. This may be benign if it's part of a cast operation "
@@ -1296,13 +1253,6 @@ def _sanity_check(self):
12961253
if not sanity_ok:
12971254
raise Exception("Sanity Check Failed")
12981255

1299-
def _get_tensor_type(self, tensor_name):
1300-
if tensor_name in self.value_info_map:
1301-
return self.value_info_map[tensor_name].type.tensor_type.elem_type
1302-
if tensor_name in self.initializer_map:
1303-
return self.initializer_map[tensor_name].data_type
1304-
raise Exception(f"did not find tensor {tensor_name}")
1305-
13061256
def _sanitize_model(self):
13071257
graph_sanitizer = GraphSanitizer(
13081258
self.model,

modelopt/onnx/autocast/utils.py

Lines changed: 2 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727

2828
import onnx
2929

30+
import modelopt.onnx.utils as onnx_utils
3031
from modelopt.onnx.utils import get_opset_version
3132

3233

@@ -60,32 +61,6 @@ def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]:
6061
return value_info_map, initializer_map, node_to_init_map
6162

6263

63-
def get_consumer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]:
64-
"""Get all consumer nodes for a given tensor name.
65-
66-
Args:
67-
model: The ONNX model to search.
68-
tensor_name: Name of the tensor to find consumers for.
69-
70-
Returns:
71-
list[onnx.NodeProto]: List of nodes that consume the tensor.
72-
"""
73-
return [n for n in model.graph.node if tensor_name in n.input]
74-
75-
76-
def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]:
77-
"""Get all producer nodes for a given tensor name.
78-
79-
Args:
80-
model: The ONNX model to search.
81-
tensor_name: Name of the tensor to find producers for.
82-
83-
Returns:
84-
list[onnx.NodeProto]: List of nodes that produce the tensor.
85-
"""
86-
return [n for n in model.graph.node if tensor_name in n.output]
87-
88-
8964
def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.NodeProto:
9065
"""Get a single consumer node and raise exception if there are multiple consumers.
9166
@@ -99,30 +74,12 @@ def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.N
9974
Raises:
10075
Exception: If there is not exactly one consumer node.
10176
"""
102-
consumers = get_consumer_nodes(model, tensor_name)
77+
consumers = onnx_utils.get_consumer_nodes(model, tensor_name)
10378
if len(consumers) != 1:
10479
raise Exception(f"Expected single consumer for {tensor_name}, found {len(consumers)}")
10580
return consumers[0]
10681

10782

108-
def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
109-
"""Get the target type from a Cast node.
110-
111-
Args:
112-
cast_node: The Cast node to extract type from.
113-
114-
Returns:
115-
int: The target type value from the Cast node's 'to' attribute.
116-
117-
Raises:
118-
ValueError: If the Cast node does not have a 'to' attribute.
119-
"""
120-
for attr in cast_node.attribute:
121-
if attr.name == "to":
122-
return attr.i
123-
raise ValueError("Cast node does not have 'to' attribute")
124-
125-
12683
def walk_subgraphs_recursive(
12784
graph: onnx.GraphProto,
12885
callback: Callable,

0 commit comments

Comments
 (0)