Skip to content

Commit 0bf7674

Browse files
committed
fix
1 parent 60b82c9 commit 0bf7674

2 files changed

Lines changed: 327 additions & 23 deletions

File tree

_scripts/export_qwen25_vl_visual.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
1010
git+https://github.com/sdpython/experimental-experiment.git # optional
1111
huggingface_hub>=1.2.1
12-
onnx-diagnostic>=0.8.4
12+
onnx-diagnostic>=0.8.5
1313
onnxruntime>=1.23
1414
torch>=2.9 # weekly is better
1515
tqdm
@@ -98,6 +98,27 @@ def main(
9898
make_zip: bool = False,
9999
output_folder: str = "dump_models",
100100
):
101+
prefix = simplify_model_id_for_a_filename(model_id)
102+
if "QWEN25ATTENTION" in os.environ:
103+
prefix = f"{prefix}.{os.environ['QWEN25ATTENTION']}"
104+
basename = os.path.join(
105+
output_folder, f"model.{prefix}.visual.{device}.{dtype}.{exporter}"
106+
)
107+
filename = f"{basename}.onnx"
108+
stat_file = f"{basename}.stats"
109+
110+
print("------------------------------------------------------------------")
111+
print(
112+
f"-- {model_id} {device} {dtype} {exporter} {pretrained} "
113+
f"{second_input} {make_zip} {output_folder} {prefix}"
114+
)
115+
print("------------------------------------------------------------------")
116+
print(f"-- export in {filename!r}")
117+
118+
if os.path.exists(stat_file):
119+
print(f"-- skipping because {stat_file!r} already exists")
120+
return
121+
101122
print("-- import torch")
102123
import torch
103124

@@ -171,7 +192,6 @@ def _config_reduction(config, task):
171192
print("-- INPUT/OUTPUT")
172193
print("-- ############")
173194

174-
prefix = simplify_model_id_for_a_filename(model_id)
175195
input_filename = os.path.join(output_folder, f"inputs.{prefix}.visual.{device}.{dtype}.pt")
176196
if os.path.exists(input_filename):
177197
print(f"-- restore inputs from {input_filename!r}")
@@ -219,7 +239,7 @@ def compute_expected():
219239
expected = torch.load(output_filename)
220240
export_expected = expected["export_expected"]
221241
other_expected = expected["other_expected"]
222-
duration = expected["duration"]
242+
durations = expected["durations"]
223243
else:
224244
print(
225245
f"-- compute with inputs: {string_type(export_inputs, with_shape=True, with_device=True)}"
@@ -229,31 +249,32 @@ def compute_expected():
229249
print(
230250
f"-- compute with inputs: {string_type(other_inputs, with_shape=True, with_device=True)}"
231251
)
232-
begin = time.perf_counter()
233252
other_expected = []
253+
durations = []
234254
for other in tqdm.tqdm(other_inputs):
255+
begin = time.perf_counter()
235256
expected = model_to_export(**other)
236257
other_expected.append(expected)
237-
duration = time.perf_counter() - begin
258+
durations.append(time.perf_counter() - begin)
238259
print(f"-- got: {string_type(other_expected, with_shape=True, with_device=True)}")
239260

240261
expected = dict(
241262
export_expected=export_expected,
242263
other_expected=other_expected,
243-
duration=duration,
264+
durations=durations,
244265
)
245266
print(f"-- dump expected outputs into {output_filename!r}")
246267
torch.save(expected, output_filename)
247-
print(f"-- computation took {duration}")
268+
print(f"-- computation took {sum(durations)}")
248269
print(
249270
f"-- export_expected={string_type(export_expected, with_shape=True, with_device=True)}"
250271
)
251272
print(
252273
f"-- other_expected={string_type(other_expected, with_shape=True, with_device=True)}"
253274
)
254-
return export_expected, other_expected, duration
275+
return export_expected, other_expected, durations
255276

256-
export_expected, other_expected, duration = (
277+
export_expected, other_expected, durations = (
257278
compute_expected() if not os.environ.get("STOPAT", "") else (None, None)
258279
)
259280

@@ -266,14 +287,6 @@ def compute_expected():
266287
grid_thw={}, # {0: "n_images"}, # TODO: fix
267288
)
268289

269-
if "QWEN25ATTENTION" in os.environ:
270-
prefix = f"{prefix}.{os.environ['QWEN25ATTENTION']}"
271-
basename = os.path.join(
272-
output_folder, f"model.{prefix}.visual.{device}.{dtype}.{exporter}"
273-
)
274-
filename = f"{basename}.onnx"
275-
print(f"-- export in {filename!r}")
276-
stat_file = f"{basename}.stats"
277290
begin = time.perf_counter()
278291

279292
target_opset = 22
@@ -290,7 +303,7 @@ def compute_expected():
290303
stop_if_static=2,
291304
):
292305
if export_expected is None:
293-
export_expected, other_expected, duration = compute_expected()
306+
export_expected, other_expected, durations = compute_expected()
294307
to_onnx(
295308
model_to_export,
296309
kwargs=export_inputs,
@@ -348,11 +361,19 @@ def fprint(s):
348361
)
349362
begin = time.perf_counter()
350363
gots = []
351-
for feed in tqdm.tqdm(feeds):
364+
for i, feed in enumerate(tqdm.tqdm(feeds)):
365+
if (
366+
device == "cpu"
367+
and os.environ.get("QWEN25ATTENTION", "default") == "LOOPA23"
368+
and i >= 2
369+
):
370+
# two slow
371+
break
352372
gots.append(sess.run(None, feed)[0])
353373
oduration = time.perf_counter() - begin
354374
fprint(
355-
f"-- torch duration={duration}, onnx duration={oduration}, speedup={duration/oduration}"
375+
f"-- torch duration={sum(durations[:len(gots)])}, onnx duration={oduration}, "
376+
f"speedup={sum(durations[:len(gots)])/oduration} n={len(gots)}"
356377
)
357378

358379
info = {
@@ -364,9 +385,10 @@ def fprint(s):
364385
"attention": os.environ.get("QWEN25ATTENTION", "default"),
365386
"timestamp": datetime.datetime.now().isoformat(),
366387
"export_duration": export_duration,
367-
"latency_torch": duration,
388+
"latency_torch": sum(durations[: len(gots)]),
368389
"latency_ort": oduration,
369-
"speedup": duration / oduration,
390+
"speedup": sum(durations[: len(gots)]) / oduration,
391+
"latency_ort_n": len(gots),
370392
"opset": target_opset,
371393
**get_versions(),
372394
}

0 commit comments

Comments
 (0)