Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 75 additions & 48 deletions _scripts/export_qwen25_vl_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def main(
second_input: bool = True,
make_zip: bool = False,
output_folder: str = "dump_models",
existing_onnx: str | None = None,
):
prefix = simplify_model_id_for_a_filename(model_id)
if "QWEN25ATTENTION" in os.environ:
Expand All @@ -115,7 +116,7 @@ def main(
print("------------------------------------------------------------------")
print(f"-- export in {filename!r}")

if os.path.exists(stat_file):
if os.path.exists(stat_file) and not existing_onnx:
print(f"-- skipping because {stat_file!r} already exists")
return

Expand Down Expand Up @@ -278,55 +279,73 @@ def compute_expected():
compute_expected() if not os.environ.get("STOPAT", "") else (None, None)
)

print("-- ######")
print("-- EXPORT")
print("-- ######")
if existing_onnx and os.path.exists(existing_onnx):
exporter = existing_onnx
filename = existing_onnx
export_duration = None
target_opset = None
else:
print("-- ######")
print("-- EXPORT")
print("-- ######")

dynamic_shapes = dict(
hidden_states={0: "hidden_width", 1: "hidden_height"},
grid_thw={}, # {0: "n_images"}, # TODO: fix
)
dynamic_shapes = dict(
hidden_states={0: "hidden_width", 1: "hidden_height"},
grid_thw={}, # {0: "n_images"}, # TODO: fix
)

begin = time.perf_counter()

target_opset = 22
if exporter == "onnx-dynamo" and device == "cuda" and "QWEN25ATTENTION" not in os.environ:
os.environ["QWEN25ATTENTION"] = "PACKED"
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
target_opset = 23

with torch_export_patches(
patch_torch=False,
patch_sympy=False,
patch_transformers=True,
verbose=1,
stop_if_static=2,
):
if export_expected is None:
export_expected, other_expected, durations = compute_expected()
to_onnx(
model_to_export,
kwargs=export_inputs,
dynamic_shapes=dynamic_shapes,
filename=filename,
exporter=exporter,
begin = time.perf_counter()

target_opset = 22
if (
exporter == "onnx-dynamo"
and device == "cuda"
and "QWEN25ATTENTION" not in os.environ
):
os.environ["QWEN25ATTENTION"] = "PACKED"
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
target_opset = 23

with torch_export_patches(
patch_torch=False,
patch_sympy=False,
patch_transformers=True,
verbose=1,
save_ep=None,
target_opset=target_opset,
optimize=True,
onnx_plugs=PLUGS,
)
export_duration = time.perf_counter() - begin
stop_if_static=2,
):
if export_expected is None:
export_expected, other_expected, durations = compute_expected()
to_onnx(
model_to_export,
kwargs=export_inputs,
dynamic_shapes=dynamic_shapes,
filename=filename,
exporter=exporter,
verbose=1,
save_ep=None,
target_opset=target_opset,
optimize=True,
onnx_plugs=PLUGS,
)
export_duration = time.perf_counter() - begin

if exporter == "onnx-dynamo":
# onnx-dynamo fails at producing function body with sequences as input / output.
# They are replaced by tensor type one step in the model.
print("-- remove_body_last_input_output_for_loop")
remove_inplace_body_last_input_output_type_for_loop(filename)
print("-- done.")
if exporter == "onnx-dynamo":
# onnx-dynamo fails at producing function body with sequences as input / output.
# They are replaced by tensor type one step in the model.
print("-- remove_body_last_input_output_for_loop")
remove_inplace_body_last_input_output_type_for_loop(filename)
print("-- done.")

with open(stat_file, "w") as f:

def _rename(k):
if rename_inputs:
if k == "hidden_states":
return "pixel_values"
if k == "grid_thw":
return "image_grid_thw"
return k

def fprint(s):
print(s)
f.write(f"{s}\n")
Expand All @@ -337,21 +356,22 @@ def fprint(s):
providers = providers[1:]
fprint(f"-- checking discrepancies with providers={providers!r}")
sess = onnxruntime.InferenceSession(filename, providers=providers)
rename_inputs = sess.get_inputs()[0].name != "hidden_states"

fprint(
f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
)
fprint(
f"-- export_expected {string_type(export_expected, with_shape=True, with_device=True)}"
)
feeds = {k: v.detach().cpu().numpy() for k, v in export_inputs.items()}
feeds = {_rename(k): v.detach().cpu().numpy() for k, v in export_inputs.items()}
small = sess.run(None, feeds)
diff = max_diff(export_expected, small[0], hist=[0.1, 0.01])
fprint(f"-- discrepancies={diff}")

if second_input:
feeds = [
{k: v.detach().cpu().numpy() for k, v in inputs.items()}
{_rename(k): v.detach().cpu().numpy() for k, v in inputs.items()}
for inputs in other_inputs
]
fprint("")
Expand Down Expand Up @@ -440,7 +460,7 @@ def fprint(s):
]
stat = (
df[[*index, *values]]
.groupby(index)
.groupby(index, dropna=False)
.agg(
{
**{c: "max" for c in values if c != "speedup"},
Expand All @@ -450,8 +470,8 @@ def fprint(s):
)
stat.to_excel(statistics + ".agg.xlsx")
stat = (
df[df.exporter == "onnx-dynamo"][[*index, *values]]
.groupby(index)
df[df.exporter != "custom"][[*index, *values]]
.groupby(index, dropna=False)
.agg(
{
**{c: "max" for c in values if c != "speedup"},
Expand Down Expand Up @@ -517,6 +537,12 @@ def get_parser() -> ArgumentParser:
help="Folders where to put the results.",
action=BooleanOptionalAction,
)
parser.add_argument(
"-x",
"--existing-onnx",
default="",
help="If an onnx file exists, only measures the disrepancies.",
)
return parser


Expand All @@ -532,4 +558,5 @@ def get_parser() -> ArgumentParser:
second_input=args.second_input,
make_zip=args.zip,
output_folder=args.output_folder,
existing_onnx=args.existing_onnx,
)
Loading