Skip to content

Commit 00f61cc

Browse files
committed
Fix test failures
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 2ebf0a2 commit 00f61cc

3 files changed

Lines changed: 16 additions & 8 deletions

File tree

modelopt/onnx/export/fp8_exporter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,15 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
105105
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
106106
"""Post-processes the ONNX model for FP8 quantization.
107107
108-
Converts remaining TRT_FP8 QDQ ops (activations) to native ONNX QuantizeLinear/DequantizeLinear,
109-
updates GELU nodes to use tanh approximation, and inserts Cast nodes after Sqrt.
108+
Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
109+
- TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
110+
- TRT_FP8DequantizeLinear -> DequantizeLinear
111+
112+
Args:
113+
onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
114+
115+
Returns:
116+
The post-processed ONNX model with native ONNX quantization ops.
110117
"""
111118
logger.info("Post-processing FP8 quantized model")
112119
graph = gs.import_onnx(onnx_model)

modelopt/onnx/export/nvfp4_exporter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
215215
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to process")
216216

217217
for node in fp4_qdq_nodes:
218-
idx = initializer_indices.get(node.input[0], None)
218+
idx = initializer_indices.get(node.input[0])
219219
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
220220

221221
tensor = initializers[idx]
@@ -259,7 +259,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
259259
fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"]
260260

261261
for node in fp4_qdq_nodes:
262-
idx = initializer_indices.get(node.input[0], None)
262+
idx = initializer_indices.get(node.input[0])
263263
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
264264

265265
tensor = initializers[idx]
@@ -365,7 +365,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str):
365365
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert")
366366

367367
for node in fp4_qdq_nodes:
368-
idx = initializer_indices.get(node.input[0], None)
368+
idx = initializer_indices.get(node.input[0])
369369
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
370370
initializers_to_delete.append(graph.initializer[idx].name)
371371

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,6 @@ def get_onnx_bytes_and_metadata(
561561
tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
562562
)
563563

564-
# TODO: Remove manual ir_version change once ORT supports ir_version 11
565-
onnx_opt_graph.ir_version = 10
566-
567564
onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
568565

569566
if dq_only:
@@ -585,6 +582,10 @@ def get_onnx_bytes_and_metadata(
585582
# TensorRT expects all scales to be positive
586583
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
587584

585+
# TODO: Remove manual ir_version change once ORT supports ir_version 11
586+
# Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version
587+
onnx_opt_graph.ir_version = 10
588+
588589
# If the onnx model contains external data store the external tensors in one file and save the onnx model
589590
if has_external_data(onnx_save_path):
590591
tensor_paths = get_external_tensor_paths(onnx_path)

0 commit comments

Comments (0)