Skip to content
Merged

minor #351

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 96 additions & 80 deletions _scripts/export_qwen25_vl_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,20 @@ def main(
make_zip: bool = False,
output_folder: str = "dump_models",
existing_onnx: str | None = None,
part: str = "visual",
):
prefix = simplify_model_id_for_a_filename(model_id)
if "QWEN25ATTENTION" in os.environ:
prefix = f"{prefix}.{os.environ['QWEN25ATTENTION']}"
basename = os.path.join(
output_folder, f"model.{prefix}.visual.{device}.{dtype}.{exporter}"
output_folder, f"model.{prefix}.{part}.{device}.{dtype}.{exporter}"
)
filename = f"{basename}.onnx"
stat_file = f"{basename}.stats"

print("------------------------------------------------------------------")
print(
f"-- {model_id} {device} {dtype} {exporter} {pretrained} "
f"-- {model_id} {part} {device} {dtype} {exporter} {pretrained} "
f"{second_input} {make_zip} {output_folder} {prefix}"
)
print("------------------------------------------------------------------")
Expand Down Expand Up @@ -186,47 +187,75 @@ def _config_reduction(config, task):
print(f"-- model.device={model.device}")
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
print(f"-- processor={type(processor)}")
model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
print(f"-- model_to_export={type(model_to_export)}")

print("-- ############")
print("-- INPUT/OUTPUT")
print("-- ############")

input_filename = os.path.join(output_folder, f"inputs.{prefix}.visual.{device}.{dtype}.pt")
if os.path.exists(input_filename):
print(f"-- restore inputs from {input_filename!r}")
data = torch.load(input_filename)
export_inputs = data["export_inputs"]
other_inputs = data["other_inputs"]
else:
export_inputs = dict(
hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),

if part == "visual":

class VisualPart(torch.nn.Module):
    """Expose a model's ``get_image_features`` method as a module ``forward``.

    The wrapped model is stored on the instance and ``forward`` calls it
    through ``self.model``, so the wrapper does not depend on any variable
    from the scope in which the class is defined.
    """

    def __init__(self, model):
        """Store *model*, which must provide ``get_image_features``."""
        super().__init__()
        self.model = model

    def forward(self, pixel_values, image_grid_thw):
        """Return ``self.model.get_image_features(pixel_values, image_grid_thw)``."""
        # Use the instance attribute rather than a closed-over outer name so
        # the module keeps working when instantiated outside that scope.
        return self.model.get_image_features(pixel_values, image_grid_thw)

assert hasattr(
model, "get_image_features"
), f"get_image_features not found in class {type(model)}"
model_to_export = VisualPart(model)

print(f"-- part={part!r}")
print(f"-- model_to_export={type(model_to_export)}")

print("-- ############")
print("-- INPUT/OUTPUT")
print("-- ############")

