Skip to content
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ NVIDIA Model Optimizer Changelog
- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
- Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and for NemotronH MoE experts in ``auto_quantize`` grouping and scoring rules.
- Add support for block-granular RHT for non-power-of-2 dimensions.

- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes.

**Misc**

- Migrated project metadata from ``setup.py`` to a fully declarative ``pyproject.toml``.
Expand Down
14 changes: 7 additions & 7 deletions modelopt/onnx/autocast/graphsanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def remove_disconnected_outputs(self) -> None:
"""Remove disconnected outputs from the model."""
tensors_to_remove = []
for tensor in self.model.graph.output:
if not utils.get_producer_nodes(self.model, tensor.name):
if not onnx_utils.get_producer_nodes(self.model, tensor.name):
tensors_to_remove.append(tensor)
logger.debug(f"Found disconnected output: {tensor.name}")

Expand Down Expand Up @@ -279,7 +279,7 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
# Find variance computation branch
pow_nodes = [
n
for n in utils.get_consumer_nodes(self.model, sub_node.output[0])
for n in onnx_utils.get_consumer_nodes(self.model, sub_node.output[0])
if n.op_type == "Pow"
]
if len(pow_nodes) != 1:
Expand All @@ -303,8 +303,8 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:

# Find Div node
# Find the Div node that consumes both sqrt and sub outputs
sqrt_consumers = utils.get_consumer_nodes(self.model, sqrt_node.output[0])
sub_consumers = utils.get_consumer_nodes(self.model, sub_node.output[0])
sqrt_consumers = onnx_utils.get_consumer_nodes(self.model, sqrt_node.output[0])
sub_consumers = onnx_utils.get_consumer_nodes(self.model, sub_node.output[0])

div_nodes = [n for n in sqrt_consumers if n in sub_consumers and n.op_type == "Div"]
if len(div_nodes) != 1:
Expand Down Expand Up @@ -342,14 +342,14 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
div_node,
]

consumers = utils.get_consumer_nodes(self.model, div_node.output[0])
consumers = onnx_utils.get_consumer_nodes(self.model, div_node.output[0])
if len(consumers) == 1 and consumers[0].op_type == "Mul":
mul_node = consumers[0]
scale = self._get_initializer_value(mul_node.input[1], return_array=True)
final_node = mul_node
nodes_to_remove.append(mul_node)

consumers = utils.get_consumer_nodes(self.model, mul_node.output[0])
consumers = onnx_utils.get_consumer_nodes(self.model, mul_node.output[0])
if len(consumers) == 1 and consumers[0].op_type == "Add":
add_node = consumers[0]
bias = self._get_initializer_value(add_node.input[1], return_array=True)
Expand Down Expand Up @@ -457,7 +457,7 @@ def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:

def _find_insertion_point(self, input_name: str) -> int:
"""Find the correct insertion point for the new LayerNorm node."""
producer_nodes = utils.get_producer_nodes(self.model, input_name)
producer_nodes = onnx_utils.get_producer_nodes(self.model, input_name)
if not producer_nodes:
return 0

Expand Down
221 changes: 20 additions & 201 deletions modelopt/onnx/autocast/precisionconverter.py

Large diffs are not rendered by default.

47 changes: 2 additions & 45 deletions modelopt/onnx/autocast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import onnx

import modelopt.onnx.utils as onnx_utils
from modelopt.onnx.utils import get_opset_version


Expand Down Expand Up @@ -60,32 +61,6 @@ def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]:
return value_info_map, initializer_map, node_to_init_map


def get_consumer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]:
"""Get all consumer nodes for a given tensor name.

Args:
model: The ONNX model to search.
tensor_name: Name of the tensor to find consumers for.

Returns:
list[onnx.NodeProto]: List of nodes that consume the tensor.
"""
return [n for n in model.graph.node if tensor_name in n.input]


def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]:
"""Get all producer nodes for a given tensor name.

Args:
model: The ONNX model to search.
tensor_name: Name of the tensor to find producers for.

Returns:
list[onnx.NodeProto]: List of nodes that produce the tensor.
"""
return [n for n in model.graph.node if tensor_name in n.output]


