Skip to content

Commit 0186223

Browse files
committed
update function to replace duplicate casts
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent e2abd9d commit 0186223

File tree

5 files changed

+89
-160
lines changed

5 files changed

+89
-160
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux)
2222
- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
2323
- Add support for image-text data calibration in PTQ for Nemotron VL models.
2424
- Add PTQ support for Nemotron Parse.
25+
- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes
2526

2627
0.41 (2026-01-19)
2728
^^^^^^^^^^^^^^^^^

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 2 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,42 +1147,13 @@ def _is_same_type_cast(self, node: onnx.NodeProto) -> bool:
11471147
output_type = utils.get_cast_to_type(node)
11481148
return all(inp_type == output_type for inp_type in input_types) and input_types is not None
11491149

1150-
def _is_sequential_cast(self, node: onnx.NodeProto) -> bool:
1151-
assert node.op_type == "Cast"
1152-
output_type = utils.get_cast_to_type(node)
1153-
1154-
# Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
1155-
# Cast to low precision -> cast to high precision affects precision and should not be removed
1156-
precision_order = [
1157-
TensorProto.DOUBLE,
1158-
TensorProto.FLOAT,
1159-
TensorProto.FLOAT16,
1160-
TensorProto.BFLOAT16,
1161-
]
1162-
consumers = [
1163-
n for n in utils.get_consumer_nodes(self.model, node.output[0]) if n.op_type == "Cast"
1164-
]
1165-
1166-
# If the first cast has additional consumers, we should not remove it
1167-
if len(consumers) != 1:
1168-
return False
1169-
1170-
next_node = consumers[0]
1171-
first_cast_type = output_type
1172-
second_cast_type = utils.get_cast_to_type(next_node)
1173-
1174-
return (
1175-
first_cast_type in precision_order
1176-
and second_cast_type in precision_order
1177-
and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
1178-
)
1179-
11801150
def _remove_redundant_casts(self):
11811151
"""Removes both sequential casts and casts that don't change precision.
11821152
11831153
This method optimizes the graph by removing unnecessary cast operations that either:
11841154
1. Don't actually change the data type
11851155
2. Could be replaced by a single cast operation
1156+
3. Can be folded into a preceding Constant node
11861157
"""
11871158
if self.custom_ops:
11881159
self.model = self._propagate_types_shapes_custom_ops(self.model)
@@ -1198,35 +1169,7 @@ def _remove_redundant_casts(self):
11981169
check_type=True,
11991170
)
12001171

1201-
nodes_to_remove = []
1202-
for node in self.model.graph.node:
1203-
if node.op_type == "Cast":
1204-
# Find cast nodes that don't change precision
1205-
if self._is_same_type_cast(node):
1206-
nodes_to_remove.append(node)
1207-
self._bypass_cast_node(node)
1208-
logger.debug(f"Found redundant same-type cast: {node.name}")
1209-
continue
1210-
1211-
# Find sequential casts that don't change precision
1212-
if self._is_sequential_cast(node):
1213-
nodes_to_remove.append(node)
1214-
self._bypass_cast_node(node)
1215-
logger.debug(f"Found removable double-cast: {node.name}")
1216-
1217-
# Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
1218-
if self._is_foldable_constant_cast_pattern(node):
1219-
nodes_to_remove.append(node)
1220-
cast_producers = utils.get_producer_nodes(self.model, node.input[0])
1221-
assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant"
1222-
constant_producer = cast_producers[0]
1223-
self._convert_constant_values(constant_producer, node)
1224-
self._bypass_cast_node(node)
1225-
logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}")
1226-
1227-
logger.debug(f"Removing redundant casts: {[n.name for n in nodes_to_remove]}")
1228-
for node in nodes_to_remove:
1229-
self.model.graph.node.remove(node)
1172+
self.model = onnx_utils.remove_redundant_casts(self.model)
12301173

12311174
def _fix_network_output_names(self):
12321175
modified = False
@@ -1360,80 +1303,6 @@ def _get_tensor_type(self, tensor_name):
13601303
return self.initializer_map[tensor_name].data_type
13611304
raise Exception(f"did not find tensor {tensor_name}")
13621305

