- Converting model precision (e.g., FP16 to BF16)
- Transposing DequantizeLinear weights for column-major storage optimization
- Graph cleanup and optimization
- Making models compatible with the NVIDIA DLA accelerator

Example usage:
    >>> from modelopt.onnx.graph_surgery import (
    ...     replace_attention_with_gqa,
    ...     convert_fp16_to_bf16,
    ...     transpose_dequantize_linear_weights,
    ...     add_cross_kv_to_encoder,
    ...     make_dla_compatible,
    ... )
    >>> # Replace attention with GQA for LLMs (FP16 model)
    >>> replace_attention_with_gqa(
    ...     model_path="model_quantized.onnx",
    ...     output_path="model_quantized_transposed.onnx",
    ... )
    >>> # Apply the full DLA compatibility pipeline (16-step transform sequence)
    >>> make_dla_compatible(
    ...     model_path="model.onnx",
    ...     output_path="model_dla.onnx",
    ... )
"""

import os

import onnx

from .dq_transpose import _transform_dq_transpose, transpose_dequantize_linear_weights
from .encoder_cross_kv import _transform_cross_kv, add_cross_kv_to_encoder
from .gqa_replacement import _transform_gqa, replace_attention_with_gqa
from .make_dla_compatible import _transform_make_dla_compatible
from .make_dla_compatible import dla_make_dla_compatible as make_dla_compatible
from .utils.dtype_conversion import _transform_fp16_to_bf16, convert_fp16_to_bf16
# Maps CLI-facing surgery names to low-level transform functions that operate
# on an in-memory ModelProto; model loading/saving is handled centrally by
# run_graph_surgery().
_SURGERY_REGISTRY = {
    "replace-gqa": _transform_gqa,
    "add-cross-kv": _transform_cross_kv,
    "convert-bf16": _transform_fp16_to_bf16,
    "transpose-dq": _transform_dq_transpose,
    "make-dla-compatible": _transform_make_dla_compatible,
}
7892
7993
@@ -82,16 +96,85 @@ def get_available_surgeries() -> list[str]:
8296 return list (_SURGERY_REGISTRY .keys ())
8397
8498
def _save_model(
    model: onnx.ModelProto,
    output_path: str,
    use_external_data: bool = True,
    external_data_name: str | None = None,
    size_threshold: int = 1024,
    verbose: bool = True,
) -> None:
    """Unified model saving logic for all graph surgeries.

    Args:
        model: The ONNX model to save.
        output_path: Path to save the model.
        use_external_data: Whether to save weights as external data.
        external_data_name: Name for external data file.
            Defaults to ``<output_filename>_data``.
        size_threshold: Minimum tensor size (bytes) to externalize.
        verbose: Whether to print progress messages.
    """
    # Function-scope import (same pattern as run_graph_surgery); presumably
    # avoids an import cycle with the package's logging setup — confirm.
    from ..logging_config import logger

    if verbose:
        logger.info(f"\nSaving modified model to: {output_path}")

    # Make sure the destination directory exists before onnx.save writes to it.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if use_external_data:
        if external_data_name is None:
            external_data_name = os.path.basename(output_path) + "_data"

        if verbose:
            logger.info(f"  Saving weights to external file: {external_data_name}")

        onnx.save(
            model,
            output_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_data_name,
            size_threshold=size_threshold,
        )
    else:
        onnx.save(model, output_path)

    # NOTE(review): this rebinds only the local name `model` and the function
    # returns None, so the caller's ModelProto is unaffected. Confirm whether
    # the reloaded model was meant to be returned to the caller.
    model = onnx.load(output_path, load_external_data=True)

    # Run shape inference (beneficial for all surgeries, no-op if nothing changed)
    if verbose:
        logger.info("\nRunning shape inference (file-to-file)...")
    try:
        onnx.shape_inference.infer_shapes_path(
            output_path, output_path, check_type=False, strict_mode=False, data_prop=False
        )
        if verbose:
            logger.info("  Shape inference completed")
    except Exception as e:
        # Non-fatal: the model file is already on disk at this point.
        if verbose:
            logger.info(f"  Shape inference failed (non-fatal, model already saved): {e}")

    if verbose:
        logger.info("Done!")
def run_graph_surgery(
    surgery_name: str,
    model_path: str,
    output_path: str,
    use_external_data: bool = True,
    external_data_name: str | None = None,
    verbose: bool = True,
    **kwargs,
) -> onnx.ModelProto:
    """Run a graph surgery by name with centralized model loading and saving.

    This is the unified entry point for all graph surgeries. It handles:
    1. Loading the input model from ``model_path``
    2. Dispatching to the appropriate transform function
    3. Saving the result to ``output_path`` with unified save logic

    When new surgeries are added to the registry, they are automatically
    available through this function without any changes to calling code.

    Args:
        surgery_name: Name of the surgery to run.
            Use get_available_surgeries() to see all available options.
        model_path: Path to the input ONNX model.
        output_path: Path to save the output ONNX model.
        use_external_data: Whether to save weights as external data file.
        external_data_name: Name for external data file.
            Defaults to ``<output_filename>_data``.
        verbose: Whether to print progress messages.
        **kwargs: Surgery-specific parameters passed directly to the transform function.

    Returns:
        The modified ONNX ModelProto.

    Raises:
        ValueError: If surgery_name is not registered.

    Example:
        >>> run_graph_surgery(
        ...     "replace-gqa",
        ...     model_path="model.onnx",
        ...     output_path="model_gqa.onnx",
        ...     hf_model_id="meta-llama/Llama-2-7b-hf",
        ... )
    """
    from ..logging_config import logger

    # Fail fast on unknown surgery names before touching the model file.
    if surgery_name not in _SURGERY_REGISTRY:
        available = ", ".join(f"'{s}'" for s in _SURGERY_REGISTRY)
        raise ValueError(f"Unknown surgery: '{surgery_name}'. Available surgeries: {available}")

    # Load
    if verbose:
        logger.info(f"Loading model from: {model_path}")
    model = onnx.load(model_path, load_external_data=True)

    # Transform — forward `verbose` unless the caller overrode it in kwargs.
    transform_fn = _SURGERY_REGISTRY[surgery_name]
    kwargs.setdefault("verbose", verbose)
    model = transform_fn(model=model, **kwargs)

    # Save
    _save_model(
        model,
        output_path,
        use_external_data=use_external_data,
        external_data_name=external_data_name,
        verbose=verbose,
    )

    return model
__all__ = [
    "add_cross_kv_to_encoder",
    "convert_fp16_to_bf16",
    "get_available_surgeries",
    "make_dla_compatible",
    "replace_attention_with_gqa",
    "run_graph_surgery",
    "transpose_dequantize_linear_weights",
]