diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 814f5d08..0f7a49e2 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.3
 +++++
 
+* :pr:`331`: adds a helper to convert an onnx model into dot
 * :pr:`330`: fixes access rope_parameters for ``transformers>=5``
 * :pr:`329`: supports lists with OnnxruntimeEvaluator
 * :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
diff --git a/_doc/api/helpers/dot_helper.rst b/_doc/api/helpers/dot_helper.rst
new file mode 100644
index 00000000..b5bedc38
--- /dev/null
+++ b/_doc/api/helpers/dot_helper.rst
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.helpers.dot_helper
+==================================
+
+.. automodule:: onnx_diagnostic.helpers.dot_helper
+    :members:
+    :no-undoc-members:
diff --git a/_doc/api/helpers/index.rst b/_doc/api/helpers/index.rst
index 7d7fab1e..e42e553c 100644
--- a/_doc/api/helpers/index.rst
+++ b/_doc/api/helpers/index.rst
@@ -11,6 +11,7 @@ onnx_diagnostic.helpers
     cache_helper
     config_helper
     doc_helper
+    dot_helper
     fake_tensor_helper
     graph_helper
     helper
diff --git a/_doc/recipes/plot_dynamic_shapes_json.py b/_doc/recipes/plot_dynamic_shapes_json.py
index 0c8a9871..19569fa9 100644
--- a/_doc/recipes/plot_dynamic_shapes_json.py
+++ b/_doc/recipes/plot_dynamic_shapes_json.py
@@ -74,7 +74,7 @@ def flatten_unflatten_like_dynamic_shapes(obj):
     start = 0
     end = 0
     subtrees = []
-    for subspec in spec.children_specs:
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
         value = flatten_unflatten_like_dynamic_shapes(value)