def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.NodeProto:
"""Get a single consumer node and raise exception if there are multiple consumers.

Expand All @@ -99,30 +74,12 @@ def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.N
Raises:
Exception: If there is not exactly one consumer node.
"""
consumers = get_consumer_nodes(model, tensor_name)
consumers = onnx_utils.get_consumer_nodes(model, tensor_name)
if len(consumers) != 1:
raise Exception(f"Expected single consumer for {tensor_name}, found {len(consumers)}")
return consumers[0]


def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
"""Get the target type from a Cast node.

Args:
cast_node: The Cast node to extract type from.

Returns:
int: The target type value from the Cast node's 'to' attribute.

Raises:
ValueError: If the Cast node does not have a 'to' attribute.
"""
for attr in cast_node.attribute:
if attr.name == "to":
return attr.i
raise ValueError("Cast node does not have 'to' attribute")


def walk_subgraphs_recursive(
graph: onnx.GraphProto,
callback: Callable,
Expand Down
57 changes: 50 additions & 7 deletions modelopt/onnx/export/fp8_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import torch
from onnx_graphsurgeon.ir.tensor import LazyValues

from modelopt.onnx.logging_config import logger

from .base_exporter import ONNXQuantExporter


Expand All @@ -45,13 +47,13 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
Even though modelopt supports FP8 onnx export, the weights are represented in fp32 + QDQ.
Weight storage is therefore very inefficient. In this function,
Q nodes will get removed from the weights and have only DQ nodes with those converted FP8
weights in the output model.
weights in the output model. TRT custom ops are converted to native ONNX DequantizeLinear.

Parameters:
onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes.
onnx_model: ONNX model with FP32/FP16 weights and TRT_FP8 QDQ nodes.

Returns:
ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations).
ONNX model with FP8 weights and native ONNX DQ nodes for weights (QDQ preserved for activations).
"""
start_time = time.time()
print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...")
Expand All @@ -62,7 +64,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:

for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
# Should not remove input QDQ
# Should not remove input QDQ (only process weight quantization)
if not isinstance(node.inputs[0], gs.Constant):
continue

Expand All @@ -88,7 +90,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)

node.outputs.clear()
# DQ Op is separated out
# Convert TRT DQ to native ONNX DequantizeLinear with FP8 weights
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
Expand All @@ -101,5 +103,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:

@staticmethod
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Post-processes the ONNX model for FP8 quantization."""
return onnx_model
"""Post-processes the ONNX model for FP8 quantization.

Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
Comment thread
gcunhase marked this conversation as resolved.
- TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
- TRT_FP8DequantizeLinear -> DequantizeLinear

Args:
onnx_model: The ONNX model containing TRT_FP8 quantization nodes.

Returns:
The post-processed ONNX model with native ONNX quantization ops.
"""
logger.info("Post-processing FP8 quantized model")
graph = gs.import_onnx(onnx_model)

# Convert TRT_FP8QuantizeLinear to native QuantizeLinear
for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
node.op = "QuantizeLinear"
# Add FP8 zero_point if not present
if len(node.inputs) == 2:
# Create FP8 zero point constant
zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1]) # 1-element tensor
zp_tensor.raw_data = b"\x00" # Zero in FP8
zp_values = LazyValues(zp_tensor)
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
node.inputs.append(zero_point)
Comment thread
gcunhase marked this conversation as resolved.
Comment on lines +127 to +134
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use a guaranteed-unique tensor name for the injected zero point.

node.name is optional in ONNX, so node.name + "_zero_point" can collapse to the same tensor name for multiple unnamed TRT FP8 Q nodes. That can make the exported graph invalid due to duplicate tensor names.

🛠️ Safer naming
-                    zero_point = gs.Constant(node.name + "_zero_point", zp_values)
+                    zero_point = gs.Constant(f"{node.outputs[0].name}_zero_point", zp_values)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 127 - 134, The injected
zero-point Constant currently uses node.name which may be empty and cause
duplicate tensor names; change the naming for the Constant created from
zp_tensor/zp_values/zero_point to a guaranteed-unique string (e.g., combine
node.name when present with a unique suffix such as a uuid4 or the node's memory
id or an incrementing counter, or use an ONNX/graph helper that returns a unique
name) so each FP8 zero-point Constant has a distinct tensor name even for
unnamed TRT FP8 Q nodes.

