Skip to content

Commit e345392

Browse files
committed
Fixed ovis incompatibility with transformers v5.
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 1a50635 commit e345392

2 files changed

Lines changed: 163 additions & 1 deletion

File tree

gptqmodel/utils/hf.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import lru_cache
1616
from transformers import (
1717
AutoConfig,
18+
AutoModel,
1819
AutoModelForCausalLM,
1920
AutoTokenizer,
2021
GenerationConfig,
@@ -1201,6 +1202,32 @@ def prepare_remote_code_compat(config: Any) -> None:
12011202
normalize_hf_config_compat(config, trust_remote_code=True)
12021203

12031204

1205+
def register_runtime_automodel_config(config, remote_module, config_attr: str, remote_model_name: str) -> None:
1206+
# Obtain the correct config class path to register the config and model.
1207+
# Fix ValueError: Unrecognized configuration class
1208+
# <class 'transformers_modules.Ovis1_dot_6_hyphen_Llama3_dot_2_hyphen_3B.e514127b17008465.configuration_ovis.
1209+
# SiglipVisualTokenizerConfig'> for this kind of AutoModel: AutoModel.
1210+
runtime_config = getattr(config, config_attr, None)
1211+
runtime_model_cls = getattr(remote_module, remote_model_name, None) if remote_module is not None else None
1212+
if runtime_config is None or runtime_model_cls is None:
1213+
return
1214+
1215+
runtime_config_cls = type(runtime_config)
1216+
1217+
try:
1218+
if getattr(runtime_model_cls, "config_class", None) is not runtime_config_cls:
1219+
runtime_model_cls.config_class = runtime_config_cls
1220+
AutoModel.register(runtime_config_cls, runtime_model_cls, exist_ok=True)
1221+
except Exception as exc:
1222+
log.debug(
1223+
"HF: failed to bridge AutoModel registration for `%s` using `%s.%s`: %s",
1224+
config_attr,
1225+
getattr(remote_module, "__name__", "unknown"),
1226+
remote_model_name,
1227+
exc,
1228+
)
1229+
1230+
12041231
def prepare_remote_model_init_compat(model_id_or_path: Optional[str], config: Any) -> None:
12051232
if not model_id_or_path:
12061233
return
@@ -1278,6 +1305,18 @@ def encoder_init_compat(self, encoder_config):
12781305
if vision_model_cls:
12791306
try_patch_legacy_flash_attn_flag(vision_model_cls)
12801307

1308+
if config.model_type == "ovis":
1309+
from transformers import LlamaForCausalLM
1310+
try_patch_legacy_flash_attn_flag(LlamaForCausalLM)
1311+
1312+
vision_model_cls = getattr(
1313+
remote_module,
1314+
"SiglipVisualTokenizer",
1315+
None,
1316+
)
1317+
if vision_model_cls:
1318+
try_patch_legacy_flash_attn_flag(vision_model_cls)
1319+
12811320
if (
12821321
outer_model_cls is not None
12831322
and hasattr(outer_model_cls, "tie_weights")
@@ -1307,6 +1346,8 @@ def tie_weights_compat(self, *args, **kwargs):
13071346
outer_model_cls._gptqmodel_tie_weights_kwargs_patch = True
13081347

13091348
if getattr(config, "model_type", None) == "ovis" and ovis_config_module is not None:
1349+
register_runtime_automodel_config(config, remote_module, "visual_tokenizer_config", "SiglipVisualTokenizer")
1350+
13101351
formatter_cls = getattr(ovis_config_module, "Llama3ConversationFormatter", None)
13111352
if formatter_cls is not None and not getattr(formatter_cls, "_gptqmodel_tokenizer_backend_patch", False):
13121353
support_tokenizer_types = list(getattr(formatter_cls, "support_tokenizer_types", None) or [])
@@ -1318,6 +1359,9 @@ def tie_weights_compat(self, *args, **kwargs):
13181359
formatter_cls.support_tokenizer_types = support_tokenizer_types
13191360
formatter_cls._gptqmodel_tokenizer_backend_patch = True
13201361

1362+
if getattr(config, "model_type", None) == "ovis2_5":
1363+
register_runtime_automodel_config(config, remote_module, "vit_config", "Siglip2NavitModel")
1364+
13211365
if getattr(config, "model_type", None) == "hymba" and remote_module is not None:
13221366
rotary_cls = getattr(remote_module, "LlamaRotaryEmbedding", None)
13231367
attention_cls = getattr(remote_module, "HymbaAttention", None)
@@ -1475,6 +1519,12 @@ def try_patch_legacy_flash_attn_flag(model_cls):
14751519
if model_cls is None or not isinstance(model_cls, type):
14761520
return
14771521

1522+
# The remote modeling code for some models(For example, ovis.) still relies on `_supports_flash_attn_2`
1523+
if hasattr(model_cls, "_supports_flash_attn"):
1524+
if not hasattr(model_cls, "_supports_flash_attn_2"):
1525+
setattr(model_cls, "_supports_flash_attn_2", bool(model_cls._supports_flash_attn))
1526+
return
1527+
14781528
# Find the most specific class that explicitly declares the newer
14791529
# `_supports_flash_attn_2` flag used by newer transformers releases.
14801530
base_with_flag = None

tests/test_hf_config_compat.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
import transformers
88
import transformers.generation.utils as generation_utils
9-
from transformers import GenerationConfig, GPTNeoXConfig, LlamaConfig, cache_utils
9+
from transformers import AutoModel, GenerationConfig, GPTNeoXConfig, LlamaConfig, cache_utils
1010
from transformers.generation.configuration_utils import GenerationMode
1111

1212
from gptqmodel.utils import hf as hf_utils
@@ -451,6 +451,118 @@ class Llama3ConversationFormatter:
451451
assert getattr(Llama3ConversationFormatter, "_gptqmodel_tokenizer_backend_patch", False) is True
452452

453453

454+
def test_prepare_remote_model_init_compat_bridges_ovis_visual_tokenizer_registration(monkeypatch):
455+
captured = {}
456+
457+
class RuntimeVisualConfig:
458+
pass
459+
460+
class RemoteVisualConfig:
461+
pass
462+
463+
class DummyVisualModel:
464+
config_class = RemoteVisualConfig
465+
466+
remote_module = ModuleType("transformers_modules.fake_ovis_bridge.modeling_ovis")
467+
remote_module.SiglipVisualTokenizer = DummyVisualModel
468+
monkeypatch.setitem(sys.modules, remote_module.__name__, remote_module)
469+
470+
config_module = ModuleType("transformers_modules.fake_ovis_bridge.configuration_ovis")
471+
472+
class Llama3ConversationFormatter:
473+
support_tokenizer_types = ["PreTrainedTokenizerFast"]
474+
475+
config_module.Llama3ConversationFormatter = Llama3ConversationFormatter
476+
monkeypatch.setitem(sys.modules, config_module.__name__, config_module)
477+
478+
class DummyRemoteModel:
479+
__module__ = remote_module.__name__
480+
481+
monkeypatch.setattr(
482+
"transformers.dynamic_module_utils.get_class_from_dynamic_module",
483+
lambda class_ref, model_id_or_path, **kwargs: DummyRemoteModel,
484+
)
485+
486+
monkeypatch.setattr(
487+
AutoModel,
488+
"register",
489+
classmethod(
490+
lambda cls, config_class, model_class, exist_ok=False: captured.update(
491+
{
492+
"config_class": config_class,
493+
"model_class": model_class,
494+
"exist_ok": exist_ok,
495+
}
496+
)
497+
),
498+
)
499+
500+
config = SimpleNamespace(
501+
model_type="ovis",
502+
auto_map={"AutoModelForCausalLM": "modeling_ovis.Ovis"},
503+
visual_tokenizer_config=RuntimeVisualConfig(),
504+
)
505+
506+
prepare_remote_model_init_compat("/tmp/ovis", config)
507+
508+
assert captured["config_class"] is RuntimeVisualConfig
509+
assert captured["model_class"] is DummyVisualModel
510+
assert captured["exist_ok"] is True
511+
assert DummyVisualModel.config_class is RuntimeVisualConfig
512+
513+
514+
def test_prepare_remote_model_init_compat_bridges_ovis2_5_vit_registration(monkeypatch):
515+
captured = {}
516+
517+
class RuntimeVitConfig:
518+
pass
519+
520+
class RemoteVitConfig:
521+
pass
522+
523+
class DummyVitModel:
524+
config_class = RemoteVitConfig
525+
526+
remote_module = ModuleType("transformers_modules.fake_ovis2_5_bridge.modeling_ovis2_5")
527+
remote_module.Siglip2NavitModel = DummyVitModel
528+
monkeypatch.setitem(sys.modules, remote_module.__name__, remote_module)
529+
530+
class DummyRemoteModel:
531+
__module__ = remote_module.__name__
532+
533+
monkeypatch.setattr(
534+
"transformers.dynamic_module_utils.get_class_from_dynamic_module",
535+
lambda class_ref, model_id_or_path, **kwargs: DummyRemoteModel,
536+
)
537+
538+
monkeypatch.setattr(
539+
AutoModel,
540+
"register",
541+
classmethod(
542+
lambda cls, config_class, model_class, exist_ok=False: captured.update(
543+
{
544+
"config_class": config_class,
545+
"model_class": model_class,
546+
"exist_ok": exist_ok,
547+
}
548+
)
549+
),
550+
)
551+
552+
config = SimpleNamespace(
553+
model_type="ovis2_5",
554+
auto_map={"AutoModelForCausalLM": "modeling_ovis2_5.Ovis2_5"},
555+
vit_config=RuntimeVitConfig(),
556+
)
557+
558+
prepare_remote_model_init_compat("/tmp/ovis2_5", config)
559+
560+
assert captured["config_class"] is RuntimeVitConfig
561+
assert captured["model_class"] is DummyVitModel
562+
assert captured["exist_ok"] is True
563+
assert DummyVitModel.config_class is RuntimeVitConfig
564+
565+
454566
def test_prepare_remote_model_init_compat_promotes_phi4_positional_seed_to_meta(monkeypatch):
455567
seen_devices = []
456568

0 commit comments

Comments
 (0)