diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 872672ea..772eaa2a 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.7 +++++ +* :pr:`366`: add command line to optimize a model * :pr:`363`: patch for DynamicDimConstraintPrinter * :pr:`360`, :pr:`364`: preliminary work for phi4 diff --git a/_doc/api/helpers/index.rst b/_doc/api/helpers/index.rst index e42e553c..0137f60a 100644 --- a/_doc/api/helpers/index.rst +++ b/_doc/api/helpers/index.rst @@ -21,6 +21,7 @@ onnx_diagnostic.helpers mini_onnx_builder model_builder_helper onnx_helper + optim_helper ort_session rt_helper torch_fx_graph_helper diff --git a/_doc/api/helpers/optim_helper.rst b/_doc/api/helpers/optim_helper.rst new file mode 100644 index 00000000..6cf28a78 --- /dev/null +++ b/_doc/api/helpers/optim_helper.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.helpers.optim_helper +==================================== + +.. automodule:: onnx_diagnostic.helpers.optim_helper + :members: + :no-undoc-members: diff --git a/_doc/cmds/index.rst b/_doc/cmds/index.rst index 4c81b46a..a357777a 100644 --- a/_doc/cmds/index.rst +++ b/_doc/cmds/index.rst @@ -10,5 +10,6 @@ Command Lines compare config + optimize sbs validate diff --git a/_doc/cmds/optimize.rst b/_doc/cmds/optimize.rst new file mode 100644 index 00000000..37450cf9 --- /dev/null +++ b/_doc/cmds/optimize.rst @@ -0,0 +1,13 @@ +-m onnx_diagnostic optimize ... optimizes an onnx model +======================================================= + +Description ++++++++++++ + +See :func:`onnx_diagnostic.helpers.optim_helper.optimize_model`. + +.. 
runpython:: + + from onnx_diagnostic._command_lines_parser import get_parser_optimize + + get_parser_optimize().print_help() diff --git a/_unittests/ut_helpers/test_optim_helper.py b/_unittests/ut_helpers/test_optim_helper.py new file mode 100644 index 00000000..201c80cb --- /dev/null +++ b/_unittests/ut_helpers/test_optim_helper.py @@ -0,0 +1,38 @@ +import unittest +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.helpers.optim_helper import optimize_model + +TFLOAT = onnx.TensorProto.FLOAT + + +class TestOptimHelper(ExtTestCase): + @hide_stdout() + def test_optimize_model(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Shape", ["X"], ["D2"], start=2, end=3), + oh.make_node("Concat", ["I1", "D2"], ["d"], axis=0), + oh.make_node("Reshape", ["X", "d"], ["Y"]), + ], + "test", + [oh.make_tensor_value_info("X", TFLOAT, [2, 3, "d"])], + [oh.make_tensor_value_info("Y", TFLOAT, [6, "d"])], + [onh.from_array(np.array([-1], dtype=np.int64), name="I1")], + ), + opset_imports=[oh.make_operatorsetid("", 18)], + ir_version=10, + ) + filename = self.dump_onnx("test_optimize_model.onnx", model) + for algo in ["default", "default+onnxruntime", "ir", "os_ort", "slim"]: + output = self.get_dump_file(f"test_optimize_model.{algo}.onnx") + with self.subTest(algo=algo): + optimize_model(algo, filename, output=output, verbose=1) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py index 1a193229..60f37180 100644 --- a/_unittests/ut_xrun_doc/test_command_lines.py +++ b/_unittests/ut_xrun_doc/test_command_lines.py @@ -10,6 +10,7 @@ get_parser_dot, get_parser_find, get_parser_lighten, + get_parser_optimize, get_parser_print, get_parser_sbs, get_parser_stats, @@ -178,6 +179,13 @@ def test_parser_compare(self): text = st.getvalue() 
self.assertIn("compare", text) + def test_parser_optimize(self): + st = StringIO() + with redirect_stdout(st): + get_parser_optimize().print_help() + text = st.getvalue() + self.assertIn("default", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py index 801d19c8..e88ecb87 100644 --- a/_unittests/ut_xrun_doc/test_command_lines_exe.py +++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py @@ -210,6 +210,15 @@ def test_j_parser_compare(self): text = st.getvalue() self.assertIn("done with distance 0", text) + def test_l_parser_optimize(self): + output = self.get_dump_file("test_parser_optimize.onnx") + st = StringIO() + with redirect_stdout(st): + main(["optimize", "default", self.dummy_path, "-o", output, "-v", "1"]) + text = st.getvalue() + self.assertIn("default", text) + self.assertExists(output) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index bb080332..9aa7db7c 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1547,6 +1547,107 @@ def _cmd_compare(argv: List[Any]): print(ObsComparePair.to_str(pair_cmp)) +def get_parser_optimize() -> ArgumentParser: + parser = ArgumentParser( + prog="optimize", + formatter_class=RawTextHelpFormatter, + description=textwrap.dedent( + """ + Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs + and replaces them by the corresponding nodes. It also does basic optimization + such as removing identity nodes or unused nodes. + """ + ), + epilog=textwrap.dedent( + """ + The goal is to make the model faster. + Argument patterns defines the patterns to apply or the set of patterns. + It is possible to show statistics or to remove a particular pattern. 
+ Here are some environment variables which can be used to trigger + these displays. + + Options available for algorithms default and default+onnxruntime: + + - DROPPATTERN=: do not apply + those patterns when optimizing a model + - DUMPPATTERNS=: dumps all matched and applied + nodes when a pattern is applied + - PATTERN=: increase verbosity for specific + patterns to understand why one pattern was not applied, + this shows which line is rejecting a pattern if it seems one pattern was missed + """ + ), + ) + parser.add_argument( + "algorithm", + choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"], + help="algorithm or patterns optimization to apply", + ) + parser.add_argument("input", type=str, help="onnx model to optimize") + parser.add_argument( + "-o", + "--output", + type=str, + required=False, + help="onnx model to output, if empty, it adds .o-{algorithm}.onnx to the name", + ) + parser.add_argument( + "-v", + "--verbose", + default=0, + required=False, + type=int, + help="verbosity", + ) + parser.add_argument( + "--infer-shapes", + default=True, + action=BooleanOptionalAction, + help="infer shapes before optimizing the model", + ) + parser.add_argument( + "--processor", + default="", + help=textwrap.dedent( + """ + optimization for a specific processor, CPU, CUDA or both CPU,CUDA, + some operators are only available in one processor, it might be not used + with all + """ + ).strip("\n"), + ) + parser.add_argument( + "--remove-shape-info", + default=True, + action=BooleanOptionalAction, + help="remove shape information before outputting the model", + ) + return parser + + +def _cmd_optimize(argv: List[Any]): + parser = get_parser_optimize() + args = parser.parse_args(argv[1:]) + + from .helpers.optim_helper import optimize_model + + output = ( + args.output + if args.output + else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx" + ) + + optimize_model( + args.algorithm, + args.input, + output=output, + verbose=args.verbose, +
processor=args.processor, + infer_shapes=args.infer_shapes, + remove_shape_info=args.remove_shape_info, + ) + + ############# # main parser ############# @@ -1563,16 +1664,17 @@ def get_main_parser() -> ArgumentParser: to get help for a specific command. agg - aggregates statistics from multiple files - config - prints a configuration for a model id + config - prints a configuration for a model id (on HuggingFace Hub) dot - converts an onnx model into dot format exportsample - produces a code to export a model find - find node consuming or producing a result - lighten - makes an onnx model lighter by removing the weights, + lighten - makes an onnx model lighter by removing the weights + optimize - optimizes an onnx model print - prints the model on standard output sbs - compares an exported program and a onnx model stats - produces statistics on a model unlighten - restores an onnx model produces by the previous experiment - validate - validate a model + validate - validate a model (knowing its model id on HuggingFace Hub) """ ), ) @@ -1585,6 +1687,7 @@ def get_main_parser() -> ArgumentParser: "exportsample", "find", "lighten", + "optimize", "print", "sbs", "stats", @@ -1605,6 +1708,7 @@ def main(argv: Optional[List[Any]] = None): exportsample=_cmd_export_sample, find=_cmd_find, lighten=_cmd_lighten, + optimize=_cmd_optimize, print=_cmd_print, sbs=_cmd_sbs, stats=_cmd_stats, @@ -1631,6 +1735,7 @@ def main(argv: Optional[List[Any]] = None): exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator] find=get_parser_find, lighten=get_parser_lighten, + optimize=get_parser_optimize, print=get_parser_print, sbs=get_parser_sbs, stats=get_parser_stats, diff --git a/onnx_diagnostic/helpers/optim_helper.py b/onnx_diagnostic/helpers/optim_helper.py new file mode 100644 index 00000000..64fe3de6 --- /dev/null +++ b/onnx_diagnostic/helpers/optim_helper.py @@ -0,0 +1,116 @@ +from typing import Optional, Union +import pprint +import onnx + + +def
optimize_model( + algorithm: str, + model: Union[onnx.ModelProto, str], + output: Optional[str] = None, + processor: Optional[str] = None, + infer_shapes: bool = True, + remove_shape_info: bool = False, + verbose: int = 1, +): + """ + Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs + and replaces them by the corresponding nodes. It also does basic optimization + such as removing identity nodes or unused nodes. + + :param algorithm: algorithm to choose + :param model: model to optimize as a proto or a filename + :param output: if not empty, the optimized model is saved + :param processor: optimizations are done for this processor + :param infer_shapes: infer shapes before optimizing, this might not be + available for all algorithms + :param remove_shape_info: remove shape information before saving the model + :param verbose: verbosity level + :return: optimized model + + The goal is to make the model faster. + Argument *algorithm* defines the patterns to apply or the set of patterns. + It is possible to show statistics or to remove a particular pattern. + Here are some environment variables which can be used to trigger + these displays.
+ + Options available for algorithms ``default`` and ``default+onnxruntime``: + + - ``DROPPATTERN=``: do not apply + those patterns when optimizing a model + - ``DUMPPATTERNS=``: dumps all matched and applied nodes when a pattern is applied + - ``PATTERN=``: increase verbosity + for specific patterns to understand why one pattern was not applied, + this shows which line is rejecting a pattern if it seems one pattern was missed + """ + if isinstance(model, str): + if verbose: + print(f"[optimize_model] load {model!r}") + proto = onnx.load(model) + if verbose: + print("[optimize_model] done loading.") + else: + proto = model + + if verbose: + print(f"[optimize_model] optimize with {algorithm!r}") + if algorithm in {"default", "default+onnxruntime"}: + from experimental_experiment.xoptim import get_pattern_list + from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions + + pats = get_pattern_list(algorithm) + + gr = GraphBuilder( + proto, + infer_shapes_options=infer_shapes, + optimization_options=OptimizationOptions( + patterns=pats, + verbose=verbose, + remove_unused=True, + constant_folding=True, + remove_identity=True, + max_iter=max(100, len(proto.graph.node) // 2), + processor=processor or "CPU", + ), + ) + if verbose: + print(f"[optimize_model] starts optimizing with {len(pats)} patterns") + print(f"[optimize_model] model has {len(proto.graph.node)} nodes") + opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True) + if verbose: + print("[optimize_model] optimization report") + pprint.pprint(report) + print("[optimize_model] done") + + elif algorithm == "slim": + import onnxslim + + opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes) + elif algorithm in {"ir", "os_ort"}: + import onnx_ir + import onnxscript.optimizer + from onnxscript.rewriter.ort_fusions import optimize_for_ort + + model_ir = onnx_ir.from_proto(proto) + if algorithm == "ir": + onnxscript.optimizer.optimize(model_ir) + else: + optimize_for_ort(model_ir) +
opt_onx = onnx_ir.serde.serialize_model(model_ir) + + del proto + if verbose: + print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes") + if remove_shape_info: + if verbose: + print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}") + del opt_onx.graph.value_info[:] + if verbose: + print("[optimize_model] done removing shape info") + + if output: + if verbose: + print(f"[optimize_model] save file into {output!r}") + onnx.save(opt_onx, output, save_as_external_data=True) + if verbose: + print("[optimize_model] done saving") + return opt_onx diff --git a/requirements-dev.txt b/requirements-dev.txt index 917587dd..211f3bd2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,6 +9,7 @@ onnx-array-api>=0.3.1 onnx onnxruntime-genai onnxscript +onnxslim openpyxl packaging pandas