# Add saturate attribute for FP8
node.attrs["saturate"] = 1
logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")

# Convert TRT_FP8DequantizeLinear to native DequantizeLinear
for node in graph.nodes:
if node.op == "TRT_FP8DequantizeLinear":
node.op = "DequantizeLinear"
logger.debug(
f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
)

graph.cleanup().toposort()
Comment on lines +121 to +147
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Using the official ONNX operator and type documentation, what is the minimum ai.onnx opset that supports FLOAT8E4M3FN with QuantizeLinear and DequantizeLinear, including the saturate attribute on QuantizeLinear?

💡 Result:

Minimum ai.onnx opset: 19.

  • QuantizeLinear first supports tensor(float8e4m3fn) (via its T2 type constraint) in since_version = 19, and this is also the version that defines the saturate attribute (applying to float8 quantization). [1]
  • DequantizeLinear first supports tensor(float8e4m3fn) (via its T1 type constraint) in since_version = 19. [2]

Sources:
[1] QuantizeLinear (ONNX operator doc, version 19) (onnx.ai)
[2] DequantizeLinear (ONNX operator doc, version 19) (onnx.ai)

Citations:


🏁 Script executed:

# Check the fp8_exporter.py file to understand the full context
# Look for the function containing this code and any opset checks
fd "fp8_exporter.py" -x wc -l {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 153


🏁 Script executed:

# Read the file to understand the function context and any opset validation
cat -n modelopt/onnx/export/fp8_exporter.py | head -200

Repository: NVIDIA/Model-Optimizer

Length of output: 7470


🏁 Script executed:

# Search for opset-related code and checks in the file
rg -n "opset|onnx_opset" modelopt/onnx/export/fp8_exporter.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Find the base exporter to understand the parent class
fd "base_exporter.py" modelopt/onnx/export

Repository: NVIDIA/Model-Optimizer

Length of output: 103


🏁 Script executed:

# Search for calls to post_process and get_onnx_bytes_and_metadata
rg -n "post_process|get_onnx_bytes_and_metadata" modelopt/onnx/export --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 942


🏁 Script executed:

# Check if there's opset validation in the base exporter
cat -n modelopt/onnx/export/base_exporter.py | head -300

Repository: NVIDIA/Model-Optimizer

Length of output: 2365


🏁 Script executed:

# Search for get_onnx_bytes_and_metadata function
rg -n "get_onnx_bytes_and_metadata|onnx_opset" modelopt/onnx --type py -B 2 -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 823


🏁 Script executed:

# Look for where FP8QuantExporter is instantiated/used
rg -n "FP8QuantExporter" modelopt/onnx --type py -B 2 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 1792


🏁 Script executed:

# Check the quantizer file that likely calls the exporter
fd "quantizer.py" modelopt/onnx -x wc -l {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search more broadly for onnx_opset and opset parameter usage
rg -n "onnx_opset|opset_version" modelopt/onnx --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 3244


🏁 Script executed:

# Look for quantizer or export related files that handle opset
fd "quantizer|export" modelopt/onnx --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 235


🏁 Script executed:

# Check the quantizer base class
rg -n "class.*Quantizer" modelopt/onnx --type py -A 10 | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the ort_patching.py file around the FP8 opset check
cat -n modelopt/onnx/quantization/ort_patching.py | sed -n '220,250p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1445


🏁 Script executed:

# Check where _check_opset_version is called and in what context
rg -n "_check_opset_version|check_opset_version" modelopt/onnx/quantization --type py -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 1253


🏁 Script executed:

# Look at the quantize.py file to understand the full flow
cat -n modelopt/onnx/quantization/quantize.py | sed -n '140,185p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2716


🏁 Script executed:

# Search for get_onnx_bytes_and_metadata function
rg -n "get_onnx_bytes_and_metadata" modelopt/onnx

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the quantize_mode constants to understand FP8 mode handling
rg -n "QDQ_PRECISION_MIN_OPSET|fp8" modelopt/onnx/quantization/quantize.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 3436


🏁 Script executed:

# Check if there's a separate export entry point that may bypass opset checks
rg -n "def.*export|get_onnx_bytes" modelopt/onnx --type py | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 473


🏁 Script executed:

# Search for quantize_fp8 function
rg -n "def quantize_fp8|def quantize" modelopt/onnx/quantization/fp8.py -A 10 | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 604


🏁 Script executed:

# Look for where process_model is called
rg -n "process_model" modelopt/onnx --type py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 453


🏁 Script executed:

# Check the fp8.py file to see if it calls FP8QuantExporter.post_process
cat -n modelopt/onnx/quantization/fp8.py | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 7314


🏁 Script executed:

# Continue reading fp8.py to see where FP8QuantExporter is used
cat -n modelopt/onnx/quantization/fp8.py | sed -n '150,250p'

Repository: NVIDIA/Model-Optimizer

Length of output: 5087


🏁 Script executed:

# Search for where FP8QuantExporter.process_model is called
rg -n "\.process_model|FP8QuantExporter" modelopt/onnx --type py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 1639


🏁 Script executed:

# Check for any test files or examples that use FP8 export
fd "test.*fp8|fp8.*test" modelopt --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 235


🏁 Script executed:

# Search for where process_model is called with FP8QuantExporter
rg -n "process_model\|FP8QuantExporter" modelopt/onnx

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look in llm_export_utils since it showed up in opset search
cat -n modelopt/onnx/llm_export_utils/export_utils.py | sed -n '155,170p'

Repository: NVIDIA/Model-Optimizer

Length of output: 379


🏁 Script executed:

# Check the entire quantize function end to see if it calls process_model or post_process
cat -n modelopt/onnx/quantization/fp8.py | tail -100

Repository: NVIDIA/Model-Optimizer

Length of output: 5585


🏁 Script executed:

# Search for get_onnx_bytes_and_metadata more carefully
rg -n "get_onnx_bytes_and_metadata\|get_onnx_bytes"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look for uses of FP8QuantExporter more comprehensively
rg -rn "FP8QuantExporter" modelopt/

Repository: NVIDIA/Model-Optimizer

Length of output: 667


🏁 Script executed:

# Check if there's an export function that uses the exporter classes
rg -n "def export\|def get_onnx" modelopt/onnx -B 2 -A 5 | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the torch_onnx.py to see how exporters are used
cat -n modelopt/torch/_deploy/utils/torch_onnx.py | grep -A 10 -B 10 "FP8QuantExporter\|n\|onnx_exporters"

Repository: NVIDIA/Model-Optimizer

Length of output: 31963


🏁 Script executed:

# Check the actual file content more carefully
grep -n "class.*Exporter" modelopt/onnx/export/*.py

Repository: NVIDIA/Model-Optimizer

Length of output: 564


🏁 Script executed:

# Look at what might be calling process_model on FP8QuantExporter
rg -n "process_model\|post_process" modelopt/torch

Repository: NVIDIA/Model-Optimizer

Length of output: 48


Add opset >= 19 validation before FP8 Q/DQ conversion in FP8QuantExporter.post_process().

The code converts TRT custom ops to native QuantizeLinear/DequantizeLinear with FLOAT8E4M3FN and the saturate attribute, but does not verify that the model's opset is >= 19 (the minimum required for these operators). When callers invoke get_onnx_bytes_and_metadata() with onnx_opset < 19 on an FP8-quantized model, the post-processor will silently generate an invalid ONNX model instead of upgrading the opset or raising an error.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, In
FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).

return gs.export_onnx(graph)
Comment thread
gcunhase marked this conversation as resolved.
6 changes: 3 additions & 3 deletions modelopt/onnx/export/nvfp4_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to process")

for node in fp4_qdq_nodes:
idx = initializer_indices.get(node.input[0], None)
idx = initializer_indices.get(node.input[0])
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."

tensor = initializers[idx]
Expand Down Expand Up @@ -259,7 +259,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"]

for node in fp4_qdq_nodes:
idx = initializer_indices.get(node.input[0], None)
idx = initializer_indices.get(node.input[0])
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."

tensor = initializers[idx]
Expand Down Expand Up @@ -365,7 +365,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str):
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert")

for node in fp4_qdq_nodes:
idx = initializer_indices.get(node.input[0], None)
idx = initializer_indices.get(node.input[0])
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
initializers_to_delete.append(graph.initializer[idx].name)

Expand Down
Loading
Loading