1363-
def _convert_constant_values(self, const_node, cast_node: onnx.NodeProto) -> None:
1364-
original_tensor = const_node.attribute[0].t
1365-
if original_tensor.data_type == onnx.TensorProto.BFLOAT16:
1366-
original_data = onnx_utils.read_f16_tensor_as_fp32(original_tensor)
1367-
else:
1368-
original_data = onnx.numpy_helper.to_array(original_tensor)
1369-
1370-
# Precompute casted value
1371-
cast_to_type = utils.get_cast_to_type(cast_node)
1372-
cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type)
1373-
1374-
# Handle bfloat16 conversion manually since numpy doesn't support it natively
1375-
if cast_to_type == onnx.TensorProto.BFLOAT16:
1376-
casted_data = original_data.astype(ml_dtypes.bfloat16)
1377-
else:
1378-
casted_data = original_data.astype(cast_dtype)
1379-
1380-
# Create a new constant node with casted data
1381-
if cast_to_type == onnx.TensorProto.BFLOAT16:
1382-
# Create TensorProto manually for bfloat16
1383-
tensor_proto = onnx.TensorProto()
1384-
tensor_proto.name = const_node.output[0]
1385-
tensor_proto.data_type = onnx.TensorProto.BFLOAT16
1386-
tensor_proto.dims.extend(casted_data.shape)
1387-
# Convert bfloat16 to raw bytes
1388-
bf16_bytes = casted_data.astype(ml_dtypes.bfloat16).view(np.uint16)
1389-
tensor_proto.raw_data = bf16_bytes.tobytes()
1390-
else:
1391-
# Create tensor manually to ensure proper handling
1392-
tensor_proto = onnx.numpy_helper.from_array(casted_data)
1393-
tensor_proto.name = const_node.output[0]
1394-
1395-
new_const_node = onnx.helper.make_node(
1396-
"Constant",
1397-
inputs=[],
1398-
outputs=const_node.output,
1399-
value=tensor_proto,
1400-
name=const_node.name,
1401-
)
1402-
1403-
# Replace the original constant node with the new constant node
1404-
# The scope of this function is to convert the constant node data. Removing the cast is done later.
1405-
for node in utils.get_consumer_nodes(self.model, const_node.name):
1406-
for i, input_name in enumerate(node.input):
1407-
if input_name == const_node.name:
1408-
node.input[i] = new_const_node.output[0]
1409-
break
1410-
1411-
const_idx = -1
1412-
for i, node in enumerate(self.model.graph.node):
1413-
if node == const_node:
1414-
const_idx = i
1415-
break
1416-
1417-
self.model.graph.node.remove(const_node)
1418-
self.model.graph.node.insert(const_idx, new_const_node)
1419-
# The Cast node is the sole consumer of the Constant node, guaranteed by _is_foldable_constant_cast_pattern
1420-
cast_node.input[0] = new_const_node.output[0]
1421-
1422-
def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
1423-
"""Constant -> Cast and Cast is the only consumer of the Constant node."""
1424-
assert node.op_type == "Cast"
1425-
1426-
producer = utils.get_producer_nodes(self.model, node.input[0])
1427-
1428-
const_producer = (
1429-
producer[0] if len(producer) == 1 and producer[0].op_type == "Constant" else None
1430-
)
1431-
1432-
if const_producer:
1433-
get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
1434-
return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
1435-
return False
1436-
14371306
def _sanitize_model(self):
14381307
graph_sanitizer = GraphSanitizer(
14391308
self.model,

modelopt/onnx/export/fp8_exporter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
127127
# Create FP8 zero point constant
128128
zp_tensor = onnx.TensorProto()
129129
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
130+
zp_tensor.dims.extend([1]) # 1-element tensor
130131
zp_tensor.raw_data = b"\x00" # Zero in FP8
131132
zp_values = LazyValues(zp_tensor)
132133
zero_point = gs.Constant(node.name + "_zero_point", zp_values)

modelopt/onnx/utils.py

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,48 +1215,106 @@ def onnx_type_str_to_enum(dtype: str) -> int:
12151215
return getattr(onnx.TensorProto, dtype)
12161216

12171217

1218-
def remove_duplicate_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
1219-
"""Removes consecutive Cast nodes that cast to the same type.
1218+
def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
1219+
"""Removes redundant Cast nodes from an ONNX model.
12201220
1221-
Example: Cast(to=FP16) -> Cast(to=FP16) becomes just Cast(to=FP16)
1221+
Handles three patterns:
1222+
1. Same-type casts: Cast where input type == output type (no-op)
1223+
2. Sequential casts: Cast(to=high_prec) -> Cast(to=low_prec), first cast removed
1224+
3. Constant->Cast folding: Fold cast into preceding Constant node's data
1225+
1226+
Args:
1227+
onnx_model: The ONNX model to optimize.
1228+
1229+
Returns:
1230+
onnx.ModelProto: Model with redundant casts removed.
12221231
"""
1232+
import ml_dtypes
1233+
12231234
graph = gs.import_onnx(onnx_model)
12241235
removed_count = 0
12251236

1237+
# Precision ordering: lower index = higher precision
1238+
precision_order = {
1239+
onnx.TensorProto.DOUBLE: 0,
1240+
onnx.TensorProto.FLOAT: 1,
1241+
onnx.TensorProto.FLOAT16: 2,
1242+
onnx.TensorProto.BFLOAT16: 3,
1243+
}
1244+
1245+
def _get_onnx_type(tensor):
1246+
"""Get ONNX type enum from a GS tensor's dtype."""
1247+
if tensor.dtype is None:
1248+
return None
1249+
try:
1250+
return onnx.helper.np_dtype_to_tensor_dtype(tensor.dtype)
1251+
except Exception:
1252+
return None
1253+
1254+
def _bypass_cast(node):
1255+
"""Reconnect consumers of cast output to use cast input, removing the cast."""
1256+
inp = node.inputs[0]
1257+
out = node.outputs[0]
1258+
for consumer in list(out.outputs):
1259+
for i, consumer_inp in enumerate(consumer.inputs):
1260+
if consumer_inp is out:
1261+
consumer.inputs[i] = inp
1262+
for i, graph_out in enumerate(graph.outputs):
1263+
if graph_out is out:
1264+
graph.outputs[i] = inp
1265+
node.outputs.clear()
1266+
12261267
for node in list(graph.nodes):
12271268
if node.op != "Cast":
12281269
continue
12291270

1230-
# Check if output goes to exactly one Cast node
1231-
if len(node.outputs) != 1 or len(node.outputs[0].outputs) != 1:
1271+
cast_to = node.attrs.get("to")
1272+
if cast_to is None:
12321273
continue
12331274

1234-
next_node = node.outputs[0].outputs[0]
1235-
if next_node.op != "Cast":
1236-
continue
1275+
input_tensor = node.inputs[0]
1276+
output_tensor = node.outputs[0]
12371277

1238-
first_to = node.attrs.get("to")
1239-
second_to = next_node.attrs.get("to")
1240-
1241-
# Only handle same-type casts
1242-
if first_to != second_to:
1278+
# Pattern 1: Same-type cast (no-op)
1279+
input_type = _get_onnx_type(input_tensor)
1280+
if input_type is not None and input_type == cast_to:
1281+
_bypass_cast(node)
1282+
removed_count += 1
1283+
logger.debug(f"Removed same-type cast: {node.name}")
12431284
continue
12441285

1245-
# Bypass the second cast - keep first, remove second
1246-
input_tensor = node.outputs[0]
1247-
output_tensor = next_node.outputs[0]
1248-
1249-
for consumer in list(output_tensor.outputs):
1250-
for i, inp in enumerate(consumer.inputs):
1251-
if inp == output_tensor:
1252-
consumer.inputs[i] = input_tensor
1253-
next_node.outputs.clear()
1254-
removed_count += 1
1255-
logger.debug(f"Removed duplicate cast: {next_node.name} (same type as {node.name})")
1286+
# Pattern 2: Sequential casts where first can be removed
1287+
# Cast(to=high) -> Cast(to=low): first cast has no effect
1288+
cast_consumers = output_tensor.outputs
1289+
if len(cast_consumers) == 1 and cast_consumers[0].op == "Cast":
1290+
next_cast_to = cast_consumers[0].attrs.get("to")
1291+
if (
1292+
cast_to in precision_order
1293+
and next_cast_to in precision_order
1294+
and precision_order[cast_to] <= precision_order[next_cast_to]
1295+
):
1296+
_bypass_cast(node)
1297+
removed_count += 1
1298+
logger.debug(f"Removed sequential cast: {node.name}")
1299+
continue
1300+
1301+
# Pattern 3: Constant -> Cast folding (only if constant has single consumer)
1302+
if isinstance(input_tensor, Constant) and len(input_tensor.outputs) == 1:
1303+
try:
1304+
if cast_to == onnx.TensorProto.BFLOAT16:
1305+
input_tensor.values = input_tensor.values.astype(ml_dtypes.bfloat16)
1306+
else:
1307+
cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to)
1308+
input_tensor.values = input_tensor.values.astype(cast_dtype)
1309+
_bypass_cast(node)
1310+
removed_count += 1
1311+
logger.debug(f"Folded Constant->Cast: {node.name}")
1312+
except Exception as e:
1313+
logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")
12561314

12571315
if removed_count > 0:
12581316
graph.cleanup().toposort()
1259-
logger.info(f"Removed {removed_count} duplicate Cast nodes")
1317+
logger.info(f"Removed {removed_count} redundant Cast nodes")
12601318

12611319
return gs.export_onnx(graph)
12621320

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
get_output_names,
5151
get_output_shapes,
5252
infer_shapes,
53-
remove_duplicate_casts,
5453
remove_node_training_mode,
54+
remove_redundant_casts,
5555
)
5656
from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
5757
from modelopt.torch.utils import flatten_tree, standardize_named_model_args
@@ -589,7 +589,7 @@ def get_onnx_bytes_and_metadata(
589589
# Change FP32 cast nodes feeding into Concat/Add to FP16
590590
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
591591

592-
onnx_opt_graph = remove_duplicate_casts(onnx_opt_graph)
592+
onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)
593593

594594
# TensorRT expects all scales to be positive
595595
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)

0 commit comments

Comments
 (0)