Skip to content

Commit 6532522

Browse files
committed
draft: skip model checker for models with external data
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent d4e15ed commit 6532522

2 files changed

Lines changed: 18 additions & 12 deletions

File tree

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,10 @@ class PrecisionConverter:
8282
Public Methods:
8383
convert: Convert specified nodes to FP16/BF16 precision while keeping others in FP32.
8484
"""
85+
def print_byte_size(self, label: str):
86+
model_proto = self.model.SerializeToString()
87+
model_size = len(model_proto)
88+
print(f"GAGAM {label} ByteSize: {model_size}")
8589

8690
def __init__(
8791
self,
@@ -175,7 +179,7 @@ def convert(
175179
onnx.ModelProto: The converted mixed precision model.
176180
"""
177181
try:
178-
self.model = onnx_utils.check_model(self.model)
182+
onnx_utils.check_model(self.model)
179183
except onnx.checker.ValidationError as e:
180184
logger.error(f"Internal error: onnx.checker failed on input model {e}")
181185
raise Exception(
@@ -1294,7 +1298,9 @@ def _fix_network_output_names(self):
12941298
def _sanity_check(self):
12951299
sanity_ok = True
12961300
try:
1301+
self.print_byte_size("before check_model")
12971302
onnx_utils.check_model(self.model)
1303+
self.print_byte_size("after check_model")
12981304
except onnx.checker.ValidationError as e:
12991305
logger.error(f"Internal error: onnx.checker failed: {e}")
13001306
sanity_ok = False

modelopt/onnx/utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -552,19 +552,19 @@ def _get_unique_name(old_name):
552552
return onnx_model, is_modified
553553

554554

555-
def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
555+
def check_model(model: onnx.ModelProto) -> None:
556556
"""Checks if the given model is valid."""
557557
if model.ByteSize() > (2 * (1024**3)): # 2GB limit
558-
with tempfile.TemporaryDirectory() as temp_dir:
559-
# ONNX also looks in CWD, so we need to use a unique id
560-
unique_id = str(uuid.uuid4())[:8]
561-
onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
562-
save_onnx(model, onnx_tmp_path, save_as_external_data=True)
563-
onnx.checker.check_model(onnx_tmp_path)
564-
return onnx.load(onnx_tmp_path)
558+
logger.warning("Model exceeds 2GB limit, skipping check_model")
559+
# with tempfile.TemporaryDirectory() as temp_dir:
560+
# # ONNX also looks in CWD, so we need to use a unique id
561+
# unique_id = str(uuid.uuid4())[:8]
562+
# onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
563+
# save_onnx(model, onnx_tmp_path, save_as_external_data=True)
564+
# onnx.checker.check_model(onnx_tmp_path)
565+
565566
else:
566567
onnx.checker.check_model(model)
567-
return model
568568

569569

570570
def find_lowest_common_ancestor(node1: Node, node2: Node) -> tuple[str | None, int, int]:
@@ -644,7 +644,7 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo
644644
model_proto = model.SerializeToString()
645645
model_size = len(model_proto)
646646
save_as_external_data = save_as_external_data or model_size > size_threshold
647-
logger.debug(
647+
logger.warning(
648648
f"Model size: {model_size} bytes, using external data: {save_as_external_data}"
649649
)
650650

@@ -658,7 +658,7 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo
658658

659659
# Set ir_version to 10, remove it once ORT supports ir_version 11
660660
model.ir_version = 10
661-
661+
save_as_external_data = True # GAGAM: for debug
662662
if save_as_external_data:
663663
external_data_path = os.path.basename(onnx_path) + "_data"
664664
if os.path.exists(external_data_path):

0 commit comments

Comments
 (0)