Skip to content

Commit 00ea80c

Browse files
committed
code de-deuplication
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent cd96fa5 commit 00ea80c

File tree

4 files changed

+17
-16
lines changed

4 files changed

+17
-16
lines changed

modelopt/onnx/autocast/referencerunner.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,7 @@ def _get_ort_runner(self, model):
128128
# Check if model has external data by checking:
129129
# 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
130130
# 2. If model size would exceed 2GB (indicating need for external data)
131-
has_external_data = any(
132-
init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
133-
for init in self.model.graph.initializer
134-
)
131+
has_external_data = onnx_utils.check_model_uses_external_data(self.model)
135132

136133
# Also check if model would be too large (>2GB) for SerializeToString
137134
# This handles cases where model was loaded with external data already loaded

modelopt/onnx/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,21 @@ def get_opset_version(model: onnx.ModelProto) -> int:
696696
return ai_onnx_domain[0].version
697697

698698

699+
def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
700+
"""Checks if the model uses external data.
701+
702+
Args:
703+
model: Loaded in-memory onnx ModelProto.
704+
705+
Returns:
706+
True if any initializer tensor has data_location set to EXTERNAL.
707+
"""
708+
return any(
709+
init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
710+
for init in model.graph.initializer
711+
)
712+
713+
699714
def bfloat16_to_float32(bf16_array):
700715
"""Converts a bfloat16 array (as raw data) to a float32 array."""
701716
uint32_array = bf16_array.astype(np.uint32) << 16

modelopt/torch/_deploy/utils/onnx_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,3 @@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> list[str]:
4545
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
4646
]
4747
return model_tensors_ext
48-
49-
50-
def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
51-
"""
52-
Checks if the model uses external data.
53-
"""
54-
model_tensors = _get_initializer_tensors(model)
55-
return any(
56-
tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
57-
for tensor in model_tensors
58-
)

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
)
4343
from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
4444
from modelopt.onnx.utils import (
45+
check_model_uses_external_data,
4546
get_input_names,
4647
get_input_shapes,
4748
get_node_names,
@@ -55,7 +56,6 @@
5556
from modelopt.torch.utils._pytree import TreeSpec
5657

5758
from ..utils.onnx_optimizer import Optimizer
58-
from .onnx_utils import check_model_uses_external_data
5959

6060
ModelMetadata = dict[str, Any]
6161
ModelType = Any

0 commit comments

Comments
 (0)