Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions _unittests/ut_helpers/test_dot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,63 @@ def test_custom_doc_kernels_layer_normalization(self):
self.maxDiff = None
self.assertEqual(expected.strip("\n "), dot.strip("\n "))

def test_custom_doc_kernels_layer_normalization_constant(self):
TFLOAT16 = onnx.TensorProto.FLOAT16
model = oh.make_model(
oh.make_graph(
[
oh.make_node(
"LayerNormalization",
["X", "W", "B"],
["ln"],
axis=-1,
epsilon=9.999999974752427e-7,
),
oh.make_node("Constant", [], ["cst"], value_float=[1]),
oh.make_node("Cast", ["cst"], ["cst16"], to=onnx.TensorProto.FLOAT16),
oh.make_node("Add", ["ln", "cst16"], ["Z"], axis=-1),
],
"dummy",
[
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
],
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
),
ir_version=9,
opset_imports=[oh.make_opsetid("", 18)],
)
dot = to_dot(model)
expected = (
textwrap.dedent(
"""
digraph {
graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
Cast_4 [label="Cast([1.0], to=FLOAT16)", fillcolor="#cccccc"];
Add_5[label="Add(.,.,axis=-1)",fillcolor="#cccccc"];
I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
LayerNormalization_3 -> Add_5 [label="FLOAT16(b,c,d)"];
Cast_4->Add_5[label="FLOAT16()"];
O_6 [label="Z\\nFLOAT16(b,c,d)", fillcolor="#aaaaee"];
Add_5 -> O_6;
}
"""
)
.strip("\n")
.replace(" ", "")
)
self.maxDiff = None
self.assertEqual(expected, dot.strip("\n").replace(" ", ""))

@requires_transformers("4.57")
def test_dot_plot_tiny(self):
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
Expand Down
35 changes: 33 additions & 2 deletions onnx_diagnostic/helpers/dot_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict, Set
import numpy as np
import onnx
import onnx.numpy_helper as onh
from ..reference import ExtendedReferenceEvaluator as Inference
from .onnx_helper import onnx_dtype_name, pretty_onnx


Expand All @@ -25,7 +27,7 @@ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:


def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("]
els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "\\n("]
ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
for att in node.attribute:
if att.name == "to":
Expand All @@ -42,7 +44,10 @@ def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
els.append(")")
if node.op_type == "Constant":
els.extend([" -> ", node.output[0]])
return "".join(els)
res = "".join(els)
if len(res) < 40:
return res.replace("\\n(", "(")
return res


def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
Expand Down Expand Up @@ -142,14 +147,37 @@ def _mkn(obj: object) -> int:
inits = list(model.graph.initializer)
tiny_inits = {}
name_to_ids = {}

for inp in inputs:
if not inp.name:
continue
lab = _make_edge_label(inp)
rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
name_to_ids[inp.name] = f"I_{_mkn(inp)}"
edge_label[inp.name] = _make_edge_label(inp, multi_line=True)

# Small constant --> initializer
for node in nodes:
if node.op_type != "Constant":
continue
skip = False
for att in node.attribute:
if att.name == "value" and (
len(att.t.dims) > 1 or np.prod(tuple(att.t.dims)) > 10
):
skip = True
break
if skip:
continue

sess = Inference(node)
value = sess.run(None, {})[0]
inits.append(onh.from_array(value, name=node.output[0]))

for init in inits:
if init.name in name_to_ids:
# hide optional inputs
continue
shape = tuple(init.dims)
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
a = onh.to_array(init)
Expand All @@ -161,7 +189,10 @@ def _mkn(obj: object) -> int:
rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
name_to_ids[init.name] = f"i_{_mkn(init)}"
edge_label[init.name] = ls

for node in nodes:
if node.op_type == "Constant" and node.output[0] in tiny_inits:
continue
color = op_type_colors.get(node.op_type, "#cccccc")
label = _make_node_label(node, tiny_inits)
rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
Expand Down
Loading