Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.7
+++++

* :pr:`366`: add command line to optimize a model
* :pr:`363`: patch for DynamicDimConstraintPrinter
* :pr:`360`, :pr:`364`: preliminary work for phi4

Expand Down
1 change: 1 addition & 0 deletions _doc/api/helpers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ onnx_diagnostic.helpers
mini_onnx_builder
model_builder_helper
onnx_helper
optim_helper
ort_session
rt_helper
torch_fx_graph_helper
Expand Down
7 changes: 7 additions & 0 deletions _doc/api/helpers/optim_helper.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.helpers.optim_helper
====================================

.. automodule:: onnx_diagnostic.helpers.optim_helper
:members:
:no-undoc-members:
1 change: 1 addition & 0 deletions _doc/cmds/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ Command Lines

compare
config
optimize
sbs
validate
13 changes: 13 additions & 0 deletions _doc/cmds/optimize.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-m onnx_diagnostic optimize ... optimizes an onnx model
=======================================================

Description
+++++++++++

See :func:`onnx_diagnostic.helpers.optim_helper.optimize_model`.

.. runpython::

from onnx_diagnostic._command_lines_parser import get_parser_optimize

get_parser_optimize().print_help()
38 changes: 38 additions & 0 deletions _unittests/ut_helpers/test_optim_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers.optim_helper import optimize_model

TFLOAT = onnx.TensorProto.FLOAT


class TestOptimHelper(ExtTestCase):
@hide_stdout()
def test_optimize_model(self):
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Shape", ["X"], ["D2"], start=2, end=3),
oh.make_node("Concat", ["I1", "D2"], ["d"], axis=0),
oh.make_node("Reshape", ["X", "d"], ["Y"]),
],
"test",
[oh.make_tensor_value_info("X", TFLOAT, [2, 3, "d"])],
[oh.make_tensor_value_info("Y", TFLOAT, [6, "d"])],
[onh.from_array(np.array([-1], dtype=np.int64), name="I1")],
),
opset_imports=[oh.make_operatorsetid("", 18)],
ir_version=10,
)
filename = self.dump_onnx("test_optimize_model.onnx", model)
for algo in ["default", "default+onnxruntime", "ir", "os_ort", "slim"]:
output = self.get_dump_file(f"test_optimize_model.{algo}.onnx")
with self.subTest(algo=algo):
optimize_model(algo, filename, output=output, verbose=1)


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 8 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_parser_dot,
get_parser_find,
get_parser_lighten,
get_parser_optimize,
get_parser_print,
get_parser_sbs,
get_parser_stats,
Expand Down Expand Up @@ -178,6 +179,13 @@ def test_parser_compare(self):
text = st.getvalue()
self.assertIn("compare", text)

def test_parser_optimize(self):
st = StringIO()
with redirect_stdout(st):
get_parser_optimize().print_help()
text = st.getvalue()
self.assertIn("default", text)


if __name__ == "__main__":
unittest.main(verbosity=2)
9 changes: 9 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines_exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ def test_j_parser_compare(self):
text = st.getvalue()
self.assertIn("done with distance 0", text)

def test_l_parser_optimize(self):
output = self.get_dump_file("test_parser_optimize.onnx")
st = StringIO()
with redirect_stdout(st):
main(["optimize", "default", self.dummy_path, "-o", output, "-v", "1"])
text = st.getvalue()
self.assertIn("default", text)
self.assertExists(output)


if __name__ == "__main__":
unittest.main(verbosity=2)
111 changes: 108 additions & 3 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,107 @@ def _cmd_compare(argv: List[Any]):
print(ObsComparePair.to_str(pair_cmp))


def get_parser_optimize() -> ArgumentParser:
parser = ArgumentParser(
prog="optimize",
formatter_class=RawTextHelpFormatter,
description=textwrap.dedent(
"""
Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
and replaces them by the corresponding nodes. It also does basic optimization
such as removing identity nodes or unused nodes.
"""
),
epilog=textwrap.dedent(
"""
The goal is to make the model faster.
Argument patterns defines the patterns to apply or the set of patterns.
It is possible to show statistics or to remove a particular pattern.
Here are some environment variables which can be used to trigger
these displays.

Available options algorithms, default and default+runtime:

- DROPPATTERN=<pattern1,patterns2,...>: do not apply
those patterns when optimizing a model
- DUMPPATTERNS=<folder>: dumps all matched and applied
nodes when a pattern is applied
- PATTERN=<pattern1,pattern2,...>: increase verbosity for specific
patterns to understand why one pattern was not applied,
this shows which line is rejecting a pattern if it seems one pattern was missed
"""
),
)
parser.add_argument(
"algorithm",
choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"],
help="algorithm or patterns optimization to apply",
)
parser.add_argument("input", type=str, help="onnx model to optimize")
parser.add_argument(
"-o",
"--output",
type=str,
required=False,
help="onnx model to output, if empty, if adds .opt-{algorithm}.onnx to the name",
)
parser.add_argument(
"-v",
"--verbose",
default=0,
required=False,
type=int,
help="verbosity",
)
parser.add_argument(
"--infer-shapes",
default=True,
action=BooleanOptionalAction,
help="infer shapes before optimizing the model",
)
parser.add_argument(
"--processor",
default="",
help=textwrap.dedent(
"""
optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
some operators are only available in one processor, it might be not used
with all
"""
).strip("\n"),
)
parser.add_argument(
"--remove-shape-info",
default=True,
action=BooleanOptionalAction,
help="remove shape information before outputting the model",
)
return parser


