diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5601cf7dbf..7b69169cde 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,7 +25,8 @@ NVIDIA Model Optimizer Changelog - Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search. - Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. - Add support for block-granular RHT for non-power-of-2 dimensions. - +- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. + **Misc** - Migrated project metadata from ``setup.py`` to a fully declarative ``pyproject.toml``. diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index 85f407a591..2154a42568 100644 --- a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -130,7 +130,7 @@ def remove_disconnected_outputs(self) -> None: """Remove disconnected outputs from the model.""" tensors_to_remove = [] for tensor in self.model.graph.output: - if not utils.get_producer_nodes(self.model, tensor.name): + if not onnx_utils.get_producer_nodes(self.model, tensor.name): tensors_to_remove.append(tensor) logger.debug(f"Found disconnected output: {tensor.name}") @@ -279,7 +279,7 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: # Find variance computation branch pow_nodes = [ n - for n in utils.get_consumer_nodes(self.model, sub_node.output[0]) + for n in onnx_utils.get_consumer_nodes(self.model, sub_node.output[0]) if n.op_type == "Pow" ] if len(pow_nodes) != 1: @@ -303,8 +303,8 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: # Find Div node # Find the Div node that consumes both sqrt and sub outputs - sqrt_consumers = utils.get_consumer_nodes(self.model, sqrt_node.output[0]) - sub_consumers = utils.get_consumer_nodes(self.model, 
sub_node.output[0]) + sqrt_consumers = onnx_utils.get_consumer_nodes(self.model, sqrt_node.output[0]) + sub_consumers = onnx_utils.get_consumer_nodes(self.model, sub_node.output[0]) div_nodes = [n for n in sqrt_consumers if n in sub_consumers and n.op_type == "Div"] if len(div_nodes) != 1: @@ -342,14 +342,14 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: div_node, ] - consumers = utils.get_consumer_nodes(self.model, div_node.output[0]) + consumers = onnx_utils.get_consumer_nodes(self.model, div_node.output[0]) if len(consumers) == 1 and consumers[0].op_type == "Mul": mul_node = consumers[0] scale = self._get_initializer_value(mul_node.input[1], return_array=True) final_node = mul_node nodes_to_remove.append(mul_node) - consumers = utils.get_consumer_nodes(self.model, mul_node.output[0]) + consumers = onnx_utils.get_consumer_nodes(self.model, mul_node.output[0]) if len(consumers) == 1 and consumers[0].op_type == "Add": add_node = consumers[0] bias = self._get_initializer_value(add_node.input[1], return_array=True) @@ -457,7 +457,7 @@ def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto: def _find_insertion_point(self, input_name: str) -> int: """Find the correct insertion point for the new LayerNorm node.""" - producer_nodes = utils.get_producer_nodes(self.model, input_name) + producer_nodes = onnx_utils.get_producer_nodes(self.model, input_name) if not producer_nodes: return 0 diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 278486c4b4..3d6cb2a849 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -914,51 +914,12 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str: return self._convert_initializer_data(init, from_type, to_type) - def _replace_tensor_name( - self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str - ) -> None: - """Replace occurrences of a 
tensor name in the given consumers' inputs with a new tensor name.""" - for consumer in consumers: - for idx, inp in enumerate(consumer.input): - if inp == original_tensor_name: - consumer.input[idx] = new_tensor_name - - def _bypass_cast_node(self, node: onnx.NodeProto) -> None: - # handling only a single input and output, as we only remove cast nodes - assert len(node.input) == 1 - assert len(node.output) == 1 - - input_tensor = node.input[0] - output_tensor = node.output[0] - - # Check if the cast output is also a graph output - is_output_producer = any(output.name == output_tensor for output in self.model.graph.output) - - # If the removed cast node is producing a network output, update the producer of the cast input so - # the network output name is preserved. - if is_output_producer: - producers = utils.get_producer_nodes(self.model, input_tensor) - for producer in producers: - for i, prod_out in enumerate(producer.output): - if prod_out == input_tensor: - producer.output[i] = output_tensor - consumers = utils.get_consumer_nodes(self.model, prod_out) - if len(consumers) > 1: - self._replace_tensor_name(consumers, prod_out, output_tensor) - else: - # Reconnect consumers of the cast output to use the cast input instead - consumers = utils.get_consumer_nodes(self.model, output_tensor) - for consumer in consumers: - for i, input_name in enumerate(consumer.input): - if input_name == output_tensor: - consumer.input[i] = input_tensor - def _remove_preexisting_casts(self) -> None: nodes_to_remove = [] for node in self.model.graph.node: if node.op_type == "Cast": - cast_from_type = self._get_tensor_type(node.input[0]) - cast_to_type = utils.get_cast_to_type(node) + cast_from_type = onnx_utils._get_tensor_type_by_name(self.model, node.input[0]) + cast_to_type = onnx_utils.get_cast_to_type(node) is_fp_cast = cast_to_type in [ onnx.TensorProto.FLOAT16, onnx.TensorProto.FLOAT, @@ -978,7 +939,7 @@ def _remove_preexisting_casts(self) -> None: ): continue 
nodes_to_remove.append(node) - self._bypass_cast_node(node) + onnx_utils._bypass_cast_node(self.model, node) logger.debug(f"Removing {len(nodes_to_remove)} pre-existing casts") for node in nodes_to_remove: @@ -1044,7 +1005,7 @@ def _add_cast( ) if tensor_to_consumers is None: - consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name) + consumer_nodes = onnx_utils.get_consumer_nodes(self.model, tensor_name) else: consumer_nodes = tensor_to_consumers.get(tensor_name, []) consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers] @@ -1067,7 +1028,7 @@ def _add_cast( # Find producer node to insert cast after it if tensor_to_producers is None: - producer_nodes = utils.get_producer_nodes(self.model, tensor_name) + producer_nodes = onnx_utils.get_producer_nodes(self.model, tensor_name) else: producer_nodes = tensor_to_producers.get(tensor_name, []) if producer_nodes: @@ -1106,7 +1067,7 @@ def _cleanup_no_consumer_nodes(self): node for node in self.model.graph.node if not any( - out in network_outputs or utils.get_consumer_nodes(self.model, out) + out in network_outputs or onnx_utils.get_consumer_nodes(self.model, out) for out in node.output ) ] @@ -1124,65 +1085,30 @@ def _cleanup_pre_output_same_type_cast(self): for output in self.model.graph.output: if "_cast_to_" in output.name: - out_producer_nodes = utils.get_producer_nodes(self.model, output.name) + out_producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name) if len(out_producer_nodes) == 1 and out_producer_nodes[0].op_type == "Cast": second_cast_node = out_producer_nodes[0] - cast_producer_nodes = utils.get_producer_nodes( + cast_producer_nodes = onnx_utils.get_producer_nodes( self.model, second_cast_node.input[0] ) if len(cast_producer_nodes) == 1 and cast_producer_nodes[0].op_type == "Cast": first_cast_node = cast_producer_nodes[0] if ( - self._is_same_type_cast(first_cast_node) - and utils.get_cast_to_type(second_cast_node) + onnx_utils._is_same_type_cast(self.model, 
first_cast_node) + and onnx_utils.get_cast_to_type(second_cast_node) == self.high_precision_type.onnx_type ): logger.debug(f"Removing pre-output double cast: {first_cast_node.name}") - self._bypass_cast_node(first_cast_node) + onnx_utils._bypass_cast_node(self.model, first_cast_node) self.model.graph.node.remove(first_cast_node) - def _is_same_type_cast(self, node: onnx.NodeProto) -> bool: - assert node.op_type == "Cast" - input_types = [self._get_tensor_type(inp) for inp in node.input] - output_type = utils.get_cast_to_type(node) - return all(inp_type == output_type for inp_type in input_types) and input_types is not None - - def _is_sequential_cast(self, node: onnx.NodeProto) -> bool: - assert node.op_type == "Cast" - output_type = utils.get_cast_to_type(node) - - # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed - # Cast to low precision -> cast to high precision affects precision and should not be removed - precision_order = [ - TensorProto.DOUBLE, - TensorProto.FLOAT, - TensorProto.FLOAT16, - TensorProto.BFLOAT16, - ] - consumers = [ - n for n in utils.get_consumer_nodes(self.model, node.output[0]) if n.op_type == "Cast" - ] - - # If the first cast has additional consumers, we should not remove it - if len(consumers) != 1: - return False - - next_node = consumers[0] - first_cast_type = output_type - second_cast_type = utils.get_cast_to_type(next_node) - - return ( - first_cast_type in precision_order - and second_cast_type in precision_order - and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type) - ) - def _remove_redundant_casts(self): """Removes both sequential casts and casts that don't change precision. This method optimizes the graph by removing unnecessary cast operations that either: 1. Don't actually change the data type 2. Could be replaced by a single cast operation + 3. 
Can be folded into a preceding Constant node """ if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) @@ -1198,42 +1124,14 @@ def _remove_redundant_casts(self): check_type=True, ) - nodes_to_remove = [] - for node in self.model.graph.node: - if node.op_type == "Cast": - # Find cast nodes that don't change precision - if self._is_same_type_cast(node): - nodes_to_remove.append(node) - self._bypass_cast_node(node) - logger.debug(f"Found redundant same-type cast: {node.name}") - continue - - # Find sequential casts that don't change precision - if self._is_sequential_cast(node): - nodes_to_remove.append(node) - self._bypass_cast_node(node) - logger.debug(f"Found removable double-cast: {node.name}") - - # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers. - if self._is_foldable_constant_cast_pattern(node): - nodes_to_remove.append(node) - cast_producers = utils.get_producer_nodes(self.model, node.input[0]) - assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant" - constant_producer = cast_producers[0] - self._convert_constant_values(constant_producer, node) - self._bypass_cast_node(node) - logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}") - - logger.debug(f"Removing redundant casts: {[n.name for n in nodes_to_remove]}") - for node in nodes_to_remove: - self.model.graph.node.remove(node) + self.model = onnx_utils.remove_redundant_casts(self.model) def _fix_network_output_names(self): modified = False for output in self.model.graph.output: if "_cast_to_" in output.name: post_cast_name = output.name - producer_nodes = utils.get_producer_nodes(self.model, output.name) + producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name) if ( len(producer_nodes) == 1 and producer_nodes[0].op_type == "Cast" @@ -1245,7 +1143,7 @@ def _fix_network_output_names(self): pre_cast_name = original_name + "_pre_cast" output.name = original_name # Update all consumers of 
the original (pre-cast) output to use the pre-cast name - for node in utils.get_consumer_nodes(self.model, original_name): + for node in onnx_utils.get_consumer_nodes(self.model, original_name): if node == cast_node: continue for i, input_name in enumerate(node.input): @@ -1253,13 +1151,15 @@ def _fix_network_output_names(self): node.input[i] = pre_cast_name # do not break, can use the same tensor for multiple node inputs # Update all consumers of the post-cast output to use the original name - for node in utils.get_consumer_nodes(self.model, post_cast_name): + for node in onnx_utils.get_consumer_nodes(self.model, post_cast_name): for i, input_name in enumerate(node.input): if input_name == post_cast_name: node.input[i] = original_name # do not break, can use the same tensor for multiple node inputs # Update all producers of the original output to use the original name - cast_producer_nodes = utils.get_producer_nodes(self.model, cast_node.input[0]) + cast_producer_nodes = onnx_utils.get_producer_nodes( + self.model, cast_node.input[0] + ) for node in cast_producer_nodes: for i, node_output in enumerate(node.output): if node_output == original_name: @@ -1305,7 +1205,7 @@ def _sanity_check(self): # Verify that the output tensors are not disconnected for output in network_outputs: - producer_nodes = utils.get_producer_nodes(self.model, output.name) + producer_nodes = onnx_utils.get_producer_nodes(self.model, output.name) if len(producer_nodes) == 0: logger.warning( f"Output tensor {output.name} is disconnected. 
This may be benign if it's part of a cast operation " @@ -1353,87 +1253,6 @@ def _sanity_check(self): if not sanity_ok: raise Exception("Sanity Check Failed") - def _get_tensor_type(self, tensor_name): - if tensor_name in self.value_info_map: - return self.value_info_map[tensor_name].type.tensor_type.elem_type - if tensor_name in self.initializer_map: - return self.initializer_map[tensor_name].data_type - raise Exception(f"did not find tensor {tensor_name}") - - def _convert_constant_values(self, const_node, cast_node: onnx.NodeProto) -> None: - original_tensor = const_node.attribute[0].t - if original_tensor.data_type == onnx.TensorProto.BFLOAT16: - original_data = onnx_utils.read_f16_tensor_as_fp32(original_tensor) - else: - original_data = onnx.numpy_helper.to_array(original_tensor) - - # Precompute casted value - cast_to_type = utils.get_cast_to_type(cast_node) - cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type) - - # Handle bfloat16 conversion manually since numpy doesn't support it natively - if cast_to_type == onnx.TensorProto.BFLOAT16: - casted_data = original_data.astype(ml_dtypes.bfloat16) - else: - casted_data = original_data.astype(cast_dtype) - - # Create a new constant node with casted data - if cast_to_type == onnx.TensorProto.BFLOAT16: - # Create TensorProto manually for bfloat16 - tensor_proto = onnx.TensorProto() - tensor_proto.name = const_node.output[0] - tensor_proto.data_type = onnx.TensorProto.BFLOAT16 - tensor_proto.dims.extend(casted_data.shape) - # Convert bfloat16 to raw bytes - bf16_bytes = casted_data.astype(ml_dtypes.bfloat16).view(np.uint16) - tensor_proto.raw_data = bf16_bytes.tobytes() - else: - # Create tensor manually to ensure proper handling - tensor_proto = onnx.numpy_helper.from_array(casted_data) - tensor_proto.name = const_node.output[0] - - new_const_node = onnx.helper.make_node( - "Constant", - inputs=[], - outputs=const_node.output, - value=tensor_proto, - name=const_node.name, - ) - - # Replace the original 
constant node with the new constant node - # The scope of this function is to convert the constant node data. Removing the cast is done later. - for node in utils.get_consumer_nodes(self.model, const_node.name): - for i, input_name in enumerate(node.input): - if input_name == const_node.name: - node.input[i] = new_const_node.output[0] - break - - const_idx = -1 - for i, node in enumerate(self.model.graph.node): - if node == const_node: - const_idx = i - break - - self.model.graph.node.remove(const_node) - self.model.graph.node.insert(const_idx, new_const_node) - # The Cast node is the sole consumer of the Constant node, guaranteed by _is_foldable_constant_cast_pattern - cast_node.input[0] = new_const_node.output[0] - - def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool: - """Constant -> Cast and Cast is the only consumer of the Constant node.""" - assert node.op_type == "Cast" - - producer = utils.get_producer_nodes(self.model, node.input[0]) - - const_producer = ( - producer[0] if len(producer) == 1 and producer[0].op_type == "Constant" else None - ) - - if const_producer: - get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0]) - return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node - return False - def _sanitize_model(self): graph_sanitizer = GraphSanitizer( self.model, diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index 629fab0890..a9ad5484c0 100644 --- a/modelopt/onnx/autocast/utils.py +++ b/modelopt/onnx/autocast/utils.py @@ -27,6 +27,7 @@ import onnx +import modelopt.onnx.utils as onnx_utils from modelopt.onnx.utils import get_opset_version @@ -60,32 +61,6 @@ def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]: return value_info_map, initializer_map, node_to_init_map -def get_consumer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: - """Get all consumer nodes for a given tensor name. 
- - Args: - model: The ONNX model to search. - tensor_name: Name of the tensor to find consumers for. - - Returns: - list[onnx.NodeProto]: List of nodes that consume the tensor. - """ - return [n for n in model.graph.node if tensor_name in n.input] - - -def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: - """Get all producer nodes for a given tensor name. - - Args: - model: The ONNX model to search. - tensor_name: Name of the tensor to find producers for. - - Returns: - list[onnx.NodeProto]: List of nodes that produce the tensor. - """ - return [n for n in model.graph.node if tensor_name in n.output] - - def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.NodeProto: """Get a single consumer node and raise exception if there are multiple consumers. @@ -99,30 +74,12 @@ def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.N Raises: Exception: If there is not exactly one consumer node. """ - consumers = get_consumer_nodes(model, tensor_name) + consumers = onnx_utils.get_consumer_nodes(model, tensor_name) if len(consumers) != 1: raise Exception(f"Expected single consumer for {tensor_name}, found {len(consumers)}") return consumers[0] -def get_cast_to_type(cast_node: onnx.NodeProto) -> int: - """Get the target type from a Cast node. - - Args: - cast_node: The Cast node to extract type from. - - Returns: - int: The target type value from the Cast node's 'to' attribute. - - Raises: - ValueError: If the Cast node does not have a 'to' attribute. 
- """ - for attr in cast_node.attribute: - if attr.name == "to": - return attr.i - raise ValueError("Cast node does not have 'to' attribute") - - def walk_subgraphs_recursive( graph: onnx.GraphProto, callback: Callable, diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index 28e6b1da1e..ffcbd89423 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -22,6 +22,8 @@ import torch from onnx_graphsurgeon.ir.tensor import LazyValues +from modelopt.onnx.logging_config import logger + from .base_exporter import ONNXQuantExporter @@ -45,13 +47,13 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: Even though modelopt supports FP8 onnx export, the weights are represented in fp32 + QDQ. The storage is therefore very bad. In this function, Q nodes will get removed from the weights and have only DQ nodes with those converted FP8 - weights in the output model. + weights in the output model. TRT custom ops are converted to native ONNX DequantizeLinear. Parameters: - onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes. + onnx_model: ONNX model with FP32/FP16 weights and TRT_FP8 QDQ nodes. Returns: - ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations). + ONNX model with FP8 weights and native ONNX DQ nodes for weights (QDQ preserved for activations). 
""" start_time = time.time() print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...") @@ -62,7 +64,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: for node in graph.nodes: if node.op == "TRT_FP8QuantizeLinear": - # Should not remove input QDQ + # Should not remove input QDQ (only process weight quantization) if not isinstance(node.inputs[0], gs.Constant): continue @@ -88,7 +90,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) node.outputs.clear() - # DQ Op is separated out + # Convert TRT DQ to native ONNX DequantizeLinear with FP8 weights dq_op.inputs[0] = onnx_weights_fp8 dq_op.op = "DequantizeLinear" dq_op.outputs[0].dtype = dq_op.inputs[1].dtype @@ -101,5 +103,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: @staticmethod def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Post-processes the ONNX model for FP8 quantization.""" - return onnx_model + """Post-processes the ONNX model for FP8 quantization. + + Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear: + - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1 + - TRT_FP8DequantizeLinear -> DequantizeLinear + + Args: + onnx_model: The ONNX model containing TRT_FP8 quantization nodes. + + Returns: + The post-processed ONNX model with native ONNX quantization ops. 
+ """ + logger.info("Post-processing FP8 quantized model") + graph = gs.import_onnx(onnx_model) + + # Convert TRT_FP8QuantizeLinear to native QuantizeLinear + for node in graph.nodes: + if node.op == "TRT_FP8QuantizeLinear": + node.op = "QuantizeLinear" + # Add FP8 zero_point if not present + if len(node.inputs) == 2: + # Create FP8 zero point constant + zp_tensor = onnx.TensorProto() + zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + zp_tensor.dims.extend([1]) # 1-element tensor + zp_tensor.raw_data = b"\x00" # Zero in FP8 + zp_values = LazyValues(zp_tensor) + zero_point = gs.Constant(node.name + "_zero_point", zp_values) + node.inputs.append(zero_point) + # Add saturate attribute for FP8 + node.attrs["saturate"] = 1 + logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear") + + # Convert TRT_FP8DequantizeLinear to native DequantizeLinear + for node in graph.nodes: + if node.op == "TRT_FP8DequantizeLinear": + node.op = "DequantizeLinear" + logger.debug( + f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear" + ) + + graph.cleanup().toposort() + return gs.export_onnx(graph) diff --git a/modelopt/onnx/export/nvfp4_exporter.py b/modelopt/onnx/export/nvfp4_exporter.py index 416c2fdf80..a80a9845fb 100644 --- a/modelopt/onnx/export/nvfp4_exporter.py +++ b/modelopt/onnx/export/nvfp4_exporter.py @@ -215,7 +215,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to process") for node in fp4_qdq_nodes: - idx = initializer_indices.get(node.input[0], None) + idx = initializer_indices.get(node.input[0]) assert idx is not None, f"Initializer for weight '{node.input[0]}' not found." 
tensor = initializers[idx] @@ -259,7 +259,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"] for node in fp4_qdq_nodes: - idx = initializer_indices.get(node.input[0], None) + idx = initializer_indices.get(node.input[0]) assert idx is not None, f"Initializer for weight '{node.input[0]}' not found." tensor = initializers[idx] @@ -365,7 +365,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str): logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert") for node in fp4_qdq_nodes: - idx = initializer_indices.get(node.input[0], None) + idx = initializer_indices.get(node.input[0]) assert idx is not None, f"Initializer for weight '{node.input[0]}' not found." initializers_to_delete.append(graph.initializer[idx].name) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 4025ea065a..b578dc3806 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -433,8 +433,8 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes: if len(init.dims) > 1: dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type) if dtype in ["float16", "float32", "float64"]: - avg = weight_metadata.get(init.name + "_avg", None) - var = weight_metadata.get(init.name + "_var", None) + avg = weight_metadata.get(init.name + "_avg") + var = weight_metadata.get(init.name + "_var") if avg and var: numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype( dtype @@ -1215,6 +1215,274 @@ def onnx_type_str_to_enum(dtype: str) -> int: return getattr(onnx.TensorProto, dtype) +def get_cast_to_type(cast_node: onnx.NodeProto) -> int: + """Get the target type from a Cast node. + + Args: + cast_node: The Cast node to extract type from. + + Returns: + int: The target type value from the Cast node's 'to' attribute. + + Raises: + ValueError: If the Cast node does not have a 'to' attribute. 
+ """ + for attr in cast_node.attribute: + if attr.name == "to": + return attr.i + raise ValueError("Cast node does not have 'to' attribute") + + +def get_consumer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: + """Get all consumer nodes for a given tensor name. + + Args: + model: The ONNX model to search. + tensor_name: Name of the tensor to find consumers for. + + Returns: + list[onnx.NodeProto]: List of nodes that consume the tensor. + """ + return [n for n in model.graph.node if tensor_name in n.input] + + +def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]: + """Get all producer nodes for a given tensor name. + + Args: + model: The ONNX model to search. + tensor_name: Name of the tensor to find producers for. + + Returns: + list[onnx.NodeProto]: List of nodes that produce the tensor. + """ + return [n for n in model.graph.node if tensor_name in n.output] + + +def _build_tensor_type_map(model: onnx.ModelProto) -> dict[str, int]: + """Build an O(1) name-to-element-type lookup from all graph tensors.""" + type_map: dict[str, int] = {} + for vi in model.graph.value_info: + type_map[vi.name] = vi.type.tensor_type.elem_type + for init in model.graph.initializer: + type_map[init.name] = init.data_type + for inp in model.graph.input: + type_map[inp.name] = inp.type.tensor_type.elem_type + for out in model.graph.output: + type_map[out.name] = out.type.tensor_type.elem_type + # Constant node outputs are often not in value_info — extract type from the attribute. + for node in model.graph.node: + if node.op_type == "Constant" and node.output[0] not in type_map: + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + type_map[node.output[0]] = attr.t.data_type + break + return type_map + + +def _get_tensor_type_by_name( + model: onnx.ModelProto, tensor_name: str, type_map: dict[str, int] | None = None +): + """Get the tensor element type. 
Searches value_info, initializers, inputs, and outputs. + + Args: + model: The ONNX model (used as fallback when type_map is not provided). + tensor_name: Name of the tensor to look up. + type_map: Pre-built lookup from _build_tensor_type_map for O(1) access. + When called in a loop, pass this to avoid repeated linear scans. + """ + if type_map is not None: + if tensor_name in type_map: + return type_map[tensor_name] + raise Exception(f"did not find tensor {tensor_name}") + for vi in model.graph.value_info: + if vi.name == tensor_name: + return vi.type.tensor_type.elem_type + for init in model.graph.initializer: + if init.name == tensor_name: + return init.data_type + for inp in model.graph.input: + if inp.name == tensor_name: + return inp.type.tensor_type.elem_type + for out in model.graph.output: + if out.name == tensor_name: + return out.type.tensor_type.elem_type + for node in model.graph.node: + if node.op_type == "Constant" and node.output[0] == tensor_name: + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + return attr.t.data_type + raise Exception(f"did not find tensor {tensor_name}") + + +def _replace_tensor_name( + consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str +) -> None: + """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name.""" + for consumer in consumers: + for idx, inp in enumerate(consumer.input): + if inp == original_tensor_name: + consumer.input[idx] = new_tensor_name + + +def _is_same_type_cast( + model: onnx.ModelProto, node: onnx.NodeProto, type_map: dict[str, int] | None = None +) -> bool: + assert node.op_type == "Cast" + input_types = [_get_tensor_type_by_name(model, inp, type_map) for inp in node.input] + output_type = get_cast_to_type(node) + return bool(input_types) and all(inp_type == output_type for inp_type in input_types) + + +def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: + 
assert node.op_type == "Cast" + output_type = get_cast_to_type(node) + + # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed + # Cast to low precision -> cast to high precision affects precision and should not be removed + precision_order = [ + onnx.TensorProto.DOUBLE, + onnx.TensorProto.FLOAT, + onnx.TensorProto.FLOAT16, + onnx.TensorProto.BFLOAT16, + ] + consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"] + + # If the first cast has additional consumers, we should not remove it + if len(consumers) != 1: + return False + + next_node = consumers[0] + first_cast_type = output_type + second_cast_type = get_cast_to_type(next_node) + + return ( + first_cast_type in precision_order + and second_cast_type in precision_order + and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type) + ) + + +def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None: + # handling only a single input and output, as we only remove cast nodes + assert len(node.input) == 1 + assert len(node.output) == 1 + + input_tensor = node.input[0] + output_tensor = node.output[0] + + # Check if the cast output is also a graph output + is_output_producer = any(output.name == output_tensor for output in model.graph.output) + + # If the removed cast node is producing a network output, update the producer of the cast input so + # the network output name is preserved. 
+ if is_output_producer: + producers = get_producer_nodes(model, input_tensor) + for producer in producers: + for i, prod_out in enumerate(producer.output): + if prod_out == input_tensor: + producer.output[i] = output_tensor + consumers = get_consumer_nodes(model, prod_out) + if len(consumers) > 1: + _replace_tensor_name(consumers, prod_out, output_tensor) + else: + # Reconnect consumers of the cast output to use the cast input instead + consumers = get_consumer_nodes(model, output_tensor) + for consumer in consumers: + for i, input_name in enumerate(consumer.input): + if input_name == output_tensor: + consumer.input[i] = input_tensor + + +def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: + """Check if a Constant -> Cast pattern can be folded.""" + assert node.op_type == "Cast" + cast_producers = get_producer_nodes(model, node.input[0]) + if len(cast_producers) == 1 and cast_producers[0].op_type == "Constant": + consumers = get_consumer_nodes(model, cast_producers[0].output[0]) + return len(consumers) == 1 + return False + + +def _convert_constant_values(constant_node: onnx.NodeProto, cast_node: onnx.NodeProto) -> None: + """Convert the Constant node's values to the Cast node's target type.""" + cast_to_type = get_cast_to_type(cast_node) + for attr in constant_node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + # Read input tensor — bfloat16 tensors use raw_data and need special handling + if attr.t.data_type == onnx.TensorProto.BFLOAT16: + np_array = read_f16_tensor_as_fp32(attr.t) + else: + np_array = onnx.numpy_helper.to_array(attr.t) + + # Write output tensor — bfloat16 cannot use numpy_helper.from_array + if cast_to_type == onnx.TensorProto.BFLOAT16: + import ml_dtypes + + new_tensor = onnx.TensorProto() + new_tensor.dims.extend(np_array.shape) + new_tensor.name = attr.t.name + new_tensor.data_type = onnx.TensorProto.BFLOAT16 + bf16_bytes = 
np_array.astype(np.float32).astype(ml_dtypes.bfloat16) + new_tensor.raw_data = bf16_bytes.view(np.uint16).tobytes() + else: + target_np_type = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type) + new_array = np_array.astype(target_np_type) + new_tensor = onnx.numpy_helper.from_array(new_array, attr.t.name) + + attr.t.CopyFrom(new_tensor) + break + + +def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Removes same-type casts, collapsible sequential casts, and foldable Constant->Cast patterns. + + This method optimizes the graph by removing unnecessary cast operations that: + 1. Don't actually change the data type + 2. Could be replaced by a single cast operation + 3. Can be folded into a preceding Constant node + + Args: + onnx_model: The ONNX model to optimize. + + Returns: + onnx.ModelProto: Model with redundant casts removed. + """ + type_map = _build_tensor_type_map(onnx_model) + nodes_to_remove = [] + for node in onnx_model.graph.node: + if node.op_type == "Cast": + # Find cast nodes that don't change precision + if _is_same_type_cast(onnx_model, node, type_map): + nodes_to_remove.append(node) + _bypass_cast_node(onnx_model, node) + logger.debug(f"Found redundant same-type cast: {node.name}") + continue + + # Find sequential casts that don't change precision + if _is_sequential_cast(onnx_model, node): + nodes_to_remove.append(node) + _bypass_cast_node(onnx_model, node) + logger.debug(f"Found removable double-cast: {node.name}") + continue + + # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers. 
+ if _is_foldable_constant_cast_pattern(onnx_model, node): + nodes_to_remove.append(node) + cast_producers = get_producer_nodes(onnx_model, node.input[0]) + assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant" + constant_producer = cast_producers[0] + _convert_constant_values(constant_producer, node) + _bypass_cast_node(onnx_model, node) + logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}") + + logger.debug(f"Removing redundant casts: {[n.name for n in nodes_to_remove]}") + for node in nodes_to_remove: + onnx_model.graph.node.remove(node) + + return onnx_model + + def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto: """Remove `training_mode` attribute and extra training outputs from nodes of a given op type. @@ -1263,3 +1531,49 @@ def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx_model.graph.value_info.extend(keep) return onnx_model + + +def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) -> onnx.ModelProto: + """Change FP16-to-FP32 Cast nodes whose entire fanout feeds target ops to cast to FP16 instead. + + Args: + model: The ONNX model to modify. + target_op_types: List of op types to check for. Cast nodes feeding exclusively into + these will be changed from FP32 to FP16. + + Returns: + The modified ONNX model with Cast nodes updated. + """ + type_map = _build_tensor_type_map(model) + + # Build a map of tensor name -> consumer nodes + tensor_to_consumers: dict[str, list[onnx.NodeProto]] = {} + for node in model.graph.node: + for inp in node.input: + if inp: + tensor_to_consumers.setdefault(inp, []).append(node) + + # Find Cast nodes that feed into target ops and change FP16->FP32 to FP16->FP16 + for node in model.graph.node: + if node.op_type != "Cast": + continue + + # Only retarget FP16->FP32 casts; leave other casts (e.g. 
FP64->FP32) alone + cast_to = get_cast_to_type(node) + if cast_to != onnx.TensorProto.FLOAT: + continue + source_type = type_map.get(node.input[0]) + if source_type != onnx.TensorProto.FLOAT16: + continue + + # Only change when ALL consumers are target ops to avoid breaking non-target branches + consumers = tensor_to_consumers.get(node.output[0], []) + if not consumers or not all(c.op_type in target_op_types for c in consumers): + continue + + for attr in node.attribute: + if attr.name == "to": + attr.i = onnx.TensorProto.FLOAT16 + break + + return model diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 304fb8ec7a..035d3b9d08 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -18,6 +18,7 @@ import base64 import inspect import json +import logging import os import shutil import tempfile @@ -25,6 +26,7 @@ from typing import Any import onnx +import onnxconverter_common.float16 as _f16_module import torch import torch.nn as nn from onnx import ModelProto @@ -42,6 +44,7 @@ ) from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero from modelopt.onnx.utils import ( + change_casts_to_fp16, check_model_uses_external_data, get_input_names, get_input_shapes, @@ -50,6 +53,7 @@ get_output_shapes, infer_shapes, remove_node_training_mode, + remove_redundant_casts, ) from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers from modelopt.torch.utils import flatten_tree, standardize_named_model_args @@ -57,6 +61,30 @@ from ..utils.onnx_optimizer import Optimizer +# Monkey-patch for onnxconverter_common bug in remove_unnecessary_cast_node(): +# cast_node_downstream_dict stores either a single node or a list of nodes, but the +# downstream-node handling at lines ~770/787 always does `downstream_node.input`, +# which raises AttributeError("'list' object has no attribute 'input'") when the +# 
value is a list (i.e. a Cast output feeds multiple consumers). +# TODO: Remove this patch once onnxconverter-common ships a fix. +# Upstream issue: https://github.com/microsoft/onnxconverter-common/issues/261 +_original_remove_unnecessary_cast_node = _f16_module.remove_unnecessary_cast_node + +_logger = logging.getLogger(__name__) + + +def _patched_remove_unnecessary_cast_node(graph): + try: + _original_remove_unnecessary_cast_node(graph) + except AttributeError as e: + if "'list' object has no attribute 'input'" in str(e): + _logger.debug("Skipping remove_unnecessary_cast_node due to known upstream bug: %s", e) + else: + raise + + +_f16_module.remove_unnecessary_cast_node = _patched_remove_unnecessary_cast_node + ModelMetadata = dict[str, Any] ModelType = Any ValueInfoType = Any @@ -414,6 +442,10 @@ def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.Mode if is_int8_quantized(model): onnx_exporters.append(INT8QuantExporter) + if len(onnx_exporters) == 0: + print("No quantization exporters found for the model.") + return onnx_model + for onnx_exporter in onnx_exporters: onnx_model = onnx_exporter.process_model(onnx_model) @@ -560,38 +592,36 @@ def get_onnx_bytes_and_metadata( tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model ) - # TODO: Remove manual ir_version change once ORT supports ir_version 11 - onnx_opt_graph.ir_version = 10 - - # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode - # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode - if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model): - onnx_opt_graph = quantize_weights(model, onnx_opt_graph) + onnx_opt_graph = quantize_weights(model, onnx_opt_graph) if dq_only: onnx_opt_graph = qdq_to_dq(onnx_opt_graph) - try: - # TODO: Single-precision torch model assumed - param_dtype = next(model.parameters()).dtype - except StopIteration: - param_dtype = torch.float32 - if 
weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32: - if is_int4_quantized(model) or is_mxfp8_quantized(model): + if weights_dtype in ["fp16", "bf16"]: + if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model): assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet" onnx_opt_graph = convert_float_to_float16( onnx_opt_graph, keep_io_types=False, disable_shape_infer=True, check_fp16_ready=False, + op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"], ) + # Change FP32 cast nodes feeding into Concat/Add to FP16 + onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"]) else: onnx_opt_graph = convert_to_f16( onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False ) - # TensorRT expects all scales to be postive - onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) + onnx_opt_graph = remove_redundant_casts(onnx_opt_graph) + + # TensorRT expects all scales to be positive + onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) + + # TODO: Remove manual ir_version change once ORT supports ir_version 11 + # Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version + onnx_opt_graph.ir_version = 10 # If the onnx model contains external data store the external tensors in one file and save the onnx model if has_external_data(onnx_save_path): diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index a14991319a..c3e1a51db5 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -1117,10 +1117,10 @@ def test_constant_cast_folding( assert const_array.attribute[0].t.data_type == low_precision_onnx_type(low_precision_type) # Check that the constant nodes are consumed directly by the Add nodes - assert len(utils.get_consumer_nodes(converted_model, "const_scalar")) == 1 - assert
utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add" - assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1 - assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add" + assert len(onnx_utils.get_consumer_nodes(converted_model, "const_scalar")) == 1 + assert onnx_utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add" + assert len(onnx_utils.get_consumer_nodes(converted_model, "const_array")) == 1 + assert onnx_utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add" @pytest.fixture