Skip to content

Commit cc7200b

Browse files
pwilkin and CISC authored
Refactor: convert_hf_to_gguf.py (ggml-org#17114)
* move conversion code to a dedicated conversion directory and split the files akin to the src/models architecture --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 769cc93 commit cc7200b

82 files changed

Lines changed: 15179 additions & 14023 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

conversion/__init__.py

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
from __future__ import annotations

import importlib
from typing import Type

from .base import (
    ModelBase, TextModel, MmprojModel, ModelType, SentencePieceTokenTypes,
    logger, _mistral_common_installed, _mistral_import_error_msg,
    get_model_architecture, LazyTorchTensor,
)
9+
10+
11+
# Names exported by `from conversion import *`: the re-exports pulled in from
# `.base` first, then the lookup helpers defined in this module.
__all__ = [
    # re-exported from .base
    "ModelBase",
    "TextModel",
    "MmprojModel",
    "ModelType",
    "SentencePieceTokenTypes",
    "get_model_architecture",
    "LazyTorchTensor",
    "logger",
    "_mistral_common_installed",
    "_mistral_import_error_msg",
    # defined in this module
    "get_model_class",
    "print_registered_models",
    "load_all_models",
]
17+
18+
19+
# HuggingFace architecture name -> name of the submodule of this package that
# is imported to obtain the TEXT model class for that architecture (see
# get_model_class). Several architectures may share one submodule.
# Keys appear to be kept in ASCII-sorted order (uppercase before lowercase).
TEXT_MODEL_MAP: dict[str, str] = {
    "AfmoeForCausalLM": "afmoe",
    "ApertusForCausalLM": "llama",
    "ArceeForCausalLM": "llama",
    "ArcticForCausalLM": "arctic",
    "AudioFlamingo3ForConditionalGeneration": "qwen",
    "BaiChuanForCausalLM": "baichuan",
    "BaichuanForCausalLM": "baichuan",
    "BailingMoeForCausalLM": "bailingmoe",
    "BailingMoeV2ForCausalLM": "bailingmoe",
    "BambaForCausalLM": "granite",
    "BertForMaskedLM": "bert",
    "BertForSequenceClassification": "bert",
    "BertModel": "bert",
    "BitnetForCausalLM": "bitnet",
    "BloomForCausalLM": "bloom",
    "BloomModel": "bloom",
    "CamembertModel": "bert",
    "ChameleonForCausalLM": "chameleon",
    "ChameleonForConditionalGeneration": "chameleon",
    "ChatGLMForConditionalGeneration": "chatglm",
    "ChatGLMModel": "chatglm",
    "CodeShellForCausalLM": "codeshell",
    "CogVLMForCausalLM": "cogvlm",
    "Cohere2ForCausalLM": "command_r",
    "CohereForCausalLM": "command_r",
    "DbrxForCausalLM": "dbrx",
    "DeciLMForCausalLM": "deci",
    "DeepseekForCausalLM": "deepseek",
    "DeepseekV2ForCausalLM": "deepseek",
    "DeepseekV3ForCausalLM": "deepseek",
    "DistilBertForMaskedLM": "bert",
    "DistilBertForSequenceClassification": "bert",
    "DistilBertModel": "bert",
    "Dots1ForCausalLM": "dots1",
    "DotsOCRForCausalLM": "qwen",
    "DreamModel": "dream",
    "Ernie4_5ForCausalLM": "ernie",
    "Ernie4_5_ForCausalLM": "ernie",
    "Ernie4_5_MoeForCausalLM": "ernie",
    "EuroBertModel": "bert",
    "Exaone4ForCausalLM": "exaone",
    "ExaoneForCausalLM": "exaone",
    "ExaoneMoEForCausalLM": "exaone",
    "FalconForCausalLM": "falcon",
    "FalconH1ForCausalLM": "falcon_h1",
    "FalconMambaForCausalLM": "mamba",
    "GPT2LMHeadModel": "gpt2",
    "GPTBigCodeForCausalLM": "starcoder",
    "GPTNeoXForCausalLM": "gptneox",
    "GPTRefactForCausalLM": "refact",
    "Gemma2ForCausalLM": "gemma",
    "Gemma3ForCausalLM": "gemma",
    "Gemma3ForConditionalGeneration": "gemma",
    "Gemma3TextModel": "gemma",
    "Gemma3nForCausalLM": "gemma",
    "Gemma3nForConditionalGeneration": "gemma",
    "Gemma4ForConditionalGeneration": "gemma",
    "GemmaForCausalLM": "gemma",
    "Glm4ForCausalLM": "glm",
    "Glm4MoeForCausalLM": "glm",
    "Glm4MoeLiteForCausalLM": "glm",
    "Glm4vForConditionalGeneration": "glm",
    "Glm4vMoeForConditionalGeneration": "glm",
    "GlmForCausalLM": "chatglm",
    "GlmMoeDsaForCausalLM": "glm",
    "GlmOcrForConditionalGeneration": "glm",
    "GptOssForCausalLM": "gpt_oss",
    "GraniteForCausalLM": "granite",
    "GraniteMoeForCausalLM": "granite",
    "GraniteMoeHybridForCausalLM": "granite",
    "GraniteMoeSharedForCausalLM": "granite",
    "GraniteSpeechForConditionalGeneration": "granite",
    "Grok1ForCausalLM": "grok",
    "GrokForCausalLM": "grok",
    "GroveMoeForCausalLM": "grovemoe",
    "HunYuanDenseV1ForCausalLM": "hunyuan",
    "HunYuanMoEV1ForCausalLM": "hunyuan",
    "HunYuanVLForConditionalGeneration": "hunyuan",
    "IQuestCoderForCausalLM": "llama",
    "InternLM2ForCausalLM": "internlm",
    "InternLM3ForCausalLM": "internlm",
    "JAISLMHeadModel": "jais",
    "Jais2ForCausalLM": "jais",
    "JambaForCausalLM": "jamba",
    "JanusForConditionalGeneration": "januspro",
    "JinaBertForMaskedLM": "bert",
    "JinaBertModel": "bert",
    "JinaEmbeddingsV5Model": "bert",
    "KORMoForCausalLM": "qwen",
    "KimiK25ForConditionalGeneration": "deepseek",
    "KimiLinearForCausalLM": "kimi_linear",
    "KimiLinearModel": "kimi_linear",
    "KimiVLForConditionalGeneration": "deepseek",
    "LFM2ForCausalLM": "lfm2",
    "LLaDAMoEModel": "llada",
    "LLaDAMoEModelLM": "llada",
    "LLaDAModelLM": "llada",
    "LLaMAForCausalLM": "llama",
    "Lfm25AudioTokenizer": "lfm2",
    "Lfm2ForCausalLM": "lfm2",
    "Lfm2Model": "lfm2",
    "Lfm2MoeForCausalLM": "lfm2",
    "Llama4ForCausalLM": "llama",
    "Llama4ForConditionalGeneration": "llama",
    "LlamaBidirectionalModel": "llama",
    "LlamaForCausalLM": "llama",
    "LlamaModel": "llama",
    "LlavaForConditionalGeneration": "llama",
    "LlavaStableLMEpochForCausalLM": "stablelm",
    "MPTForCausalLM": "mpt",
    "MT5ForConditionalGeneration": "t5",
    "MaincoderForCausalLM": "maincoder",
    "Mamba2ForCausalLM": "mamba",
    "MambaForCausalLM": "mamba",
    "MambaLMHeadModel": "mamba",
    "MiMoV2FlashForCausalLM": "mimo",
    "MiMoV2ForCausalLM": "mimo",
    "MiniCPM3ForCausalLM": "minicpm",
    "MiniCPMForCausalLM": "minicpm",
    "MiniCPMV4_6ForConditionalGeneration": "minicpm",
    "MiniMaxM2ForCausalLM": "minimax",
    "Ministral3ForCausalLM": "mistral3",
    "Mistral3ForConditionalGeneration": "mistral3",
    "MistralForCausalLM": "llama",
    "MixtralForCausalLM": "llama",
    "ModernBertForMaskedLM": "bert",
    "ModernBertForSequenceClassification": "bert",
    "ModernBertModel": "bert",
    "NemotronForCausalLM": "nemotron",
    "NemotronHForCausalLM": "nemotron",
    "NeoBERT": "bert",
    "NeoBERTForSequenceClassification": "bert",
    "NeoBERTLMHead": "bert",
    "NomicBertModel": "bert",
    "OLMoForCausalLM": "olmo",
    "Olmo2ForCausalLM": "olmo",
    "Olmo3ForCausalLM": "olmo",
    "OlmoForCausalLM": "olmo",
    "OlmoeForCausalLM": "olmo",
    "OpenELMForCausalLM": "openelm",
    "OrionForCausalLM": "orion",
    "PLMForCausalLM": "plm",
    "PLaMo2ForCausalLM": "plamo",
    "PLaMo3ForCausalLM": "plamo",
    "PaddleOCRVLForConditionalGeneration": "ernie",
    "PanguEmbeddedForCausalLM": "pangu",
    "Phi3ForCausalLM": "phi",
    "Phi4ForCausalLMV": "phi",
    "PhiForCausalLM": "phi",
    "PhiMoEForCausalLM": "phi",
    "Plamo2ForCausalLM": "plamo",
    "Plamo3ForCausalLM": "plamo",
    "PlamoForCausalLM": "plamo",
    "QWenLMHeadModel": "qwen",
    "Qwen2AudioForConditionalGeneration": "qwen",
    "Qwen2ForCausalLM": "qwen",
    "Qwen2Model": "qwen",
    "Qwen2MoeForCausalLM": "qwen",
    "Qwen2VLForConditionalGeneration": "qwenvl",
    "Qwen2VLModel": "qwenvl",
    "Qwen2_5OmniModel": "qwenvl",
    "Qwen2_5_VLForConditionalGeneration": "qwenvl",
    "Qwen3ASRForConditionalGeneration": "qwen3vl",
    "Qwen3ForCausalLM": "qwen",
    "Qwen3Model": "qwen",
    "Qwen3MoeForCausalLM": "qwen",
    "Qwen3NextForCausalLM": "qwen",
    "Qwen3OmniMoeForConditionalGeneration": "qwen3vl",
    "Qwen3VLForConditionalGeneration": "qwen3vl",
    "Qwen3VLMoeForConditionalGeneration": "qwen3vl",
    "Qwen3_5ForCausalLM": "qwen",
    "Qwen3_5ForConditionalGeneration": "qwen",
    "Qwen3_5MoeForCausalLM": "qwen",
    "Qwen3_5MoeForConditionalGeneration": "qwen",
    "RND1": "qwen",
    "RWForCausalLM": "falcon",
    "RWKV6Qwen2ForCausalLM": "rwkv",
    "RWKV7ForCausalLM": "rwkv",
    "RobertaForSequenceClassification": "bert",
    "RobertaModel": "bert",
    "RuGPT3XLForCausalLM": "gpt2",
    "Rwkv6ForCausalLM": "rwkv",
    "Rwkv7ForCausalLM": "rwkv",
    "RwkvHybridForCausalLM": "rwkv",
    "Sarashina2VisionForCausalLM": "sarashina2",
    "SarvamMoEForCausalLM": "bailingmoe",
    "SeedOssForCausalLM": "olmo",
    "SmallThinkerForCausalLM": "smallthinker",
    "SmolLM3ForCausalLM": "llama",
    "SolarOpenForCausalLM": "glm",
    "StableLMEpochForCausalLM": "stablelm",
    "StableLmForCausalLM": "stablelm",
    "Starcoder2ForCausalLM": "starcoder",
    "Step3p5ForCausalLM": "step3",
    "StepVLForConditionalGeneration": "step3",
    "T5EncoderModel": "t5",
    "T5ForConditionalGeneration": "t5",
    "T5WithLMHeadModel": "t5",
    "UMT5ForConditionalGeneration": "t5",
    "UMT5Model": "t5",
    "UltravoxModel": "ultravox",
    "VLlama3ForCausalLM": "llama",
    "VoxtralForConditionalGeneration": "llama",
    "WavTokenizerDec": "wavtokenizer",
    "XLMRobertaForSequenceClassification": "bert",
    "XLMRobertaModel": "bert",
    "XverseForCausalLM": "xverse",
    "YoutuForCausalLM": "deepseek",
    "YoutuVLForConditionalGeneration": "deepseek",
    "modeling_grove_moe.GroveMoeForCausalLM": "grovemoe",
    "modeling_sarvam_moe.SarvamMoEForCausalLM": "bailingmoe",
}
232+
233+
234+
# HuggingFace architecture name -> name of the submodule of this package that
# is imported to obtain the MMPROJ model class for that architecture (used by
# get_model_class when mmproj=True). An architecture may also appear in
# TEXT_MODEL_MAP, possibly pointing at a different submodule.
MMPROJ_MODEL_MAP: dict[str, str] = {
    "AudioFlamingo3ForConditionalGeneration": "ultravox",
    "CogVLMForCausalLM": "cogvlm",
    "DeepseekOCRForCausalLM": "deepseek",
    "DotsOCRForCausalLM": "dotsocr",
    "Gemma3ForConditionalGeneration": "gemma",
    "Gemma3nForConditionalGeneration": "gemma",
    "Gemma4ForConditionalGeneration": "gemma",
    "Glm4vForConditionalGeneration": "qwen3vl",
    "Glm4vMoeForConditionalGeneration": "qwen3vl",
    "GlmOcrForConditionalGeneration": "qwen3vl",
    "GlmasrModel": "ultravox",
    "GraniteSpeechForConditionalGeneration": "granite",
    "HunYuanVLForConditionalGeneration": "hunyuan",
    "Idefics3ForConditionalGeneration": "smolvlm",
    "InternVisionModel": "internvl",
    "JanusForConditionalGeneration": "januspro",
    "KimiK25ForConditionalGeneration": "kimivl",
    "KimiVLForConditionalGeneration": "kimivl",
    "Lfm2AudioForConditionalGeneration": "lfm2",
    "Lfm2VlForConditionalGeneration": "lfm2",
    "LightOnOCRForConditionalGeneration": "lighton_ocr",
    "Llama4ForConditionalGeneration": "llama4",
    "LlavaForConditionalGeneration": "llava",
    "MERaLiON2ForConditionalGeneration": "ultravox",
    "MiMoV2ForCausalLM": "mimo",
    "MiniCPMV4_6ForConditionalGeneration": "minicpm",
    "Mistral3ForConditionalGeneration": "llava",
    "NemotronH_Nano_VL_V2": "nemotron",
    "PaddleOCRVisionModel": "ernie",
    "Phi4ForCausalLMV": "phi",
    "Qwen2AudioForConditionalGeneration": "ultravox",
    "Qwen2VLForConditionalGeneration": "qwenvl",
    "Qwen2VLModel": "qwenvl",
    "Qwen2_5OmniModel": "qwenvl",
    "Qwen2_5_VLForConditionalGeneration": "qwenvl",
    "Qwen3ASRForConditionalGeneration": "qwen3vl",
    "Qwen3OmniMoeForConditionalGeneration": "qwen3vl",
    "Qwen3VLForConditionalGeneration": "qwen3vl",
    "Qwen3VLMoeForConditionalGeneration": "qwen3vl",
    "Qwen3_5ForConditionalGeneration": "qwen3vl",
    "Qwen3_5MoeForConditionalGeneration": "qwen3vl",
    "RADIOModel": "nemotron",
    "Sarashina2VisionForCausalLM": "sarashina2",
    "SmolVLMForConditionalGeneration": "smolvlm",
    "StepVLForConditionalGeneration": "step3",
    "UltravoxModel": "ultravox",
    "VoxtralForConditionalGeneration": "ultravox",
    "YoutuVLForConditionalGeneration": "youtuvl",
}
284+
285+
286+
# Distinct submodule names referenced by each map, de-duplicated and sorted so
# that bulk loading walks them in a deterministic order.
_TEXT_MODEL_MODULES = sorted({*TEXT_MODEL_MAP.values()})
_MMPROJ_MODEL_MODULES = sorted({*MMPROJ_MODEL_MAP.values()})

# Submodules that load_all_models() has already imported successfully, tracked
# separately for the text and mmproj lists.
_loaded_text_modules: set[str] = set()
_loaded_mmproj_modules: set[str] = set()
292+
293+
294+
def _import_model_modules(module_names: list[str], loaded: set[str]) -> None:
    """Import each not-yet-loaded submodule, recording successes in *loaded*.

    Best-effort: a submodule that fails to import is logged as a warning and
    skipped, so one broken model file does not prevent the others from loading.
    It is not added to *loaded*, so a later call will retry it.
    """
    for module_name in module_names:
        if module_name in loaded:
            continue
        try:
            # Resolve against this package's own qualified name instead of a
            # hard-coded "conversion", so the submodule that gets loaded is the
            # sibling of this file even if the package is imported under a
            # different (e.g. nested) name.
            importlib.import_module(f"{__name__}.{module_name}")
        except Exception as e:
            logger.warning(f"Failed to load model module {module_name}: {e}")
        else:
            loaded.add(module_name)


def load_all_models() -> None:
    """Import all model modules to trigger @ModelBase.register() decorators.

    Walks every submodule named in TEXT_MODEL_MAP and MMPROJ_MODEL_MAP. Modules
    imported successfully by a previous call are skipped; failures are logged
    as warnings and do not raise.
    """
    _import_model_modules(_TEXT_MODEL_MODULES, _loaded_text_modules)
    _import_model_modules(_MMPROJ_MODEL_MODULES, _loaded_mmproj_modules)
313+
314+
315+
def get_model_class(name: str, mmproj: bool = False) -> Type[ModelBase]:
    """Dynamically import and return a model class by its HuggingFace architecture name.

    Args:
        name: HuggingFace architecture name; must be a key of TEXT_MODEL_MAP
            (or of MMPROJ_MODEL_MAP when *mmproj* is true).
        mmproj: Select the MMPROJ map and registry instead of the TEXT ones.

    Returns:
        The class stored under *name* in ``ModelBase._model_classes`` for the
        selected model type.

    Raises:
        NotImplementedError: *name* is not listed in the selected map.
        ImportError: the mapped submodule cannot be imported.
        KeyError: the submodule imported fine but nothing was registered under
            *name* for the selected model type.
    """
    relevant_map = MMPROJ_MODEL_MAP if mmproj else TEXT_MODEL_MAP
    if name not in relevant_map:
        raise NotImplementedError(f"Architecture {name!r} not supported!")

    # Importing the submodule is what populates the registry. Resolve it
    # against this package's own qualified name rather than a hard-coded
    # "conversion", so the lookup below sees the same ModelBase the submodule
    # registered into even when the package is imported under another name.
    importlib.import_module(f"{__name__}.{relevant_map[name]}")

    model_type = ModelType.MMPROJ if mmproj else ModelType.TEXT
    return ModelBase._model_classes[model_type][name]
324+
325+
326+
def print_registered_models() -> None:
    """Log every supported architecture name, TEXT first and then MMPROJ.

    Calls load_all_models() first, then emits one heading per model type
    followed by the sorted keys of the corresponding map. The names come from
    the maps, not from the live registry. Everything is logged at error level.
    """
    load_all_models()
    sections = (
        ("TEXT models:", TEXT_MODEL_MAP),
        ("MMPROJ models:", MMPROJ_MODEL_MAP),
    )
    for heading, arch_map in sections:
        logger.error(heading)
        for arch_name in sorted(arch_map):
            logger.error(f" - {arch_name}")

0 commit comments

Comments
 (0)