Skip to content

Commit 6afc650

Browse files
committed
documentation
1 parent a6532fd commit 6afc650

2 files changed

Lines changed: 73 additions & 0 deletions

File tree

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.9.2
55
+++++
66

7+
* :pr:`415`: improves function make_model_with_local_functions to support ill-defined partitions
78
* :pr:`413`: fix InputObserver in the generic case
89
* :pr:`412`: patches for ViTModel (through rewriting)
910

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,6 +1887,78 @@ def make_model_with_local_functions(
18871887
a partition if there are not already inside another partition
18881888
:param verbose: verbosity
18891889
:return: model proto
1890+
1891+
Example:
1892+
1893+
.. runpython::
1894+
:showcode:
1895+
1896+
import numpy as np
1897+
import onnx
1898+
import onnx.helper as oh
1899+
import onnx.numpy_helper as onh
1900+
from onnx_diagnostic.helpers.onnx_helper import (
1901+
make_model_with_local_functions,
1902+
pretty_onnx,
1903+
)
1904+
1905+
model = oh.make_model(
1906+
oh.make_graph(
1907+
[
1908+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
1909+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
1910+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
1911+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
1912+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
1913+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
1914+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
1915+
],
1916+
"dummy",
1917+
[oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [320, 1280])],
1918+
[oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [3, 5, 320, 640])],
1919+
[
1920+
onh.from_array(
1921+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
1922+
),
1923+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
1924+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
1925+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
1926+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
1927+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
1928+
],
1929+
),
1930+
opset_imports=[oh.make_opsetid("", 18)],
1931+
ir_version=9,
1932+
)
1933+
for i_node in [0, 1, 2, 3]:
1934+
node = model.graph.node[i_node]
1935+
meta = node.metadata_props.add()
1936+
meta.key = f"source[{i_node}]"
1937+
meta.value = f"LLL{i_node//3}"
1938+
1939+
print("-- model before --")
1940+
print(pretty_onnx(model))
1941+
print()
1942+
print("-- metadata --")
1943+
for node in model.graph.node:
1944+
text = (
1945+
f" -- [{node.metadata_props[0].key}: {node.metadata_props[0].value}]"
1946+
if node.metadata_props
1947+
else ""
1948+
)
1949+
print(
1950+
f"-- {node.op_type}({', '.join(node.input)}) -> "
1951+
f"{', '.join(node.output)}{text}"
1952+
)
1953+
print()
1954+
1955+
new_model = make_model_with_local_functions(
1956+
model, "^LLL[01]$", metadata_key_prefix="source[", verbose=1
1957+
)
1958+
1959+
print()
1960+
print("-- model after --")
1961+
print(pretty_onnx(new_model))
18901962
"""
18911963
prefix = (
18921964
metadata_key_prefix

0 commit comments

Comments
 (0)