Skip to content

Commit 2ce8f17

Browse files
galagamdanielkorzekwa
authored and committed
[5725362] AutoCast Fixes for models with external data (#731)
## What does this PR do? **Type of change:** Bug fix <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Fix AutoCast ReferenceRunner to handle large models. Models above 2GB cannot be serialized to string, which is what polygraphy is doing under the hood. Use a temporary file instead to save the modified onnx with all tensors marked as outputs. ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: No - **Did you add or update any necessary documentation?**: No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Improvements** * Enhanced model processing to better support large ONNX models during validation and runtime execution * Added diagnostic logging of model sizes at key processing stages for improved debugging and performance monitoring <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
1 parent 498968a commit 2ce8f17

File tree

5 files changed

+58
-23
lines changed

5 files changed

+58
-23
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def convert(
175175
onnx.ModelProto: The converted mixed precision model.
176176
"""
177177
try:
178-
self.model = onnx_utils.check_model(self.model)
178+
onnx_utils.check_model(self.model)
179179
except onnx.checker.ValidationError as e:
180180
logger.error(f"Internal error: onnx.checker failed on input model {e}")
181181
raise Exception(

modelopt/onnx/autocast/referencerunner.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
import copy
2525
import io
2626
import sys
27+
import tempfile
2728
from collections import OrderedDict
2829

2930
import numpy as np
3031
import onnx
3132

33+
from modelopt.onnx import utils as onnx_utils
3234
from modelopt.onnx.autocast.logging_config import configure_logging, logger
3335
from modelopt.onnx.quantization.calib_utils import CalibrationDataProvider
3436
from modelopt.onnx.quantization.ort_utils import _prepare_ep_list
@@ -125,13 +127,42 @@ def _load_inputs(self, inputs):
125127

126128
return data_loader
127129

130+
def _get_ort_runner(self, model):
131+
from polygraphy.backend.onnx import BytesFromOnnx
132+
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
133+
134+
# Check if model has external data by checking:
135+
# 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
136+
# 2. If model size would exceed 2GB (indicating need for external data)
137+
needs_external_data = onnx_utils.check_model_uses_external_data(
138+
self.model
139+
) or self.model.ByteSize() > 2 * (1024**3)
140+
if needs_external_data:
141+
logger.debug("Model has external data, using file-based approach")
142+
# Get the actual ONNX ModelProto from ModifyOutputs wrapper
143+
modified_model = model()
144+
145+
# Use a persistent temp file, because we need the file to be present in a broader context
146+
tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
147+
tmp_file.close()
148+
tmp_file_path = tmp_file.name
149+
onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
150+
logger.debug(f"Model with all outputs saved to {tmp_file_path}")
151+
build_onnxrt_session = SessionFromOnnx(tmp_file_path, providers=self.providers)
152+
153+
else:
154+
# For models without external data, use the original BytesFromOnnx approach (no tmp files)
155+
logger.debug("Model has no external data, using BytesFromOnnx approach")
156+
serialize_onnx = BytesFromOnnx(model)
157+
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
158+
runners = [OnnxrtRunner(build_onnxrt_session)]
159+
return runners
160+
128161
def run(self, inputs=None):
129162
"""Run FP32 inference with provided or random inputs."""
130163
import onnxruntime as ort
131164
from polygraphy import constants
132-
from polygraphy.backend.onnx import BytesFromOnnx
133165
from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs
134-
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
135166
from polygraphy.comparator import Comparator
136167

137168
logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...")
@@ -140,9 +171,9 @@ def run(self, inputs=None):
140171

141172
model_copy = copy.deepcopy(self.model)
142173
modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL)
143-
serialize_onnx = BytesFromOnnx(modify_outputs)
144-
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
145-
runners = [OnnxrtRunner(build_onnxrt_session)]
174+
175+
# Load the modified model and create an inference session
176+
runners = self._get_ort_runner(modify_outputs)
146177

147178
# Comparator is used despite the fact that we are using ONNXRuntime
148179
# because it provides the ability to generate random inputs using DataLoader

modelopt/onnx/utils.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Utility functions related to onnx."""
1717

18+
import copy
1819
import io
1920
import os
2021
import tempfile
@@ -552,7 +553,7 @@ def _get_unique_name(old_name):
552553
return onnx_model, is_modified
553554

554555

555-
def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
556+
def check_model(model: onnx.ModelProto) -> None:
556557
"""Checks if the given model is valid."""
557558
if model.ByteSize() > (2 * (1024**3)): # 2GB limit
558559
with tempfile.TemporaryDirectory() as temp_dir:
@@ -561,10 +562,8 @@ def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
561562
onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
562563
save_onnx(model, onnx_tmp_path, save_as_external_data=True)
563564
onnx.checker.check_model(onnx_tmp_path)
564-
return onnx.load(onnx_tmp_path)
565565
else:
566566
onnx.checker.check_model(model)
567-
return model
568567

569568

570569
def find_lowest_common_ancestor(node1: Node, node2: Node) -> tuple[str | None, int, int]:
@@ -658,15 +657,16 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo
658657

659658
# Set ir_version to 10, remove it once ORT supports ir_version 11
660659
model.ir_version = 10
661-
662660
if save_as_external_data:
663661
external_data_path = os.path.basename(onnx_path) + "_data"
664662
if os.path.exists(external_data_path):
665663
logger.warning(f"Removing existing external data file: {external_data_path}")
666664
os.remove(external_data_path)
667665

666+
# Copy so the onnx.ModelProto object will not be modified
667+
model_copy = copy.deepcopy(model)
668668
onnx.save_model(
669-
model,
669+
model_copy,
670670
onnx_path,
671671
save_as_external_data=True,
672672
all_tensors_to_one_file=True,
@@ -696,6 +696,21 @@ def get_opset_version(model: onnx.ModelProto) -> int:
696696
return ai_onnx_domain[0].version
697697

698698

699+
def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
700+
"""Checks if the model uses external data.
701+
702+
Args:
703+
model: Loaded in-memory onnx ModelProto.
704+
705+
Returns:
706+
True if any initializer tensor has data_location set to EXTERNAL.
707+
"""
708+
return any(
709+
init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
710+
for init in model.graph.initializer
711+
)
712+
713+
699714
def bfloat16_to_float32(bf16_array):
700715
"""Converts a bfloat16 array (as raw data) to a float32 array."""
701716
uint32_array = bf16_array.astype(np.uint32) << 16

modelopt/torch/_deploy/utils/onnx_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,3 @@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> list[str]:
4545
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
4646
]
4747
return model_tensors_ext
48-
49-
50-
def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
51-
"""
52-
Checks if the model uses external data.
53-
"""
54-
model_tensors = _get_initializer_tensors(model)
55-
return any(
56-
tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
57-
for tensor in model_tensors
58-
)

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
)
4343
from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
4444
from modelopt.onnx.utils import (
45+
check_model_uses_external_data,
4546
get_input_names,
4647
get_input_shapes,
4748
get_node_names,
@@ -55,7 +56,6 @@
5556
from modelopt.torch.utils._pytree import TreeSpec
5657

5758
from ..utils.onnx_optimizer import Optimizer
58-
from .onnx_utils import check_model_uses_external_data
5959

6060
ModelMetadata = dict[str, Any]
6161
ModelType = Any

0 commit comments

Comments
 (0)