Skip to content

Commit e4df91b

Browse files
authored
[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes (#852)
## What does this PR do? **Type of change:** New feature **Overview:** - Updated FP8 quant exporter to replace modelopt custom QDQ nodes with native ONNX QDQ nodes - Updated get_onnx_bytes_and_metadata to make convert_float_to_float16() default instead of autocast - Created util functions to fix graph structure after conversion ## Testing ``` python torch_quant_to_onnx.py --quantize_mode=fp8 \ --onnx_save_path=<model_path> \ --calibration_data_size 64 \ --batch_size 128 python evaluate.py --onnx_path=<model_path> \ --model_name=vit_base_patch16_224 \ --results_path=./results.txt \ --batch_size 128 ``` Results: Before replacement: ``` The top1 accuracy of the model is 85.06% The top5 accuracy of the model is 97.558% Inference latency of the model is 5.27963 ms ``` After replacement: ``` The top1 accuracy of the model is 85.054% The top5 accuracy of the model is 97.542% Inference latency of the model is 5.74771 ms ``` ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: No - Replaced modelopt QDQ nodes with native ONNX qdq nodes - **Did you write any new necessary tests?**: No - **Did you add or update any necessary documentation?**: No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * ONNX utilities to remove redundant Casts, fold Constant→Cast patterns, and convert targeted Casts to FP16. * **Improvements** * FP8 QDQ nodes now converted to native ONNX QDQ/Dequantize nodes for improved compatibility. 
* Export pipeline streamlined: consistent FP16 handling, unified weight quantization, cast cleanup ordering, and added logging for better traceability. * **Tests** * Unit tests updated to use the new ONNX utilities. * **Changelog** * Entry added noting FP8 QDQ → native ONNX QDQ conversion. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent beac6e9 commit e4df91b

File tree

9 files changed

+450
-286
lines changed

9 files changed

+450
-286
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ NVIDIA Model Optimizer Changelog
3232
- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
3333
- Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.
3434
- Add support for block-granular RHT for non-power-of-2 dimensions.
35-
35+
- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes.
36+
3637
**Misc**
3738

3839
- Migrated project metadata from ``setup.py`` to a fully declarative ``pyproject.toml``.

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def remove_disconnected_outputs(self) -> None:
130130
"""Remove disconnected outputs from the model."""
131131
tensors_to_remove = []
132132
for tensor in self.model.graph.output:
133-
if not utils.get_producer_nodes(self.model, tensor.name):
133+
if not onnx_utils.get_producer_nodes(self.model, tensor.name):
134134
tensors_to_remove.append(tensor)
135135
logger.debug(f"Found disconnected output: {tensor.name}")
136136

@@ -279,7 +279,7 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
279279
# Find variance computation branch
280280
pow_nodes = [
281281
n
282-
for n in utils.get_consumer_nodes(self.model, sub_node.output[0])
282+
for n in onnx_utils.get_consumer_nodes(self.model, sub_node.output[0])
283283
if n.op_type == "Pow"
284284
]
285285
if len(pow_nodes) != 1:
@@ -303,8 +303,8 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
303303

304304
# Find Div node
305305
# Find the Div node that consumes both sqrt and sub outputs
306-
sqrt_consumers = utils.get_consumer_nodes(self.model, sqrt_node.output[0])
307-
sub_consumers = utils.get_consumer_nodes(self.model, sub_node.output[0])
306+
sqrt_consumers = onnx_utils.get_consumer_nodes(self.model, sqrt_node.output[0])
307+
sub_consumers = onnx_utils.get_consumer_nodes(self.model, sub_node.output[0])
308308

309309
div_nodes = [n for n in sqrt_consumers if n in sub_consumers and n.op_type == "Div"]
310310
if len(div_nodes) != 1:
@@ -342,14 +342,14 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
342342
div_node,
343343
]
344344

345-
consumers = utils.get_consumer_nodes(self.model, div_node.output[0])
345+
consumers = onnx_utils.get_consumer_nodes(self.model, div_node.output[0])
346346
if len(consumers) == 1 and consumers[0].op_type == "Mul":
347347
mul_node = consumers[0]
348348
scale = self._get_initializer_value(mul_node.input[1], return_array=True)
349349
final_node = mul_node
350350
nodes_to_remove.append(mul_node)
351351

352-
consumers = utils.get_consumer_nodes(self.model, mul_node.output[0])
352+
consumers = onnx_utils.get_consumer_nodes(self.model, mul_node.output[0])
353353
if len(consumers) == 1 and consumers[0].op_type == "Add":
354354
add_node = consumers[0]
355355
bias = self._get_initializer_value(add_node.input[1], return_array=True)
@@ -457,7 +457,7 @@ def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
457457

458458
def _find_insertion_point(self, input_name: str) -> int:
459459
"""Find the correct insertion point for the new LayerNorm node."""
460-
producer_nodes = utils.get_producer_nodes(self.model, input_name)
460+
producer_nodes = onnx_utils.get_producer_nodes(self.model, input_name)
461461
if not producer_nodes:
462462
return 0
463463

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 20 additions & 201 deletions
Large diffs are not rendered by default.

modelopt/onnx/autocast/utils.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import onnx
2929

30+
import modelopt.onnx.utils as onnx_utils
3031
from modelopt.onnx.utils import get_opset_version
3132

3233

@@ -60,32 +61,6 @@ def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]:
6061
return value_info_map, initializer_map, node_to_init_map
6162