diff --git a/_unittests/ut_helpers/test_dot_helper.py b/_unittests/ut_helpers/test_dot_helper.py
new file mode 100644
index 00000000..2e591852
--- /dev/null
+++ b/_unittests/ut_helpers/test_dot_helper.py
@@ -0,0 +1,80 @@
+import textwrap
+import unittest
+import onnx
+import onnx.helper as oh
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
+from onnx_diagnostic.helpers.dot_helper import to_dot
+from onnx_diagnostic.export.api import to_onnx
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+
+class TestDotHelper(ExtTestCase):
+    def test_custom_doc_kernels_layer_normalization(self):
+        TFLOAT16 = onnx.TensorProto.FLOAT16
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node(
+                        "LayerNormalization",
+                        ["X", "W", "B"],
+                        ["ln"],
+                        axis=-1,
+                        epsilon=9.999999974752427e-7,
+                    ),
+                    oh.make_node(
+                        "Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
+                    ),
+                ],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
+                    oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
+                    oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
+                ],
+                [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        dot = to_dot(model)
+        expected = textwrap.dedent(
+            """
+            digraph {
+             graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
+             node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
+             edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
+             I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
+             I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
+             I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
+             LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
+             Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"];
+             I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
+             I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
+             I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
+             LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"];
+             I_1 -> Add_4 [label="FLOAT16(d)"];
+             O_5 [label="Z\\nFLOAT16(b,c,d)", fillcolor="#aaaaee"];
+             Add_4 -> O_5;
+            }
+            """
+        )
+        self.maxDiff = None
+        self.assertEqual(expected.strip("\n "), dot.strip("\n "))
+
+    @requires_transformers("4.57")
+    def test_dot_plot_tiny(self):
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
+        dot = to_dot(em.model_proto)
+        name = self.get_dump_file("test_dot_plot_tiny.dot")
+        with open(name, "w") as f:
+            f.write(dot)
+        # dot -Tpng dump_test/test_dot_plot_tiny.dot -o dump_test/test_dot_plot_tiny.png
+        self.assertIn("-> Add", dot)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py
index 5317190f..449d5fac 100644
--- a/_unittests/ut_xrun_doc/test_command_lines.py
+++ b/_unittests/ut_xrun_doc/test_command_lines.py
@@ -6,6 +6,7 @@
     get_main_parser,
     get_parser_agg,
     get_parser_config,
+    get_parser_dot,
     get_parser_find,
     get_parser_lighten,
     get_parser_print,
@@ -23,6 +24,7 @@ def test_main_parser(self):
             get_main_parser().print_help()
         text = st.getvalue()
         self.assertIn("lighten", text)
+        self.assertIn("dot", text)
 
     def test_parser_lighten(self):
         st = StringIO()
@@ -87,6 +89,13 @@ def test_parser_sbs(self):
         text = st.getvalue()
         self.assertIn("--onnx", text)
 
+    def test_parser_dot(self):
+        st = StringIO()
+        with redirect_stdout(st):
+            get_parser_dot().print_help()
+        text = st.getvalue()
+        self.assertIn("--run", text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py
index f9ca7c5c..6ea08def 100644
--- a/_unittests/ut_xrun_doc/test_command_lines_exe.py
+++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py
@@ -162,6 +162,45 @@ def forward(self, x):
         sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")]
         self.assertEqual(sdf.shape[0], 4)
 
+    @ignore_warnings(UserWarning)
+    @requires_transformers("4.53")
+    def test_i_parser_dot(self):
+        import torch
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.fc1 = torch.nn.Linear(10, 32)  # input size 10 → hidden size 32
+                self.relu = torch.nn.ReLU()
+                self.fc2 = torch.nn.Linear(32, 1)  # hidden → output
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                x = self.fc2(x)
+                return x
+
+        inputs = dict(x=torch.randn((5, 10)))
+        ds = dict(x={0: "batch"})
+        onnx_file = self.get_dump_file("test_i_parser_dot.model.onnx")
+        to_onnx(
+            Model(),
+            kwargs=inputs,
+            dynamic_shapes=ds,
+            exporter="custom",
+            filename=onnx_file,
+        )
+
+        output = self.get_dump_file("test_i_parser_dot.dot")
+        args = ["dot", onnx_file, "-v", "1", "-o", output]
+        if not self.unit_test_going():
+            args.extend(["--run", "svg"])
+
+        st = StringIO()
+        with redirect_stdout(st):
+            main(args)
+        text = st.getvalue()
+        print(text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index 6a34d763..519e5dc2 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -11,6 +11,83 @@
 from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
 
 
+def get_parser_dot() -> ArgumentParser:
+    """Builds the argument parser of the ``dot`` sub-command."""
+    parser = ArgumentParser(
+        prog="dot",
+        description=textwrap.dedent(
+            """
+            Converts a model into a dot file dot can draw into a graph.
+            """
+        ),
+    )
+    parser.add_argument("input", type=str, help="onnx model to convert into dot")
+    parser.add_argument(
+        "-o",
+        "--output",
+        default="",
+        type=str,
+        required=False,
+        help="dot model to output or empty to print out the result",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        type=int,
+        default=0,
+        required=False,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "-r",
+        "--run",
+        default="",
+        required=False,
+        help="run dot, in that case, format must be given (svg, png)",
+    )
+    return parser
+
+
+def _cmd_dot(argv: List[Any]):
+    """
+    Implements the ``dot`` sub-command: converts an onnx model into dot
+    and optionally runs ``dot`` to render it.
+
+    :param argv: command line, the first item is the sub-command name
+    """
+    import subprocess
+    from .helpers.dot_helper import to_dot
+
+    parser = get_parser_dot()
+    args = parser.parse_args(argv[1:])
+    if args.run and not args.output:
+        # dot renders a file, fails before loading or converting anything
+        parser.error("Cannot run dot without an output file.")
+    if args.verbose:
+        print(f"-- loads {args.input!r}")
+    onx = onnx.load(args.input, load_external_data=False)
+    if args.verbose:
+        print("-- converts into dot")
+    dot = to_dot(onx)
+    if args.output:
+        if args.verbose:
+            print(f"-- saves into {args.output}")
+        with open(args.output, "w") as f:
+            f.write(dot)
+    else:
+        print(dot)
+    if args.run:
+        cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
+        if args.verbose:
+            print(f"-- run {' '.join(cmds)}")
+        # text=True decodes the streams, bytes would be printed as b'...'
+        res = subprocess.run(cmds, capture_output=True, text=True, check=False)
+        for stream in (res.stdout, res.stderr):
+            if stream:
+                print("--")
+                print(stream)
+
+
 def get_parser_lighten() -> ArgumentParser:
     parser = ArgumentParser(
         prog="lighten",
@@ -1412,6 +1489,7 @@ def get_main_parser() -> ArgumentParser:
 
             agg          - aggregates statistics from multiple files
             config       - prints a configuration for a model id
+            dot          - converts an onnx model into dot format
             exportsample - produces a code to export a model
             find         - find node consuming or producing a result
             lighten      - makes an onnx model lighter by removing the weights,
@@ -1428,6 +1506,7 @@ def get_main_parser() -> ArgumentParser:
         choices=[
             "agg",
             "config",
+            "dot",
             "exportsample",
             "find",
             "lighten",
@@ -1446,6 +1525,7 @@ def main(argv: Optional[List[Any]] = None):
     fcts = dict(
         agg=_cmd_agg,
         config=_cmd_config,
+        dot=_cmd_dot,
         exportsample=_cmd_export_sample,
         find=_cmd_find,
         lighten=_cmd_lighten,
@@ -1470,6 +1550,7 @@ def main(argv: Optional[List[Any]] = None):
             parsers = dict(
                 agg=get_parser_agg,
                 config=get_parser_config,
+                dot=get_parser_dot,
                 exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
                 find=get_parser_find,
                 lighten=get_parser_lighten,
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index d00617be..3ff36d9b 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
     start = 0
     end = 0
     subtrees = []
-    for subspec in spec.children_specs:
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
         value = flatten_unflatten_for_dynamic_shapes(
diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py
new file mode 100644
index 00000000..328446cd
--- /dev/null
+++ b/onnx_diagnostic/helpers/dot_helper.py
@@ -0,0 +1,246 @@
+from typing import Dict, Set
+import onnx
+import onnx.numpy_helper as onh
+from .onnx_helper import onnx_dtype_name, pretty_onnx
+
+
+def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
+    """
+    Returns the names a graph consumes without defining them,
+    they must be defined by the graph holding this one.
+
+    :param graph: graph to scan, its subgraphs are scanned as well
+    :return: names of the hidden inputs
+    """
+    hidden = set()
+    memo = (
+        {i.name for i in graph.initializer}
+        | {i.values.name for i in graph.sparse_initializer}
+        | {i.name for i in graph.input}
+    )
+    for node in graph.node:
+        for i in node.input:
+            # an empty name is a skipped optional input, not a result
+            if i and i not in memo:
+                hidden.add(i)
+        for att in node.attribute:
+            if att.type == onnx.AttributeProto.GRAPH and att.g:
+                hidden |= {h for h in _get_hidden_inputs(att.g) if h not in memo}
+        memo |= set(node.output)
+    return hidden
+
+
+def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
+    """
+    Builds the label of a node: operator type, inputs and a few attributes.
+
+    :param node: node to describe
+    :param tiny_inits: small initializers displayed with their value
+        instead of a dot, maps a name to the string to display
+    :return: label
+    """
+    els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
+    ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
+    for att in node.attribute:
+        if att.name == "to":
+            ee.append(f"{att.name}={onnx_dtype_name(att.i)}")
+        elif att.name in {"axis", "value_int", "stash_type", "start", "end"}:
+            ee.append(f"{att.name}={att.i}")
+        elif att.name in {"value_float"}:
+            ee.append(f"{att.name}={att.f}")
+        elif att.name in {"value_floats"}:
+            ee.append(f"{att.name}={att.floats}")
+        elif att.name in {"value_ints", "perm"}:
+            ee.append(f"{att.name}={att.ints}")
+    els.append(", ".join(ee))
+    els.append(")")
+    if node.op_type == "Constant":
+        els.extend([" -> ", node.output[0]])
+    return "".join(els)
+
+
+def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
+    """
+    Builds the label of an edge: element type and shape of a tensor.
+
+    :param value_info: typed result
+    :param multi_line: puts one dimension per line if the shape is long
+    :return: label, empty if the element type is undefined
+    """
+    itype = value_info.type.tensor_type.elem_type
+    if itype == onnx.TensorProto.UNDEFINED:
+        return ""
+    res = []
+    for d in value_info.type.tensor_type.shape.dim:
+        if d.dim_param:
+            # dimension names starting with "unk" are displayed as unknown
+            res.append("?" if d.dim_param.startswith("unk") else d.dim_param)
+        elif d.HasField("dim_value"):
+            res.append(str(d.dim_value))
+        else:
+            # neither a name nor a value, 0 would be misleading
+            res.append("?")
+    sshape = ",".join(res)
+    if multi_line and len(sshape) > 30:
+        sshape = ",\\n".join(res)
+    return f"{onnx_dtype_name(itype)}({sshape})"
+
+
+def to_dot(model: onnx.ModelProto) -> str:
+    """
+    Converts a model into a dot graph.
+    Here is an example:
+
+    .. gdot::
+        :script: DOT-SECTION
+        :process:
+
+        from onnx_diagnostic.helpers.dot_helper import to_dot
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
+        dot = to_dot(em.model_proto)
+        print("DOT-SECTION", dot)
+
+    Or this one obtained with :func:`torch.onnx.export`.
+
+    .. gdot::
+        :script: DOT-SECTION
+        :process:
+
+        from onnx_diagnostic.helpers.dot_helper import to_dot
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            em = to_onnx(model, kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo")
+        dot = to_dot(em.model_proto)
+        print("DOT-SECTION", dot)
+
+    :param model: model to convert
+    :return: dot graph as a string
+    """
+    _unique: Dict[int, int] = {}
+
+    def _mkn(obj: object) -> int:
+        # numbers the objects in the order they are first seen
+        id_obj = id(obj)
+        if id_obj in _unique:
+            return _unique[id_obj]
+        i = len(_unique)
+        _unique[id_obj] = i
+        return i
+
+    # the value_info of the inferred model is used to label the edges
+    model = onnx.shape_inference.infer_shapes(model)
+
+    op_type_colors = {
+        "Shape": "#d2a81f",
+        "MatMul": "#ee9999",
+        "Transpose": "#ee99ee",
+        "Reshape": "#eeeeee",
+        "Squeeze": "#eeeeee",
+        "Unsqueeze": "#eeeeee",
+    }
+
+    edge_label = {}
+    for val in model.graph.value_info:
+        edge_label[val.name] = _make_edge_label(val, multi_line=True)
+
+    rows = [
+        "digraph {",
+        (
+            " graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, "
+            "ranksep=0.2, fontsize=8];"
+        ),
+        ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];',
+        " edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];",
+    ]
+    inputs = list(model.graph.input)
+    outputs = list(model.graph.output)
+    nodes = list(model.graph.node)
+    inits = list(model.graph.initializer)
+    tiny_inits = {}
+    name_to_ids = {}
+    # inputs
+    for inp in inputs:
+        if not inp.name:
+            continue
+        lab = _make_edge_label(inp)
+        rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
+        name_to_ids[inp.name] = f"I_{_mkn(inp)}"
+        edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
+    # initializers, the small ones are displayed inside the node labels
+    for init in inits:
+        shape = tuple(init.dims)
+        if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
+            a = onh.to_array(init)
+            tiny_inits[init.name] = (
+                str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]"
+            )
+        else:
+            ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
+            rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
+            name_to_ids[init.name] = f"i_{_mkn(init)}"
+            edge_label[init.name] = ls
+    # nodes
+    for node in nodes:
+        color = op_type_colors.get(node.op_type, "#cccccc")
+        label = _make_node_label(node, tiny_inits)
+        rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
+        name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o})
+
+    # edges
+    done = set()
+    for node in nodes:
+        names = list(node.input)
+        for i in names:
+            if not i or i in tiny_inits:
+                continue
+            if i not in name_to_ids:
+                raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")
+            edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
+            if edge in done:
+                continue
+            done.add(edge)
+            lab = edge_label.get(i, "")
+            if lab:
+                ls = ",".join([f'label="{lab}"'])
+                lab = f" [{ls}]"
+            rows.append(f" {edge[0]} -> {edge[1]}{lab};")
+        if node.op_type in {"Scan", "Loop", "If"}:
+            unique = set()
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    unique |= _get_hidden_inputs(att.g)
+            # sorted to make the output deterministic
+            for i in sorted(unique):
+                if not i or i in tiny_inits:
+                    continue
+                if i not in name_to_ids:
+                    raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")
+                edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
+                if edge in done:
+                    continue
+                done.add(edge)
+                rows.append(f" {edge[0]} -> {edge[1]} [style=dotted];")
+
+    # outputs
+    for out in outputs:
+        if not out.name:
+            continue
+        lab = _make_edge_label(out)
+        rows.append(f' O_{_mkn(out)} [label="{out.name}\\n{lab}", fillcolor="#aaaaee"];')
+        edge = name_to_ids[out.name], f"O_{_mkn(out)}"
+        rows.append(f" {edge[0]} -> {edge[1]};")
+
+    rows.append("}")
+    return "\n".join(rows)
diff --git a/pyproject.toml b/pyproject.toml
index bf68f382..5e12ab22 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -172,6 +172,7 @@ select = [
 "_scripts/compare_model_execution.py" = ["E402", "F401"]
 "_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "RUF015", "SIM105", "SIM117"]
 "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
+"_unittests/ut_helpers/test_dot_helper.py" = ["E501"]
 "_unittests/ut_tasks/try_export.py" = ["B008", "B904", "E501", "PIE808", "SIM117", "SIM105", "UP008"]
 "onnx_diagnostic/export/__init__.py" = ["F401"]
 "onnx_diagnostic/helpers/__init__.py" = ["F401"]