Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions _scripts/qwen25_vl_visual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import os
import sys
from argparse import ArgumentParser, BooleanOptionalAction


def main(
model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct",
device: str = "cpu",
dtype: str = "float32",
exporter: str = "onnx-dynamo",
pretrained: bool = True,
second_input: bool = True,
):
print("-- import torch")
import torch

print("-- import onnxruntime")
import onnxruntime

print("-- import transformers")
from transformers import AutoModel, AutoProcessor

print("-- import onnx_diagnostic")
from onnx_diagnostic.helpers import string_type, max_diff
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
PLUGS,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.export.api import to_onnx

print(f"-- creating model {model_id!r}")
print(
f"-- device={device!r}, dtype={dtype!r}, exporter={exporter!r}, "
f"pretrained={pretrained!r}"
)
torch_dtype = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}[dtype]

if pretrained:
print("-- pretrained model")
model = AutoModel.from_pretrained(
model_id, device_map=device, dtype=torch_dtype, attn_implementation="sdpa"
).eval()
else:
print("-- random model")

def _config_reduction(config, task):
return {
# "num_hidden_layers": 2,
"text_config": {
"num_hidden_layers": 2,
"layer_types": ["full_attention", "full_attention"],
},
# "_attn_implementation": "flash_attention_2",
"_attn_implementation": "sdpa",
"dtype": "float16",
}

config_reduction = _config_reduction
data = get_untrained_model_with_inputs(
model_id, verbose=1, add_second_input=False, config_reduction=config_reduction
)
model = data["model"]

model = model.to(device).to(getattr(torch, dtype))

print(f"-- config._attn_implementation={model.config._attn_implementation}")
print(f"-- model.dtype={model.dtype}")
print(f"-- model.device={model.device}")
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
print(f"-- processor={type(processor)}")

inputs = dict(
hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
)
big_inputs = (
dict(
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
)
if second_input
else None
)

model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
if not os.environ.get("STOPAT", ""):
print(f"-- compute with inputs: {string_type(inputs, with_shape=True)}")
expected = model_to_export(**inputs)
print(f"-- got: {string_type(expected, with_shape=True)}")
print(f"-- compute with inputs: {string_type(big_inputs, with_shape=True)}")
expected_big = None if big_inputs is None else model_to_export(**big_inputs)
print(f"-- got: {string_type(expected_big, with_shape=True)}")
else:
expected = None
expected_big = None
print(f"-- expected: {string_type(expected, with_shape=True)}")

dynamic_shapes = dict(
hidden_states={0: "hidden_width", 1: "hidden_height"},
grid_thw={}, # {0: "n_images"}, # TODO: fix
)

filename = f"qwen25_vli_visual.{device}.{dtype}.{exporter}.onnx"
print(f"-- export in {filename!r}")

export_inputs = inputs
with torch_export_patches(
patch_torch=False,
patch_sympy=False,
patch_transformers=True,
verbose=1,
stop_if_static=2,
):
if expected is None:
expected = model_to_export(**inputs)
expected_big = None if big_inputs is None else model_to_export(**big_inputs)
to_onnx(
model_to_export,
kwargs=export_inputs,
dynamic_shapes=dynamic_shapes,
filename=filename,
exporter=exporter,
verbose=1,
save_ep=None,
target_opset=22,
optimize=True,
onnx_plugs=PLUGS,
)

print("-- checking discrepancies")
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
if device == "cpu":
providers = providers[1:]
sess = onnxruntime.InferenceSession(filename, providers=providers)

print(f"-- inputs {string_type(inputs, with_shape=True)}")
feeds = {k: v.detach().cpu().numpy() for k, v in inputs.items()}
small = sess.run(None, feeds)
diff = max_diff(expected, small[0], hist=[0.1])
print(f"-- discrepancies={diff}")

if second_input:
print(f"-- inputs {string_type(big_inputs, with_shape=True)}")
feeds = {k: v.detach().cpu().numpy() for k, v in big_inputs.items()}
big = sess.run(None, feeds)
diff = max_diff(expected_big, big[0], hist=[0.1])
print(f"-- discrepancies={diff}")


def get_parser() -> ArgumentParser:
parser = ArgumentParser(
prog="qwen25", description="""Export visual part of model Qwen 2.5 VL."""
)
parser.add_argument(
"-m",
"--mid",
type=str,
default="Qwen/Qwen2.5-VL-7B-Instruct",
help="model id, default is Qwen/Qwen2.5-VL-7B-Instruct",
)
parser.add_argument("-d", "--device", default="cpu", help="Device, cpu (default) or cuda.")
parser.add_argument(
"-t", "--dtype", default="float32", help="dtype, float32 (default) or float16"
)
parser.add_argument(
"-e", "--exporter", default="onnx-dynamo", help="exporter, default is onnx-dynamo"
)
parser.add_argument(
"--pretrained",
default=True,
help="use pretrained model or a random model",
action=BooleanOptionalAction,
)
parser.add_argument(
"--second-input",
default=True,
help="check discrepancies with other inputs",
action=BooleanOptionalAction,
)
return parser