6263

63-
def get_consumer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]:
64-
"""Get all consumer nodes for a given tensor name.
65-
66-
Args:
67-
model: The ONNX model to search.
68-
tensor_name: Name of the tensor to find consumers for.
69-
70-
Returns:
71-
list[onnx.NodeProto]: List of nodes that consume the tensor.
72-
"""
73-
return [n for n in model.graph.node if tensor_name in n.input]
74-
75-
76-
def get_producer_nodes(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]:
77-
"""Get all producer nodes for a given tensor name.
78-
79-
Args:
80-
model: The ONNX model to search.
81-
tensor_name: Name of the tensor to find producers for.
82-
83-
Returns:
84-
list[onnx.NodeProto]: List of nodes that produce the tensor.
85-
"""
86-
return [n for n in model.graph.node if tensor_name in n.output]
87-
88-
8964
def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.NodeProto:
9065
"""Get a single consumer node and raise exception if there are multiple consumers.
9166
@@ -99,30 +74,12 @@ def get_unique_consumer_node(model: onnx.ModelProto, tensor_name: str) -> onnx.N
9974
Raises:
10075
Exception: If there is not exactly one consumer node.
10176
"""
102-
consumers = get_consumer_nodes(model, tensor_name)
77+
consumers = onnx_utils.get_consumer_nodes(model, tensor_name)
10378
if len(consumers) != 1:
10479
raise Exception(f"Expected single consumer for {tensor_name}, found {len(consumers)}")
10580
return consumers[0]
10681

10782

