File tree Expand file tree Collapse file tree 4 files changed +17
-16
lines changed
Expand file tree Collapse file tree 4 files changed +17
-16
lines changed Original file line number Diff line number Diff line change @@ -128,10 +128,7 @@ def _get_ort_runner(self, model):
128128 # Check if model has external data by checking:
129129 # 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
130130 # 2. If model size would exceed 2GB (indicating need for external data)
131- has_external_data = any (
132- init .HasField ("data_location" ) and init .data_location == onnx .TensorProto .EXTERNAL
133- for init in self .model .graph .initializer
134- )
131+ has_external_data = onnx_utils .check_model_uses_external_data (self .model )
135132
136133 # Also check if model would be too large (>2GB) for SerializeToString
137134 # This handles cases where model was loaded with external data already loaded
Original file line number Diff line number Diff line change @@ -696,6 +696,21 @@ def get_opset_version(model: onnx.ModelProto) -> int:
696696 return ai_onnx_domain [0 ].version
697697
698698
699+ def check_model_uses_external_data (model : onnx .ModelProto ) -> bool :
700+ """Checks if the model uses external data.
701+
702+ Args:
703+ model: Loaded in-memory onnx ModelProto.
704+
705+ Returns:
706+ True if any initializer tensor has data_location set to EXTERNAL.
707+ """
708+ return any (
709+ init .HasField ("data_location" ) and init .data_location == onnx .TensorProto .EXTERNAL
710+ for init in model .graph .initializer
711+ )
712+
713+
699714def bfloat16_to_float32 (bf16_array ):
700715 """Converts a bfloat16 array (as raw data) to a float32 array."""
701716 uint32_array = bf16_array .astype (np .uint32 ) << 16
Original file line number Diff line number Diff line change @@ -45,14 +45,3 @@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> list[str]:
4545 if tensor .HasField ("data_location" ) and tensor .data_location == onnx .TensorProto .EXTERNAL
4646 ]
4747 return model_tensors_ext
48-
49-
50- def check_model_uses_external_data (model : onnx .ModelProto ) -> bool :
51- """
52- Checks if the model uses external data.
53- """
54- model_tensors = _get_initializer_tensors (model )
55- return any (
56- tensor .HasField ("data_location" ) and tensor .data_location == onnx .TensorProto .EXTERNAL
57- for tensor in model_tensors
58- )
Original file line number Diff line number Diff line change 4242)
4343from modelopt .onnx .quantization .qdq_utils import qdq_to_dq , replace_zero_scale_with_smallest_nonzero
4444from modelopt .onnx .utils import (
45+ check_model_uses_external_data ,
4546 get_input_names ,
4647 get_input_shapes ,
4748 get_node_names ,
5556from modelopt .torch .utils ._pytree import TreeSpec
5657
5758from ..utils .onnx_optimizer import Optimizer
58- from .onnx_utils import check_model_uses_external_data
5959
6060ModelMetadata = dict [str , Any ]
6161ModelType = Any
You can’t perform that action at this time.
0 commit comments