1515from functools import lru_cache
1616from transformers import (
1717 AutoConfig ,
18+ AutoModel ,
1819 AutoModelForCausalLM ,
1920 AutoTokenizer ,
2021 GenerationConfig ,
@@ -1201,6 +1202,32 @@ def prepare_remote_code_compat(config: Any) -> None:
12011202 normalize_hf_config_compat (config , trust_remote_code = True )
12021203
12031204
1205+ def register_runtime_automodel_config (config , remote_module , config_attr : str , remote_model_name : str ) -> None :
1206+ # Obtain the correct config class path to register the config and model.
1207+ # Fix ValueError: Unrecognized configuration class
1208+ # <class 'transformers_modules.Ovis1_dot_6_hyphen_Llama3_dot_2_hyphen_3B.e514127b17008465.configuration_ovis.
1209+ # SiglipVisualTokenizerConfig'> for this kind of AutoModel: AutoModel.
1210+ runtime_config = getattr (config , config_attr , None )
1211+ runtime_model_cls = getattr (remote_module , remote_model_name , None ) if remote_module is not None else None
1212+ if runtime_config is None or runtime_model_cls is None :
1213+ return
1214+
1215+ runtime_config_cls = type (runtime_config )
1216+
1217+ try :
1218+ if getattr (runtime_model_cls , "config_class" , None ) is not runtime_config_cls :
1219+ runtime_model_cls .config_class = runtime_config_cls
1220+ AutoModel .register (runtime_config_cls , runtime_model_cls , exist_ok = True )
1221+ except Exception as exc :
1222+ log .debug (
1223+ "HF: failed to bridge AutoModel registration for `%s` using `%s.%s`: %s" ,
1224+ config_attr ,
1225+ getattr (remote_module , "__name__" , "unknown" ),
1226+ remote_model_name ,
1227+ exc ,
1228+ )
1229+
1230+
12041231def prepare_remote_model_init_compat (model_id_or_path : Optional [str ], config : Any ) -> None :
12051232 if not model_id_or_path :
12061233 return
@@ -1278,6 +1305,18 @@ def encoder_init_compat(self, encoder_config):
12781305 if vision_model_cls :
12791306 try_patch_legacy_flash_attn_flag (vision_model_cls )
12801307
1308+ if config .model_type == "ovis" :
1309+ from transformers import LlamaForCausalLM
1310+ try_patch_legacy_flash_attn_flag (LlamaForCausalLM )
1311+
1312+ vision_model_cls = getattr (
1313+ remote_module ,
1314+ "SiglipVisualTokenizer" ,
1315+ None ,
1316+ )
1317+ if vision_model_cls :
1318+ try_patch_legacy_flash_attn_flag (vision_model_cls )
1319+
12811320 if (
12821321 outer_model_cls is not None
12831322 and hasattr (outer_model_cls , "tie_weights" )
@@ -1307,6 +1346,8 @@ def tie_weights_compat(self, *args, **kwargs):
13071346 outer_model_cls ._gptqmodel_tie_weights_kwargs_patch = True
13081347
13091348 if getattr (config , "model_type" , None ) == "ovis" and ovis_config_module is not None :
1349+ register_runtime_automodel_config (config , remote_module , "visual_tokenizer_config" , "SiglipVisualTokenizer" )
1350+
13101351 formatter_cls = getattr (ovis_config_module , "Llama3ConversationFormatter" , None )
13111352 if formatter_cls is not None and not getattr (formatter_cls , "_gptqmodel_tokenizer_backend_patch" , False ):
13121353 support_tokenizer_types = list (getattr (formatter_cls , "support_tokenizer_types" , None ) or [])
@@ -1318,6 +1359,9 @@ def tie_weights_compat(self, *args, **kwargs):
13181359 formatter_cls .support_tokenizer_types = support_tokenizer_types
13191360 formatter_cls ._gptqmodel_tokenizer_backend_patch = True
13201361
1362+ if getattr (config , "model_type" , None ) == "ovis2_5" :
1363+ register_runtime_automodel_config (config , remote_module , "vit_config" , "Siglip2NavitModel" )
1364+
13211365 if getattr (config , "model_type" , None ) == "hymba" and remote_module is not None :
13221366 rotary_cls = getattr (remote_module , "LlamaRotaryEmbedding" , None )
13231367 attention_cls = getattr (remote_module , "HymbaAttention" , None )
@@ -1475,6 +1519,12 @@ def try_patch_legacy_flash_attn_flag(model_cls):
14751519 if model_cls is None or not isinstance (model_cls , type ):
14761520 return
14771521
1522+ # The remote modeling code for some models(For example, ovis.) still relies on `_supports_flash_attn_2`
1523+ if hasattr (model_cls , "_supports_flash_attn" ):
1524+ if not hasattr (model_cls , "_supports_flash_attn_2" ):
1525+ setattr (model_cls , "_supports_flash_attn_2" , bool (model_cls ._supports_flash_attn ))
1526+ return
1527+
14781528 # Find the most specific class that explicitly declares the newer
14791529 # `_supports_flash_attn_2` flag used by newer transformers releases.
14801530 base_with_flag = None
0 commit comments