@@ -121,54 +121,34 @@ def _load_inputs(self, inputs):
121121 return data_loader
122122
def _get_ort_runner(self, model):
    """Build the ONNX Runtime runner(s) used to execute ``model``.

    Two strategies are used to hand the model to ONNX Runtime:

    * File-based: the model is written to a temporary ``.onnx`` file with its
      weights stored as external data, and the session is created from that
      path. Chosen when ``self.model`` already uses external data or is too
      large to be serialized into a single protobuf message.
    * In-memory: the model is serialized with ``BytesFromOnnx`` and the
      session is created from the resulting bytes (no temporary files).

    Args:
        model: Zero-argument callable that yields the ONNX model to run. It is
            called directly on the file-based path and wrapped in
            ``BytesFromOnnx`` on the in-memory path.

    Returns:
        A single-element list containing the ``OnnxrtRunner``.
    """
    from polygraphy.backend.onnx import BytesFromOnnx
    from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx

    # A serialized protobuf message must be strictly smaller than 2 GiB, so a
    # model of exactly this size already requires the file-based approach.
    protobuf_size_limit = 2 * (1024**3)

    # Check if the model needs external data by checking:
    # 1. If the model is reported as using external data
    # 2. If the model size reaches the 2 GiB protobuf limit
    needs_external_data = onnx_utils.check_model_uses_external_data(self.model)
    if not needs_external_data:
        try:
            needs_external_data = self.model.ByteSize() >= protobuf_size_limit
        except Exception as err:
            # NOTE(review): some protobuf backends appear to compute ByteSize()
            # by serializing the message, which fails for oversized models.
            # Treat any failure to measure the size as "too large": the
            # file-based path works regardless of model size.
            needs_external_data = True
            logger.debug(f"Could not determine model size ({err}), using file-based approach")

    if needs_external_data:
        logger.debug("Model has external data, using file-based approach")
        # Get the actual ONNX ModelProto from ModifyOutputs wrapper
        modified_model = model()

        # Use a persistent temp file (delete=False), because the file has to
        # remain available in a broader context than this method.
        # NOTE(review): nothing in this method removes the file afterwards.
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp_file:
            tmp_file_path = tmp_file.name
        onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
        logger.debug(f"Model with all outputs saved to { tmp_file_path } ")
        build_onnxrt_session = SessionFromOnnx(tmp_file_path, providers=self.providers)
    else:
        # For models without external data, use the original BytesFromOnnx approach (no tmp files)
        logger.debug("Model has no external data, using BytesFromOnnx approach")
        serialize_onnx = BytesFromOnnx(model)
        build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)

    runners = [OnnxrtRunner(build_onnxrt_session)]
    return runners
173153
174154 def run (self , inputs = None ):
0 commit comments