From a6044a2f8dfac6167f80688056e53ab26eb7a269 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 1 Dec 2025 09:25:16 +0100 Subject: [PATCH 1/7] Adds a helper to convert an onnx model into dot --- _doc/api/helpers/dot_helper.rst | 7 + _doc/api/helpers/index.rst | 1 + _doc/recipes/plot_dynamic_shapes_json.py | 2 +- _unittests/ut_helpers/test_dot_helper.py | 79 ++++++++++ onnx_diagnostic/helpers/cache_helper.py | 2 +- onnx_diagnostic/helpers/dot_helper.py | 187 +++++++++++++++++++++++ pyproject.toml | 1 + 7 files changed, 277 insertions(+), 2 deletions(-) create mode 100644 _doc/api/helpers/dot_helper.rst create mode 100644 _unittests/ut_helpers/test_dot_helper.py create mode 100644 onnx_diagnostic/helpers/dot_helper.py diff --git a/_doc/api/helpers/dot_helper.rst b/_doc/api/helpers/dot_helper.rst new file mode 100644 index 00000000..b5bedc38 --- /dev/null +++ b/_doc/api/helpers/dot_helper.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.helpers.dot_helper +================================== + +.. automodule:: onnx_diagnostic.helpers.dot_helper + :members: + :no-undoc-members: diff --git a/_doc/api/helpers/index.rst b/_doc/api/helpers/index.rst index 7d7fab1e..e42e553c 100644 --- a/_doc/api/helpers/index.rst +++ b/_doc/api/helpers/index.rst @@ -11,6 +11,7 @@ onnx_diagnostic.helpers cache_helper config_helper doc_helper + dot_helper fake_tensor_helper graph_helper helper diff --git a/_doc/recipes/plot_dynamic_shapes_json.py b/_doc/recipes/plot_dynamic_shapes_json.py index 0c8a9871..717c276f 100644 --- a/_doc/recipes/plot_dynamic_shapes_json.py +++ b/_doc/recipes/plot_dynamic_shapes_json.py @@ -74,7 +74,7 @@ def flatten_unflatten_like_dynamic_shapes(obj): start = 0 end = 0 subtrees = [] - for subspec in spec.children_specs: + for subspec in spec.children(): end += subspec.num_leaves value = subspec.unflatten(flat[start:end]) value = flatten_unflatten_like_dynamic_shapes(value) diff --git a/_unittests/ut_helpers/test_dot_helper.py b/_unittests/ut_helpers/test_dot_helper.py new file mode 100644 index 00000000..3726e085 --- /dev/null +++ b/_unittests/ut_helpers/test_dot_helper.py @@ -0,0 +1,79 @@ +import textwrap +import unittest +import onnx +import onnx.helper as oh +from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.helpers.dot_helper import to_dot +from onnx_diagnostic.export.api import to_onnx +from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs + + +class TestDotHelper(ExtTestCase): + def test_custom_doc_kernels_layer_normalization(self): + TFLOAT16 = onnx.TensorProto.FLOAT16 + model = oh.make_model( + oh.make_graph( + [ + oh.make_node( + "LayerNormalization", + ["X", "W", "B"], + ["ln"], + axis=-1, + epsilon=9.999999974752427e-7, + ), + oh.make_node( + "Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7 + ), + ], + "dummy", + [ + oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]), + oh.make_tensor_value_info("W", TFLOAT16, ["d"]), + oh.make_tensor_value_info("B", TFLOAT16, ["d"]), + ], + [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])], + ), + ir_version=9, + opset_imports=[oh.make_opsetid("", 18)], + ) + dot = to_dot(model) + expected = textwrap.dedent( + """ + digraph { + graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8]; + node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box]; + edge [arrowhead=vee, fontsize=6]; + I_0 [label="X", fillcolor="#aaeeaa"]; + I_1 [label="W", fillcolor="#aaeeaa"]; + I_2 [label="B", fillcolor="#aaeeaa"]; + LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"]; + Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"]; + I_0 -> LayerNormalization_3; + I_1 -> LayerNormalization_3; + I_2 -> LayerNormalization_3; + LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"]; + I_1 -> Add_4; + O_5 [label="Z", fillcolor="#aaaaee"]; + Add_4 -> O_5; + } + """ + ) + self.maxDiff = None + self.assertEqual(expected.strip("\n "), dot.strip("\n ")) + + def test_dot_plot_tiny(self): + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + with torch_export_patches(patch_transformers=True): + em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom") + dot = to_dot(em.model_proto) + name = self.get_dump_file("test_dot_plot_tiny.dot") + with open(name, "w") as f: + f.write(dot) + # dot -Tpng dump_test/test_dot_plot_tiny.dot -o dump_test/test_dot_plot_tiny.png + self.assertIn("-> Add", dot) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index d00617be..582d71a3 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes( start = 0 end = 0 subtrees = [] - for subspec in spec.children_specs: + for subspec in spec.children(): end += subspec.num_leaves value = subspec.unflatten(flat[start:end]) value = flatten_unflatten_for_dynamic_shapes( diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py new file mode 100644 index 00000000..9d3cc7ac --- /dev/null +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -0,0 +1,187 @@ +from typing import Set +import onnx +from .onnx_helper import onnx_dtype_name, pretty_onnx + + +def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: + hidden = set() + memo = ( + {i.name for i in graph.initializer} + | {i.values.name for i in graph.sparse_initializer} + | {i.name for i in graph.input} + ) + for node in graph.node: + for i in node.input: + if i not in memo: + hidden.add(i) + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH and att.g: + hid = _get_hidden_inputs(att.g) + less = set(h for h in hid if h not in memo) + hidden |= less + memo |= set(node.output) + return hidden + + +def _make_node_label(node: onnx.NodeProto) -> str: + els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("] + ee = ["." if i else "" for i in node.input] + for att in node.attribute: + if att.name == "to": + ee.append(f"{att.name}={onnx_dtype_name(att.i)}") + elif att.name in {"to", "axis", "value_int", "stash_type"}: + ee.append(f"{att.name}={att.i}") + elif att.name in {"value_float"}: + ee.append(f"{att.name}={att.f}") + elif att.name in {"value_floats"}: + ee.append(f"{att.name}={att.floats}") + elif att.name in {"value_ints", "perm"}: + ee.append(f"{att.name}={att.ints}") + els.append(", ".join(ee)) + els.append(")") + if node.op_type == "Constant": + els.extend([" -> ", node.output[0]]) + return "".join(els) + + +def to_dot(model: onnx.ModelProto) -> str: + """ + Converts a model into a dot graph. + Here is an example: + + .. gdot:: + :script: DOT-SECTION + :process: + + from onnx_diagnostic.helpers.dot_helper import to_dot + from onnx_diagnostic.export.api import to_onnx + from onnx_diagnostic.torch_export_patches import torch_export_patches + from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs + + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + with torch_export_patches(patch_transformers=True): + em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom") + dot = to_dot(em.model_proto) + print("DOT-SECTION", dot) + + Or this one obtained with :func:`torch.onnx.export`. + + .. gdot:: + :script: DOT-SECTION + :process: + + from onnx_diagnostic.helpers.dot_helper import to_dot + from onnx_diagnostic.export.api import to_onnx + from onnx_diagnostic.torch_export_patches import torch_export_patches + from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs + + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + with torch_export_patches(patch_transformers=True): + em = to_onnx(model, kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo") + dot = to_dot(em.model_proto) + print("DOT-SECTION", dot) + """ + _unique = {} + + def _mkn(obj: object) -> int: + id_obj = id(obj) + if id_obj in _unique: + return _unique[id_obj] + i = len(_unique) + _unique[id_obj] = i + return i + + model = onnx.shape_inference.infer_shapes(model) + + op_type_colors = { + "Shape": "#eeeeee", + "MatMul": "#ee9999", + "Transpose": "#ee99ee", + } + + edge_label = {} + for val in model.graph.value_info: + itype = val.type.tensor_type.elem_type + if itype == onnx.TensorProto.UNDEFINED: + continue + shape = tuple( + d.dim_param if d.dim_param else d.dim_value for d in val.type.tensor_type.shape.dim + ) + sshape = ",".join( + map( + str, + [("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape], + ) + ) + edge_label[val.name] = f"{onnx_dtype_name(itype)}({sshape})" + + rows = [ + "digraph {", + ( + " graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, " + "ranksep=0.2, fontsize=8];" + ), + ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];', + " edge [arrowhead=vee, fontsize=6];", + ] + inputs = list(model.graph.input) + outputs = list(model.graph.output) + nodes = list(model.graph.node) + inits = list(model.graph.initializer) + name_to_ids = {} + for inp in inputs: + if not inp.name: + continue + rows.append(f' I_{_mkn(inp)} [label="{inp.name}", fillcolor="#aaeeaa"];') + name_to_ids[inp.name] = f"I_{_mkn(inp)}" + for init in inits: + rows.append(f' i_{_mkn(init)} [label="{init.name}", fillcolor="#cccc00"];') + name_to_ids[init.name] = f"i_{_mkn(init)}" + for node in nodes: + color = op_type_colors.get(node.op_type, "#cccccc") + label = _make_node_label(node) + rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];') + name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o}) + + # nodes + done = set() + for node in nodes: + names = list(node.input) + for i in names: + if not i: + continue + if i not in name_to_ids: + raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}") + edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}" + if edge in done: + continue + done.add(edge) + lab = edge_label.get(i, "") + if lab: + ls = ",".join([f'label="{lab}"']) + lab = f" [{ls}]" + rows.append(f" {edge[0]} -> {edge[1]}{lab};") + if node.op_type in {"Scan", "Loop", "If"}: + unique = set() + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + unique |= _get_hidden_inputs(att.g) + for i in unique: + edge = name_to_ids[i], _mkn(node) + if edge in done: + continue + done.add(edge) + rows.append(f" {edge[0]} -> {edge[1]} [style=dotted];") + + # outputs + for out in outputs: + if not out.name: + continue + rows.append(f' O_{_mkn(out)} [label="{out.name}", fillcolor="#aaaaee"];') + edge = name_to_ids[out.name], f"O_{_mkn(out)}" + rows.append(f" {edge[0]} -> {edge[1]};") + + rows.append("}") + return "\n".join(rows) diff --git a/pyproject.toml b/pyproject.toml index bf68f382..5e12ab22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,6 +172,7 @@ select = [ "_scripts/compare_model_execution.py" = ["E402", "F401"] "_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "RUF015", "SIM105", "SIM117"] "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"] +"_unittests/ut_helpers/test_dot_helper.py" = ["E501"] "_unittests/ut_tasks/try_export.py" = ["B008", "B904", "E501", "PIE808", "SIM117", "SIM105", "UP008"] "onnx_diagnostic/export/__init__.py" = ["F401"] "onnx_diagnostic/helpers/__init__.py" = ["F401"] From 19a76a82525852a7a1899b2ddb976e7a35d0e47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 1 Dec 2025 09:26:22 +0100 Subject: [PATCH 2/7] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 814f5d08..0f7a49e2 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.3 +++++ +* :pr:`331`: adds a helper to convert an onnx model into dot * :pr:`330`: fixes access rope_parameters for ``transformers>=5`` * :pr:`329`: supports lists with OnnxruntimeEvaluator * :pr:`326`: use ConcatFromSequence in LoopMHA with the loop From 70526e6a2da01bd0bb1f58147d79ffe8ec997c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 1 Dec 2025 09:27:52 +0100 Subject: [PATCH 3/7] mypy --- onnx_diagnostic/helpers/dot_helper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index 9d3cc7ac..a871d5a0 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -1,4 +1,4 @@ -from typing import Set +from typing import Dict, Set import onnx from .onnx_helper import onnx_dtype_name, pretty_onnx @@ -83,7 +83,7 @@ def to_dot(model: onnx.ModelProto) -> str: dot = to_dot(em.model_proto) print("DOT-SECTION", dot) """ - _unique = {} + _unique: Dict[int, int] = {} def _mkn(obj: object) -> int: id_obj = id(obj) @@ -169,7 +169,7 @@ def _mkn(obj: object) -> int: if att.type == onnx.AttributeProto.GRAPH: unique |= _get_hidden_inputs(att.g) for i in unique: - edge = name_to_ids[i], _mkn(node) + edge = name_to_ids[i], _mkn(node) # type: ignore[assignment] if edge in done: continue done.add(edge) From de538cc90d8d3467f0177a2217a41b7ce8bd3d38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 1 Dec 2025 10:02:31 +0100 Subject: [PATCH 4/7] improve dot --- _unittests/ut_helpers/test_dot_helper.py | 18 ++++---- onnx_diagnostic/helpers/dot_helper.py | 54 ++++++++++++++++-------- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/_unittests/ut_helpers/test_dot_helper.py b/_unittests/ut_helpers/test_dot_helper.py index 3726e085..dff60ef3 100644 --- a/_unittests/ut_helpers/test_dot_helper.py +++ b/_unittests/ut_helpers/test_dot_helper.py @@ -43,18 +43,18 @@ def test_custom_doc_kernels_layer_normalization(self): digraph { graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8]; node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box]; - edge [arrowhead=vee, fontsize=6]; - I_0 [label="X", fillcolor="#aaeeaa"]; - I_1 [label="W", fillcolor="#aaeeaa"]; - I_2 [label="B", fillcolor="#aaeeaa"]; + edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0]; + I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"]; + I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"]; + I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"]; LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"]; Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"]; - I_0 -> LayerNormalization_3; - I_1 -> LayerNormalization_3; - I_2 -> LayerNormalization_3; + I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"]; + I_1 -> LayerNormalization_3 [label="FLOAT16(d)"]; + I_2 -> LayerNormalization_3 [label="FLOAT16(d)"]; LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"]; - I_1 -> Add_4; - O_5 [label="Z", fillcolor="#aaaaee"]; + I_1 -> Add_4 [label="FLOAT16(d)"]; + O_5 [label="Z\\nFLOAT16(d)", fillcolor="#aaaaee"]; Add_4 -> O_5; } """ diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index a871d5a0..66fed738 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -1,5 +1,6 @@ from typing import Dict, Set import onnx +import onnx.numpy_helper as onh from .onnx_helper import onnx_dtype_name, pretty_onnx @@ -44,6 +45,24 @@ def _make_node_label(node: onnx.NodeProto) -> str: return "".join(els) +def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str: + itype = value_info.type.tensor_type.elem_type + if itype == onnx.TensorProto.UNDEFINED: + return "" + shape = tuple( + d.dim_param if d.dim_param else d.dim_value + for d in value_info.type.tensor_type.shape.dim + ) + res = [ + str(a) + for a in [("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape] + ] + sshape = ",".join(res) + if multi_line and len(sshape) > 30: + sshape = ",\\n".join(res) + return f"{onnx_dtype_name(itype)}({sshape})" + + def to_dot(model: onnx.ModelProto) -> str: """ Converts a model into a dot graph. @@ -103,19 +122,7 @@ def _mkn(obj: object) -> int: edge_label = {} for val in model.graph.value_info: - itype = val.type.tensor_type.elem_type - if itype == onnx.TensorProto.UNDEFINED: - continue - shape = tuple( - d.dim_param if d.dim_param else d.dim_value for d in val.type.tensor_type.shape.dim - ) - sshape = ",".join( - map( - str, - [("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape], - ) - ) - edge_label[val.name] = f"{onnx_dtype_name(itype)}({sshape})" + edge_label[val.name] = _make_edge_label(val, multi_line=True) rows = [ "digraph {", @@ -124,7 +131,7 @@ def _mkn(obj: object) -> int: "ranksep=0.2, fontsize=8];" ), ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];', - " edge [arrowhead=vee, fontsize=6];", + " edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];", ] inputs = list(model.graph.input) outputs = list(model.graph.output) @@ -134,11 +141,23 @@ def _mkn(obj: object) -> int: for inp in inputs: if not inp.name: continue - rows.append(f' I_{_mkn(inp)} [label="{inp.name}", fillcolor="#aaeeaa"];') + lab = _make_edge_label(inp) + rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];') name_to_ids[inp.name] = f"I_{_mkn(inp)}" + edge_label[inp.name] = _make_edge_label(inp, multi_line=True) for init in inits: - rows.append(f' i_{_mkn(init)} [label="{init.name}", fillcolor="#cccc00"];') + shape = tuple(init.dims) + if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10): + a = onh.to_array(init) + vals = f" = {a}" if len(shape) == 0 else f"\\n=[{', '.join([str(i) for i in a])}]" + else: + vals = "" + ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})" + rows.append( + f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];' + ) name_to_ids[init.name] = f"i_{_mkn(init)}" + edge_label[init.name] = ls for node in nodes: color = op_type_colors.get(node.op_type, "#cccccc") label = _make_node_label(node) @@ -179,7 +198,8 @@ def _mkn(obj: object) -> int: for out in outputs: if not out.name: continue - rows.append(f' O_{_mkn(out)} [label="{out.name}", fillcolor="#aaaaee"];') + lab = _make_edge_label(inp) + rows.append(f' O_{_mkn(out)} [label="{out.name}\\n{lab}", fillcolor="#aaaaee"];') edge = name_to_ids[out.name], f"O_{_mkn(out)}" rows.append(f" {edge[0]} -> {edge[1]};") From f4fea2cb030109dc9309ffcdef87b19f8666ab8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 1 Dec 2025 10:34:56 +0100 Subject: [PATCH 5/7] improve rendering --- _unittests/ut_helpers/test_dot_helper.py | 3 ++- onnx_diagnostic/helpers/dot_helper.py | 32 ++++++++++++++---------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/_unittests/ut_helpers/test_dot_helper.py b/_unittests/ut_helpers/test_dot_helper.py index dff60ef3..2e591852 100644 --- a/_unittests/ut_helpers/test_dot_helper.py +++ b/_unittests/ut_helpers/test_dot_helper.py @@ -2,7 +2,7 @@ import unittest import onnx import onnx.helper as oh -from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers from onnx_diagnostic.helpers.dot_helper import to_dot from onnx_diagnostic.export.api import to_onnx from onnx_diagnostic.torch_export_patches import torch_export_patches @@ -62,6 +62,7 @@ def test_custom_doc_kernels_layer_normalization(self): self.maxDiff = None self.assertEqual(expected.strip("\n "), dot.strip("\n ")) + @requires_transformers("4.57") def test_dot_plot_tiny(self): data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index 66fed738..cd87141d 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -24,13 +24,13 @@ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: return hidden -def _make_node_label(node: onnx.NodeProto) -> str: +def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str: els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("] - ee = ["." if i else "" for i in node.input] + ee = [tiny_inits.get(i, ".") if i else "" for i in node.input] for att in node.attribute: if att.name == "to": ee.append(f"{att.name}={onnx_dtype_name(att.i)}") - elif att.name in {"to", "axis", "value_int", "stash_type"}: + elif att.name in {"to", "axis", "value_int", "stash_type", "start", "end"}: ee.append(f"{att.name}={att.i}") elif att.name in {"value_float"}: ee.append(f"{att.name}={att.f}") @@ -115,9 +115,12 @@ def _mkn(obj: object) -> int: model = onnx.shape_inference.infer_shapes(model) op_type_colors = { - "Shape": "#eeeeee", + "Shape": "#d2a81f", "MatMul": "#ee9999", "Transpose": "#ee99ee", + "Reshape": "#eeeeee", + "Squeeze": "#eeeeee", + "Unsqueeze": "#eeeeee", } edge_label = {} @@ -137,6 +140,7 @@ def _mkn(obj: object) -> int: outputs = list(model.graph.output) nodes = list(model.graph.node) inits = list(model.graph.initializer) + tiny_inits = {} name_to_ids = {} for inp in inputs: if not inp.name: @@ -150,17 +154,19 @@ def _mkn(obj: object) -> int: if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10): a = onh.to_array(init) vals = f" = {a}" if len(shape) == 0 else f"\\n=[{', '.join([str(i) for i in a])}]" + tiny_inits[init.name] = ( + str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]" + ) else: - vals = "" - ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})" - rows.append( - f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];' - ) - name_to_ids[init.name] = f"i_{_mkn(init)}" - edge_label[init.name] = ls + ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})" + rows.append( + f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];' + ) + name_to_ids[init.name] = f"i_{_mkn(init)}" + edge_label[init.name] = ls for node in nodes: color = op_type_colors.get(node.op_type, "#cccccc") - label = _make_node_label(node) + label = _make_node_label(node, tiny_inits) rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];') name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o}) @@ -169,7 +175,7 @@ def _mkn(obj: object) -> int: for node in nodes: names = list(node.input) for i in names: - if not i: + if not i or i in tiny_inits: continue if i not in name_to_ids: raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}") From b80e7b6534e69cd05a70118250a3e1aee28f3dd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 1 Dec 2025 12:00:38 +0100 Subject: [PATCH 6/7] fix --- _unittests/ut_xrun_doc/test_command_lines.py | 9 +++ .../ut_xrun_doc/test_command_lines_exe.py | 39 ++++++++++ onnx_diagnostic/_command_lines_parser.py | 75 +++++++++++++++++++ onnx_diagnostic/helpers/dot_helper.py | 5 +- 4 files changed, 124 insertions(+), 4 deletions(-) diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py index 5317190f..449d5fac 100644 --- a/_unittests/ut_xrun_doc/test_command_lines.py +++ b/_unittests/ut_xrun_doc/test_command_lines.py @@ -6,6 +6,7 @@ get_main_parser, get_parser_agg, get_parser_config, + get_parser_dot, get_parser_find, get_parser_lighten, get_parser_print, @@ -23,6 +24,7 @@ def test_main_parser(self): get_main_parser().print_help() text = st.getvalue() self.assertIn("lighten", text) + self.assertIn("dot", text) def test_parser_lighten(self): st = StringIO() @@ -87,6 +89,13 @@ def test_parser_sbs(self): text = st.getvalue() self.assertIn("--onnx", text) + def test_parser_dot(self): + st = StringIO() + with redirect_stdout(st): + get_parser_dot().print_help() + text = st.getvalue() + self.assertIn("--run", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py index f9ca7c5c..6ea08def 100644 --- a/_unittests/ut_xrun_doc/test_command_lines_exe.py +++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py @@ -162,6 +162,45 @@ def forward(self, x): sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")] self.assertEqual(sdf.shape[0], 4) + @ignore_warnings(UserWarning) + @requires_transformers("4.53") + def test_i_parser_dot(self): + import torch + + class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32 + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(32, 1) # hidden → output + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.fc2(x) + return x + + inputs = dict(x=torch.randn((5, 10))) + ds = dict(x={0: "batch"}) + onnx_file = self.get_dump_file("test_i_parser_dot.model.onnx") + to_onnx( + Model(), + kwargs=inputs, + dynamic_shapes=ds, + exporter="custom", + filename=onnx_file, + ) + + output = self.get_dump_file("test_i_parser_dot.dot") + args = ["dot", onnx_file, "-v", "1", "-o", output] + if not self.unit_test_going(): + args.extend(["--run", "svg"]) + + st = StringIO() + with redirect_stdout(st): + main(args) + text = st.getvalue() + print(text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 6a34d763..519e5dc2 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -11,6 +11,77 @@ from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction +def get_parser_dot() -> ArgumentParser: + parser = ArgumentParser( + prog="dot", + description=textwrap.dedent( + """ + Converts a model into a dot file dot can draw into a graph. + """ + ), + ) + parser.add_argument("input", type=str, help="onnx model to lighten") + parser.add_argument( + "-o", + "--output", + default="", + type=str, + required=False, + help="dot model to output or empty to print out the result", + ) + parser.add_argument( + "-v", + "--verbose", + type=int, + default=0, + required=False, + help="verbosity", + ) + parser.add_argument( + "-r", + "--run", + default="", + required=False, + help="run dot, in that case, format must be given (svg, png)", + ) + return parser + + +def _cmd_dot(argv: List[Any]): + import subprocess + from .helpers.dot_helper import to_dot + + parser = get_parser_dot() + args = parser.parse_args(argv[1:]) + if args.verbose: + print(f"-- loads {args.input!r}") + onx = onnx.load(args.input, load_external_data=False) + if args.verbose: + print("-- converts into dot") + dot = to_dot(onx) + if args.output: + if args.verbose: + print(f"-- saves into {args.output}") + with open(args.output, "w") as f: + f.write(dot) + else: + print(dot) + if args.run: + assert args.output, "Cannot run dot without an output file." + cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"] + if args.verbose: + print(f"-- run {' '.join(cmds)}") + p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + res = p.communicate() + out, err = res + if out: + print("--") + print(out) + if err: + print("--") + print(err) + + def get_parser_lighten() -> ArgumentParser: parser = ArgumentParser( prog="lighten", @@ -1412,6 +1483,7 @@ def get_main_parser() -> ArgumentParser: agg - aggregates statistics from multiple files config - prints a configuration for a model id + dot - converts an onnx model into dot format exportsample - produces a code to export a model find - find node consuming or producing a result lighten - makes an onnx model lighter by removing the weights, @@ -1428,6 +1500,7 @@ def get_main_parser() -> ArgumentParser: choices=[ "agg", "config", + "dot", "exportsample", "find", "lighten", @@ -1446,6 +1519,7 @@ def main(argv: Optional[List[Any]] = None): fcts = dict( agg=_cmd_agg, config=_cmd_config, + dot=_cmd_dot, exportsample=_cmd_export_sample, find=_cmd_find, lighten=_cmd_lighten, @@ -1470,6 +1544,7 @@ def main(argv: Optional[List[Any]] = None): parsers = dict( agg=get_parser_agg, config=get_parser_config, + dot=get_parser_dot, exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator] find=get_parser_find, lighten=get_parser_lighten, diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index cd87141d..328446cd 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -153,15 +153,12 @@ def _mkn(obj: object) -> int: shape = tuple(init.dims) if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10): a = onh.to_array(init) - vals = f" = {a}" if len(shape) == 0 else f"\\n=[{', '.join([str(i) for i in a])}]" tiny_inits[init.name] = ( str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]" ) else: ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})" - rows.append( - f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];' - ) + rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];') name_to_ids[init.name] = f"i_{_mkn(init)}" edge_label[init.name] = ls for node in nodes: From aa9685d215843f64d6fc6fae48b96fea675eee40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 1 Dec 2025 12:30:10 +0100 Subject: [PATCH 7/7] json --- _doc/recipes/plot_dynamic_shapes_json.py | 2 +- onnx_diagnostic/helpers/cache_helper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/_doc/recipes/plot_dynamic_shapes_json.py b/_doc/recipes/plot_dynamic_shapes_json.py index 717c276f..19569fa9 100644 --- a/_doc/recipes/plot_dynamic_shapes_json.py +++ b/_doc/recipes/plot_dynamic_shapes_json.py @@ -74,7 +74,7 @@ def flatten_unflatten_like_dynamic_shapes(obj): start = 0 end = 0 subtrees = [] - for subspec in spec.children(): + for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs): end += subspec.num_leaves value = subspec.unflatten(flat[start:end]) value = flatten_unflatten_like_dynamic_shapes(value) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 582d71a3..3ff36d9b 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes( start = 0 end = 0 subtrees = [] - for subspec in spec.children(): + for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs): end += subspec.num_leaves value = subspec.unflatten(flat[start:end]) value = flatten_unflatten_for_dynamic_shapes(