108-
def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
109-
"""Get the target type from a Cast node.
110-
111-
Args:
112-
cast_node: The Cast node to extract type from.
113-
114-
Returns:
115-
int: The target type value from the Cast node's 'to' attribute.
116-
117-
Raises:
118-
ValueError: If the Cast node does not have a 'to' attribute.
119-
"""
120-
for attr in cast_node.attribute:
121-
if attr.name == "to":
122-
return attr.i
123-
raise ValueError("Cast node does not have 'to' attribute")
124-
125-
12683
def walk_subgraphs_recursive(
12784
graph: onnx.GraphProto,
12885
callback: Callable,

modelopt/onnx/export/fp8_exporter.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import torch
2323
from onnx_graphsurgeon.ir.tensor import LazyValues
2424

25+
from modelopt.onnx.logging_config import logger
26+
2527
from .base_exporter import ONNXQuantExporter
2628

2729

@@ -45,13 +47,13 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
4547
Even though modelopt supports FP8 onnx export, the weights are represented in fp32 + QDQ.
4648
The storage is therefore very bad. In this function,
4749
Q nodes will get removed from the weights and have only DQ nodes with those converted FP8
48-
weights in the output model.
50+
weights in the output model. TRT custom ops are converted to native ONNX DequantizeLinear.
4951
5052
Parameters:
51-
onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes.
53+
onnx_model: ONNX model with FP32/FP16 weights and TRT_FP8 QDQ nodes.
5254
5355
Returns:
54-
ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations).
56+
ONNX model with FP8 weights and native ONNX DQ nodes for weights (QDQ preserved for activations).
5557
"""
5658
start_time = time.time()
5759
print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...")
@@ -62,7 +64,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
6264

6365
for node in graph.nodes:
6466
if node.op == "TRT_FP8QuantizeLinear":
65-
# Should not remove input QDQ
67+
# Should not remove input QDQ (only process weight quantization)
6668
if not isinstance(node.inputs[0], gs.Constant):
6769
continue
6870

@@ -88,7 +90,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
8890
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)
8991

9092
node.outputs.clear()
91-
# DQ Op is separated out
93+
# Convert TRT DQ to native ONNX DequantizeLinear with FP8 weights
9294
dq_op.inputs[0] = onnx_weights_fp8
9395
dq_op.op = "DequantizeLinear"
9496
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
@@ -101,5 +103,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
101103

102104
@staticmethod
103105
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
104-
"""Post-processes the ONNX model for FP8 quantization."""
105-
return onnx_model
106+
"""Post-processes the ONNX model for FP8 quantization.
107+
108+
Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
109+
- TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
110+
- TRT_FP8DequantizeLinear -> DequantizeLinear
111+
112+
Args:
113+
onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
114+
115+
Returns:
116+
The post-processed ONNX model with native ONNX quantization ops.
117+
"""
118+
logger.info("Post-processing FP8 quantized model")
119+
graph = gs.import_onnx(onnx_model)
120+
121+
# Convert TRT_FP8QuantizeLinear to native QuantizeLinear
122+
for node in graph.nodes:
123+
if node.op == "TRT_FP8QuantizeLinear":
124+
node.op = "QuantizeLinear"
125+
# Add FP8 zero_point if not present
126+
if len(node.inputs) == 2:
127+
# Create FP8 zero point constant
128+
zp_tensor = onnx.TensorProto()
129+
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
130+
zp_tensor.dims.extend([1]) # 1-element tensor
131+
zp_tensor.raw_data = b"\x00" # Zero in FP8
132+
zp_values = LazyValues(zp_tensor)
133+
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
134+
node.inputs.append(zero_point)
135+
# Add saturate attribute for FP8
136+
node.attrs["saturate"] = 1
137+
logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")
138+
139+
# Convert TRT_FP8DequantizeLinear to native DequantizeLinear
140+
for node in graph.nodes:
141+
if node.op == "TRT_FP8DequantizeLinear":
142+
node.op = "DequantizeLinear"
143+
logger.debug(
144+
f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
145+
)
146+
147+
graph.cleanup().toposort()
148+
return gs.export_onnx(graph)

modelopt/onnx/export/nvfp4_exporter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
215215
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to process")
216216

217217
for node in fp4_qdq_nodes:
218-
idx = initializer_indices.get(node.input[0], None)
218+
idx = initializer_indices.get(node.input[0])
219219
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
220220

221221
tensor = initializers[idx]
@@ -259,7 +259,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
259259
fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"]
260260

261261
for node in fp4_qdq_nodes:
262-
idx = initializer_indices.get(node.input[0], None)
262+
idx = initializer_indices.get(node.input[0])
263263
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
264264

265265
tensor = initializers[idx]
@@ -365,7 +365,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str):
365365
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert")
366366

367367
for node in fp4_qdq_nodes:
368-
idx = initializer_indices.get(node.input[0], None)
368+
idx = initializer_indices.get(node.input[0])
369369
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
370370
initializers_to_delete.append(graph.initializer[idx].name)
371371

0 commit comments

Comments
 (0)