From 121f6afb0d02019d81024c571f7b370c91e4df4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 9 Dec 2025 18:29:31 +0000 Subject: [PATCH] minor --- _scripts/export_qwen25_vl_visual.py | 176 +++++++++++++++------------- 1 file changed, 96 insertions(+), 80 deletions(-) diff --git a/_scripts/export_qwen25_vl_visual.py b/_scripts/export_qwen25_vl_visual.py index 4f68e19d..48887786 100644 --- a/_scripts/export_qwen25_vl_visual.py +++ b/_scripts/export_qwen25_vl_visual.py @@ -98,19 +98,20 @@ def main( make_zip: bool = False, output_folder: str = "dump_models", existing_onnx: str | None = None, + part: str = "visual", ): prefix = simplify_model_id_for_a_filename(model_id) if "QWEN25ATTENTION" in os.environ: prefix = f"{prefix}.{os.environ['QWEN25ATTENTION']}" basename = os.path.join( - output_folder, f"model.{prefix}.visual.{device}.{dtype}.{exporter}" + output_folder, f"model.{prefix}.{part}.{device}.{dtype}.{exporter}" ) filename = f"{basename}.onnx" stat_file = f"{basename}.stats" print("------------------------------------------------------------------") print( - f"-- {model_id} {device} {dtype} {exporter} {pretrained} " + f"-- {model_id} {part} {device} {dtype} {exporter} {pretrained} " f"{second_input} {make_zip} {output_folder} {prefix}" ) print("------------------------------------------------------------------") @@ -186,47 +187,75 @@ def _config_reduction(config, task): print(f"-- model.device={model.device}") processor = AutoProcessor.from_pretrained(model_id, use_fast=True) print(f"-- processor={type(processor)}") - model_to_export = model.visual if hasattr(model, "visual") else model.model.visual - print(f"-- model_to_export={type(model_to_export)}") - - print("-- ############") - print("-- INPUT/OUTPUT") - print("-- ############") - - input_filename = os.path.join(output_folder, f"inputs.{prefix}.visual.{device}.{dtype}.pt") - if os.path.exists(input_filename): - print(f"-- restore inputs from {input_filename!r}") - data = torch.load(input_filename) - export_inputs = data["export_inputs"] - other_inputs = data["other_inputs"] - else: - export_inputs = dict( - hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device), - grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device), + + if part == "visual": + + class VisualPart(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, pixel_values, image_grid_thw): + return model.get_image_features(pixel_values, image_grid_thw) + + assert hasattr( + model, "get_image_features" + ), f"get_image_features not found in class {type(model)}" + model_to_export = VisualPart(model) + + print(f"-- part={part!r}") + print(f"-- model_to_export={type(model_to_export)}") + + print("-- ############") + print("-- INPUT/OUTPUT") + print("-- ############") + + input_filename = os.path.join( + output_folder, f"inputs.{prefix}.{part}.{device}.{dtype}.pt" ) - other_inputs = [] - if second_input: - other_inputs = [ - dict( - hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device), - grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device), - ), - dict( - hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device), - grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device), - ), - dict( - hidden_states=torch.randn((14308, 1176), dtype=torch_dtype).to(device), - grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device), - ), - dict( - hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device), - grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device), - ), - ] - data = dict(export_inputs=export_inputs, other_inputs=other_inputs) - print(f"-- dump inputs into {input_filename!r}") - torch.save(data, input_filename) + if os.path.exists(input_filename): + print(f"-- restore inputs from {input_filename!r}") + data = torch.load(input_filename) + export_inputs = data["export_inputs"] + other_inputs = data["other_inputs"] + else: + export_inputs = dict( + pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device), + image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device), + ) + other_inputs = [] + if second_input: + other_inputs = [ + dict( + pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device), + image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to( + device + ), + ), + dict( + pixel_values=torch.rand((1292, 1176), dtype=torch_dtype).to(device), + image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to( + device + ), + ), + dict( + pixel_values=torch.randn((14308, 1176), dtype=torch_dtype).to(device), + image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to( + device + ), + ), + dict( + pixel_values=torch.rand((14308, 1176), dtype=torch_dtype).to(device), + image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to( + device + ), + ), + ] + data = dict(export_inputs=export_inputs, other_inputs=other_inputs) + print(f"-- dump inputs into {input_filename!r}") + torch.save(data, input_filename) + else: + raise NotImplementedError(f"part={part!r} not implemnted yet") print(f"-- export_inputs={string_type(export_inputs, with_shape=True, with_device=True)}") print(f"-- other_inputs={string_type(other_inputs, with_shape=True, with_device=True)}") @@ -290,8 +319,8 @@ def compute_expected(): print("-- ######") dynamic_shapes = dict( - hidden_states={0: "hidden_width", 1: "hidden_height"}, - grid_thw={}, # {0: "n_images"}, # TODO: fix + pixel_values={0: "hidden_width", 1: "hidden_height"}, + image_grid_thw={}, # {0: "n_images"}, # TODO: fix ) begin = time.perf_counter() @@ -336,15 +365,11 @@ def compute_expected(): remove_inplace_body_last_input_output_type_for_loop(filename) print("-- done.") - with open(stat_file, "w") as f: + ############### + # check for discrepancies + ############### - def _rename(k): - if rename_inputs: - if k == "hidden_states": - return "pixel_values" - if k == "grid_thw": - return "image_grid_thw" - return k + with open(stat_file, "w") as f: def fprint(s): print(s) @@ -355,8 +380,8 @@ def fprint(s): if device == "cpu": providers = providers[1:] fprint(f"-- checking discrepancies with providers={providers!r}") + fprint(f"-- filename={filename!r}") sess = onnxruntime.InferenceSession(filename, providers=providers) - rename_inputs = sess.get_inputs()[0].name != "hidden_states" fprint( f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}" @@ -364,14 +389,14 @@ def fprint(s): fprint( f"-- export_expected {string_type(export_expected, with_shape=True, with_device=True)}" ) - feeds = {_rename(k): v.detach().cpu().numpy() for k, v in export_inputs.items()} + feeds = {k: v.detach().cpu().numpy() for k, v in export_inputs.items()} small = sess.run(None, feeds) diff = max_diff(export_expected, small[0], hist=[0.1, 0.01]) fprint(f"-- discrepancies={diff}") if second_input: feeds = [ - {_rename(k): v.detach().cpu().numpy() for k, v in inputs.items()} + {k: v.detach().cpu().numpy() for k, v in inputs.items()} for inputs in other_inputs ] fprint("") @@ -399,6 +424,7 @@ def fprint(s): info = { "model_id": model_id, + "part": part, "device": device, "dtype": dtype, "exporter": exporter, @@ -432,23 +458,16 @@ def fprint(s): "timestamp", "model_id", "pretrained", + "part", "device", "dtype", "attention", "opset", ] + index = [*first[1:], "exporter"] df = df[[*first, *[c for c in df.columns if c not in set(first)]]] df.to_excel(statistics + ".xlsx") - index = [ - "model_id", - "pretrained", - "device", - "dtype", - "attention", - "opset", - "exporter", - ] values = [ "abs", "%>0.1", @@ -458,26 +477,16 @@ def fprint(s): "latency_torch", "latency_ort_n", ] - stat = ( - df[[*index, *values]] - .groupby(index, dropna=False) - .agg( - { - **{c: "max" for c in values if c != "speedup"}, - "speedup": "min", - } - ) - ) + agg = { + **{c: "max" for c in values if c != "speedup"}, + "speedup": "min", + } + stat = df[[*index, *values]].groupby(index, dropna=False).agg(agg) stat.to_excel(statistics + ".agg.xlsx") stat = ( df[df.exporter != "custom"][[*index, *values]] .groupby(index, dropna=False) - .agg( - { - **{c: "max" for c in values if c != "speedup"}, - "speedup": "min", - } - ) + .agg(agg) ) stat.to_excel(statistics + ".agg.onnx-dynamo.xlsx") @@ -543,6 +552,12 @@ def get_parser() -> ArgumentParser: default="", help="If an onnx file exists, only measures the disrepancies.", ) + parser.add_argument( + "-p", + "--part", + default="visual", + help="part of the model to export", + ) return parser @@ -559,4 +574,5 @@ def get_parser() -> ArgumentParser: make_zip=args.zip, output_folder=args.output_folder, existing_onnx=args.existing_onnx, + part=args.part, )