def _cmd_optimize(argv: List[Any]):
parser = get_parser_optimize()
args = parser.parse_args(argv[1:])

from .helpers.optim_helper import optimize_model

output = (
args.output
if args.output
else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx"
)

optimize_model(
args.algorithm,
args.input,
output=output,
verbose=args.verbose,
processor=args.processor,
infer_shapes=args.infer_shapes,
remove_shape_info=args.remove_shape_info,
)


#############
# main parser
#############
Expand All @@ -1563,16 +1664,17 @@ def get_main_parser() -> ArgumentParser:
to get help for a specific command.

agg - aggregates statistics from multiple files
config - prints a configuration for a model id
config - prints a configuration for a model id (on HuggingFace Hub)
dot - converts an onnx model into dot format
exportsample - produces a code to export a model
find - find node consuming or producing a result
lighten - makes an onnx model lighter by removing the weights,
lighten - makes an onnx model lighter by removing the weights
optimize - optimizes an onnx model
print - prints the model on standard output
sbs - compares an exported program and a onnx model
stats - produces statistics on a model
unlighten - restores an onnx model produces by the previous experiment
validate - validate a model
validate - validate a model (knowing its model id on HuggginFace Hub)
"""
),
)
Expand All @@ -1585,6 +1687,7 @@ def get_main_parser() -> ArgumentParser:
"exportsample",
"find",
"lighten",
"optimize",
"print",
"sbs",
"stats",
Expand All @@ -1605,6 +1708,7 @@ def main(argv: Optional[List[Any]] = None):
exportsample=_cmd_export_sample,
find=_cmd_find,
lighten=_cmd_lighten,
optimize=_cmd_optimize,
print=_cmd_print,
sbs=_cmd_sbs,
stats=_cmd_stats,
Expand All @@ -1631,6 +1735,7 @@ def main(argv: Optional[List[Any]] = None):
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
find=get_parser_find,
lighten=get_parser_lighten,
optimize=get_parser_optimize,
print=get_parser_print,
sbs=get_parser_sbs,
stats=get_parser_stats,
Expand Down
116 changes: 116 additions & 0 deletions onnx_diagnostic/helpers/optim_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import Optional, Union
import pprint
import onnx


def optimize_model(
algorithm: str,
model: Union[onnx.ModelProto, str],
output: Optional[str] = None,
processor: Optional[str] = None,
infer_shapes: bool = True,
remove_shape_info: bool = False,
verbose: int = 1,
):
"""
Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
and replaces them by the corresponding nodes. It also does basic optimization
such as removing identity nodes or unused nodes.

:param algorithm: algorithm to choose
:param model: model to optimize as a proto or a filename
:param output: if not empty, the optimized model is saved
:param processor: optimization are done for the processor
:param infer_shapes: infer shapes before optimizing, this might not be
available for all algorithm
:param remove_shape_info: remove shape information before saving the model
:param verbose: verbosity level
:return: optimized model

The goal is to make the model faster.
Argument patterns defines the patterns to apply or the set of patterns.
It is possible to show statistics or to remove a particular pattern.
Here are some environment variables which can be used to trigger
these displays.

Available options algorithms, default and default+runtime:

- ``DROPPATTERN=<pattern1,patterns2,...>``: do not apply
those patterns when optimizing a model
- ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
- ``PATTERN=<pattern1,pattern2,...>``: increase verbosity
for specific patterns to understand why one pattern was not applied,
this shows which line is rejecting a pattern if it seems one pattern was missed
"""
if isinstance(model, str):
if verbose:
print(f"[optimize_model] load {model!r}")
proto = onnx.load(model)
if verbose:
print("[optimize_model] done loading.")
else:
proto = model

if verbose:
print(f"[optimize_model] optimize with {algorithm!r}")
if algorithm in {"default", "default+onnxruntime"}:
from experimental_experiment.xoptim import get_pattern_list
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

pats = get_pattern_list(algorithm)

gr = GraphBuilder(
proto,
infer_shapes_options=infer_shapes,
optimization_options=OptimizationOptions(
patterns=pats,
verbose=verbose,
remove_unused=True,
constant_folding=True,
remove_identity=True,
max_iter=max(100, len(proto.graph.node) // 2),
processor=processor or "CPU",
),
)
if verbose:
print(f"[optimize_model] starts optimizing with {len(pats)} patterns")
print(f"[optimize_model] model has {len(proto.graph.node)} nodes")
opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True)
if verbose:
print("[optimize_model] optimization report")
pprint.pprint(report)
print("[optimize_model] done")

elif algorithm == "slim":
import onnxslim

opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes)
elif algorithm in {"ir", "os_ort"}:
import onnx_ir
import onnxscript.optimizer
from onnxscript.rewriter.ort_fusions import optimize_for_ort

model_ir = onnx_ir.from_proto(proto)
if algorithm == "ir":
onnxscript.optimizer.optimize(model_ir)
else:
optimize_for_ort(model_ir)
opt_onx = onnx_ir.serde.serialize_model(model_ir)

del proto
if verbose:
print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes")
if remove_shape_info:
if verbose:
print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}")
del opt_onx.graph.value_info[:]
if verbose:
print("[optimize_model] done removing shape info")

if output:
if verbose:
print(f"[optimize_model] save file into {output!r}")
onnx.save(opt_onx, output, save_as_external_data=True)
if verbose:
print("[optimize_model] done saving")
return opt_onx
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ onnx-array-api>=0.3.1
onnx
onnxruntime-genai
onnxscript
onnxslim
openpyxl
packaging
pandas
Expand Down
Loading