Skip to content

Commit d4e15ed

Browse files
committed
Fix ReferenceRunner to support ONNX models with external data
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent 391f6cb commit d4e15ed

1 file changed

Lines changed: 59 additions & 5 deletions

File tree

modelopt/onnx/autocast/referencerunner.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
import copy
2525
import io
2626
import sys
27+
import tempfile
2728
from collections import OrderedDict
2829

2930
import numpy as np
3031
import onnx
3132

33+
from modelopt.onnx import utils as onnx_utils
3234
from modelopt.onnx.autocast.logging_config import configure_logging, logger
3335
from modelopt.onnx.quantization.ort_utils import _prepare_ep_list
3436

@@ -118,13 +120,65 @@ def _load_inputs(self, inputs):
118120

119121
return data_loader
120122

123+
def _get_ort_runner(self, model):
    """Build the polygraphy ONNX Runtime runner(s) for a possibly-large model.

    Models whose initializers live in external data files — or whose serialized
    size would exceed protobuf's 2GB hard limit — cannot go through the
    in-memory ``BytesFromOnnx`` path, so they are saved to a temporary file
    (with external data) and loaded from disk instead.

    Args:
        model: A polygraphy loader (e.g. ``ModifyOutputs``) that yields the
            ONNX ``ModelProto`` when called.

    Returns:
        A single-element list containing an ``OnnxrtRunner`` bound to an
        ONNX Runtime session created with ``self.providers``.
    """
    import onnxruntime as ort
    from polygraphy.backend.onnx import BytesFromOnnx
    from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx

    # The file-based path is needed when:
    # 1. any initializer is marked EXTERNAL (even if its data is already loaded), or
    # 2. the serialized model would exceed the 2GB protobuf limit.
    has_external_data = any(
        init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
        for init in self.model.graph.initializer
    )

    if not has_external_data:
        # ByteSize() reports the serialized size without materializing the
        # (potentially multi-GB) byte string that SerializeToString() would
        # allocate — SerializeToString() raises outright above 2GB, which the
        # previous implementation had to detect by sniffing the error message.
        model_size = self.model.ByteSize()
        if model_size > 2 * (1024**3):  # protobuf's 2GB serialization limit
            has_external_data = True
            logger.debug(
                f"Model size ({model_size / (1024**3):.2f} GB) exceeds 2GB, using file-based approach"
            )

    if has_external_data:
        logger.debug("Model has external data, using file-based approach")
        # Materialize the actual ONNX ModelProto from the polygraphy loader wrapper.
        modified_model = model()

        # Persist the model so ONNX Runtime can resolve external-data files.
        # delete=False keeps the file alive past this scope because the session
        # is created from the path. NOTE(review): the temp file (and its
        # external-data sidecar) is never removed; consider cleaning it up once
        # the session is known not to need it anymore.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
        tmp_file.close()
        tmp_file_path = tmp_file.name
        onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
        logger.debug(f"Model with all outputs saved to {tmp_file_path}")
        session = ort.InferenceSession(tmp_file_path, providers=self.providers)
        runners = [OnnxrtRunner(lambda: session)]
    else:
        # Small models: serialize fully in memory — no temp files required.
        logger.debug("Model has no external data, using BytesFromOnnx approach")
        serialize_onnx = BytesFromOnnx(model)
        build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
        runners = [OnnxrtRunner(build_onnxrt_session)]

    return runners
176+
121177
def run(self, inputs=None):
122178
"""Run FP32 inference with provided or random inputs."""
123179
import onnxruntime as ort
124180
from polygraphy import constants
125-
from polygraphy.backend.onnx import BytesFromOnnx
126181
from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs
127-
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
128182
from polygraphy.comparator import Comparator
129183

130184
logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...")
@@ -133,9 +187,9 @@ def run(self, inputs=None):
133187

134188
model_copy = copy.deepcopy(self.model)
135189
modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL)
136-
serialize_onnx = BytesFromOnnx(modify_outputs)
137-
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
138-
runners = [OnnxrtRunner(build_onnxrt_session)]
190+
191+
# Load the modified model and create an inference session
192+
runners = self._get_ort_runner(modify_outputs)
139193

140194
# Comparator is used despite the fact that we are using ONNXRuntime
141195
# because it provides the ability to generate random inputs using DataLoader

0 commit comments

Comments
 (0)