Skip to content

Commit caafcbd

Browse files
committed
code cleanup
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent 00ea80c commit caafcbd

File tree

1 file changed

+7
-27
lines changed

1 file changed

+7
-27
lines changed

modelopt/onnx/autocast/referencerunner.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -121,54 +121,34 @@ def _load_inputs(self, inputs):
121121
return data_loader
122122

123123
def _get_ort_runner(self, model):
124-
import onnxruntime as ort
125124
from polygraphy.backend.onnx import BytesFromOnnx
126125
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
127126

128127
# Check if model has external data by checking:
129128
# 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
130129
# 2. If model size would exceed 2GB (indicating need for external data)
131-
has_external_data = onnx_utils.check_model_uses_external_data(self.model)
132-
133-
# Also check if model would be too large (>2GB) for SerializeToString
134-
# This handles cases where model was loaded with external data already loaded
135-
if not has_external_data:
136-
try:
137-
# Try to estimate size by serializing the model
138-
# If it fails or exceeds 2GB, we need file-based approach
139-
model_size = len(self.model.SerializeToString())
140-
if model_size > 2 * (1024**3): # 2GB threshold
141-
has_external_data = True
142-
logger.debug(
143-
f"Model size ({model_size / (1024**3):.2f} GB) exceeds 2GB, using file-based approach"
144-
)
145-
except (ValueError, AttributeError) as e:
146-
# SerializeToString failed (likely >2GB limit), use file-based approach
147-
if "exceeds maximum protobuf size" in str(e) or "2GB" in str(e):
148-
has_external_data = True
149-
logger.debug("Model exceeds protobuf 2GB limit, using file-based approach")
150-
151-
if has_external_data:
130+
needs_external_data = onnx_utils.check_model_uses_external_data(
131+
self.model
132+
) or self.model.ByteSize() > 2 * (1024**3)
133+
if needs_external_data:
152134
logger.debug("Model has external data, using file-based approach")
153135
# Get the actual ONNX ModelProto from ModifyOutputs wrapper
154136
modified_model = model()
155137

156-
# Use a persistent temp file to handle external data files properly
138+
# Use a persistent temp file, because we need the file to be present in an broader context
157139
tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
158140
tmp_file.close()
159141
tmp_file_path = tmp_file.name
160142
onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
161143
logger.debug(f"Model with all outputs saved to {tmp_file_path}")
162-
session = ort.InferenceSession(tmp_file_path, providers=self.providers)
163-
runners = [OnnxrtRunner(lambda: session)]
144+
build_onnxrt_session = SessionFromOnnx(tmp_file_path, providers=self.providers)
164145

165146
else:
166147
# For models without external data, use the original BytesFromOnnx approach (no tmp files)
167148
logger.debug("Model has no external data, using BytesFromOnnx approach")
168149
serialize_onnx = BytesFromOnnx(model)
169150
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
170-
runners = [OnnxrtRunner(build_onnxrt_session)]
171-
151+
runners = [OnnxrtRunner(build_onnxrt_session)]
172152
return runners
173153

174154
def run(self, inputs=None):

0 commit comments

Comments
 (0)