1+ import os
12from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
23import torch
34from .onnx_plug import EagerDirectReplacementWithOnnx
@@ -138,7 +139,7 @@ def to_onnx(
138139 else None
139140 )
140141
141- return _to_onnx (
142+ proto , opt_stats = _to_onnx (
142143 mod ,
143144 args = args ,
144145 kwargs = kwargs ,
@@ -155,11 +156,47 @@ def to_onnx(
155156 inline = inline ,
156157 dispatcher = main_dispatcher ,
157158 optimize = optimize ,
159+ return_optimize_report = True ,
158160 ** (exporter_kwargs or {}),
159161 )
162+ if opt_stats and filename and os .path .exists (filename ):
163+ import pandas
164+
165+ stat_filename = f"{ os .path .splitext (filename )[0 ]} .opt.xlsx"
166+ pattern_stats = []
167+ for k , v in opt_stats .items ():
168+ if "time" in k :
169+ pattern_stats .append (dict (level = "main" , pattern = k , time_in = v ))
170+ pattern_stats .extend (
171+ [{** obs , "level" : "detailed" } for obs in opt_stats ["optimization" ]]
172+ )
173+ df = pandas .DataFrame (pattern_stats )
174+ df .to_excel (stat_filename , index = False )
175+ cols = [
176+ c
177+ for c in [
178+ "level" ,
179+ "pattern" ,
180+ "time_in" ,
181+ "iteration" ,
182+ "inlined" ,
183+ "removed" ,
184+ "added" ,
185+ "instances" ,
186+ "changed" ,
187+ "scale" ,
188+ ]
189+ if c in df .columns
190+ ]
191+ agg = {k : "sum" for k in cols if k not in ("level" , "pattern" )}
192+ agg .update (dict (iteration = "max" , instances = "mean" ))
193+ agg = {k : v for k , v in agg .items () if k in df .columns }
194+ stat_filename = f"{ os .path .splitext (filename )[0 ]} .opt.agg.xlsx"
195+ df [cols ].groupby (["level" , "pattern" ]).agg (agg ).to_excel (stat_filename )
196+
197+ return proto
160198
161199 if exporter in ("dynamo" , "onnx-dynamo" ):
162- import os
163200 from ..helpers import flatten_object
164201 import onnxscript .rewriter .ort_fusions as ort_fusions
165202
@@ -226,7 +263,6 @@ def to_onnx(
226263 return epo
227264
228265 if exporter == "modelbuilder" :
229- import os
230266 from ..helpers import flatten_object , string_type
231267 from ..helpers .model_builder_helper import create_model_builder , save_model_builder
232268
0 commit comments