-
Notifications
You must be signed in to change notification settings - Fork 360
[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes #852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2eaacb9
81a19ef
03b7232
2b7e4d4
97828b5
1c6a3ba
05c33b2
f6ce7b3
e223a2e
9e3a35a
de464f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,8 @@ | |
| import torch | ||
| from onnx_graphsurgeon.ir.tensor import LazyValues | ||
|
|
||
| from modelopt.onnx.logging_config import logger | ||
|
|
||
| from .base_exporter import ONNXQuantExporter | ||
|
|
||
|
|
||
|
|
@@ -45,13 +47,13 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| Even though modelopt supports FP8 onnx export, the weights are represented in fp32 + QDQ. | ||
| The storage is therefore very bad. In this function, | ||
| Q nodes will get removed from the weights and have only DQ nodes with those converted FP8 | ||
| weights in the output model. | ||
| weights in the output model. TRT custom ops are converted to native ONNX DequantizeLinear. | ||
|
|
||
| Parameters: | ||
| onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes. | ||
| onnx_model: ONNX model with FP32/FP16 weights and TRT_FP8 QDQ nodes. | ||
|
|
||
| Returns: | ||
| ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations). | ||
| ONNX model with FP8 weights and native ONNX DQ nodes for weights (QDQ preserved for activations). | ||
| """ | ||
| start_time = time.time() | ||
| print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...") | ||
|
|
@@ -62,7 +64,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
|
|
||
| for node in graph.nodes: | ||
| if node.op == "TRT_FP8QuantizeLinear": | ||
| # Should not remove input QDQ | ||
| # Should not remove input QDQ (only process weight quantization) | ||
| if not isinstance(node.inputs[0], gs.Constant): | ||
| continue | ||
|
|
||
|
|
@@ -88,7 +90,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) | ||
|
|
||
| node.outputs.clear() | ||
| # DQ Op is separated out | ||
| # Convert TRT DQ to native ONNX DequantizeLinear with FP8 weights | ||
| dq_op.inputs[0] = onnx_weights_fp8 | ||
| dq_op.op = "DequantizeLinear" | ||
| dq_op.outputs[0].dtype = dq_op.inputs[1].dtype | ||
|
|
@@ -101,5 +103,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
|
|
||
| @staticmethod | ||
| def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | ||
| """Post-processes the ONNX model for FP8 quantization.""" | ||
| return onnx_model | ||
| """Post-processes the ONNX model for FP8 quantization. | ||
|
|
||
| Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear: | ||
| - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1 | ||
| - TRT_FP8DequantizeLinear -> DequantizeLinear | ||
|
|
||
| Args: | ||
| onnx_model: The ONNX model containing TRT_FP8 quantization nodes. | ||
|
|
||
| Returns: | ||
| The post-processed ONNX model with native ONNX quantization ops. | ||
| """ | ||
| logger.info("Post-processing FP8 quantized model") | ||
| graph = gs.import_onnx(onnx_model) | ||
|
|
||
| # Convert TRT_FP8QuantizeLinear to native QuantizeLinear | ||
| for node in graph.nodes: | ||
| if node.op == "TRT_FP8QuantizeLinear": | ||
| node.op = "QuantizeLinear" | ||
| # Add FP8 zero_point if not present | ||
| if len(node.inputs) == 2: | ||
| # Create FP8 zero point constant | ||
| zp_tensor = onnx.TensorProto() | ||
| zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN | ||
| zp_tensor.dims.extend([1]) # 1-element tensor | ||
| zp_tensor.raw_data = b"\x00" # Zero in FP8 | ||
| zp_values = LazyValues(zp_tensor) | ||
| zero_point = gs.Constant(node.name + "_zero_point", zp_values) | ||
| node.inputs.append(zero_point) | ||
|
gcunhase marked this conversation as resolved.
Comment on lines
+127
to
+134
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use a guaranteed-unique tensor name for the injected zero point.
🛠️ Safer naming- zero_point = gs.Constant(node.name + "_zero_point", zp_values)
+ zero_point = gs.Constant(f"{node.outputs[0].name}_zero_point", zp_values)🤖 Prompt for AI Agents |
||
| # Add saturate attribute for FP8 | ||
| node.attrs["saturate"] = 1 | ||
| logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear") | ||
|
|
||
| # Convert TRT_FP8DequantizeLinear to native DequantizeLinear | ||
| for node in graph.nodes: | ||
| if node.op == "TRT_FP8DequantizeLinear": | ||
| node.op = "DequantizeLinear" | ||
| logger.debug( | ||
| f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear" | ||
| ) | ||
|
|
||
| graph.cleanup().toposort() | ||
|
Comment on lines
+121
to
+147
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: Minimum ai.onnx opset: 19.
Sources: Citations:
🏁 Script executed: # Check the fp8_exporter.py file to understand the full context
# Look for the function containing this code and any opset checks
fd "fp8_exporter.py" -x wc -l {} \;Repository: NVIDIA/Model-Optimizer Length of output: 153 🏁 Script executed: # Read the file to understand the function context and any opset validation
cat -n modelopt/onnx/export/fp8_exporter.py | head -200Repository: NVIDIA/Model-Optimizer Length of output: 7470 🏁 Script executed: # Search for opset-related code and checks in the file
rg -n "opset|onnx_opset" modelopt/onnx/export/fp8_exporter.py -B 2 -A 2Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Find the base exporter to understand the parent class
fd "base_exporter.py" modelopt/onnx/exportRepository: NVIDIA/Model-Optimizer Length of output: 103 🏁 Script executed: # Search for calls to post_process and get_onnx_bytes_and_metadata
rg -n "post_process|get_onnx_bytes_and_metadata" modelopt/onnx/export --type pyRepository: NVIDIA/Model-Optimizer Length of output: 942 🏁 Script executed: # Check if there's opset validation in the base exporter
cat -n modelopt/onnx/export/base_exporter.py | head -300Repository: NVIDIA/Model-Optimizer Length of output: 2365 🏁 Script executed: # Search for get_onnx_bytes_and_metadata function
rg -n "get_onnx_bytes_and_metadata|onnx_opset" modelopt/onnx --type py -B 2 -A 5Repository: NVIDIA/Model-Optimizer Length of output: 823 🏁 Script executed: # Look for where FP8QuantExporter is instantiated/used
rg -n "FP8QuantExporter" modelopt/onnx --type py -B 2 -A 3Repository: NVIDIA/Model-Optimizer Length of output: 1792 🏁 Script executed: # Check the quantizer file that likely calls the exporter
fd "quantizer.py" modelopt/onnx -x wc -l {} \;Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Search more broadly for onnx_opset and opset parameter usage
rg -n "onnx_opset|opset_version" modelopt/onnx --type pyRepository: NVIDIA/Model-Optimizer Length of output: 3244 🏁 Script executed: # Look for quantizer or export related files that handle opset
fd "quantizer|export" modelopt/onnx --type py | head -20Repository: NVIDIA/Model-Optimizer Length of output: 235 🏁 Script executed: # Check the quantizer base class
rg -n "class.*Quantizer" modelopt/onnx --type py -A 10 | head -50Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look at the ort_patching.py file around the FP8 opset check
cat -n modelopt/onnx/quantization/ort_patching.py | sed -n '220,250p'Repository: NVIDIA/Model-Optimizer Length of output: 1445 🏁 Script executed: # Check where _check_opset_version is called and in what context
rg -n "_check_opset_version|check_opset_version" modelopt/onnx/quantization --type py -B 3 -A 3Repository: NVIDIA/Model-Optimizer Length of output: 1253 🏁 Script executed: # Look at the quantize.py file to understand the full flow
cat -n modelopt/onnx/quantization/quantize.py | sed -n '140,185p'Repository: NVIDIA/Model-Optimizer Length of output: 2716 🏁 Script executed: # Search for get_onnx_bytes_and_metadata function
rg -n "get_onnx_bytes_and_metadata" modelopt/onnxRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look at the quantize_mode constants to understand FP8 mode handling
rg -n "QDQ_PRECISION_MIN_OPSET|fp8" modelopt/onnx/quantization/quantize.py -B 2 -A 2Repository: NVIDIA/Model-Optimizer Length of output: 3436 🏁 Script executed: # Check if there's a separate export entry point that may bypass opset checks
rg -n "def.*export|get_onnx_bytes" modelopt/onnx --type py | head -30Repository: NVIDIA/Model-Optimizer Length of output: 473 🏁 Script executed: # Search for quantize_fp8 function
rg -n "def quantize_fp8|def quantize" modelopt/onnx/quantization/fp8.py -A 10 | head -50Repository: NVIDIA/Model-Optimizer Length of output: 604 🏁 Script executed: # Look for where process_model is called
rg -n "process_model" modelopt/onnx --type py -B 2 -A 2Repository: NVIDIA/Model-Optimizer Length of output: 453 🏁 Script executed: # Check the fp8.py file to see if it calls FP8QuantExporter.post_process
cat -n modelopt/onnx/quantization/fp8.py | head -150Repository: NVIDIA/Model-Optimizer Length of output: 7314 🏁 Script executed: # Continue reading fp8.py to see where FP8QuantExporter is used
cat -n modelopt/onnx/quantization/fp8.py | sed -n '150,250p'Repository: NVIDIA/Model-Optimizer Length of output: 5087 🏁 Script executed: # Search for where FP8QuantExporter.process_model is called
rg -n "\.process_model|FP8QuantExporter" modelopt/onnx --type py -B 2 -A 2Repository: NVIDIA/Model-Optimizer Length of output: 1639 🏁 Script executed: # Check for any test files or examples that use FP8 export
fd "test.*fp8|fp8.*test" modelopt --type pyRepository: NVIDIA/Model-Optimizer Length of output: 235 🏁 Script executed: # Search for where process_model is called with FP8QuantExporter
rg -n "process_model\|FP8QuantExporter" modelopt/onnxRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look in llm_export_utils since it showed up in opset search
cat -n modelopt/onnx/llm_export_utils/export_utils.py | sed -n '155,170p'Repository: NVIDIA/Model-Optimizer Length of output: 379 🏁 Script executed: # Check the entire quantize function end to see if it calls process_model or post_process
cat -n modelopt/onnx/quantization/fp8.py | tail -100Repository: NVIDIA/Model-Optimizer Length of output: 5585 🏁 Script executed: # Search for get_onnx_bytes_and_metadata more carefully
rg -n "get_onnx_bytes_and_metadata\|get_onnx_bytes"Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look for uses of FP8QuantExporter more comprehensively
rg -rn "FP8QuantExporter" modelopt/Repository: NVIDIA/Model-Optimizer Length of output: 667 🏁 Script executed: # Check if there's an export function that uses the exporter classes
rg -n "def export\|def get_onnx" modelopt/onnx -B 2 -A 5 | head -80Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Look at the torch_onnx.py to see how exporters are used
cat -n modelopt/torch/_deploy/utils/torch_onnx.py | grep -A 10 -B 10 "FP8QuantExporter\|n\|onnx_exporters"Repository: NVIDIA/Model-Optimizer Length of output: 31963 🏁 Script executed: # Check the actual file content more carefully
grep -n "class.*Exporter" modelopt/onnx/export/*.pyRepository: NVIDIA/Model-Optimizer Length of output: 564 🏁 Script executed: # Look at what might be calling process_model on FP8QuantExporter
rg -n "process_model\|post_process" modelopt/torchRepository: NVIDIA/Model-Optimizer Length of output: 48 Add opset >= 19 validation before FP8 Q/DQ conversion in The code converts TRT custom ops to native 🤖 Prompt for AI Agents |
||
| return gs.export_onnx(graph) | ||
|
gcunhase marked this conversation as resolved.
|
||
Uh oh!
There was an error while loading. Please reload this page.