-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathsupported_models.py
More file actions
61 lines (47 loc) · 1.94 KB
/
supported_models.py
File metadata and controls
61 lines (47 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Dict, List
from collections import OrderedDict
from collators import COLLATORS
from loaders import LOADERS
MODULE_KEYWORDS: Dict[str, Dict[str, List]] = {
"qwen2-vl-7b": {
"vision_encoder": ["visual.patch_embed", "visual.rotary_pos_emb", "visual.blocks"],
"vision_projector": ["visual.merger"],
"llm": ["model"]
},
"qwen2-vl-2b": {
"vision_encoder": ["visual.patch_embed", "visual.rotary_pos_emb", "visual.blocks"],
"vision_projector": ["visual.merger"],
"llm": ["model"]
}
}
MODEL_HF_PATH = OrderedDict()
MODEL_FAMILIES = OrderedDict()
def register_model(model_id: str, model_family_id: str, model_hf_path: str) -> None:
if model_id in MODEL_HF_PATH or model_id in MODEL_FAMILIES:
raise ValueError(f"Duplicate model_id: {model_id}")
MODEL_HF_PATH[model_id] = model_hf_path
MODEL_FAMILIES[model_id] = model_family_id
#=============================================================
register_model(
model_id="qwen2-vl-7b",
model_family_id="qwen2-vl-7b",
model_hf_path="./checkpoints/hf_models/Qwen2-VL-7B-Instruct"
)
register_model(
model_id="qwen2-vl-2b",
model_family_id="qwen2-vl-2b",
model_hf_path="./checkpoints/hf_models/Qwen2-VL-2B-Instruct"
)
# sanity check
for model_family_id in MODEL_FAMILIES.values():
assert model_family_id in COLLATORS, f"Collator not found for model family: {model_family_id}"
assert model_family_id in LOADERS, f"Loader not found for model family: {model_family_id}"
assert model_family_id in MODULE_KEYWORDS, f"Module keywords not found for model family: {model_family_id}"
if __name__ == "__main__":
temp = "Model ID"
ljust = 30
print("Supported models:")
print(f" {temp.ljust(ljust)}: HuggingFace Path")
print(" ------------------------------------------------")
for model_id, model_hf_path in MODEL_HF_PATH.items():
print(f" {model_id.ljust(ljust)}: {model_hf_path}")