if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args(sys.argv[1:])
main(
model_id=args.mid,
device=args.device,
dtype=args.dtype,
exporter=args.exporter,
pretrained=args.pretrained,
second_input=args.second_input,
)
1 change: 1 addition & 0 deletions _unittests/ut_helpers/test_cache_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
)
self.assertEqual(0, max_diff(c2, c2)["abs"])
self.assertIsInstance(c2, transformers.cache_utils.EncoderDecoderCache)
self.assertEqual(max_diff(c2, c2)["abs"], 0)
flat, _spec = torch.utils._pytree.tree_flatten(c2)
self.assertIsInstance(flat, list)
self.assertEqual(len(flat), 12)
Expand Down
7 changes: 7 additions & 0 deletions _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
extract_subset_of_nodes,
make_submodel,
select_model_inputs_outputs,
_enumerate_model_node_outputs,
)


Expand Down Expand Up @@ -602,6 +603,12 @@ def _get_model_select(self):
)
return onnx_model

def test__enumerate_model_node_outputs(self):
model = self._get_model_select()
outputs1 = list(_enumerate_model_node_outputs(model, order=False))
outputs2 = list(_enumerate_model_node_outputs(model, order=True))
self.assertEqual(set(outputs1), set(outputs2))

def test_select_model_inputs_outputs(self):
def enumerate_model_tensors(model):
for tensor in _get_all_tensors(model):
Expand Down
9 changes: 1 addition & 8 deletions _unittests/ut_torch_models/test_validate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,11 @@
requires_experimental,
requires_transformers,
requires_cuda,
has_torch,
has_transformers,
)
from onnx_diagnostic.torch_models.validate import validate_model


torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")


class TestValidateModel(ExtTestCase):
@unittest.skipIf(torch29_and_tr_main, "combination not working")
@requires_transformers("4.53")
@requires_torch("2.7.99")
@requires_experimental()
Expand All @@ -40,12 +34,12 @@ def test_validate_tiny_llms_bfloat16(self):
dtype="bfloat16",
device="cuda",
runtime="orteval",
optimization="default+onnxruntime+os_ort",
)
self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-2)
self.assertIn("onnx_filename", data)
self.clean_dump()

@unittest.skipIf(torch29_and_tr_main, "combination not working")
@requires_transformers("4.57") # 4.53 works for some jobs fails due to no space left
@requires_torch("2.9.99") # 2.9 works for some jobs fails due to no space left
@requires_experimental()
Expand All @@ -68,7 +62,6 @@ def test_validate_microsoft_phi4_reasoning(self):
self.assertIn("onnx_filename", data)
self.clean_dump()

@unittest.skipIf(torch29_and_tr_main, "combination not working")
@requires_transformers("4.53")
@requires_torch("2.8.99")
@requires_experimental()
Expand Down
6 changes: 0 additions & 6 deletions _unittests/ut_torch_models/test_validate_whole_models1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
requires_experimental,
requires_onnxscript,
requires_transformers,
has_torch,
has_transformers,
)
from onnx_diagnostic.torch_models.validate import (
get_inputs_for_task,
Expand All @@ -24,9 +22,6 @@
from onnx_diagnostic.tasks import supported_tasks


torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")


class TestValidateWholeModels1(ExtTestCase):
def test_a_get_inputs_for_task(self):
fcts = supported_tasks()
Expand Down Expand Up @@ -205,7 +200,6 @@ def test_k_filter_inputs(self):
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"], model=["a", "b"])
self.assertEqual((ni, nd), (((None,), {"b": 4}), {"b": 30}))

@unittest.skipIf(torch29_and_tr_main, "combination not working")
@requires_torch("2.9.99")
@hide_stdout()
@ignore_warnings(FutureWarning)
Expand Down
5 changes: 0 additions & 5 deletions _unittests/ut_torch_models/test_validate_whole_models2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@
ignore_warnings,
requires_torch,
requires_transformers,
has_torch,
has_transformers,
)
from onnx_diagnostic.torch_models.validate import validate_model

torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")


class TestValidateWholeModels2(ExtTestCase):
@unittest.skipIf(torch29_and_tr_main, "combination not working")
@requires_torch("2.9")
@hide_stdout()
@ignore_warnings(FutureWarning)
Expand Down
5 changes: 0 additions & 5 deletions _unittests/ut_torch_models/test_validate_whole_models3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,11 @@
ignore_warnings,
requires_torch,
requires_transformers,
has_torch,
has_transformers,
)
from onnx_diagnostic.torch_models.validate import validate_model

torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")


class TestValidateWholeModels3(ExtTestCase):
@unittest.skipIf(torch29_and_tr_main, "combination not working")
@requires_torch("2.7")
@hide_stdout()
@ignore_warnings(FutureWarning)
Expand Down
Loading
Loading