input_filename = os.path.join(
output_folder, f"inputs.{prefix}.{part}.{device}.{dtype}.pt"
)
other_inputs = []
if second_input:
other_inputs = [
dict(
hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
),
dict(
hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
),
dict(
hidden_states=torch.randn((14308, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
),
dict(
hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
),
]
data = dict(export_inputs=export_inputs, other_inputs=other_inputs)
print(f"-- dump inputs into {input_filename!r}")
torch.save(data, input_filename)
if os.path.exists(input_filename):
print(f"-- restore inputs from {input_filename!r}")
data = torch.load(input_filename)
export_inputs = data["export_inputs"]
other_inputs = data["other_inputs"]
else:
export_inputs = dict(
pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
)
other_inputs = []
if second_input:
other_inputs = [
dict(
pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(
device
),
),
dict(
pixel_values=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(
device
),
),
dict(
pixel_values=torch.randn((14308, 1176), dtype=torch_dtype).to(device),
image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(
device
),
),
dict(
pixel_values=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(
device
),
),
]
data = dict(export_inputs=export_inputs, other_inputs=other_inputs)
print(f"-- dump inputs into {input_filename!r}")
torch.save(data, input_filename)
else:
raise NotImplementedError(f"part={part!r} not implemnted yet")

print(f"-- export_inputs={string_type(export_inputs, with_shape=True, with_device=True)}")
print(f"-- other_inputs={string_type(other_inputs, with_shape=True, with_device=True)}")
Expand Down Expand Up @@ -290,8 +319,8 @@ def compute_expected():
print("-- ######")

dynamic_shapes = dict(
hidden_states={0: "hidden_width", 1: "hidden_height"},
grid_thw={}, # {0: "n_images"}, # TODO: fix
pixel_values={0: "hidden_width", 1: "hidden_height"},
image_grid_thw={}, # {0: "n_images"}, # TODO: fix
)

begin = time.perf_counter()
Expand Down Expand Up @@ -336,15 +365,11 @@ def compute_expected():
remove_inplace_body_last_input_output_type_for_loop(filename)
print("-- done.")

with open(stat_file, "w") as f:
###############
# check for discrepancies
###############

def _rename(k):
if rename_inputs:
if k == "hidden_states":
return "pixel_values"
if k == "grid_thw":
return "image_grid_thw"
return k
with open(stat_file, "w") as f:

def fprint(s):
print(s)
Expand All @@ -355,23 +380,23 @@ def fprint(s):
if device == "cpu":
providers = providers[1:]
fprint(f"-- checking discrepancies with providers={providers!r}")
fprint(f"-- filename={filename!r}")
sess = onnxruntime.InferenceSession(filename, providers=providers)
rename_inputs = sess.get_inputs()[0].name != "hidden_states"

fprint(
f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
)
fprint(
f"-- export_expected {string_type(export_expected, with_shape=True, with_device=True)}"
)
feeds = {_rename(k): v.detach().cpu().numpy() for k, v in export_inputs.items()}
feeds = {k: v.detach().cpu().numpy() for k, v in export_inputs.items()}
small = sess.run(None, feeds)
diff = max_diff(export_expected, small[0], hist=[0.1, 0.01])
fprint(f"-- discrepancies={diff}")

if second_input:
feeds = [
{_rename(k): v.detach().cpu().numpy() for k, v in inputs.items()}
{k: v.detach().cpu().numpy() for k, v in inputs.items()}
for inputs in other_inputs
]
fprint("")
Expand Down Expand Up @@ -399,6 +424,7 @@ def fprint(s):

info = {
"model_id": model_id,
"part": part,
"device": device,
"dtype": dtype,
"exporter": exporter,
Expand Down Expand Up @@ -432,23 +458,16 @@ def fprint(s):
"timestamp",
"model_id",
"pretrained",
"part",
"device",
"dtype",
"attention",
"opset",
]
index = [*first[1:], "exporter"]
df = df[[*first, *[c for c in df.columns if c not in set(first)]]]
df.to_excel(statistics + ".xlsx")

index = [
"model_id",
"pretrained",
"device",
"dtype",
"attention",
"opset",
"exporter",
]
values = [
"abs",
"%>0.1",
Expand All @@ -458,26 +477,16 @@ def fprint(s):
"latency_torch",
"latency_ort_n",
]
stat = (
df[[*index, *values]]
.groupby(index, dropna=False)
.agg(
{
**{c: "max" for c in values if c != "speedup"},
"speedup": "min",
}
)
)
agg = {
**{c: "max" for c in values if c != "speedup"},
"speedup": "min",
}
stat = df[[*index, *values]].groupby(index, dropna=False).agg(agg)
stat.to_excel(statistics + ".agg.xlsx")
stat = (
df[df.exporter != "custom"][[*index, *values]]
.groupby(index, dropna=False)
.agg(
{
**{c: "max" for c in values if c != "speedup"},
"speedup": "min",
}
)
.agg(agg)
)
stat.to_excel(statistics + ".agg.onnx-dynamo.xlsx")

Expand Down Expand Up @@ -543,6 +552,12 @@ def get_parser() -> ArgumentParser:
default="",
help="If an onnx file exists, only measures the disrepancies.",
)
parser.add_argument(
"-p",
"--part",
default="visual",
help="part of the model to export",
)
return parser


Expand All @@ -559,4 +574,5 @@ def get_parser() -> ArgumentParser:
make_zip=args.zip,
output_folder=args.output_folder,
existing_onnx=args.existing_onnx,
part=args.part,
)
Loading