Skip to content

Commit cf10c69

Browse files
committed
Added code changes to have a pathway for when serialize succedes and when serialize fails
Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
1 parent b2c31e4 commit cf10c69

2 files changed

Lines changed: 28 additions & 16 deletions

File tree

modelopt/onnx/trt_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,22 @@ def load_onnx_model(
322322

323323
# Load the model and weights
324324
onnx_model = onnx.load(onnx_path, load_external_data=True)
325-
size_threshold = 2 * (1024**3) # 2GB
326-
use_external_data_format = onnx_model.ByteSize() > size_threshold or use_external_data_format
325+
if not use_external_data_format:
326+
try:
327+
model_size = onnx_model.ByteSize()
328+
except Exception as e:
329+
logger.warning(
330+
"Failed to compute model size with ByteSize (%s). Saving tensors as external data.",
331+
e,
332+
)
333+
use_external_data_format = True
334+
else:
335+
if model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF:
336+
use_external_data_format = True
337+
logger.debug(
338+
"Model is too large to save as a single file but 'use_external_data_format'"
339+
" is False. Saving tensors as external data, regardless."
340+
)
327341

328342
# If inputs are dynamic and override shapes are given, set them as static
329343
dynamic_inputs = get_dynamic_graph_inputs(onnx_model)

modelopt/onnx/utils.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from collections import defaultdict
2424
from typing import Any
2525

26-
import google.protobuf.message
2726
import numpy as np
2827
import onnx
2928
import onnx_graphsurgeon as gs
@@ -643,19 +642,18 @@ def get_variable_inputs(node: Node) -> list[Variable]:
643642
def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: bool = False):
644643
"""Save an ONNX model to given path. If a model is larger than 2GB, will save with external data."""
645644
size_threshold = 2 * (1024**3) # 2GB
646-
try:
647-
model_proto = model.SerializeToString()
648-
model_size = len(model_proto)
649-
save_as_external_data = save_as_external_data or model_size > size_threshold
650-
logger.debug(
651-
f"Model size: {model_size} bytes, using external data: {save_as_external_data}"
652-
)
653-
654-
except (ValueError, google.protobuf.message.EncodeError) as e:
655-
logger.warning(
656-
"Model exceeds 2GB limit, switching to external data storage. Error message: [%s]", e
657-
)
658-
save_as_external_data = True
645+
if not save_as_external_data:
646+
try:
647+
model_proto = model.SerializeToString()
648+
except Exception as e:
649+
logger.warning("Failed to serialize model. Saving tensors as external data. (%s)", e)
650+
save_as_external_data = True
651+
else:
652+
model_size = len(model_proto)
653+
save_as_external_data = model_size > size_threshold
654+
logger.debug(
655+
f"Model size: {model_size} bytes, using external data: {save_as_external_data}"
656+
)
659657

660658
# Set ir_version to 10, remove it once ORT supports ir_version 11
661659
model.ir_version = 10

0 commit comments

Comments
 (0)