Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.5
+++++

* :pr:`348`: add format dot, shape to command line print
* :pr:`346`: fix patch for sdpa_mask_recent_torch even if it was removed in transformers>=5.0

0.8.4
Expand Down
6 changes: 4 additions & 2 deletions _unittests/ut_xrun_doc/test_command_lines_exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def dummy_path(self):
)

def test_a_parser_print(self):
for fmt in ["raw", "text", "pretty", "printer"]:
for fmt in ["raw", "text", "pretty", "printer", "shape", "dot"]:
with self.subTest(format=fmt):
st = StringIO()
with redirect_stdout(st):
Expand Down Expand Up @@ -199,7 +199,9 @@ def forward(self, x):
with redirect_stdout(st):
main(args)
text = st.getvalue()
print(text)
if text:
# text is empty is dot is not installed
self.assertIn("converts into dot", text)


if __name__ == "__main__":
Expand Down
18 changes: 15 additions & 3 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,19 @@ def get_parser_print() -> ArgumentParser:
)
parser.add_argument(
"fmt",
choices=["pretty", "raw", "text", "printer"],
choices=["dot", "pretty", "printer", "raw", "shape", "text"],
default="pretty",
help=textwrap.dedent(
"""
Prints out a model on the standard output.
raw - just prints the model with print(...)
printer - onnx.printer.to_text(...)

dot - converts the graph into dot
pretty - an improved rendering
printer - onnx.printer.to_text(...)
raw - just prints the model with print(...)
shape - prints every node node with input and output shapes
text - uses GraphRendering

""".strip(
"\n"
)
Expand All @@ -232,6 +236,14 @@ def _cmd_print(argv: List[Any]):
from .helpers.graph_helper import GraphRendering

print(GraphRendering(onx).text_rendering())
elif args.fmt == "shape":
from experimental_experiment.xbuilder import GraphBuilder

print(GraphBuilder(onx).pretty_text())
elif args.fmt == "dot":
from .helpers.dot_helper import to_dot

print(to_dot(onx))
else:
raise ValueError(f"Unexpected value fmt={args.fmt!r}")

Expand Down
Loading