99
1010 git+https://github.com/sdpython/experimental-experiment.git # optional
1111 huggingface_hub>=1.2.1
12- onnx-diagnostic>=0.8.4
12+ onnx-diagnostic>=0.8.5
1313 onnxruntime>=1.23
1414 torch>=2.9 # weekly is better
1515 tqdm
@@ -98,6 +98,27 @@ def main(
9898 make_zip : bool = False ,
9999 output_folder : str = "dump_models" ,
100100):
101+ prefix = simplify_model_id_for_a_filename (model_id )
102+ if "QWEN25ATTENTION" in os .environ :
103+ prefix = f"{ prefix } .{ os .environ ['QWEN25ATTENTION' ]} "
104+ basename = os .path .join (
105+ output_folder , f"model.{ prefix } .visual.{ device } .{ dtype } .{ exporter } "
106+ )
107+ filename = f"{ basename } .onnx"
108+ stat_file = f"{ basename } .stats"
109+
110+ print ("------------------------------------------------------------------" )
111+ print (
112+ f"-- { model_id } { device } { dtype } { exporter } { pretrained } "
113+ f"{ second_input } { make_zip } { output_folder } { prefix } "
114+ )
115+ print ("------------------------------------------------------------------" )
116+ print (f"-- export in { filename !r} " )
117+
118+ if os .path .exists (stat_file ):
119+ print (f"-- skipping because { stat_file !r} already exists" )
120+ return
121+
101122 print ("-- import torch" )
102123 import torch
103124
@@ -171,7 +192,6 @@ def _config_reduction(config, task):
171192 print ("-- INPUT/OUTPUT" )
172193 print ("-- ############" )
173194
174- prefix = simplify_model_id_for_a_filename (model_id )
175195 input_filename = os .path .join (output_folder , f"inputs.{ prefix } .visual.{ device } .{ dtype } .pt" )
176196 if os .path .exists (input_filename ):
177197 print (f"-- restore inputs from { input_filename !r} " )
@@ -219,7 +239,7 @@ def compute_expected():
219239 expected = torch .load (output_filename )
220240 export_expected = expected ["export_expected" ]
221241 other_expected = expected ["other_expected" ]
222- duration = expected ["duration " ]
242+ durations = expected ["durations " ]
223243 else :
224244 print (
225245 f"-- compute with inputs: { string_type (export_inputs , with_shape = True , with_device = True )} "
@@ -229,31 +249,32 @@ def compute_expected():
229249 print (
230250 f"-- compute with inputs: { string_type (other_inputs , with_shape = True , with_device = True )} "
231251 )
232- begin = time .perf_counter ()
233252 other_expected = []
253+ durations = []
234254 for other in tqdm .tqdm (other_inputs ):
255+ begin = time .perf_counter ()
235256 expected = model_to_export (** other )
236257 other_expected .append (expected )
237- duration = time .perf_counter () - begin
258+ durations . append ( time .perf_counter () - begin )
238259 print (f"-- got: { string_type (other_expected , with_shape = True , with_device = True )} " )
239260
240261 expected = dict (
241262 export_expected = export_expected ,
242263 other_expected = other_expected ,
243- duration = duration ,
264+ durations = durations ,
244265 )
245266 print (f"-- dump expected outputs into { output_filename !r} " )
246267 torch .save (expected , output_filename )
247- print (f"-- computation took { duration } " )
268+ print (f"-- computation took { sum ( durations ) } " )
248269 print (
249270 f"-- export_expected={ string_type (export_expected , with_shape = True , with_device = True )} "
250271 )
251272 print (
252273 f"-- other_expected={ string_type (other_expected , with_shape = True , with_device = True )} "
253274 )
254- return export_expected , other_expected , duration
275+ return export_expected , other_expected , durations
255276
256- export_expected , other_expected , duration = (
277+ export_expected , other_expected , durations = (
257278 compute_expected () if not os .environ .get ("STOPAT" , "" ) else (None , None )
258279 )
259280
@@ -266,14 +287,6 @@ def compute_expected():
266287 grid_thw = {}, # {0: "n_images"}, # TODO: fix
267288 )
268289
269- if "QWEN25ATTENTION" in os .environ :
270- prefix = f"{ prefix } .{ os .environ ['QWEN25ATTENTION' ]} "
271- basename = os .path .join (
272- output_folder , f"model.{ prefix } .visual.{ device } .{ dtype } .{ exporter } "
273- )
274- filename = f"{ basename } .onnx"
275- print (f"-- export in { filename !r} " )
276- stat_file = f"{ basename } .stats"
277290 begin = time .perf_counter ()
278291
279292 target_opset = 22
@@ -290,7 +303,7 @@ def compute_expected():
290303 stop_if_static = 2 ,
291304 ):
292305 if export_expected is None :
293- export_expected , other_expected , duration = compute_expected ()
306+ export_expected , other_expected , durations = compute_expected ()
294307 to_onnx (
295308 model_to_export ,
296309 kwargs = export_inputs ,
@@ -348,11 +361,19 @@ def fprint(s):
348361 )
349362 begin = time .perf_counter ()
350363 gots = []
351- for feed in tqdm .tqdm (feeds ):
364+ for i , feed in enumerate (tqdm .tqdm (feeds )):
365+ if (
366+ device == "cpu"
367+ and os .environ .get ("QWEN25ATTENTION" , "default" ) == "LOOPA23"
368+ and i >= 2
369+ ):
370+ # two slow
371+ break
352372 gots .append (sess .run (None , feed )[0 ])
353373 oduration = time .perf_counter () - begin
354374 fprint (
355- f"-- torch duration={ duration } , onnx duration={ oduration } , speedup={ duration / oduration } "
375+ f"-- torch duration={ sum (durations [:len (gots )])} , onnx duration={ oduration } , "
376+ f"speedup={ sum (durations [:len (gots )])/ oduration } n={ len (gots )} "
356377 )
357378
358379 info = {
@@ -364,9 +385,10 @@ def fprint(s):
364385 "attention" : os .environ .get ("QWEN25ATTENTION" , "default" ),
365386 "timestamp" : datetime .datetime .now ().isoformat (),
366387 "export_duration" : export_duration ,
367- "latency_torch" : duration ,
388+ "latency_torch" : sum ( durations [: len ( gots )]) ,
368389 "latency_ort" : oduration ,
369- "speedup" : duration / oduration ,
390+ "speedup" : sum (durations [: len (gots )]) / oduration ,
391+ "latency_ort_n" : len (gots ),
370392 "opset" : target_opset ,
371393 ** get_versions (),
372394 }
0 commit comments