Skip to content

Commit c1c9361

Browse files
committed
qwan
1 parent c1cc15d commit c1c9361

1 file changed

Lines changed: 198 additions & 0 deletions

File tree

_scripts/qwen25.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import os
2+
import sys
3+
from argparse import ArgumentParser, BooleanOptionalAction
4+
5+
6+
def main(
    model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct",
    device: str = "cpu",
    dtype: str = "float32",
    exporter: str = "onnx-dynamo",
    pretrained: bool = True,
    second_input: bool = True,
):
    """Export the visual sub-model of a Qwen 2.5 VL model to ONNX and check it.

    The function loads (or randomly initializes) the model, runs its visual
    part on random inputs to get reference outputs, exports that part to
    ``qwen25_vli_visual.<device>.<dtype>.<exporter>.onnx`` in the current
    directory, then runs the exported file with onnxruntime and prints the
    discrepancies with the reference outputs.

    If the environment variable ``STOPAT`` is set to a non-empty value, the
    reference outputs are not computed before the export but inside the
    ``torch_export_patches`` context.

    Args:
        model_id: model identifier given to ``from_pretrained``.
        device: device the model and the inputs are moved to; ``"cpu"``
            restricts onnxruntime to ``CPUExecutionProvider``.
        dtype: one of ``"float16"``, ``"bfloat16"``, ``"float32"``.
        exporter: exporter name forwarded to ``to_onnx``.
        pretrained: load the pretrained weights if True, otherwise build a
            reduced random model.
        second_input: also check discrepancies on a second, larger input.

    Raises:
        KeyError: if ``dtype`` is not one of the supported names.
    """
    # Heavy imports are done lazily, each announced first, so that the user
    # can see which one is slow or failing.
    print("-- import torch")
    import torch

    print("-- import onnxruntime")
    import onnxruntime

    print("-- import transformers")
    from transformers import AutoModel, AutoProcessor

    print("-- import onnx_diagnostic")
    from onnx_diagnostic.helpers import string_type, max_diff
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        PLUGS,
    )
    from onnx_diagnostic.torch_export_patches import torch_export_patches
    from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
    from onnx_diagnostic.export.api import to_onnx

    print(f"-- creating model {model_id!r}")
    print(
        f"-- device={device!r}, dtype={dtype!r}, exporter={exporter!r}, "
        f"pretrained={pretrained!r}"
    )
    supported_dtypes = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if dtype not in supported_dtypes:
        # Same exception type as a plain dictionary lookup, with a usable message.
        raise KeyError(
            f"unsupported dtype {dtype!r}, expected one of {sorted(supported_dtypes)}"
        )
    torch_dtype = supported_dtypes[dtype]

    if pretrained:
        print("-- pretrained model")
        model = AutoModel.from_pretrained(
            model_id, device_map=device, dtype=torch_dtype, attn_implementation="sdpa"
        ).eval()
    else:
        print("-- random model")

        def _config_reduction(config, task):
            # Both arguments are ignored: the same overrides are always returned.
            return {
                # "num_hidden_layers": 2,
                "text_config": {
                    "num_hidden_layers": 2,
                    "layer_types": ["full_attention", "full_attention"],
                },
                # "_attn_implementation": "flash_attention_2",
                "_attn_implementation": "sdpa",
                "dtype": "float16",
            }

        data = get_untrained_model_with_inputs(
            model_id, verbose=1, add_second_input=False, config_reduction=_config_reduction
        )
        model = data["model"]

    # Make sure the model is on the requested device with the requested dtype
    # (the random model is built from a configuration asking for float16).
    model = model.to(device).to(torch_dtype)

    print(f"-- config._attn_implementation={model.config._attn_implementation}")
    print(f"-- model.dtype={model.dtype}")
    print(f"-- model.device={model.device}")
    # The processor is only loaded to report its type.
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    print(f"-- processor={type(processor)}")

    # hidden_states has one row per cell of the grid: 1292 = 1 * 34 * 38.
    inputs = dict(
        hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
        grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
    )
    # Second, larger input (14308 = 1 * 98 * 146) used to check the exported
    # model on a shape different from the one used to export.
    big_inputs = None
    if second_input:
        big_inputs = dict(
            hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
            grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
        )

    model_to_export = model.visual if hasattr(model, "visual") else model.model.visual

    expected = None
    expected_big = None
    if not os.environ.get("STOPAT", ""):
        # The reference outputs are only compared numerically afterwards,
        # there is no need to record the autograd graph.
        with torch.no_grad():
            print(f"-- compute with inputs: {string_type(inputs, with_shape=True)}")
            expected = model_to_export(**inputs)
            print(f"-- got: {string_type(expected, with_shape=True)}")
            if big_inputs is not None:
                print(f"-- compute with inputs: {string_type(big_inputs, with_shape=True)}")
                expected_big = model_to_export(**big_inputs)
                print(f"-- got: {string_type(expected_big, with_shape=True)}")
    print(f"-- expected: {string_type(expected, with_shape=True)}")

    dynamic_shapes = dict(
        hidden_states={0: "hidden_width", 1: "hidden_height"},
        grid_thw={},  # {0: "n_images"}, # TODO: fix
    )

    filename = f"qwen25_vli_visual.{device}.{dtype}.{exporter}.onnx"
    print(f"-- export in {filename!r}")

    with torch_export_patches(
        patch_torch=False,
        patch_sympy=False,
        patch_transformers=True,
        verbose=1,
        stop_if_static=2,
    ):
        if expected is None:
            # STOPAT was set: the reference outputs are computed here instead.
            with torch.no_grad():
                expected = model_to_export(**inputs)
                if big_inputs is not None:
                    expected_big = model_to_export(**big_inputs)
        to_onnx(
            model_to_export,
            kwargs=inputs,
            dynamic_shapes=dynamic_shapes,
            filename=filename,
            exporter=exporter,
            verbose=1,
            save_ep=None,
            target_opset=22,
            optimize=True,
            onnx_plugs=PLUGS,
        )

    print("-- checking discrepancies")
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if device == "cpu":
        providers = providers[1:]
    sess = onnxruntime.InferenceSession(filename, providers=providers)

    def _report_discrepancies(torch_inputs, reference):
        """Run the exported model on *torch_inputs* and print the gap to *reference*."""
        print(f"-- inputs {string_type(torch_inputs, with_shape=True)}")
        # NOTE(review): Tensor.numpy() does not handle bfloat16, so this step
        # presumably fails when dtype="bfloat16" — confirm.
        feeds = {k: v.detach().cpu().numpy() for k, v in torch_inputs.items()}
        got = sess.run(None, feeds)
        # Only the first output of the exported model is compared.
        diff = max_diff(reference, got[0], hist=[0.1])
        print(f"-- discrepancies={diff}")

    _report_discrepancies(inputs, expected)
    if big_inputs is not None:
        _report_discrepancies(big_inputs, expected_big)
