Skip to content

Commit a4c62d8

Browse files
committed
fix max_diff function
1 parent d79520a commit a4c62d8

3 files changed

Lines changed: 305 additions & 72 deletions

File tree

_scripts/export_qwen25_vl_visual.py

Lines changed: 153 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Export visual embedding of Qwen/Qwen2.5-VL-7B-Instruct
33
======================================================
44
5-
requirements
5+
Requirements
66
++++++++++++
77
88
::
@@ -12,6 +12,7 @@
1212
onnx-diagnostic>=0.8.4
1313
onnxruntime>=1.23
1414
torch>=2.9 # weekly is better
15+
tqdm
1516
transformers>=4.57
1617
1718
Examples
@@ -26,12 +27,19 @@
2627
And to untar:
2728
``tar -xzvf model.tar.gz``.
2829
30+
Rewritings
31+
++++++++++
32+
33+
* `overview <https://sdpython.github.io/doc/onnx-diagnostic/dev/status/patches_diff.html#auto-patch-transformers-qwen2-5-vlforconditionalgeneration-prepare-inputs-for-generation-patched-qwen2-5-vlforconditionalgeneration-prepare-inputs-for-generation>`_
34+
* code: `_patch_transformers_qwen2_5.py <https://github.com/sdpython/onnx-diagnostic/blob/main/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py>`_
35+
2936
Attention
3037
+++++++++
3138
32-
The attention is either implemented with ``MultiHeadAttention`` in a loop, either with ``PackedMultiHeadAttention``.
33-
The choice is made based on the device. It is possible to overwrite this by by setting
34-
environment variable to ``QWEN25ATTENTION`` to:
39+
The attention is either implemented with ``MultiHeadAttention`` in a loop,
40+
or with ``PackedMultiHeadAttention``. The choice is made based on the device.
41+
It is possible to override this by setting the environment variable
42+
``QWEN25ATTENTION`` to:
3543
3644
* ``PACKED``: PackedMultiHeadAttention
3745
* ``LOOPMHA``: Loop over MultiHeadAttention
@@ -68,7 +76,8 @@ def main(
6876
exporter: str = "onnx-dynamo",
6977
pretrained: bool = True,
7078
second_input: bool = True,
71-
zip: bool = False,
79+
make_zip: bool = False,
80+
output_folder: str = "dump_models",
7281
):
7382
print("-- import torch")
7483
import torch
@@ -80,6 +89,7 @@ def main(
8089
from transformers import AutoModel, AutoProcessor
8190

8291
print("-- import onnx_diagnostic")
92+
import tqdm
8393
from onnx_diagnostic.helpers import string_type, max_diff
8494
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
8595
PLUGS,
@@ -88,6 +98,9 @@ def main(
8898
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
8999
from onnx_diagnostic.export.api import to_onnx
90100

101+
if output_folder and output_folder != ".":
102+
os.makedirs(output_folder, exist_ok=True)
103+
91104
print(f"-- creating model {model_id!r}")
92105
print(
93106
f"-- device={device!r}, dtype={dtype!r}, exporter={exporter!r}, "
@@ -132,42 +145,113 @@ def _config_reduction(config, task):
132145
print(f"-- model.device={model.device}")
133146
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
134147
print(f"-- processor={type(processor)}")
148+
model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
149+
print(f"-- model_to_export={type(model_to_export)}")
135150

136-
inputs = dict(
137-
hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
138-
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
139-
)
140-
big_inputs = (
141-
dict(
142-
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
143-
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
151+
print("-- ############")
152+
print("-- INPUT/OUTPUT")
153+
print("-- ############")
154+
155+
prefix = simplify_model_id_for_a_filename(model_id)
156+
input_filename = os.path.join(output_folder, f"inputs.{prefix}.visual.{device}.{dtype}.pt")
157+
if os.path.exists(input_filename):
158+
print(f"-- restore inputs from {input_filename!r}")
159+
data = torch.load(input_filename)
160+
export_inputs = data["export_inputs"]
161+
other_inputs = data["other_inputs"]
162+
else:
163+
export_inputs = dict(
164+
hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
165+
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
166+
)
167+
other_inputs = []
168+
if second_input:
169+
other_inputs = [
170+
dict(
171+
hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
172+
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
173+
),
174+
dict(
175+
hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
176+
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
177+
),
178+
dict(
179+
hidden_states=torch.randn((14308, 1176), dtype=torch_dtype).to(device),
180+
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
181+
),
182+
dict(
183+
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
184+
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
185+
),
186+
]
187+
data = dict(export_inputs=export_inputs, other_inputs=other_inputs)
188+
print(f"-- dump inputs into {input_filename!r}")
189+
torch.save(data, input_filename)
190+
191+
print(f"-- export_inputs={string_type(export_inputs, with_shape=True, with_device=True)}")
192+
print(f"-- other_inputs={string_type(other_inputs, with_shape=True, with_device=True)}")
193+
194+
def compute_expected():
195+
output_filename = os.path.join(
196+
output_folder, f"expected.{prefix}.visual.{device}.{dtype}.pt"
197+
)
198+
if os.path.exists(output_filename):
199+
print(f"-- restore expected outputs from {output_filename!r}")
200+
expected = torch.load(output_filename)
201+
export_expected = expected["export_expected"]
202+
other_expected = expected["other_expected"]
203+
duration = expected["duration"]
204+
else:
205+
print(
206+
f"-- compute with inputs: {string_type(export_inputs, with_shape=True, with_device=True)}"
207+
)
208+
export_expected = model_to_export(**export_inputs)
209+
print(f"-- got: {string_type(export_expected, with_shape=True)}")
210+
print(
211+
f"-- compute with inputs: {string_type(other_inputs, with_shape=True, with_device=True)}"
212+
)
213+
begin = time.perf_counter()
214+
other_expected = []
215+
for other in tqdm.tqdm(other_inputs):
216+
expected = model_to_export(**other)
217+
other_expected.append(expected)
218+
duration = time.perf_counter() - begin
219+
print(f"-- got: {string_type(other_expected, with_shape=True, with_device=True)}")
220+
221+
expected = dict(
222+
export_expected=export_expected,
223+
other_expected=other_expected,
224+
duration=duration,
225+
)
226+
print(f"-- dump expected outputs into {output_filename!r}")
227+
torch.save(expected, output_filename)
228+
print(f"-- computation took {duration}")
229+
print(
230+
f"-- export_expected={string_type(export_expected, with_shape=True, with_device=True)}"
144231
)
145-
if second_input
146-
else None
232+
print(
233+
f"-- other_expected={string_type(other_expected, with_shape=True, with_device=True)}"
234+
)
235+
return export_expected, other_expected, duration
236+
237+
export_expected, other_expected, duration = (
238+
compute_expected() if not os.environ.get("STOPAT", "") else (None, None)
147239
)
148240

149-
model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
150-
if not os.environ.get("STOPAT", ""):
151-
print(f"-- compute with inputs: {string_type(inputs, with_shape=True)}")
152-
expected = model_to_export(**inputs)
153-
print(f"-- got: {string_type(expected, with_shape=True)}")
154-
print(f"-- compute with inputs: {string_type(big_inputs, with_shape=True)}")
155-
expected_big = None if big_inputs is None else model_to_export(**big_inputs)
156-
print(f"-- got: {string_type(expected_big, with_shape=True)}")
157-
else:
158-
expected = None
159-
expected_big = None
160-
print(f"-- expected: {string_type(expected, with_shape=True)}")
241+
print("-- ######")
242+
print("-- EXPORT")
243+
print("-- ######")
161244

162245
dynamic_shapes = dict(
163246
hidden_states={0: "hidden_width", 1: "hidden_height"},
164247
grid_thw={}, # {0: "n_images"}, # TODO: fix
165248
)
166249

167-
prefix = simplify_model_id_for_a_filename(model_id)
168250
if "QWEN25ATTENTION" in os.environ:
169251
prefix = f"{prefix}.{os.environ['QWEN25ATTENTION']}"
170-
basename = f"model.{prefix}.visual.{device}.{dtype}.{exporter}"
252+
basename = os.path.join(
253+
output_folder, f"model.{prefix}.visual.{device}.{dtype}.{exporter}"
254+
)
171255
filename = f"{basename}.onnx"
172256
print(f"-- export in {filename!r}")
173257
stat_file = f"{basename}.stats"
@@ -176,17 +260,15 @@ def _config_reduction(config, task):
176260
if exporter == "onnx-dynamo" and device == "cuda" and "QWEN25ATTENTION" not in os.environ:
177261
os.environ["QWEN25ATTENTION"] = "PACKED"
178262

179-
export_inputs = inputs
180263
with torch_export_patches(
181264
patch_torch=False,
182265
patch_sympy=False,
183266
patch_transformers=True,
184267
verbose=1,
185268
stop_if_static=2,
186269
):
187-
if expected is None:
188-
expected = model_to_export(**inputs)
189-
expected_big = None if big_inputs is None else model_to_export(**big_inputs)
270+
if export_expected is None:
271+
export_expected, other_expected, duration = compute_expected()
190272
to_onnx(
191273
model_to_export,
192274
kwargs=export_inputs,
@@ -199,7 +281,7 @@ def _config_reduction(config, task):
199281
optimize=True,
200282
onnx_plugs=PLUGS,
201283
)
202-
duration = time.perf_counter() - begin
284+
export_duration = time.perf_counter() - begin
203285

204286
if exporter == "onnx-dynamo":
205287
# onnx-dynamo fails at producing function body with sequences as input / output.
@@ -214,32 +296,48 @@ def fprint(s):
214296
print(s)
215297
f.write(f"{s}\n")
216298

217-
fprint(f"-- export duration: {duration}")
299+
fprint(f"-- export duration: {export_duration}")
218300
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
219301
if device == "cpu":
220302
providers = providers[1:]
221303
fprint(f"-- checking discrepancies with providers={providers!r}")
222304
sess = onnxruntime.InferenceSession(filename, providers=providers)
223305

224-
fprint(f"-- inputs {string_type(inputs, with_shape=True, with_device=True)}")
225-
fprint(f"-- expected {string_type(expected, with_shape=True, with_device=True)}")
226-
feeds = {k: v.detach().cpu().numpy() for k, v in inputs.items()}
306+
fprint(
307+
f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
308+
)
309+
fprint(
310+
f"-- export_expected {string_type(export_expected, with_shape=True, with_device=True)}"
311+
)
312+
feeds = {k: v.detach().cpu().numpy() for k, v in export_inputs.items()}
227313
small = sess.run(None, feeds)
228-
diff = max_diff(expected, small[0], hist=[0.1])
314+
diff = max_diff(export_expected, small[0], hist=[0.1, 0.01])
229315
fprint(f"-- discrepancies={diff}")
230316

231317
if second_input:
318+
feeds = [
319+
{k: v.detach().cpu().numpy() for k, v in inputs.items()}
320+
for inputs in other_inputs
321+
]
232322
fprint("")
233-
fprint(f"-- inputs {string_type(big_inputs, with_shape=True, with_device=True)}")
323+
fprint(f"-- inputs {string_type(feeds, with_shape=True, with_device=True)}")
324+
fprint(
325+
f"-- expected {string_type(other_expected, with_shape=True, with_device=True)}"
326+
)
327+
begin = time.perf_counter()
328+
gots = []
329+
for feed in tqdm.tqdm(feeds):
330+
gots.append(sess.run(None, feed)[0])
331+
oduration = time.perf_counter() - begin
234332
fprint(
235-
f"-- expected {string_type(expected_big, with_shape=True, with_device=True)}"
333+
f"-- torch duration={duration}, onnx duration={oduration}, speedup={duration/oduration}"
236334
)
237-
feeds = {k: v.detach().cpu().numpy() for k, v in big_inputs.items()}
238-
big = sess.run(None, feeds)
239-
diff = max_diff(expected_big, big[0], hist=[0.1])
240-
fprint(f"-- discrepancies={diff}")
241335

242-
if zip:
336+
for fe, e, b in zip(feeds, other_expected, gots):
337+
diff = max_diff(e, b, hist=[0.1, 0.01])
338+
fprint(f"-- inputs={string_type(fe, with_shape=True)} -- {diff}")
339+
340+
if make_zip:
243341
tar_file_name = f"{basename}.zip"
244342
print()
245343
print(f"-- make file {tar_file_name!r}")
@@ -288,6 +386,13 @@ def get_parser() -> ArgumentParser:
288386
help="Creates a file .zip with onnx file and data file.",
289387
action=BooleanOptionalAction,
290388
)
389+
parser.add_argument(
390+
"-o",
391+
"--output-folder",
392+
default="dump_models",
393+
help="Folders where to put the results.",
394+
action=BooleanOptionalAction,
395+
)
291396
return parser
292397

293398

@@ -301,5 +406,6 @@ def get_parser() -> ArgumentParser:
301406
exporter=args.exporter,
302407
pretrained=args.pretrained,
303408
second_input=args.second_input,
304-
zip=args.zip,
409+
make_zip=args.zip,
410+
output_folder=args.output_folder,
305411
)

0 commit comments

Comments
 (0)