Skip to content

Commit 0b7c916

Browse files
committed
DLA surgeries added as part of ModelOpt
Signed-off-by: mgohil-png <mgohil@nvidia.com>
1 parent e2d29c8 commit 0b7c916

44 files changed

Lines changed: 16546 additions & 183 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

modelopt/onnx/graph_surgery/__init__.py

Lines changed: 124 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -23,13 +23,15 @@
2323
- Converting model precision (e.g., FP16 to BF16)
2424
- Transposing DequantizeLinear weights for column-major storage optimization
2525
- Graph cleanup and optimization
26+
- Making models compatible with the NVIDIA DLA accelerator
2627
2728
Example usage:
2829
>>> from modelopt.onnx.graph_surgery import (
2930
... replace_attention_with_gqa,
3031
... convert_fp16_to_bf16,
3132
... transpose_dequantize_linear_weights,
3233
... add_cross_kv_to_encoder,
34+
... make_dla_compatible,
3335
... )
3436
>>> # Replace attention with GQA for LLMs (FP16 model)
3537
>>> replace_attention_with_gqa(
@@ -62,18 +64,30 @@
6264
... model_path="model_quantized.onnx",
6365
... output_path="model_quantized_transposed.onnx",
6466
... )
67+
>>> # Apply the full DLA compatibility pipeline (16-step transform sequence)
68+
>>> make_dla_compatible(
69+
... model_path="model.onnx",
70+
... output_path="model_dla.onnx",
71+
... )
6572
"""
6673

67-
from .dq_transpose import transpose_dequantize_linear_weights
68-
from .encoder_cross_kv import add_cross_kv_to_encoder
69-
from .gqa_replacement import replace_attention_with_gqa
70-
from .utils.dtype_conversion import convert_fp16_to_bf16
74+
import os
75+
76+
import onnx
77+
78+
from .dq_transpose import _transform_dq_transpose, transpose_dequantize_linear_weights
79+
from .encoder_cross_kv import _transform_cross_kv, add_cross_kv_to_encoder
80+
from .gqa_replacement import _transform_gqa, replace_attention_with_gqa
81+
from .make_dla_compatible import _transform_make_dla_compatible
82+
from .make_dla_compatible import dla_make_dla_compatible as make_dla_compatible
83+
from .utils.dtype_conversion import _transform_fp16_to_bf16, convert_fp16_to_bf16
7184

7285
# Maps hyphenated surgery names to their in-memory transform functions.
# Each transform is called as ``fn(model=<ModelProto>, **kwargs)`` and
# returns the modified ModelProto; consumed by run_graph_surgery() below.
_SURGERY_REGISTRY = {
    "replace-gqa": _transform_gqa,
    "add-cross-kv": _transform_cross_kv,
    "convert-bf16": _transform_fp16_to_bf16,
    "transpose-dq": _transform_dq_transpose,
    "make-dla-compatible": _transform_make_dla_compatible,
}
7892

7993

@@ -82,16 +96,85 @@ def get_available_surgeries() -> list[str]:
8296
return list(_SURGERY_REGISTRY.keys())
8397

8498

99+
def _save_model(
    model: onnx.ModelProto,
    output_path: str,
    use_external_data: bool = True,
    external_data_name: str | None = None,
    size_threshold: int = 1024,
    verbose: bool = True,
) -> None:
    """Unified model saving logic for all graph surgeries.

    Saves the model (optionally externalizing large weights to a side-car
    file) and then runs file-to-file shape inference on the saved result.

    Args:
        model: The ONNX model to save.
        output_path: Path to save the model.
        use_external_data: Whether to save weights as external data.
        external_data_name: Name for external data file.
            Defaults to ``<output_filename>_data``.
        size_threshold: Minimum tensor size (bytes) to externalize.
        verbose: Whether to print progress messages.
    """
    # Imported lazily to avoid import-time cycles with the package logger.
    from ..logging_config import logger

    if verbose:
        logger.info(f"\nSaving modified model to: {output_path}")

    # Ensure the destination directory exists (dirname is "" for a bare filename).
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if use_external_data:
        if external_data_name is None:
            external_data_name = os.path.basename(output_path) + "_data"

        if verbose:
            logger.info(f" Saving weights to external file: {external_data_name}")

        onnx.save(
            model,
            output_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_data_name,
            size_threshold=size_threshold,
        )
    else:
        onnx.save(model, output_path)

    # FIX: removed ``model = onnx.load(output_path, load_external_data=True)``
    # that followed the save — it rebound a local that was never read again
    # (this function returns None), so the whole model was pointlessly
    # reloaded from disk after every save.

    # Run shape inference (beneficial for all surgeries, no-op if nothing
    # changed). Operates file-to-file on the saved artifact rather than on
    # the in-memory proto.
    if verbose:
        logger.info("\nRunning shape inference (file-to-file)...")
    try:
        onnx.shape_inference.infer_shapes_path(
            output_path, output_path, check_type=False, strict_mode=False, data_prop=False
        )
        if verbose:
            logger.info(" Shape inference completed")
    except Exception as e:  # deliberately best-effort: the model is already on disk
        if verbose:
            logger.info(f" Shape inference failed (non-fatal, model already saved): {e}")

    if verbose:
        logger.info("Done!")
161+
162+
85163
def run_graph_surgery(
    surgery_name: str,
    model_path: str,
    output_path: str,
    use_external_data: bool = True,
    external_data_name: str | None = None,
    verbose: bool = True,
    **kwargs,
) -> onnx.ModelProto:
    """Run a graph surgery by name with centralized model loading and saving.

    This is the unified entry point for all graph surgeries. It handles:

    1. Loading the input model from ``model_path``
    2. Dispatching to the appropriate transform function
    3. Saving the result to ``output_path`` with unified save logic

    When new surgeries are added to the registry, they are automatically
    available through this function without any changes to calling code.

    Args:
        surgery_name: Name of the registered surgery to run.
            Use get_available_surgeries() to see all available options.
        model_path: Path to the input ONNX model.
        output_path: Path to save the output ONNX model.
        use_external_data: Whether to save weights as external data file.
        external_data_name: Name for external data file.
            Defaults to ``<output_filename>_data``.
        verbose: Whether to print progress messages.
        **kwargs: Surgery-specific parameters passed directly to the transform function.

    Returns:
        The modified ONNX ModelProto.

    Raises:
        ValueError: If surgery_name is not registered.

    Example:
        >>> run_graph_surgery(
        ...     "replace-gqa",
        ...     model_path="model.onnx",
        ...     output_path="model_gqa.onnx",
        ...     hf_model_id="meta-llama/Llama-2-7b-hf",
        ... )
    """
    from ..logging_config import logger

    # Validate the surgery name up front, before any (potentially slow) model I/O.
    surgery_fn = _SURGERY_REGISTRY.get(surgery_name)
    if surgery_fn is None:
        available = ", ".join(f"'{s}'" for s in _SURGERY_REGISTRY)
        raise ValueError(f"Unknown surgery: '{surgery_name}'. Available surgeries: {available}")

    # Load the full model, pulling in any external weight files.
    if verbose:
        logger.info(f"Loading model from: {model_path}")
    model = onnx.load(model_path, load_external_data=True)

    # Transform in memory. The surgery inherits the shared ``verbose`` flag
    # unless the caller overrides it explicitly in kwargs.
    kwargs.setdefault("verbose", verbose)
    model = surgery_fn(model=model, **kwargs)

    # Persist via the unified save path (external data + shape inference).
    _save_model(
        model,
        output_path,
        use_external_data=use_external_data,
        external_data_name=external_data_name,
        verbose=verbose,
    )

    return model
129236

130237

131238
__all__ = [
132239
"add_cross_kv_to_encoder",
133240
"convert_fp16_to_bf16",
134241
"get_available_surgeries",
242+
"make_dla_compatible",
135243
"replace_attention_with_gqa",
136244
"run_graph_surgery",
137245
"transpose_dequantize_linear_weights",

0 commit comments

Comments
 (0)