|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from .base import ( |
| 4 | + ModelBase, TextModel, MmprojModel, ModelType, SentencePieceTokenTypes, |
| 5 | + logger, _mistral_common_installed, _mistral_import_error_msg, |
| 6 | + get_model_architecture, LazyTorchTensor, |
| 7 | +) |
| 8 | +from typing import Type |
| 9 | + |
| 10 | + |
# Public surface of this package: names re-exported from `.base` followed by
# the lookup/loading helpers defined in this module.
__all__ = [
    # re-exported from .base
    "ModelBase",
    "TextModel",
    "MmprojModel",
    "ModelType",
    "SentencePieceTokenTypes",
    "get_model_architecture",
    "LazyTorchTensor",
    "logger",
    "_mistral_common_installed",
    "_mistral_import_error_msg",
    # defined in this module
    "get_model_class",
    "print_registered_models",
    "load_all_models",
]
| 17 | + |
| 18 | + |
# Maps an architecture name (the key passed to `get_model_class` as `name`)
# to the name of the module under `conversion.` that is imported to obtain the
# text model class. Several architectures may share one module. A few keys
# are module-qualified (e.g. "modeling_grove_moe.GroveMoeForCausalLM").
TEXT_MODEL_MAP: dict[str, str] = {
    "AfmoeForCausalLM": "afmoe",
    "ApertusForCausalLM": "llama",
    "ArceeForCausalLM": "llama",
    "ArcticForCausalLM": "arctic",
    "AudioFlamingo3ForConditionalGeneration": "qwen",
    "BaiChuanForCausalLM": "baichuan",
    "BaichuanForCausalLM": "baichuan",
    "BailingMoeForCausalLM": "bailingmoe",
    "BailingMoeV2ForCausalLM": "bailingmoe",
    "BambaForCausalLM": "granite",
    "BertForMaskedLM": "bert",
    "BertForSequenceClassification": "bert",
    "BertModel": "bert",
    "BitnetForCausalLM": "bitnet",
    "BloomForCausalLM": "bloom",
    "BloomModel": "bloom",
    "CamembertModel": "bert",
    "ChameleonForCausalLM": "chameleon",
    "ChameleonForConditionalGeneration": "chameleon",
    "ChatGLMForConditionalGeneration": "chatglm",
    "ChatGLMModel": "chatglm",
    "CodeShellForCausalLM": "codeshell",
    "CogVLMForCausalLM": "cogvlm",
    "Cohere2ForCausalLM": "command_r",
    "CohereForCausalLM": "command_r",
    "DbrxForCausalLM": "dbrx",
    "DeciLMForCausalLM": "deci",
    "DeepseekForCausalLM": "deepseek",
    "DeepseekV2ForCausalLM": "deepseek",
    "DeepseekV3ForCausalLM": "deepseek",
    "DistilBertForMaskedLM": "bert",
    "DistilBertForSequenceClassification": "bert",
    "DistilBertModel": "bert",
    "Dots1ForCausalLM": "dots1",
    "DotsOCRForCausalLM": "qwen",
    "DreamModel": "dream",
    "Ernie4_5ForCausalLM": "ernie",
    "Ernie4_5_ForCausalLM": "ernie",
    "Ernie4_5_MoeForCausalLM": "ernie",
    "EuroBertModel": "bert",
    "Exaone4ForCausalLM": "exaone",
    "ExaoneForCausalLM": "exaone",
    "ExaoneMoEForCausalLM": "exaone",
    "FalconForCausalLM": "falcon",
    "FalconH1ForCausalLM": "falcon_h1",
    "FalconMambaForCausalLM": "mamba",
    "GPT2LMHeadModel": "gpt2",
    "GPTBigCodeForCausalLM": "starcoder",
    "GPTNeoXForCausalLM": "gptneox",
    "GPTRefactForCausalLM": "refact",
    "Gemma2ForCausalLM": "gemma",
    "Gemma3ForCausalLM": "gemma",
    "Gemma3ForConditionalGeneration": "gemma",
    "Gemma3TextModel": "gemma",
    "Gemma3nForCausalLM": "gemma",
    "Gemma3nForConditionalGeneration": "gemma",
    "Gemma4ForConditionalGeneration": "gemma",
    "GemmaForCausalLM": "gemma",
    "Glm4ForCausalLM": "glm",
    "Glm4MoeForCausalLM": "glm",
    "Glm4MoeLiteForCausalLM": "glm",
    "Glm4vForConditionalGeneration": "glm",
    "Glm4vMoeForConditionalGeneration": "glm",
    "GlmForCausalLM": "chatglm",
    "GlmMoeDsaForCausalLM": "glm",
    "GlmOcrForConditionalGeneration": "glm",
    "GptOssForCausalLM": "gpt_oss",
    "GraniteForCausalLM": "granite",
    "GraniteMoeForCausalLM": "granite",
    "GraniteMoeHybridForCausalLM": "granite",
    "GraniteMoeSharedForCausalLM": "granite",
    "GraniteSpeechForConditionalGeneration": "granite",
    "Grok1ForCausalLM": "grok",
    "GrokForCausalLM": "grok",
    "GroveMoeForCausalLM": "grovemoe",
    "HunYuanDenseV1ForCausalLM": "hunyuan",
    "HunYuanMoEV1ForCausalLM": "hunyuan",
    "HunYuanVLForConditionalGeneration": "hunyuan",
    "IQuestCoderForCausalLM": "llama",
    "InternLM2ForCausalLM": "internlm",
    "InternLM3ForCausalLM": "internlm",
    "JAISLMHeadModel": "jais",
    "Jais2ForCausalLM": "jais",
    "JambaForCausalLM": "jamba",
    "JanusForConditionalGeneration": "januspro",
    "JinaBertForMaskedLM": "bert",
    "JinaBertModel": "bert",
    "JinaEmbeddingsV5Model": "bert",
    "KORMoForCausalLM": "qwen",
    "KimiK25ForConditionalGeneration": "deepseek",
    "KimiLinearForCausalLM": "kimi_linear",
    "KimiLinearModel": "kimi_linear",
    "KimiVLForConditionalGeneration": "deepseek",
    "LFM2ForCausalLM": "lfm2",
    "LLaDAMoEModel": "llada",
    "LLaDAMoEModelLM": "llada",
    "LLaDAModelLM": "llada",
    "LLaMAForCausalLM": "llama",
    "Lfm25AudioTokenizer": "lfm2",
    "Lfm2ForCausalLM": "lfm2",
    "Lfm2Model": "lfm2",
    "Lfm2MoeForCausalLM": "lfm2",
    "Llama4ForCausalLM": "llama",
    "Llama4ForConditionalGeneration": "llama",
    "LlamaBidirectionalModel": "llama",
    "LlamaForCausalLM": "llama",
    "LlamaModel": "llama",
    "LlavaForConditionalGeneration": "llama",
    "LlavaStableLMEpochForCausalLM": "stablelm",
    "MPTForCausalLM": "mpt",
    "MT5ForConditionalGeneration": "t5",
    "MaincoderForCausalLM": "maincoder",
    "Mamba2ForCausalLM": "mamba",
    "MambaForCausalLM": "mamba",
    "MambaLMHeadModel": "mamba",
    "MiMoV2FlashForCausalLM": "mimo",
    "MiMoV2ForCausalLM": "mimo",
    "MiniCPM3ForCausalLM": "minicpm",
    "MiniCPMForCausalLM": "minicpm",
    "MiniCPMV4_6ForConditionalGeneration": "minicpm",
    "MiniMaxM2ForCausalLM": "minimax",
    "Ministral3ForCausalLM": "mistral3",
    "Mistral3ForConditionalGeneration": "mistral3",
    "MistralForCausalLM": "llama",
    "MixtralForCausalLM": "llama",
    "ModernBertForMaskedLM": "bert",
    "ModernBertForSequenceClassification": "bert",
    "ModernBertModel": "bert",
    "NemotronForCausalLM": "nemotron",
    "NemotronHForCausalLM": "nemotron",
    "NeoBERT": "bert",
    "NeoBERTForSequenceClassification": "bert",
    "NeoBERTLMHead": "bert",
    "NomicBertModel": "bert",
    "OLMoForCausalLM": "olmo",
    "Olmo2ForCausalLM": "olmo",
    "Olmo3ForCausalLM": "olmo",
    "OlmoForCausalLM": "olmo",
    "OlmoeForCausalLM": "olmo",
    "OpenELMForCausalLM": "openelm",
    "OrionForCausalLM": "orion",
    "PLMForCausalLM": "plm",
    "PLaMo2ForCausalLM": "plamo",
    "PLaMo3ForCausalLM": "plamo",
    "PaddleOCRVLForConditionalGeneration": "ernie",
    "PanguEmbeddedForCausalLM": "pangu",
    "Phi3ForCausalLM": "phi",
    "Phi4ForCausalLMV": "phi",
    "PhiForCausalLM": "phi",
    "PhiMoEForCausalLM": "phi",
    "Plamo2ForCausalLM": "plamo",
    "Plamo3ForCausalLM": "plamo",
    "PlamoForCausalLM": "plamo",
    "QWenLMHeadModel": "qwen",
    "Qwen2AudioForConditionalGeneration": "qwen",
    "Qwen2ForCausalLM": "qwen",
    "Qwen2Model": "qwen",
    "Qwen2MoeForCausalLM": "qwen",
    "Qwen2VLForConditionalGeneration": "qwenvl",
    "Qwen2VLModel": "qwenvl",
    "Qwen2_5OmniModel": "qwenvl",
    "Qwen2_5_VLForConditionalGeneration": "qwenvl",
    "Qwen3ASRForConditionalGeneration": "qwen3vl",
    "Qwen3ForCausalLM": "qwen",
    "Qwen3Model": "qwen",
    "Qwen3MoeForCausalLM": "qwen",
    "Qwen3NextForCausalLM": "qwen",
    "Qwen3OmniMoeForConditionalGeneration": "qwen3vl",
    "Qwen3VLForConditionalGeneration": "qwen3vl",
    "Qwen3VLMoeForConditionalGeneration": "qwen3vl",
    "Qwen3_5ForCausalLM": "qwen",
    "Qwen3_5ForConditionalGeneration": "qwen",
    "Qwen3_5MoeForCausalLM": "qwen",
    "Qwen3_5MoeForConditionalGeneration": "qwen",
    "RND1": "qwen",
    "RWForCausalLM": "falcon",
    "RWKV6Qwen2ForCausalLM": "rwkv",
    "RWKV7ForCausalLM": "rwkv",
    "RobertaForSequenceClassification": "bert",
    "RobertaModel": "bert",
    "RuGPT3XLForCausalLM": "gpt2",
    "Rwkv6ForCausalLM": "rwkv",
    "Rwkv7ForCausalLM": "rwkv",
    "RwkvHybridForCausalLM": "rwkv",
    "Sarashina2VisionForCausalLM": "sarashina2",
    "SarvamMoEForCausalLM": "bailingmoe",
    "SeedOssForCausalLM": "olmo",
    "SmallThinkerForCausalLM": "smallthinker",
    "SmolLM3ForCausalLM": "llama",
    "SolarOpenForCausalLM": "glm",
    "StableLMEpochForCausalLM": "stablelm",
    "StableLmForCausalLM": "stablelm",
    "Starcoder2ForCausalLM": "starcoder",
    "Step3p5ForCausalLM": "step3",
    "StepVLForConditionalGeneration": "step3",
    "T5EncoderModel": "t5",
    "T5ForConditionalGeneration": "t5",
    "T5WithLMHeadModel": "t5",
    "UMT5ForConditionalGeneration": "t5",
    "UMT5Model": "t5",
    "UltravoxModel": "ultravox",
    "VLlama3ForCausalLM": "llama",
    "VoxtralForConditionalGeneration": "llama",
    "WavTokenizerDec": "wavtokenizer",
    "XLMRobertaForSequenceClassification": "bert",
    "XLMRobertaModel": "bert",
    "XverseForCausalLM": "xverse",
    "YoutuForCausalLM": "deepseek",
    "YoutuVLForConditionalGeneration": "deepseek",
    "modeling_grove_moe.GroveMoeForCausalLM": "grovemoe",
    "modeling_sarvam_moe.SarvamMoEForCausalLM": "bailingmoe",
}
| 232 | + |
| 233 | + |
# Maps an architecture name to the name of the module under `conversion.`
# that is imported when `get_model_class` is called with `mmproj=True`.
# An architecture may appear both here and in TEXT_MODEL_MAP, possibly
# pointing at a different module in each map.
MMPROJ_MODEL_MAP: dict[str, str] = {
    "AudioFlamingo3ForConditionalGeneration": "ultravox",
    "CogVLMForCausalLM": "cogvlm",
    "DeepseekOCRForCausalLM": "deepseek",
    "DotsOCRForCausalLM": "dotsocr",
    "Gemma3ForConditionalGeneration": "gemma",
    "Gemma3nForConditionalGeneration": "gemma",
    "Gemma4ForConditionalGeneration": "gemma",
    "Glm4vForConditionalGeneration": "qwen3vl",
    "Glm4vMoeForConditionalGeneration": "qwen3vl",
    "GlmOcrForConditionalGeneration": "qwen3vl",
    "GlmasrModel": "ultravox",
    "GraniteSpeechForConditionalGeneration": "granite",
    "HunYuanVLForConditionalGeneration": "hunyuan",
    "Idefics3ForConditionalGeneration": "smolvlm",
    "InternVisionModel": "internvl",
    "JanusForConditionalGeneration": "januspro",
    "KimiK25ForConditionalGeneration": "kimivl",
    "KimiVLForConditionalGeneration": "kimivl",
    "Lfm2AudioForConditionalGeneration": "lfm2",
    "Lfm2VlForConditionalGeneration": "lfm2",
    "LightOnOCRForConditionalGeneration": "lighton_ocr",
    "Llama4ForConditionalGeneration": "llama4",
    "LlavaForConditionalGeneration": "llava",
    "MERaLiON2ForConditionalGeneration": "ultravox",
    "MiMoV2ForCausalLM": "mimo",
    "MiniCPMV4_6ForConditionalGeneration": "minicpm",
    "Mistral3ForConditionalGeneration": "llava",
    "NemotronH_Nano_VL_V2": "nemotron",
    "PaddleOCRVisionModel": "ernie",
    "Phi4ForCausalLMV": "phi",
    "Qwen2AudioForConditionalGeneration": "ultravox",
    "Qwen2VLForConditionalGeneration": "qwenvl",
    "Qwen2VLModel": "qwenvl",
    "Qwen2_5OmniModel": "qwenvl",
    "Qwen2_5_VLForConditionalGeneration": "qwenvl",
    "Qwen3ASRForConditionalGeneration": "qwen3vl",
    "Qwen3OmniMoeForConditionalGeneration": "qwen3vl",
    "Qwen3VLForConditionalGeneration": "qwen3vl",
    "Qwen3VLMoeForConditionalGeneration": "qwen3vl",
    "Qwen3_5ForConditionalGeneration": "qwen3vl",
    "Qwen3_5MoeForConditionalGeneration": "qwen3vl",
    "RADIOModel": "nemotron",
    "Sarashina2VisionForCausalLM": "sarashina2",
    "SmolVLMForConditionalGeneration": "smolvlm",
    "StepVLForConditionalGeneration": "step3",
    "UltravoxModel": "ultravox",
    "VoxtralForConditionalGeneration": "ultravox",
    "YoutuVLForConditionalGeneration": "youtuvl",
}
| 284 | + |
| 285 | + |
# Distinct module names referenced by each map, in sorted order.
_TEXT_MODEL_MODULES: list[str] = sorted({*TEXT_MODEL_MAP.values()})
_MMPROJ_MODEL_MODULES: list[str] = sorted({*MMPROJ_MODEL_MAP.values()})

