Skip to content

Commit 21f2bc5

Browse files
committed
output statistics
1 parent d92e136 commit 21f2bc5

1 file changed

Lines changed: 39 additions & 3 deletions

File tree

onnx_diagnostic/export/api.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
23
import torch
34
from .onnx_plug import EagerDirectReplacementWithOnnx
@@ -138,7 +139,7 @@ def to_onnx(
138139
else None
139140
)
140141

141-
return _to_onnx(
142+
proto, opt_stats = _to_onnx(
142143
mod,
143144
args=args,
144145
kwargs=kwargs,
@@ -155,11 +156,47 @@ def to_onnx(
155156
inline=inline,
156157
dispatcher=main_dispatcher,
157158
optimize=optimize,
159+
return_optimize_report=True,
158160
**(exporter_kwargs or {}),
159161
)
162+
if opt_stats and os.path.exists(filename):
163+
import pandas
164+
165+
stat_filename = f"{os.path.splitext(filename)[0]}.opt.xlsx"
166+
pattern_stats = []
167+
for k, v in opt_stats.items():
168+
if "time" in k:
169+
pattern_stats.append(dict(level="main", pattern=k, time_in=v))
170+
pattern_stats.extend(
171+
[{**obs, "level": "detailed"} for obs in opt_stats["optimization"]]
172+
)
173+
df = pandas.DataFrame(pattern_stats)
174+
df.to_excel(stat_filename, index=False)
175+
cols = [
176+
c
177+
for c in [
178+
"level",
179+
"pattern",
180+
"time_in",
181+
"iteration",
182+
"inlined",
183+
"removed",
184+
"added",
185+
"instances",
186+
"changed",
187+
"scale",
188+
]
189+
if c in df.columns
190+
]
191+
agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
192+
agg.update(dict(iteration="max", instances="mean"))
193+
agg = {k: v for k, v in agg.items() if k in df.columns}
194+
stat_filename = f"{os.path.splitext(filename)[0]}.opt.agg.xlsx"
195+
df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(stat_filename)
196+
197+
return proto
160198

161199
if exporter in ("dynamo", "onnx-dynamo"):
162-
import os
163200
from ..helpers import flatten_object
164201
import onnxscript.rewriter.ort_fusions as ort_fusions
165202

@@ -226,7 +263,6 @@ def to_onnx(
226263
return epo
227264

228265
if exporter == "modelbuilder":
229-
import os
230266
from ..helpers import flatten_object, string_type
231267
from ..helpers.model_builder_helper import create_model_builder, save_model_builder
232268

0 commit comments

Comments
 (0)