2424import copy
2525import io
2626import sys
27+ import tempfile
2728from collections import OrderedDict
2829
2930import numpy as np
3031import onnx
3132
33+ from modelopt .onnx import utils as onnx_utils
3234from modelopt .onnx .autocast .logging_config import configure_logging , logger
3335from modelopt .onnx .quantization .ort_utils import _prepare_ep_list
3436
@@ -118,13 +120,65 @@ def _load_inputs(self, inputs):
118120
119121 return data_loader
120122
123+ def _get_ort_runner (self , model ):
124+ import onnxruntime as ort
125+ from polygraphy .backend .onnx import BytesFromOnnx
126+ from polygraphy .backend .onnxrt import OnnxrtRunner , SessionFromOnnx
127+
128+ # Check if model has external data by checking:
129+ # 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
130+ # 2. If model size would exceed 2GB (indicating need for external data)
131+ has_external_data = any (
132+ init .HasField ("data_location" ) and init .data_location == onnx .TensorProto .EXTERNAL
133+ for init in self .model .graph .initializer
134+ )
135+
136+ # Also check if model would be too large (>2GB) for SerializeToString
137+ # This handles cases where model was loaded with external data already loaded
138+ if not has_external_data :
139+ try :
140+ # Try to estimate size by serializing the model
141+ # If it fails or exceeds 2GB, we need file-based approach
142+ model_size = len (self .model .SerializeToString ())
143+ if model_size > 2 * (1024 ** 3 ): # 2GB threshold
144+ has_external_data = True
145+ logger .debug (
146+ f"Model size ({ model_size / (1024 ** 3 ):.2f} GB) exceeds 2GB, using file-based approach"
147+ )
148+ except (ValueError , AttributeError ) as e :
149+ # SerializeToString failed (likely >2GB limit), use file-based approach
150+ if "exceeds maximum protobuf size" in str (e ) or "2GB" in str (e ):
151+ has_external_data = True
152+ logger .debug ("Model exceeds protobuf 2GB limit, using file-based approach" )
153+
154+ if has_external_data :
155+ logger .debug ("Model has external data, using file-based approach" )
156+ # Get the actual ONNX ModelProto from ModifyOutputs wrapper
157+ modified_model = model ()
158+
159+ # Use a persistent temp file to handle external data files properly
160+ tmp_file = tempfile .NamedTemporaryFile (suffix = ".onnx" , delete = False )
161+ tmp_file .close ()
162+ tmp_file_path = tmp_file .name
163+ onnx_utils .save_onnx (modified_model , tmp_file_path , save_as_external_data = True )
164+ logger .debug (f"Model with all outputs saved to { tmp_file_path } " )
165+ session = ort .InferenceSession (tmp_file_path , providers = self .providers )
166+ runners = [OnnxrtRunner (lambda : session )]
167+
168+ else :
169+ # For models without external data, use the original BytesFromOnnx approach (no tmp files)
170+ logger .debug ("Model has no external data, using BytesFromOnnx approach" )
171+ serialize_onnx = BytesFromOnnx (model )
172+ build_onnxrt_session = SessionFromOnnx (serialize_onnx , providers = self .providers )
173+ runners = [OnnxrtRunner (build_onnxrt_session )]
174+
175+ return runners
176+
121177 def run (self , inputs = None ):
122178 """Run FP32 inference with provided or random inputs."""
123179 import onnxruntime as ort
124180 from polygraphy import constants
125- from polygraphy .backend .onnx import BytesFromOnnx
126181 from polygraphy .backend .onnx import ModifyOutputs as ModifyOnnxOutputs
127- from polygraphy .backend .onnxrt import OnnxrtRunner , SessionFromOnnx
128182 from polygraphy .comparator import Comparator
129183
130184 logger .info ("Running ONNX Runtime to obtain reference outputs (this may take a while)..." )
@@ -133,9 +187,9 @@ def run(self, inputs=None):
133187
134188 model_copy = copy .deepcopy (self .model )
135189 modify_outputs = ModifyOnnxOutputs (model_copy , outputs = constants .MARK_ALL )
136- serialize_onnx = BytesFromOnnx ( modify_outputs )
137- build_onnxrt_session = SessionFromOnnx ( serialize_onnx , providers = self . providers )
138- runners = [ OnnxrtRunner ( build_onnxrt_session )]
190+
191+ # Load the modified model and create an inference session
192+ runners = self . _get_ort_runner ( modify_outputs )
139193
140194 # Comparator is used despite the fact that we are using ONNXRuntime
141195 # because it provides the ability to generate random inputs using DataLoader
0 commit comments