# Module names that load_all_models() has already imported successfully,
# tracked separately for the text and mmproj maps.
_loaded_text_modules: set[str] = set()
_loaded_mmproj_modules: set[str] = set()
| 292 | + |
| 293 | + |
def _import_model_modules(module_names: list[str], loaded: set[str]) -> None:
    """Import each ``conversion.<name>`` module not yet in *loaded*, recording successes there."""
    # Function-scope import keeps this block independent of the file's import header.
    import importlib

    for module_name in module_names:
        if module_name in loaded:
            continue
        try:
            importlib.import_module(f"conversion.{module_name}")
        except Exception as e:
            # Best-effort: a module that fails to import must not prevent the
            # remaining modules from loading. It is not added to *loaded*, so a
            # later call will try it again.
            logger.warning("Failed to load model module %s: %s", module_name, e)
        else:
            loaded.add(module_name)


def load_all_models() -> None:
    """Import all model modules to trigger @ModelBase.register() decorators.

    Every module named in TEXT_MODEL_MAP and MMPROJ_MODEL_MAP is imported at
    most once per successful load; modules that already loaded are skipped on
    subsequent calls. Import failures are logged as warnings and do not raise.
    """
    _import_model_modules(_TEXT_MODEL_MODULES, _loaded_text_modules)
    _import_model_modules(_MMPROJ_MODEL_MODULES, _loaded_mmproj_modules)