153+
154+
155+
def get_parser() -> ArgumentParser:
    """Build the command line parser of the script.

    Returns:
        A parser exposing ``--mid``, ``--device``, ``--dtype``,
        ``--exporter``, ``--pretrained/--no-pretrained`` and
        ``--second-input/--no-second-input``.
    """
    parser = ArgumentParser(
        prog="qwen25", description="""Export visual part of model Qwen 2.5 VL."""
    )
    parser.add_argument(
        "-m",
        "--mid",
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help="model id, default is Qwen/Qwen2.5-VL-7B-Instruct",
    )
    parser.add_argument("-d", "--device", default="cpu", help="Device, cpu (default) or cuda.")
    parser.add_argument(
        "-t",
        "--dtype",
        default="float32",
        # Restricted to the names main() knows how to map to a torch dtype,
        # so that a typo fails here and not after the heavy imports.
        choices=("float32", "float16", "bfloat16"),
        help="dtype, float32 (default), float16 or bfloat16",
    )
    parser.add_argument(
        "-e", "--exporter", default="onnx-dynamo", help="exporter, default is onnx-dynamo"
    )
    parser.add_argument(
        "--pretrained",
        default=True,
        help="use pretrained model or a random model",
        action=BooleanOptionalAction,
    )
    parser.add_argument(
        "--second-input",
        default=True,
        help="check discrepancies with other inputs",
        action=BooleanOptionalAction,
    )
    return parser
186+
187+
188+
if __name__ == "__main__":
    # Script entry point: read the command line, then forward every option
    # to main() under the keyword it expects.
    parser = get_parser()
    args = parser.parse_args(sys.argv[1:])
    cli_options = {
        "model_id": args.mid,
        "device": args.device,
        "dtype": args.dtype,
        "exporter": args.exporter,
        "pretrained": args.pretrained,
        "second_input": args.second_input,
    }
    main(**cli_options)

0 commit comments

Comments
 (0)