From 2eaacb93181375f6b3509c4cd6b9ac299acdbb9c Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:40:55 +0000 Subject: [PATCH 01/11] [OMNIML-2663] Use standard QDQ nodes for ONNX export Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/torch/_deploy/utils/torch_onnx.py | 37 +++++++--------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 304fb8ec7a..98c1964e88 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -31,7 +31,6 @@ from onnxconverter_common import convert_float_to_float16 from torch.nn.parallel import DataParallel, DistributedDataParallel -from modelopt.onnx.autocast.convert import convert_to_f16 from modelopt.onnx.export import ( FP8QuantExporter, INT4QuantExporter, @@ -563,35 +562,23 @@ def get_onnx_bytes_and_metadata( # TODO: Remove manual ir_version change once ORT supports ir_version 11 onnx_opt_graph.ir_version = 10 - # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode - # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode - if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model): - onnx_opt_graph = quantize_weights(model, onnx_opt_graph) + onnx_opt_graph = quantize_weights(model, onnx_opt_graph) if dq_only: onnx_opt_graph = qdq_to_dq(onnx_opt_graph) - try: - # TODO: Single-precision torch model assumed - param_dtype = next(model.parameters()).dtype - except StopIteration: - param_dtype = torch.float32 - if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32: - if is_int4_quantized(model) or is_mxfp8_quantized(model): - assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet" - onnx_opt_graph = convert_float_to_float16( - onnx_opt_graph, - keep_io_types=False, 
- disable_shape_infer=True, - check_fp16_ready=False, - ) - else: - onnx_opt_graph = convert_to_f16( - onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False - ) + assert weights_dtype == "fp16", ( + "Only FP16 weights are supported for torch quantization -> onnx export" + ) + onnx_opt_graph = convert_float_to_float16( + onnx_opt_graph, + keep_io_types=False, + disable_shape_infer=True, + check_fp16_ready=False, + ) - # TensorRT expects all scales to be postive - onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) + # TensorRT expects all scales to be postive + onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) # If the onnx model contains external data store the external tensors in one file and save the onnx model if has_external_data(onnx_save_path): From 81a19ef4b1b1cc3aacc321de4f208464eb6ebdcd Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Mon, 2 Feb 2026 18:25:02 +0000 Subject: [PATCH 02/11] Update the fp8 exporter Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/export/fp8_exporter.py | 49 +++++++++++++++++++--- modelopt/torch/_deploy/utils/torch_onnx.py | 4 +- 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index 28e6b1da1e..2d431dcbc3 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -22,6 +22,8 @@ import torch from onnx_graphsurgeon.ir.tensor import LazyValues +from modelopt.onnx.logging_config import logger + from .base_exporter import ONNXQuantExporter @@ -45,13 +47,13 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: Even though modelopt supports FP8 onnx export, the weights are represented in fp32 + QDQ. The storage is therefore very bad. 
In this function, Q nodes will get removed from the weights and have only DQ nodes with those converted FP8 - weights in the output model. + weights in the output model. TRT custom ops are converted to native ONNX DequantizeLinear. Parameters: - onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes. + onnx_model: ONNX model with FP32/FP16 weights and TRT_FP8 QDQ nodes. Returns: - ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations). + ONNX model with FP8 weights and native ONNX DQ nodes for weights (QDQ preserved for activations). """ start_time = time.time() print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...") @@ -62,7 +64,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: for node in graph.nodes: if node.op == "TRT_FP8QuantizeLinear": - # Should not remove input QDQ + # Should not remove input QDQ (only process weight quantization) if not isinstance(node.inputs[0], gs.Constant): continue @@ -88,7 +90,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) node.outputs.clear() - # DQ Op is separated out + # Convert TRT DQ to native ONNX DequantizeLinear with FP8 weights dq_op.inputs[0] = onnx_weights_fp8 dq_op.op = "DequantizeLinear" dq_op.outputs[0].dtype = dq_op.inputs[1].dtype @@ -101,5 +103,40 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: @staticmethod def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Post-processes the ONNX model for FP8 quantization.""" + """Post-processes the ONNX model for FP8 quantization. + + Converts remaining TRT_FP8 QDQ ops (activations) to native ONNX QuantizeLinear/DequantizeLinear, + updates GELU nodes to use tanh approximation, and inserts Cast nodes after Sqrt. 
+ """ + logger.info("Post-processing FP8 quantized model") + graph = gs.import_onnx(onnx_model) + + # Convert TRT_FP8QuantizeLinear to native QuantizeLinear + for node in graph.nodes: + if node.op == "TRT_FP8QuantizeLinear": + node.op = "QuantizeLinear" + # Add FP8 zero_point if not present + if len(node.inputs) == 2: + # Create FP8 zero point constant + zp_tensor = onnx.TensorProto() + zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + zp_tensor.raw_data = b"\x00" # Zero in FP8 + zp_values = LazyValues(zp_tensor) + zero_point = gs.Constant(node.name + "_zero_point", zp_values) + node.inputs.append(zero_point) + # Add saturate attribute for FP8 + node.attrs["saturate"] = 1 + logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear") + + # Convert TRT_FP8DequantizeLinear to native DequantizeLinear + for node in graph.nodes: + if node.op == "TRT_FP8DequantizeLinear": + node.op = "DequantizeLinear" + logger.debug( + f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear" + ) + + graph.cleanup().toposort() + onnx_model = gs.export_onnx(graph) + return onnx_model diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 98c1964e88..156dec4de2 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -567,8 +567,8 @@ def get_onnx_bytes_and_metadata( if dq_only: onnx_opt_graph = qdq_to_dq(onnx_opt_graph) - assert weights_dtype == "fp16", ( - "Only FP16 weights are supported for torch quantization -> onnx export" + assert weights_dtype in ["fp16", "fp32"], ( + "Only FP16 and FP32 weights are supported for torch quantization -> onnx export" ) onnx_opt_graph = convert_float_to_float16( onnx_opt_graph, From 03b7232a1bbf2efb6da0b099f8e0df3d24182511 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 4 Feb 2026 01:00:20 +0000 Subject: [PATCH 03/11] Update fp8 quantization flow Signed-off-by: 
ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/export/fp8_exporter.py | 4 +- modelopt/onnx/utils.py | 90 +++++++++++++++++++++- modelopt/torch/_deploy/utils/torch_onnx.py | 23 +++--- 3 files changed, 103 insertions(+), 14 deletions(-) diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index 2d431dcbc3..479048112e 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -137,6 +137,4 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: ) graph.cleanup().toposort() - onnx_model = gs.export_onnx(graph) - - return onnx_model + return gs.export_onnx(graph) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 4025ea065a..30ce8139f7 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -433,8 +433,8 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes: if len(init.dims) > 1: dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type) if dtype in ["float16", "float32", "float64"]: - avg = weight_metadata.get(init.name + "_avg", None) - var = weight_metadata.get(init.name + "_var", None) + avg = weight_metadata.get(init.name + "_avg") + var = weight_metadata.get(init.name + "_var") if avg and var: numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype( dtype @@ -1215,6 +1215,52 @@ def onnx_type_str_to_enum(dtype: str) -> int: return getattr(onnx.TensorProto, dtype) +def remove_duplicate_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Removes consecutive Cast nodes that cast to the same type. 
+ + Example: Cast(to=FP16) -> Cast(to=FP16) becomes just Cast(to=FP16) + """ + graph = gs.import_onnx(onnx_model) + removed_count = 0 + + for node in list(graph.nodes): + if node.op != "Cast": + continue + + # Check if output goes to exactly one Cast node + if len(node.outputs) != 1 or len(node.outputs[0].outputs) != 1: + continue + + next_node = node.outputs[0].outputs[0] + if next_node.op != "Cast": + continue + + first_to = node.attrs.get("to") + second_to = next_node.attrs.get("to") + + # Only handle same-type casts + if first_to != second_to: + continue + + # Bypass the second cast - keep first, remove second + input_tensor = node.outputs[0] + output_tensor = next_node.outputs[0] + + for consumer in list(output_tensor.outputs): + for i, inp in enumerate(consumer.inputs): + if inp == output_tensor: + consumer.inputs[i] = input_tensor + next_node.outputs.clear() + removed_count += 1 + logger.debug(f"Removed duplicate cast: {next_node.name} (same type as {node.name})") + + if removed_count > 0: + graph.cleanup().toposort() + logger.info(f"Removed {removed_count} duplicate Cast nodes") + + return gs.export_onnx(graph) + + def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto: """Remove `training_mode` attribute and extra training outputs from nodes of a given op type. @@ -1263,3 +1309,43 @@ def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx_model.graph.value_info.extend(keep) return onnx_model + + +def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) -> onnx.ModelProto: + """Change Cast nodes that cast to FP32 and feed into specified nodes to cast to FP16 instead. + + Args: + model: The ONNX model to modify. + target_op_types: List of op types to check for. Cast nodes feeding into these will be + changed from FP32 to FP16. + + Returns: + The modified ONNX model with Cast nodes updated. 
+ """ + # Build a map of tensor name -> consumer nodes + tensor_to_consumers: dict[str, list[onnx.NodeProto]] = {} + for node in model.graph.node: + for inp in node.input: + if inp: + tensor_to_consumers.setdefault(inp, []).append(node) + + # Find Cast nodes that feed into target ops and change FP32 -> FP16 + for node in model.graph.node: + if node.op_type != "Cast": + continue + + # Check if this Cast outputs to a target op type + cast_output = node.output[0] + consumers = tensor_to_consumers.get(cast_output, []) + feeds_target = any(c.op_type in target_op_types for c in consumers) + + if not feeds_target: + continue + + # Check if Cast is to FP32, and change to FP16 + for attr in node.attribute: + if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT: + attr.i = onnx.TensorProto.FLOAT16 + break + + return model diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 156dec4de2..8c1a83cceb 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -41,6 +41,7 @@ ) from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero from modelopt.onnx.utils import ( + change_casts_to_fp16, check_model_uses_external_data, get_input_names, get_input_shapes, @@ -48,6 +49,7 @@ get_output_names, get_output_shapes, infer_shapes, + remove_duplicate_casts, remove_node_training_mode, ) from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers @@ -567,15 +569,18 @@ def get_onnx_bytes_and_metadata( if dq_only: onnx_opt_graph = qdq_to_dq(onnx_opt_graph) - assert weights_dtype in ["fp16", "fp32"], ( - "Only FP16 and FP32 weights are supported for torch quantization -> onnx export" - ) - onnx_opt_graph = convert_float_to_float16( - onnx_opt_graph, - keep_io_types=False, - disable_shape_infer=True, - check_fp16_ready=False, - ) + if weights_dtype == "fp16": + onnx_opt_graph = convert_float_to_float16( + 
onnx_opt_graph, + keep_io_types=False, + disable_shape_infer=True, + check_fp16_ready=False, + op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"], + ) + # Change FP32 cast nodes feeding into Concat/Add to FP16 + onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"]) + + onnx_opt_graph = remove_duplicate_casts(onnx_opt_graph) # TensorRT expects all scales to be postive onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) From 2b7e4d40ac22e5723cc660e36b82726497123644 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 4 Feb 2026 23:26:30 +0000 Subject: [PATCH 04/11] Fix test failures Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/export/fp8_exporter.py | 11 +++++++++-- modelopt/onnx/export/nvfp4_exporter.py | 6 +++--- modelopt/torch/_deploy/utils/torch_onnx.py | 7 ++++--- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index 479048112e..427851d9c7 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -105,8 +105,15 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Post-processes the ONNX model for FP8 quantization. - Converts remaining TRT_FP8 QDQ ops (activations) to native ONNX QuantizeLinear/DequantizeLinear, - updates GELU nodes to use tanh approximation, and inserts Cast nodes after Sqrt. + Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear: + - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1 + - TRT_FP8DequantizeLinear -> DequantizeLinear + + Args: + onnx_model: The ONNX model containing TRT_FP8 quantization nodes. + + Returns: + The post-processed ONNX model with native ONNX quantization ops. 
""" logger.info("Post-processing FP8 quantized model") graph = gs.import_onnx(onnx_model) diff --git a/modelopt/onnx/export/nvfp4_exporter.py b/modelopt/onnx/export/nvfp4_exporter.py index 416c2fdf80..a80a9845fb 100644 --- a/modelopt/onnx/export/nvfp4_exporter.py +++ b/modelopt/onnx/export/nvfp4_exporter.py @@ -215,7 +215,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to process") for node in fp4_qdq_nodes: - idx = initializer_indices.get(node.input[0], None) + idx = initializer_indices.get(node.input[0]) assert idx is not None, f"Initializer for weight '{node.input[0]}' not found." tensor = initializers[idx] @@ -259,7 +259,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"] for node in fp4_qdq_nodes: - idx = initializer_indices.get(node.input[0], None) + idx = initializer_indices.get(node.input[0]) assert idx is not None, f"Initializer for weight '{node.input[0]}' not found." tensor = initializers[idx] @@ -365,7 +365,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str): logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert") for node in fp4_qdq_nodes: - idx = initializer_indices.get(node.input[0], None) + idx = initializer_indices.get(node.input[0]) assert idx is not None, f"Initializer for weight '{node.input[0]}' not found." 
initializers_to_delete.append(graph.initializer[idx].name) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 8c1a83cceb..fadb44c937 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -561,9 +561,6 @@ def get_onnx_bytes_and_metadata( tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model ) - # TODO: Remove manual ir_version change once ORT supports ir_version 11 - onnx_opt_graph.ir_version = 10 - onnx_opt_graph = quantize_weights(model, onnx_opt_graph) if dq_only: @@ -585,6 +582,10 @@ def get_onnx_bytes_and_metadata( # TensorRT expects all scales to be postive onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) + # TODO: Remove manual ir_version change once ORT supports ir_version 11 + # Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version + onnx_opt_graph.ir_version = 10 + # If the onnx model contains external data store the external tensors in one file and save the onnx model if has_external_data(onnx_save_path): tensor_paths = get_external_tensor_paths(onnx_path) From 97828b5ed03a84dc2ce3be1fd1122910230af20d Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Fri, 6 Feb 2026 07:01:29 +0000 Subject: [PATCH 05/11] Fix test failures Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/torch/_deploy/utils/torch_onnx.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index fadb44c937..2d0ff5649f 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -21,10 +21,11 @@ import os import shutil import tempfile -from contextlib import nullcontext +from contextlib import nullcontext, suppress from typing import Any import onnx +import 
onnxconverter_common.float16 as _f16_module import torch import torch.nn as nn from onnx import ModelProto @@ -58,6 +59,17 @@ from ..utils.onnx_optimizer import Optimizer +# Monkey-patch to fix onnxconverter_common bug where downstream_node is a list +_original_remove_unnecessary_cast_node = _f16_module.remove_unnecessary_cast_node + + +def _patched_remove_unnecessary_cast_node(graph): + with suppress(AttributeError): + _original_remove_unnecessary_cast_node(graph) + + +_f16_module.remove_unnecessary_cast_node = _patched_remove_unnecessary_cast_node + ModelMetadata = dict[str, Any] ModelType = Any ValueInfoType = Any From 1c6a3baa4886b410a2a72764286c7422b15bac0b Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:09:35 +0000 Subject: [PATCH 06/11] update function to replace duplicate casts Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- CHANGELOG.rst | 3 +- modelopt/onnx/autocast/precisionconverter.py | 135 +------------------ modelopt/onnx/export/fp8_exporter.py | 1 + modelopt/onnx/utils.py | 108 +++++++++++---- modelopt/torch/_deploy/utils/torch_onnx.py | 4 +- 5 files changed, 90 insertions(+), 161 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5601cf7dbf..7b69169cde 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,7 +25,8 @@ NVIDIA Model Optimizer Changelog - Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search. - Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. - Add support for block-granular RHT for non-power-of-2 dimensions. - +- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. + **Misc** - Migrated project metadata from ``setup.py`` to a fully declarative ``pyproject.toml``. 
diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 278486c4b4..d7987e6499 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -1147,42 +1147,13 @@ def _is_same_type_cast(self, node: onnx.NodeProto) -> bool: output_type = utils.get_cast_to_type(node) return all(inp_type == output_type for inp_type in input_types) and input_types is not None - def _is_sequential_cast(self, node: onnx.NodeProto) -> bool: - assert node.op_type == "Cast" - output_type = utils.get_cast_to_type(node) - - # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed - # Cast to low precision -> cast to high precision affects precision and should not be removed - precision_order = [ - TensorProto.DOUBLE, - TensorProto.FLOAT, - TensorProto.FLOAT16, - TensorProto.BFLOAT16, - ] - consumers = [ - n for n in utils.get_consumer_nodes(self.model, node.output[0]) if n.op_type == "Cast" - ] - - # If the first cast has additional consumers, we should not remove it - if len(consumers) != 1: - return False - - next_node = consumers[0] - first_cast_type = output_type - second_cast_type = utils.get_cast_to_type(next_node) - - return ( - first_cast_type in precision_order - and second_cast_type in precision_order - and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type) - ) - def _remove_redundant_casts(self): """Removes both sequential casts and casts that don't change precision. This method optimizes the graph by removing unnecessary cast operations that either: 1. Don't actually change the data type 2. Could be replaced by a single cast operation + 3. 
Can be folded into a preceding Constant node """ if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) @@ -1198,35 +1169,7 @@ def _remove_redundant_casts(self): check_type=True, ) - nodes_to_remove = [] - for node in self.model.graph.node: - if node.op_type == "Cast": - # Find cast nodes that don't change precision - if self._is_same_type_cast(node): - nodes_to_remove.append(node) - self._bypass_cast_node(node) - logger.debug(f"Found redundant same-type cast: {node.name}") - continue - - # Find sequential casts that don't change precision - if self._is_sequential_cast(node): - nodes_to_remove.append(node) - self._bypass_cast_node(node) - logger.debug(f"Found removable double-cast: {node.name}") - - # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers. - if self._is_foldable_constant_cast_pattern(node): - nodes_to_remove.append(node) - cast_producers = utils.get_producer_nodes(self.model, node.input[0]) - assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant" - constant_producer = cast_producers[0] - self._convert_constant_values(constant_producer, node) - self._bypass_cast_node(node) - logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}") - - logger.debug(f"Removing redundant casts: {[n.name for n in nodes_to_remove]}") - for node in nodes_to_remove: - self.model.graph.node.remove(node) + self.model = onnx_utils.remove_redundant_casts(self.model) def _fix_network_output_names(self): modified = False @@ -1360,80 +1303,6 @@ def _get_tensor_type(self, tensor_name): return self.initializer_map[tensor_name].data_type raise Exception(f"did not find tensor {tensor_name}") - def _convert_constant_values(self, const_node, cast_node: onnx.NodeProto) -> None: - original_tensor = const_node.attribute[0].t - if original_tensor.data_type == onnx.TensorProto.BFLOAT16: - original_data = onnx_utils.read_f16_tensor_as_fp32(original_tensor) - else: - original_data = 
onnx.numpy_helper.to_array(original_tensor) - - # Precompute casted value - cast_to_type = utils.get_cast_to_type(cast_node) - cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type) - - # Handle bfloat16 conversion manually since numpy doesn't support it natively - if cast_to_type == onnx.TensorProto.BFLOAT16: - casted_data = original_data.astype(ml_dtypes.bfloat16) - else: - casted_data = original_data.astype(cast_dtype) - - # Create a new constant node with casted data - if cast_to_type == onnx.TensorProto.BFLOAT16: - # Create TensorProto manually for bfloat16 - tensor_proto = onnx.TensorProto() - tensor_proto.name = const_node.output[0] - tensor_proto.data_type = onnx.TensorProto.BFLOAT16 - tensor_proto.dims.extend(casted_data.shape) - # Convert bfloat16 to raw bytes - bf16_bytes = casted_data.astype(ml_dtypes.bfloat16).view(np.uint16) - tensor_proto.raw_data = bf16_bytes.tobytes() - else: - # Create tensor manually to ensure proper handling - tensor_proto = onnx.numpy_helper.from_array(casted_data) - tensor_proto.name = const_node.output[0] - - new_const_node = onnx.helper.make_node( - "Constant", - inputs=[], - outputs=const_node.output, - value=tensor_proto, - name=const_node.name, - ) - - # Replace the original constant node with the new constant node - # The scope of this function is to convert the constant node data. Removing the cast is done later. 
- for node in utils.get_consumer_nodes(self.model, const_node.name): - for i, input_name in enumerate(node.input): - if input_name == const_node.name: - node.input[i] = new_const_node.output[0] - break - - const_idx = -1 - for i, node in enumerate(self.model.graph.node): - if node == const_node: - const_idx = i - break - - self.model.graph.node.remove(const_node) - self.model.graph.node.insert(const_idx, new_const_node) - # The Cast node is the sole consumer of the Constant node, guaranteed by _is_foldable_constant_cast_pattern - cast_node.input[0] = new_const_node.output[0] - - def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool: - """Constant -> Cast and Cast is the only consumer of the Constant node.""" - assert node.op_type == "Cast" - - producer = utils.get_producer_nodes(self.model, node.input[0]) - - const_producer = ( - producer[0] if len(producer) == 1 and producer[0].op_type == "Constant" else None - ) - - if const_producer: - get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0]) - return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node - return False - def _sanitize_model(self): graph_sanitizer = GraphSanitizer( self.model, diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index 427851d9c7..ffcbd89423 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -127,6 +127,7 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: # Create FP8 zero point constant zp_tensor = onnx.TensorProto() zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + zp_tensor.dims.extend([1]) # 1-element tensor zp_tensor.raw_data = b"\x00" # Zero in FP8 zp_values = LazyValues(zp_tensor) zero_point = gs.Constant(node.name + "_zero_point", zp_values) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 30ce8139f7..eb2ae27fff 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1215,48 +1215,106 
@@ def onnx_type_str_to_enum(dtype: str) -> int: return getattr(onnx.TensorProto, dtype) -def remove_duplicate_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Removes consecutive Cast nodes that cast to the same type. +def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Removes redundant Cast nodes from an ONNX model. - Example: Cast(to=FP16) -> Cast(to=FP16) becomes just Cast(to=FP16) + Handles three patterns: + 1. Same-type casts: Cast where input type == output type (no-op) + 2. Sequential casts: Cast(to=high_prec) -> Cast(to=low_prec), first cast removed + 3. Constant->Cast folding: Fold cast into preceding Constant node's data + + Args: + onnx_model: The ONNX model to optimize. + + Returns: + onnx.ModelProto: Model with redundant casts removed. """ + import ml_dtypes + graph = gs.import_onnx(onnx_model) removed_count = 0 + # Precision ordering: lower index = higher precision + precision_order = { + onnx.TensorProto.DOUBLE: 0, + onnx.TensorProto.FLOAT: 1, + onnx.TensorProto.FLOAT16: 2, + onnx.TensorProto.BFLOAT16: 3, + } + + def _get_onnx_type(tensor): + """Get ONNX type enum from a GS tensor's dtype.""" + if tensor.dtype is None: + return None + try: + return onnx.helper.np_dtype_to_tensor_dtype(tensor.dtype) + except Exception: + return None + + def _bypass_cast(node): + """Reconnect consumers of cast output to use cast input, removing the cast.""" + inp = node.inputs[0] + out = node.outputs[0] + for consumer in list(out.outputs): + for i, consumer_inp in enumerate(consumer.inputs): + if consumer_inp is out: + consumer.inputs[i] = inp + for i, graph_out in enumerate(graph.outputs): + if graph_out is out: + graph.outputs[i] = inp + node.outputs.clear() + for node in list(graph.nodes): if node.op != "Cast": continue - # Check if output goes to exactly one Cast node - if len(node.outputs) != 1 or len(node.outputs[0].outputs) != 1: + cast_to = node.attrs.get("to") + if cast_to is None: continue - next_node = 
node.outputs[0].outputs[0] - if next_node.op != "Cast": - continue + input_tensor = node.inputs[0] + output_tensor = node.outputs[0] - first_to = node.attrs.get("to") - second_to = next_node.attrs.get("to") - - # Only handle same-type casts - if first_to != second_to: + # Pattern 1: Same-type cast (no-op) + input_type = _get_onnx_type(input_tensor) + if input_type is not None and input_type == cast_to: + _bypass_cast(node) + removed_count += 1 + logger.debug(f"Removed same-type cast: {node.name}") continue - # Bypass the second cast - keep first, remove second - input_tensor = node.outputs[0] - output_tensor = next_node.outputs[0] - - for consumer in list(output_tensor.outputs): - for i, inp in enumerate(consumer.inputs): - if inp == output_tensor: - consumer.inputs[i] = input_tensor - next_node.outputs.clear() - removed_count += 1 - logger.debug(f"Removed duplicate cast: {next_node.name} (same type as {node.name})") + # Pattern 2: Sequential casts where first can be removed + # Cast(to=high) -> Cast(to=low): first cast has no effect + cast_consumers = output_tensor.outputs + if len(cast_consumers) == 1 and cast_consumers[0].op == "Cast": + next_cast_to = cast_consumers[0].attrs.get("to") + if ( + cast_to in precision_order + and next_cast_to in precision_order + and precision_order[cast_to] <= precision_order[next_cast_to] + ): + _bypass_cast(node) + removed_count += 1 + logger.debug(f"Removed sequential cast: {node.name}") + continue + + # Pattern 3: Constant -> Cast folding (only if constant has single consumer) + if isinstance(input_tensor, Constant) and len(input_tensor.outputs) == 1: + try: + if cast_to == onnx.TensorProto.BFLOAT16: + input_tensor.values = input_tensor.values.astype(ml_dtypes.bfloat16) + else: + cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to) + input_tensor.values = input_tensor.values.astype(cast_dtype) + _bypass_cast(node) + removed_count += 1 + logger.debug(f"Folded Constant->Cast: {node.name}") + except Exception as e: + 
logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}") if removed_count > 0: graph.cleanup().toposort() - logger.info(f"Removed {removed_count} duplicate Cast nodes") + logger.info(f"Removed {removed_count} redundant Cast nodes") return gs.export_onnx(graph) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 2d0ff5649f..b89a7a4e9e 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -50,8 +50,8 @@ get_output_names, get_output_shapes, infer_shapes, - remove_duplicate_casts, remove_node_training_mode, + remove_redundant_casts, ) from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers from modelopt.torch.utils import flatten_tree, standardize_named_model_args @@ -589,7 +589,7 @@ def get_onnx_bytes_and_metadata( # Change FP32 cast nodes feeding into Concat/Add to FP16 onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"]) - onnx_opt_graph = remove_duplicate_casts(onnx_opt_graph) + onnx_opt_graph = remove_redundant_casts(onnx_opt_graph) # TensorRT expects all scales to be postive onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) From 05c33b216afa832694c4d24b7afec25bd1d6b8da Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:40:02 +0000 Subject: [PATCH 07/11] Restore logic for remove_redundant_casts Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/autocast/graphsanitizer.py | 14 +- modelopt/onnx/autocast/precisionconverter.py | 86 ++---- modelopt/onnx/autocast/utils.py | 47 +-- modelopt/onnx/utils.py | 280 ++++++++++++------ .../onnx/autocast/test_precisionconverter.py | 8 +- 5 files changed, 223 insertions(+), 212 deletions(-) diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index 85f407a591..2154a42568 100644 --- 
a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -130,7 +130,7 @@ def remove_disconnected_outputs(self) -> None: """Remove disconnected outputs from the model.""" tensors_to_remove = [] for tensor in self.model.graph.output: - if not utils.get_producer_nodes(self.model, tensor.name): + if not onnx_utils.get_producer_nodes(self.model, tensor.name): tensors_to_remove.append(tensor) logger.debug(f"Found disconnected output: {tensor.name}") @@ -279,7 +279,7 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: # Find variance computation branch pow_nodes = [ n - for n in utils.get_consumer_nodes(self.model, sub_node.output[0]) + for n in onnx_utils.get_consumer_nodes(self.model, sub_node.output[0]) if n.op_type == "Pow" ] if len(pow_nodes) != 1: @@ -303,8 +303,8 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: # Find Div node # Find the Div node that consumes both sqrt and sub outputs - sqrt_consumers = utils.get_consumer_nodes(self.model, sqrt_node.output[0]) - sub_consumers = utils.get_consumer_nodes(self.model, sub_node.output[0]) + sqrt_consumers = onnx_utils.get_consumer_nodes(self.model, sqrt_node.output[0]) + sub_consumers = onnx_utils.get_consumer_nodes(self.model, sub_node.output[0]) div_nodes = [n for n in sqrt_consumers if n in sub_consumers and n.op_type == "Div"] if len(div_nodes) != 1: @@ -342,14 +342,14 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: div_node, ] - consumers = utils.get_consumer_nodes(self.model, div_node.output[0]) + consumers = onnx_utils.get_consumer_nodes(self.model, div_node.output[0]) if len(consumers) == 1 and consumers[0].op_type == "Mul": mul_node = consumers[0] scale = self._get_initializer_value(mul_node.input[1], return_array=True) final_node = mul_node nodes_to_remove.append(mul_node) - consumers = utils.get_consumer_nodes(self.model, mul_node.output[0]) + consumers = 
onnx_utils.get_consumer_nodes(self.model, mul_node.output[0]) if len(consumers) == 1 and consumers[0].op_type == "Add": add_node = consumers[0] bias = self._get_initializer_value(add_node.input[1], return_array=True) @@ -457,7 +457,7 @@ def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto: def _find_insertion_point(self, input_name: str) -> int: """Find the correct insertion point for the new LayerNorm node.""" - producer_nodes = utils.get_producer_nodes(self.model, input_name) + producer_nodes = onnx_utils.get_producer_nodes(self.model, input_name) if not producer_nodes: return 0 diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index d7987e6499..3d6cb2a849 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -914,51 +914,12 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str: return self._convert_initializer_data(init, from_type, to_type) - def _replace_tensor_name( - self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str - ) -> None: - """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name.""" - for consumer in consumers: - for idx, inp in enumerate(consumer.input): - if inp == original_tensor_name: - consumer.input[idx] = new_tensor_name - - def _bypass_cast_node(self, node: onnx.NodeProto) -> None: - # handling only a single input and output, as we only remove cast nodes - assert len(node.input) == 1 - assert len(node.output) == 1 - - input_tensor = node.input[0] - output_tensor = node.output[0] - - # Check if the cast output is also a graph output - is_output_producer = any(output.name == output_tensor for output in self.model.graph.output) - - # If the removed cast node is producing a network output, update the producer of the cast input so - # the network output name is preserved. 
- if is_output_producer: - producers = utils.get_producer_nodes(self.model, input_tensor) - for producer in producers: - for i, prod_out in enumerate(producer.output): - if prod_out == input_tensor: - producer.output[i] = output_tensor - consumers = utils.get_consumer_nodes(self.model, prod_out) - if len(consumers) > 1: - self._replace_tensor_name(consumers, prod_out, output_tensor) - else: - # Reconnect consumers of the cast output to use the cast input instead - consumers = utils.get_consumer_nodes(self.model, output_tensor) - for consumer in consumers: - for i, input_name in enumerate(consumer.input): - if input_name == output_tensor: - consumer.input[i] = input_tensor - def _remove_preexisting_casts(self) -> None: nodes_to_remove = [] for node in self.model.graph.node: if node.op_type == "Cast": - cast_from_type = self._get_tensor_type(node.input[0]) - cast_to_type = utils.get_cast_to_type(node) + cast_from_type = onnx_utils._get_tensor_type_by_name(self.model, node.input[0]) + cast_to_type = onnx_utils.get_cast_to_type(node) is_fp_cast = cast_to_type in [ onnx.TensorProto.FLOAT16, onnx.TensorProto.FLOAT, @@ -978,7 +939,7 @@ def _remove_preexisting_casts(self) -> None: ): continue nodes_to_remove.append(node) - self._bypass_cast_node(node) + onnx_utils._bypass_cast_node(self.model, node) logger.debug(f"Removing {len(nodes_to_remove)} pre-existing casts") for node in nodes_to_remove: @@ -1044,7 +1005,7 @@ def _add_cast( ) if tensor_to_consumers is None: - consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name) + consumer_nodes = onnx_utils.get_consumer_nodes(self.model, tensor_name) else: consumer_nodes = tensor_to_consumers.get(tensor_name, []) consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers] @@ -1067,7 +1028,7 @@ def _add_cast( # Find producer node to insert cast after it if tensor_to_producers is None: - producer_nodes = utils.get_producer_nodes(self.model, tensor_name) + producer_nodes = 
onnx_utils.get_producer_nodes(self.model, tensor_name) else: producer_nodes = tensor_to_producers.get(tensor_name, []) if producer_nodes: @@ -1106,7 +1067,7 @@ def _cleanup_no_consumer_nodes(self): node for node in self.model.graph.node if not any( - out in network_outputs or utils.get_consumer_nodes(self.model, out) + out in network_outputs or onnx_utils.get_consumer_nodes(self.model, out) for out in node.output ) ] @@ -1124,29 +1085,23 @@ def _cleanup_pre_output_same_type_cast(self): for output in self.model.graph.output: if "_cast_to_" in output.name: - out_producer_nodes = utils.get_producer_nodes(self.model, output.name) + out_producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name) if len(out_producer_nodes) == 1 and out_producer_nodes[0].op_type == "Cast": second_cast_node = out_producer_nodes[0] - cast_producer_nodes = utils.get_producer_nodes( + cast_producer_nodes = onnx_utils.get_producer_nodes( self.model, second_cast_node.input[0] ) if len(cast_producer_nodes) == 1 and cast_producer_nodes[0].op_type == "Cast": first_cast_node = cast_producer_nodes[0] if ( - self._is_same_type_cast(first_cast_node) - and utils.get_cast_to_type(second_cast_node) + onnx_utils._is_same_type_cast(self.model, first_cast_node) + and onnx_utils.get_cast_to_type(second_cast_node) == self.high_precision_type.onnx_type ): logger.debug(f"Removing pre-output double cast: {first_cast_node.name}") - self._bypass_cast_node(first_cast_node) + onnx_utils._bypass_cast_node(self.model, first_cast_node) self.model.graph.node.remove(first_cast_node) - def _is_same_type_cast(self, node: onnx.NodeProto) -> bool: - assert node.op_type == "Cast" - input_types = [self._get_tensor_type(inp) for inp in node.input] - output_type = utils.get_cast_to_type(node) - return all(inp_type == output_type for inp_type in input_types) and input_types is not None - def _remove_redundant_casts(self): """Removes both sequential casts and casts that don't change precision. 
@@ -1176,7 +1131,7 @@ def _fix_network_output_names(self): for output in self.model.graph.output: if "_cast_to_" in output.name: post_cast_name = output.name - producer_nodes = utils.get_producer_nodes(self.model, output.name) + producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name) if ( len(producer_nodes) == 1 and producer_nodes[0].op_type == "Cast" @@ -1188,7 +1143,7 @@ def _fix_network_output_names(self): pre_cast_name = original_name + "_pre_cast" output.name = original_name # Update all consumers of the original (pre-cast) output to use the pre-cast name - for node in utils.get_consumer_nodes(self.model, original_name): + for node in onnx_utils.get_consumer_nodes(self.model, original_name): if node == cast_node: continue for i, input_name in enumerate(node.input): @@ -1196,13 +1151,15 @@ def _fix_network_output_names(self): node.input[i] = pre_cast_name # do not break, can use the same tensor for multiple node inputs # Update all consumers of the post-cast output to use the original name - for node in utils.get_consumer_nodes(self.model, post_cast_name): + for node in onnx_utils.get_consumer_nodes(self.model, post_cast_name): for i, input_name in enumerate(node.input): if input_name == post_cast_name: node.input[i] = original_name # do not break, can use the same tensor for multiple node inputs # Update all producers of the original output to use the original name - cast_producer_nodes = utils.get_producer_nodes(self.model, cast_node.input[0]) + cast_producer_nodes = onnx_utils.get_producer_nodes( + self.model, cast_node.input[0] + ) for node in cast_producer_nodes: for i, node_output in enumerate(node.output): if node_output == original_name: @@ -1248,7 +1205,7 @@ def _sanity_check(self): # Verify that the output tensors are not disconnected for output in network_outputs: - producer_nodes = utils.get_producer_nodes(self.model, output.name) + producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name) if len(producer_nodes) == 0: 
logger.warning( f"Output tensor {output.name} is disconnected. This may be benign if it's part of a cast operation " @@ -1296,13 +1253,6 @@ def _sanity_check(self): if not sanity_ok: raise Exception("Sanity Check Failed") - def _get_tensor_type(self, tensor_name): - if tensor_name in self.value_info_map: - return self.value_info_map[tensor_name].type.tensor_type.elem_type - if tensor_name in self.initializer_map: - return self.initializer_map[tensor_name].data_type - raise Exception(f"did not find tensor {tensor_name}") - def _sanitize_model(self): graph_sanitizer = GraphSanitizer( self.model, diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index 629fab0890..a9ad5484c0 100644 --- a/modelopt/onnx/autocast/utils.py +++ b/modelopt/onnx/autocast/utils.py @@ -27,6 +27,7 @@ import onnx +import modelopt.onnx.utils as onnx_utils from modelopt.onnx.utils import get_opset_version @@ -60,32 +61,6 @@ def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]: return value_info_map, initializer_map, node_to_init_map -def get_consumer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: - """Get all consumer nodes for a given tensor name. - - Args: - model: The ONNX model to search. - tensor_name: Name of the tensor to find consumers for. - - Returns: - list[onnx.NodeProto]: List of nodes that consume the tensor. - """ - return [n for n in model.graph.node if tensor_name in n.input] - - -def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: - """Get all producer nodes for a given tensor name. - - Args: - model: The ONNX model to search. - tensor_name: Name of the tensor to find producers for. - - Returns: - list[onnx.NodeProto]: List of nodes that produce the tensor. 
- """ - return [n for n in model.graph.node if tensor_name in n.output] - - def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.NodeProto: """Get a single consumer node and raise exception if there are multiple consumers. @@ -99,30 +74,12 @@ def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.N Raises: Exception: If there is not exactly one consumer node. """ - consumers = get_consumer_nodes(model, tensor_name) + consumers = onnx_utils.get_consumer_nodes(model, tensor_name) if len(consumers) != 1: raise Exception(f"Expected single consumer for {tensor_name}, found {len(consumers)}") return consumers[0] -def get_cast_to_type(cast_node: onnx.NodeProto) -> int: - """Get the target type from a Cast node. - - Args: - cast_node: The Cast node to extract type from. - - Returns: - int: The target type value from the Cast node's 'to' attribute. - - Raises: - ValueError: If the Cast node does not have a 'to' attribute. - """ - for attr in cast_node.attribute: - if attr.name == "to": - return attr.i - raise ValueError("Cast node does not have 'to' attribute") - - def walk_subgraphs_recursive( graph: onnx.GraphProto, callback: Callable, diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index eb2ae27fff..4e36686fb4 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1215,108 +1215,212 @@ def onnx_type_str_to_enum(dtype: str) -> int: return getattr(onnx.TensorProto, dtype) -def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Removes redundant Cast nodes from an ONNX model. +def get_cast_to_type(cast_node: onnx.NodeProto) -> int: + """Get the target type from a Cast node. + + Args: + cast_node: The Cast node to extract type from. + + Returns: + int: The target type value from the Cast node's 'to' attribute. + + Raises: + ValueError: If the Cast node does not have a 'to' attribute. 
+ """ + for attr in cast_node.attribute: + if attr.name == "to": + return attr.i + raise ValueError("Cast node does not have 'to' attribute") + - Handles three patterns: - 1. Same-type casts: Cast where input type == output type (no-op) - 2. Sequential casts: Cast(to=high_prec) -> Cast(to=low_prec), first cast removed - 3. Constant->Cast folding: Fold cast into preceding Constant node's data +def get_consumer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: + """Get all consumer nodes for a given tensor name. Args: - onnx_model: The ONNX model to optimize. + model: The ONNX model to search. + tensor_name: Name of the tensor to find consumers for. Returns: - onnx.ModelProto: Model with redundant casts removed. + list[onnx.NodeProto]: List of nodes that consume the tensor. """ - import ml_dtypes + return [n for n in model.graph.node if tensor_name in n.input] - graph = gs.import_onnx(onnx_model) - removed_count = 0 - - # Precision ordering: lower index = higher precision - precision_order = { - onnx.TensorProto.DOUBLE: 0, - onnx.TensorProto.FLOAT: 1, - onnx.TensorProto.FLOAT16: 2, - onnx.TensorProto.BFLOAT16: 3, - } - - def _get_onnx_type(tensor): - """Get ONNX type enum from a GS tensor's dtype.""" - if tensor.dtype is None: - return None - try: - return onnx.helper.np_dtype_to_tensor_dtype(tensor.dtype) - except Exception: - return None - def _bypass_cast(node): - """Reconnect consumers of cast output to use cast input, removing the cast.""" - inp = node.inputs[0] - out = node.outputs[0] - for consumer in list(out.outputs): - for i, consumer_inp in enumerate(consumer.inputs): - if consumer_inp is out: - consumer.inputs[i] = inp - for i, graph_out in enumerate(graph.outputs): - if graph_out is out: - graph.outputs[i] = inp - node.outputs.clear() - - for node in list(graph.nodes): - if node.op != "Cast": - continue +def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: + """Get all producer nodes for a given 
tensor name. - cast_to = node.attrs.get("to") - if cast_to is None: - continue + Args: + model: The ONNX model to search. + tensor_name: Name of the tensor to find producers for. + + Returns: + list[onnx.NodeProto]: List of nodes that produce the tensor. + """ + return [n for n in model.graph.node if tensor_name in n.output] - input_tensor = node.inputs[0] - output_tensor = node.outputs[0] - # Pattern 1: Same-type cast (no-op) - input_type = _get_onnx_type(input_tensor) - if input_type is not None and input_type == cast_to: - _bypass_cast(node) - removed_count += 1 - logger.debug(f"Removed same-type cast: {node.name}") - continue +def _get_tensor_type_by_name(model: onnx.ModelProto, tensor_name: str): + """Get the tensor element type. Searches value_info, initializers, inputs, and outputs.""" + for vi in model.graph.value_info: + if vi.name == tensor_name: + return vi.type.tensor_type.elem_type + for init in model.graph.initializer: + if init.name == tensor_name: + return init.data_type + for inp in model.graph.input: + if inp.name == tensor_name: + return inp.type.tensor_type.elem_type + for out in model.graph.output: + if out.name == tensor_name: + return out.type.tensor_type.elem_type + raise Exception(f"did not find tensor {tensor_name}") + + +def _replace_tensor_name( + consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str +) -> None: + """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name.""" + for consumer in consumers: + for idx, inp in enumerate(consumer.input): + if inp == original_tensor_name: + consumer.input[idx] = new_tensor_name + + +def _is_same_type_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: + assert node.op_type == "Cast" + input_types = [_get_tensor_type_by_name(model, inp) for inp in node.input] + output_type = get_cast_to_type(node) + return all(inp_type == output_type for inp_type in input_types) and input_types is not None + + +def _is_sequential_cast(model: 
onnx.ModelProto, node: onnx.NodeProto) -> bool: + assert node.op_type == "Cast" + output_type = get_cast_to_type(node) + + # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed + # Cast to low precision -> cast to high precision affects precision and should not be removed + precision_order = [ + onnx.TensorProto.DOUBLE, + onnx.TensorProto.FLOAT, + onnx.TensorProto.FLOAT16, + onnx.TensorProto.BFLOAT16, + ] + consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"] + + # If the first cast has additional consumers, we should not remove it + if len(consumers) != 1: + return False + + next_node = consumers[0] + first_cast_type = output_type + second_cast_type = get_cast_to_type(next_node) + + return ( + first_cast_type in precision_order + and second_cast_type in precision_order + and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type) + ) + + +def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None: + # handling only a single input and output, as we only remove cast nodes + assert len(node.input) == 1 + assert len(node.output) == 1 + + input_tensor = node.input[0] + output_tensor = node.output[0] + + # Check if the cast output is also a graph output + is_output_producer = any(output.name == output_tensor for output in model.graph.output) + + # If the removed cast node is producing a network output, update the producer of the cast input so + # the network output name is preserved. 
+ if is_output_producer: + producers = get_producer_nodes(model, input_tensor) + for producer in producers: + for i, prod_out in enumerate(producer.output): + if prod_out == input_tensor: + producer.output[i] = output_tensor + consumers = get_consumer_nodes(model, prod_out) + if len(consumers) > 1: + _replace_tensor_name(consumers, prod_out, output_tensor) + else: + # Reconnect consumers of the cast output to use the cast input instead + consumers = get_consumer_nodes(model, output_tensor) + for consumer in consumers: + for i, input_name in enumerate(consumer.input): + if input_name == output_tensor: + consumer.input[i] = input_tensor + + +def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: + """Check if a Constant -> Cast pattern can be folded.""" + assert node.op_type == "Cast" + cast_producers = get_producer_nodes(model, node.input[0]) + if len(cast_producers) == 1 and cast_producers[0].op_type == "Constant": + consumers = get_consumer_nodes(model, cast_producers[0].output[0]) + return len(consumers) == 1 + return False + + +def _convert_constant_values(constant_node: onnx.NodeProto, cast_node: onnx.NodeProto) -> None: + """Convert the Constant node's values to the Cast node's target type.""" + cast_to_type = get_cast_to_type(cast_node) + for attr in constant_node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + np_array = onnx.numpy_helper.to_array(attr.t) + target_np_type = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type) + new_array = np_array.astype(target_np_type) + new_tensor = onnx.numpy_helper.from_array(new_array, attr.t.name) + attr.t.CopyFrom(new_tensor) + break + + +def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Removes both sequential casts and casts that don't change precision. + + This method optimizes the graph by removing unnecessary cast operations that either: + 1. Don't actually change the data type + 2. 
Could be replaced by a single cast operation + 3. Can be folded into a preceding Constant node + + Args: + onnx_model: The ONNX model to optimize. - # Pattern 2: Sequential casts where first can be removed - # Cast(to=high) -> Cast(to=low): first cast has no effect - cast_consumers = output_tensor.outputs - if len(cast_consumers) == 1 and cast_consumers[0].op == "Cast": - next_cast_to = cast_consumers[0].attrs.get("to") - if ( - cast_to in precision_order - and next_cast_to in precision_order - and precision_order[cast_to] <= precision_order[next_cast_to] - ): - _bypass_cast(node) - removed_count += 1 - logger.debug(f"Removed sequential cast: {node.name}") + Returns: + onnx.ModelProto: Model with redundant casts removed. + """ + nodes_to_remove = [] + for node in onnx_model.graph.node: + if node.op_type == "Cast": + # Find cast nodes that don't change precision + if _is_same_type_cast(onnx_model, node): + nodes_to_remove.append(node) + _bypass_cast_node(onnx_model, node) + logger.debug(f"Found redundant same-type cast: {node.name}") continue - # Pattern 3: Constant -> Cast folding (only if constant has single consumer) - if isinstance(input_tensor, Constant) and len(input_tensor.outputs) == 1: - try: - if cast_to == onnx.TensorProto.BFLOAT16: - input_tensor.values = input_tensor.values.astype(ml_dtypes.bfloat16) - else: - cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to) - input_tensor.values = input_tensor.values.astype(cast_dtype) - _bypass_cast(node) - removed_count += 1 - logger.debug(f"Folded Constant->Cast: {node.name}") - except Exception as e: - logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}") - - if removed_count > 0: - graph.cleanup().toposort() - logger.info(f"Removed {removed_count} redundant Cast nodes") - - return gs.export_onnx(graph) + # Find sequential casts that don't change precision + if _is_sequential_cast(onnx_model, node): + nodes_to_remove.append(node) + _bypass_cast_node(onnx_model, node) + logger.debug(f"Found 
removable double-cast: {node.name}") + + # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers. + if _is_foldable_constant_cast_pattern(onnx_model, node): + nodes_to_remove.append(node) + cast_producers = get_producer_nodes(onnx_model, node.input[0]) + assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant" + constant_producer = cast_producers[0] + _convert_constant_values(constant_producer, node) + _bypass_cast_node(onnx_model, node) + logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}") + + logger.debug(f"Removing redundant casts: {[n.name for n in nodes_to_remove]}") + for node in nodes_to_remove: + onnx_model.graph.node.remove(node) + + return onnx_model def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto: diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index a14991319a..c3e1a51db5 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -1117,10 +1117,10 @@ def test_constant_cast_folding( assert const_array.attribute[0].t.data_type == low_precision_onnx_type(low_precision_type) # Check that the constant nodes are consumed directly by the Add nodes - assert len(utils.get_consumer_nodes(converted_model, "const_scalar")) == 1 - assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add" - assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1 - assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add" + assert len(onnx_utils.get_consumer_nodes(converted_model, "const_scalar")) == 1 + assert onnx_utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add" + assert len(onnx_utils.get_consumer_nodes(converted_model, "const_array")) == 1 + assert onnx_utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add" 
@pytest.fixture From f6ce7b37910a4a104457823e88c500d05efb3866 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Mon, 16 Mar 2026 17:49:35 +0000 Subject: [PATCH 08/11] Fix comments Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/utils.py | 25 ++++++++++++++++---- modelopt/torch/_deploy/utils/torch_onnx.py | 27 ++++++++++++++-------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 4e36686fb4..22f5cbcb94 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1368,10 +1368,27 @@ def _convert_constant_values(constant_node: onnx.NodeProto, cast_node: onnx.Node cast_to_type = get_cast_to_type(cast_node) for attr in constant_node.attribute: if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: - np_array = onnx.numpy_helper.to_array(attr.t) - target_np_type = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type) - new_array = np_array.astype(target_np_type) - new_tensor = onnx.numpy_helper.from_array(new_array, attr.t.name) + # Read input tensor — bfloat16 tensors use raw_data and need special handling + if attr.t.data_type == onnx.TensorProto.BFLOAT16: + np_array = read_f16_tensor_as_fp32(attr.t) + else: + np_array = onnx.numpy_helper.to_array(attr.t) + + # Write output tensor — bfloat16 cannot use numpy_helper.from_array + if cast_to_type == onnx.TensorProto.BFLOAT16: + import ml_dtypes + + new_tensor = onnx.TensorProto() + new_tensor.dims.extend(np_array.shape) + new_tensor.name = attr.t.name + new_tensor.data_type = onnx.TensorProto.BFLOAT16 + bf16_bytes = np_array.astype(np.float32).astype(ml_dtypes.bfloat16) + new_tensor.raw_data = bf16_bytes.view(np.uint16).tobytes() + else: + target_np_type = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type) + new_array = np_array.astype(target_np_type) + new_tensor = onnx.numpy_helper.from_array(new_array, attr.t.name) + attr.t.CopyFrom(new_tensor) break 
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index b89a7a4e9e..9b5cc6bdde 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -32,6 +32,7 @@ from onnxconverter_common import convert_float_to_float16 from torch.nn.parallel import DataParallel, DistributedDataParallel +from modelopt.onnx.autocast.convert import convert_to_f16 from modelopt.onnx.export import ( FP8QuantExporter, INT4QuantExporter, @@ -578,16 +579,22 @@ def get_onnx_bytes_and_metadata( if dq_only: onnx_opt_graph = qdq_to_dq(onnx_opt_graph) - if weights_dtype == "fp16": - onnx_opt_graph = convert_float_to_float16( - onnx_opt_graph, - keep_io_types=False, - disable_shape_infer=True, - check_fp16_ready=False, - op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"], - ) - # Change FP32 cast nodes feeding into Concat/Add to FP16 - onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"]) + if weights_dtype in ["fp16", "bf16"]: + if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model): + assert weights_dtype == "fp16", "BF16 + MXFP8/INT4/FP8 mixed precision is not supported yet" + onnx_opt_graph = convert_float_to_float16( + onnx_opt_graph, + keep_io_types=False, + disable_shape_infer=True, + check_fp16_ready=False, + op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"], + ) + # Change FP32 cast nodes feeding into Concat/Add to FP16 + onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"]) + else: + onnx_opt_graph = convert_to_f16( + onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False + ) onnx_opt_graph = remove_redundant_casts(onnx_opt_graph) From e223a2e2e2578d66480cedf2e3be7bdb27d738e0 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:04:11 +0000 Subject: [PATCH 09/11] Add details about monkeypatch Signed-off-by: ajrasane 
<131806219+ajrasane@users.noreply.github.com> --- modelopt/torch/_deploy/utils/torch_onnx.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 9b5cc6bdde..1218d7ebe4 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -18,10 +18,11 @@ import base64 import inspect import json +import logging import os import shutil import tempfile -from contextlib import nullcontext, suppress +from contextlib import nullcontext from typing import Any import onnx @@ -60,13 +61,26 @@ from ..utils.onnx_optimizer import Optimizer -# Monkey-patch to fix onnxconverter_common bug where downstream_node is a list +# Monkey-patch for onnxconverter_common bug in remove_unnecessary_cast_node(): +# cast_node_downstream_dict stores either a single node or a list of nodes, but the +# downstream-node handling at lines ~770/787 always does `downstream_node.input`, +# which raises AttributeError("'list' object has no attribute 'input'") when the +# value is a list (i.e. a Cast output feeds multiple consumers). +# TODO: Remove this patch once onnxconverter-common ships a fix. 
+# Upstream issue: https://github.com/microsoft/onnxconverter-common/issues/261 _original_remove_unnecessary_cast_node = _f16_module.remove_unnecessary_cast_node +_logger = logging.getLogger(__name__) + def _patched_remove_unnecessary_cast_node(graph): - with suppress(AttributeError): + try: _original_remove_unnecessary_cast_node(graph) + except AttributeError as e: + if "'list' object has no attribute 'input'" in str(e): + _logger.debug("Skipping remove_unnecessary_cast_node due to known upstream bug: %s", e) + else: + raise _f16_module.remove_unnecessary_cast_node = _patched_remove_unnecessary_cast_node From 9e3a35aac196837bfcf3823bea0f5824cb327a63 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:16:10 +0000 Subject: [PATCH 10/11] Cache model information for faster lookup Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/utils.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 22f5cbcb94..a1ae961654 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1259,8 +1259,35 @@ def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.No return [n for n in model.graph.node if tensor_name in n.output] -def _get_tensor_type_by_name(model: onnx.ModelProto, tensor_name: str): - """Get the tensor element type. 
Searches value_info, initializers, inputs, and outputs.""" +def _build_tensor_type_map(model: onnx.ModelProto) -> dict[str, int]: + """Build an O(1) name-to-element-type lookup from all graph tensors.""" + type_map: dict[str, int] = {} + for vi in model.graph.value_info: + type_map[vi.name] = vi.type.tensor_type.elem_type + for init in model.graph.initializer: + type_map[init.name] = init.data_type + for inp in model.graph.input: + type_map[inp.name] = inp.type.tensor_type.elem_type + for out in model.graph.output: + type_map[out.name] = out.type.tensor_type.elem_type + return type_map + + +def _get_tensor_type_by_name( + model: onnx.ModelProto, tensor_name: str, type_map: dict[str, int] | None = None +): + """Get the tensor element type. Searches value_info, initializers, inputs, and outputs. + + Args: + model: The ONNX model (used as fallback when type_map is not provided). + tensor_name: Name of the tensor to look up. + type_map: Pre-built lookup from _build_tensor_type_map for O(1) access. + When called in a loop, pass this to avoid repeated linear scans. 
+ """ + if type_map is not None: + if tensor_name in type_map: + return type_map[tensor_name] + raise Exception(f"did not find tensor {tensor_name}") for vi in model.graph.value_info: if vi.name == tensor_name: return vi.type.tensor_type.elem_type @@ -1286,11 +1313,13 @@ def _replace_tensor_name( consumer.input[idx] = new_tensor_name -def _is_same_type_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: +def _is_same_type_cast( + model: onnx.ModelProto, node: onnx.NodeProto, type_map: dict[str, int] | None = None +) -> bool: assert node.op_type == "Cast" - input_types = [_get_tensor_type_by_name(model, inp) for inp in node.input] + input_types = [_get_tensor_type_by_name(model, inp, type_map) for inp in node.input] output_type = get_cast_to_type(node) - return all(inp_type == output_type for inp_type in input_types) and input_types is not None + return bool(input_types) and all(inp_type == output_type for inp_type in input_types) def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: @@ -1407,11 +1436,12 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: Returns: onnx.ModelProto: Model with redundant casts removed. 
""" + type_map = _build_tensor_type_map(onnx_model) nodes_to_remove = [] for node in onnx_model.graph.node: if node.op_type == "Cast": # Find cast nodes that don't change precision - if _is_same_type_cast(onnx_model, node): + if _is_same_type_cast(onnx_model, node, type_map): nodes_to_remove.append(node) _bypass_cast_node(onnx_model, node) logger.debug(f"Found redundant same-type cast: {node.name}") From de464f99267b598995f682881f773921f89c27f1 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Mon, 16 Mar 2026 19:31:14 +0000 Subject: [PATCH 11/11] Address comments Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/utils.py | 41 ++++++++++++++++------ modelopt/torch/_deploy/utils/torch_onnx.py | 4 +++ 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index a1ae961654..b578dc3806 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1270,6 +1270,13 @@ def _build_tensor_type_map(model: onnx.ModelProto) -> dict[str, int]: type_map[inp.name] = inp.type.tensor_type.elem_type for out in model.graph.output: type_map[out.name] = out.type.tensor_type.elem_type + # Constant node outputs are often not in value_info — extract type from the attribute. 
+ for node in model.graph.node: + if node.op_type == "Constant" and node.output[0] not in type_map: + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + type_map[node.output[0]] = attr.t.data_type + break return type_map @@ -1300,6 +1307,11 @@ def _get_tensor_type_by_name( for out in model.graph.output: if out.name == tensor_name: return out.type.tensor_type.elem_type + for node in model.graph.node: + if node.op_type == "Constant" and node.output[0] == tensor_name: + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + return attr.t.data_type raise Exception(f"did not find tensor {tensor_name}") @@ -1452,6 +1464,7 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: nodes_to_remove.append(node) _bypass_cast_node(onnx_model, node) logger.debug(f"Found removable double-cast: {node.name}") + continue # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers. if _is_foldable_constant_cast_pattern(onnx_model, node): @@ -1521,16 +1534,18 @@ def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) -> onnx.ModelProto: - """Change Cast nodes that cast to FP32 and feed into specified nodes to cast to FP16 instead. + """Change FP16-to-FP32 Cast nodes whose entire fanout feeds target ops to cast to FP16 instead. Args: model: The ONNX model to modify. - target_op_types: List of op types to check for. Cast nodes feeding into these will be - changed from FP32 to FP16. + target_op_types: List of op types to check for. Cast nodes feeding exclusively into + these will be changed from FP32 to FP16. Returns: The modified ONNX model with Cast nodes updated. 
""" + type_map = _build_tensor_type_map(model) + # Build a map of tensor name -> consumer nodes tensor_to_consumers: dict[str, list[onnx.NodeProto]] = {} for node in model.graph.node: @@ -1538,22 +1553,26 @@ def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) -> if inp: tensor_to_consumers.setdefault(inp, []).append(node) - # Find Cast nodes that feed into target ops and change FP32 -> FP16 + # Find Cast nodes that feed into target ops and change FP16->FP32 to FP16->FP16 for node in model.graph.node: if node.op_type != "Cast": continue - # Check if this Cast outputs to a target op type - cast_output = node.output[0] - consumers = tensor_to_consumers.get(cast_output, []) - feeds_target = any(c.op_type in target_op_types for c in consumers) + # Only retarget FP16->FP32 casts; leave other casts (e.g. FP64->FP32) alone + cast_to = get_cast_to_type(node) + if cast_to != onnx.TensorProto.FLOAT: + continue + source_type = type_map.get(node.input[0]) + if source_type != onnx.TensorProto.FLOAT16: + continue - if not feeds_target: + # Only change when ALL consumers are target ops to avoid breaking non-target branches + consumers = tensor_to_consumers.get(node.output[0], []) + if not consumers or not all(c.op_type in target_op_types for c in consumers): continue - # Check if Cast is to FP32, and change to FP16 for attr in node.attribute: - if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT: + if attr.name == "to": attr.i = onnx.TensorProto.FLOAT16 break diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 1218d7ebe4..035d3b9d08 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -442,6 +442,10 @@ def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.Mode if is_int8_quantized(model): onnx_exporters.append(INT8QuantExporter) + if len(onnx_exporters) == 0: + print("No quantization exporters found for the model.") + 
return onnx_model + for onnx_exporter in onnx_exporters: onnx_model = onnx_exporter.process_model(onnx_model)