| 313 | + |
| 314 | + |
def get_model_class(name: str, mmproj: bool = False) -> Type[ModelBase]:
    """Dynamically import and return a model class by its HuggingFace architecture name.

    Args:
        name: Architecture name; must be a key of MMPROJ_MODEL_MAP when
            ``mmproj`` is true, otherwise of TEXT_MODEL_MAP.
        mmproj: Select the MMPROJ class registry instead of the TEXT one.

    Returns:
        The class found in ``ModelBase._model_classes`` for the selected model
        type under ``name`` after the mapped module has been imported.

    Raises:
        NotImplementedError: If ``name`` is not in the selected map, or if the
            mapped module was imported but no class is registered under
            ``name`` for the selected model type.
    """
    # Function-scope import keeps this block independent of the file's import header.
    import importlib

    relevant_map = MMPROJ_MODEL_MAP if mmproj else TEXT_MODEL_MAP
    try:
        module_name = relevant_map[name]
    except KeyError:
        raise NotImplementedError(f"Architecture {name!r} not supported!") from None

    # Importing the module is what populates ModelBase._model_classes.
    importlib.import_module(f"conversion.{module_name}")

    model_type = ModelType.MMPROJ if mmproj else ModelType.TEXT
    try:
        return ModelBase._model_classes[model_type][name]
    except KeyError as err:
        # The map and the registry disagree; report it through the same
        # exception type as an unknown architecture rather than a bare KeyError.
        raise NotImplementedError(
            f"Architecture {name!r} not supported! "
            f"(module 'conversion.{module_name}' did not register it as a {model_type.name} model)"
        ) from err
| 324 | + |
| 325 | + |
def print_registered_models() -> None:
    """Log every architecture name from both maps, grouped by model type.

    All model modules are loaded first via load_all_models(). The names are
    then emitted in sorted order at ERROR level, TEXT entries before MMPROJ.
    """
    load_all_models()
    sections = (
        ("TEXT models:", TEXT_MODEL_MAP),
        ("MMPROJ models:", MMPROJ_MODEL_MAP),
    )
    for heading, arch_map in sections:
        logger.error(heading)
        for arch in sorted(arch_map):
            logger.error(f" - {arch}")
0 commit comments