Skip to content

Commit d79520a

Browse files
authored
add more print format to print command line (#348)
* add more print format to print command line * style * dic
1 parent 35960aa commit d79520a

3 files changed

Lines changed: 20 additions & 5 deletions

File tree

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.5
55
+++++
66

7+
* :pr:`348`: add format dot, shape to command line print
78
* :pr:`346`: fix patch for sdpa_mask_recent_torch even if it was removed in transformers>=5.0
89

910
0.8.4

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def dummy_path(self):
1818
)
1919

2020
def test_a_parser_print(self):
21-
for fmt in ["raw", "text", "pretty", "printer"]:
21+
for fmt in ["raw", "text", "pretty", "printer", "shape", "dot"]:
2222
with self.subTest(format=fmt):
2323
st = StringIO()
2424
with redirect_stdout(st):
@@ -199,7 +199,9 @@ def forward(self, x):
199199
with redirect_stdout(st):
200200
main(args)
201201
text = st.getvalue()
202-
print(text)
202+
if text:
203+
# text is empty is dot is not installed
204+
self.assertIn("converts into dot", text)
203205

204206

205207
if __name__ == "__main__":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,19 @@ def get_parser_print() -> ArgumentParser:
198198
)
199199
parser.add_argument(
200200
"fmt",
201-
choices=["pretty", "raw", "text", "printer"],
201+
choices=["dot", "pretty", "printer", "raw", "shape", "text"],
202202
default="pretty",
203203
help=textwrap.dedent(
204204
"""
205205
Prints out a model on the standard output.
206-
raw - just prints the model with print(...)
207-
printer - onnx.printer.to_text(...)
206+
207+
dot - converts the graph into dot
208208
pretty - an improved rendering
209+
printer - onnx.printer.to_text(...)
210+
raw - just prints the model with print(...)
211+
shape - prints every node node with input and output shapes
209212
text - uses GraphRendering
213+
210214
""".strip(
211215
"\n"
212216
)
@@ -232,6 +236,14 @@ def _cmd_print(argv: List[Any]):
232236
from .helpers.graph_helper import GraphRendering
233237

234238
print(GraphRendering(onx).text_rendering())
239+
elif args.fmt == "shape":
240+
from experimental_experiment.xbuilder import GraphBuilder
241+
242+
print(GraphBuilder(onx).pretty_text())
243+
elif args.fmt == "dot":
244+
from .helpers.dot_helper import to_dot
245+
246+
print(to_dot(onx))
235247
else:
236248
raise ValueError(f"Unexpected value fmt={args.fmt!r}")
237249

0 commit comments

Comments
 (0)