This repository was archived by the owner on May 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path convert_onnx.py
More file actions
48 lines (38 loc) · 1.46 KB
/
convert_onnx.py
File metadata and controls
48 lines (38 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from pathlib import Path

import torch

from AI.src.modeling.backbones import NET_DEFAULT_CONFIG
def main() -> None:
    """Export each configured backbone to an ONNX file.

    For every model name in ``SAVE_PATH`` this looks up the model class,
    weights and dummy-input spec in ``NET_DEFAULT_CONFIG``, instantiates the
    model, and exports it with ``torch.onnx.export`` (opset 16) using a random
    dummy input. The exported input is named ``"input"`` and the (first)
    output ``"output"``. Dim 0 of both is marked dynamic (``"batch_size"``).

    Files are written to ``<home>/Downloads/<name>.onnx``. The directory is
    created if it does not exist.
    """
    # NOTE(review): RETURN_NODES and TRACER_ARGS look like arguments for an
    # FX feature extractor (e.g. torchvision's create_feature_extractor with
    # return_nodes= / tracer_kwargs=), but nothing below consumes them. The
    # *whole* model is exported, not the sub-graph named here. For
    # "clip_vit/b16" that is presumably not what was intended. Confirm, then
    # either wire these in or delete them.
    RETURN_NODES = {
        "rgb_i3d": {"squeeze_1": "output"},
        "s3d": {"mean": "output"},
        "clip_vit/b16": {"vision_model": "output"},
    }
    TRACER_ARGS = {
        "rgb_i3d": {},
        "s3d": {},
        "clip_vit/b16": {"concrete_args": {"return_loss": None, "return_dict": None}},
    }

    # A literal "~" in a path string is never expanded by file-opening APIs,
    # and backslash separators only work on Windows, so build the path
    # explicitly from the user's home directory.
    download_dir = Path.home() / "Downloads"
    # These are ONNX graphs, so use the .onnx extension rather than .pt.
    SAVE_PATH = {
        "rgb_i3d": download_dir / "rgb_i3d.onnx",
        "s3d": download_dir / "s3d.onnx",
        "clip_vit/b16": download_dir / "clip_vit_b16.onnx",
    }
    download_dir.mkdir(parents=True, exist_ok=True)

    # Iterate by key instead of zip()-ing several dicts, so each model is
    # always paired with its own settings regardless of dict ordering.
    for model_name, save_path in SAVE_PATH.items():
        # NOTE(review): relies on the insertion order of the config dict's
        # values; assumed to be (model class, weights, <unused>, dummy input
        # shape). TODO confirm against NET_DEFAULT_CONFIG.
        model_cls, weight, _, dummy_input = NET_DEFAULT_CONFIG[model_name].values()
        net: torch.nn.Module = model_cls(weight)
        # Export in inference mode explicitly (dropout / batch-norm behaviour).
        net.eval()
        torch.onnx.export(
            net,
            # assumes dummy_input is a shape accepted by torch.rand — confirm
            torch.rand(dummy_input),
            str(save_path),
            opset_version=16,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
            export_params=True,
            do_constant_folding=True,
        )
    return None
if __name__ == "__main__":
    # Run the export only when executed as a script, never on import.
    main()