Skip to content

Commit 5c7fd20

Browse files
authored
Fix Ling compat (#2786)
* Fix LazyTurtle checkpoint path aliases
* Patch remote flash attention compat
* Restore legacy default rope init compat
1 parent 5b3b036 commit 5c7fd20

6 files changed

Lines changed: 132 additions & 5 deletions

File tree

gptqmodel/models/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ class BaseQModel(nn.Module):
247247
server = None
248248

249249
support_offload_to_disk = True
250+
# Optional runtime->checkpoint prefix overrides used by LazyTurtle when the
251+
# execution shell inserts wrapper modules that are absent from the checkpoint.
252+
checkpoint_path_aliases: Optional[tuple[tuple[str, str], ...]] = None
250253
# Optional runtime->checkpoint prefix overrides for LazyTurtle. When unset,
251254
# the loader derives them from Hugging Face conversion mappings.
252255
HF_CONVERSION_MAP_REVERSED: Optional[Dict[str, str]] = None

gptqmodel/models/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,7 @@ def skip(*args, **kwargs):
657657
config=model.config,
658658
model_init_kwargs=shell_model_init_kwargs,
659659
module_tree=copy.deepcopy(getattr(cls, "module_tree", None)),
660+
checkpoint_path_aliases=copy.deepcopy(getattr(cls, "checkpoint_path_aliases", None)),
660661
hf_conversion_map_reversed=copy.deepcopy(
661662
cls.resolve_hf_conversion_map_reversed(target_model=model)
662663
),

gptqmodel/utils/hf.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,11 @@ def _patch_transformers_remote_code_compat() -> None:
576576
except Exception:
577577
utils = None
578578

579+
try:
580+
import transformers.modeling_rope_utils as rope_utils
581+
except Exception:
582+
rope_utils = None
583+
579584
import transformers.utils.generic as generic
580585
with _MONKEY_PATCH_LOCK:
581586
if not hasattr(import_utils, "is_torch_fx_available"):
@@ -599,6 +604,51 @@ def is_flash_attn_greater_or_equal_2_10() -> bool:
599604

600605
utils.is_flash_attn_greater_or_equal_2_10 = is_flash_attn_greater_or_equal_2_10
601606

607+
if rope_utils is not None and "default" not in getattr(rope_utils, "ROPE_INIT_FUNCTIONS", {}):
608+
# transformers 5.x removed the legacy `"default"` RoPE entrypoint,
609+
# but older trust_remote_code model files still resolve it directly.
610+
# Recreate the unscaled/base initializer instead of aliasing to
611+
# `"linear"` so configs do not need an artificial `factor=1.0`.
612+
def _compute_default_rope_parameters_compat(
613+
config: Optional["PreTrainedConfig"] = None,
614+
device: Optional["torch.device"] = None,
615+
seq_len: int | None = None,
616+
layer_type: str | None = None,
617+
) -> tuple["torch.Tensor", float]:
618+
del seq_len
619+
if config is None:
620+
raise ValueError("`config` is required to compute default RoPE parameters.")
621+
622+
standardize_rope_params = getattr(config, "standardize_rope_params", None)
623+
if callable(standardize_rope_params):
624+
standardize_rope_params()
625+
626+
rope_parameters = getattr(config, "rope_parameters", None)
627+
if layer_type is not None and isinstance(rope_parameters, dict):
628+
rope_parameters = rope_parameters.get(layer_type, rope_parameters)
629+
630+
rope_theta = None
631+
partial_rotary_factor = 1.0
632+
if isinstance(rope_parameters, dict):
633+
rope_theta = rope_parameters.get("rope_theta")
634+
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", partial_rotary_factor)
635+
636+
if rope_theta is None:
637+
rope_theta = getattr(config, "rope_theta", None)
638+
if rope_theta is None:
639+
rope_theta = getattr(config, "default_theta", 10_000.0)
640+
641+
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
642+
dim = int(head_dim * partial_rotary_factor)
643+
attention_factor = 1.0
644+
inv_freq = 1.0 / (
645+
rope_theta
646+
** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
647+
)
648+
return inv_freq, attention_factor
649+
650+
rope_utils.ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters_compat
651+
602652
if cache_utils is not None and not hasattr(cache_utils, "SlidingWindowCache") and hasattr(cache_utils, "StaticCache"):
603653
# transformers 5.x folds sliding-window behavior into StaticCache
604654
# layers, but older remote code still imports the legacy symbol.
@@ -1028,7 +1078,7 @@ def prepare_remote_model_init_compat(model_id_or_path: Optional[str], config: An
10281078
input_mode_enum = getattr(remote_module, "InputMode", None) if remote_module is not None else None
10291079

10301080
with _MONKEY_PATCH_LOCK:
1031-
if config.model_type == "minicpm" or config.model_type == "instella":
1081+
if outer_model_cls is not None:
10321082
try_patch_legacy_flash_attn_flag(outer_model_cls)
10331083

10341084
if config.model_type == "minicpmv" or config.model_type == "minicpmo":
@@ -1212,7 +1262,8 @@ def try_patch_legacy_flash_attn_flag(model_cls):
12121262
if model_cls is None or not isinstance(model_cls, type):
12131263
return
12141264

1215-
# Find the "source class" that defines _supports_flash_attn_2.
1265+
# Find the most specific class that explicitly declares the newer
1266+
# `_supports_flash_attn_2` flag used by newer transformers releases.
12161267
base_with_flag = None
12171268
for cls in model_cls.__mro__:
12181269
if "_supports_flash_attn_2" in cls.__dict__:
@@ -1222,8 +1273,15 @@ def try_patch_legacy_flash_attn_flag(model_cls):
12221273
if base_with_flag is None:
12231274
return
12241275

1276+
# Respect remote models that already define the legacy flag themselves.
1277+
for cls in model_cls.__mro__:
1278+
if cls is base_with_flag:
1279+
break
1280+
if "_supports_flash_attn" in cls.__dict__:
1281+
return
1282+
12251283
flash_attn_2_val = base_with_flag.__dict__["_supports_flash_attn_2"]
1226-
setattr(cls, "_supports_flash_attn", bool(flash_attn_2_val))
1284+
setattr(base_with_flag, "_supports_flash_attn", bool(flash_attn_2_val))
12271285

12281286

12291287
def load_tokenizer(tokenizer_or_path, *, model_config: Any = None, **kwargs):

gptqmodel/utils/structure.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ def __init__(
590590
config: Any,
591591
model_init_kwargs: Optional[Dict[str, Any]] = None,
592592
module_tree: Optional[Any] = None,
593+
checkpoint_path_aliases: Optional[Any] = None,
593594
hf_conversion_map_reversed: Optional[Any] = None,
594595
target_model: Optional[nn.Module] = None,
595596
) -> None:
@@ -602,10 +603,11 @@ def __init__(
602603
self._module_tree_layer_prefix, self._moe_alias_specs = self._build_moe_alias_specs(self._module_tree)
603604
# Resolve runtime->checkpoint aliases once up front so per-tensor lookups
604605
# only need cheap prefix rewrites.
605-
alias_items = self._normalize_runtime_to_checkpoint_aliases(
606+
alias_items = self._merge_runtime_to_checkpoint_aliases(
607+
checkpoint_path_aliases,
606608
hf_conversion_map_reversed
607609
if hf_conversion_map_reversed is not None
608-
else self.infer_hf_conversion_map_reversed(target_model=target_model)
610+
else self.infer_hf_conversion_map_reversed(target_model=target_model),
609611
)
610612
self._runtime_to_checkpoint_aliases = tuple(alias_items)
611613
self._lock = threading.RLock()
@@ -618,6 +620,7 @@ def maybe_create(
618620
config: Any,
619621
model_init_kwargs: Optional[Dict[str, Any]] = None,
620622
module_tree: Optional[Any] = None,
623+
checkpoint_path_aliases: Optional[Any] = None,
621624
hf_conversion_map_reversed: Optional[Any] = None,
622625
target_model: Optional[nn.Module] = None,
623626
) -> Optional["LazyTurtle"]:
@@ -630,6 +633,7 @@ def maybe_create(
630633
config=config,
631634
model_init_kwargs=model_init_kwargs,
632635
module_tree=module_tree,
636+
checkpoint_path_aliases=checkpoint_path_aliases,
633637
hf_conversion_map_reversed=hf_conversion_map_reversed,
634638
target_model=target_model,
635639
)
@@ -765,6 +769,19 @@ def _normalize_runtime_to_checkpoint_aliases(raw_aliases: Optional[Any]) -> tupl
765769
alias_items.sort(key=lambda item: len(item[0]), reverse=True)
766770
return tuple(alias_items)
767771

772+
@classmethod
def _merge_runtime_to_checkpoint_aliases(cls, *raw_alias_sets: Optional[Any]) -> tuple[tuple[str, str], ...]:
    """Merge several raw alias collections into one deduplicated tuple.

    Each collection is normalized with
    ``_normalize_runtime_to_checkpoint_aliases``. Exact duplicate pairs are
    dropped, keeping the first occurrence, and the result is ordered by
    descending length of each pair's first element. Pairs of equal length
    keep their first-seen order.
    """
    # dict keys preserve insertion order, giving first-seen-wins dedup.
    unique_pairs = dict.fromkeys(
        pair
        for raw_aliases in raw_alias_sets
        for pair in cls._normalize_runtime_to_checkpoint_aliases(raw_aliases)
    )
    ordered = sorted(unique_pairs, key=lambda pair: len(pair[0]), reverse=True)
    return tuple(ordered)
784+
768785
@classmethod
769786
def _extract_hf_source_prefix(cls, pattern: Any) -> Optional[str]:
770787
if not isinstance(pattern, str):

tests/test_hf_config_compat.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,27 @@ def test_normalize_hf_config_compat_restores_flash_attn_legacy_version_probe(mon
185185
assert transformers.utils.is_flash_attn_greater_or_equal_2_10() is True
186186

187187

188+
def test_normalize_hf_config_compat_restores_default_rope_init_alias(monkeypatch):
    """The compat pass must re-register a ``"default"`` RoPE initializer."""
    import transformers.modeling_rope_utils as rope_utils

    # Start from a registry that has no "default" entry.
    monkeypatch.delitem(rope_utils.ROPE_INIT_FUNCTIONS, "default", raising=False)

    rope_config = SimpleNamespace(
        rope_parameters={"rope_type": "default", "rope_theta": 10000.0},
        rope_theta=10000.0,
        hidden_size=64,
        num_attention_heads=8,
        head_dim=8,
    )
    normalize_hf_config_compat(rope_config, trust_remote_code=True)

    default_init = rope_utils.ROPE_INIT_FUNCTIONS["default"]
    inv_freq, attention_factor = default_init(rope_config, None)

    # head_dim=8 with theta=1e4 gives exponents 0, 1/4, 1/2, 3/4.
    expected_inv_freq = torch.tensor([1.0, 0.1, 0.01, 0.001])
    torch.testing.assert_close(inv_freq, expected_inv_freq)
    assert attention_factor == 1.0
207+
208+
188209
def test_normalize_hf_config_compat_restores_legacy_cache_length_helpers(monkeypatch):
189210
monkeypatch.delattr(cache_utils.Cache, "get_max_length", raising=False)
190211
monkeypatch.delattr(cache_utils.Cache, "get_usable_length", raising=False)
@@ -377,6 +398,31 @@ def tie_weights(self):
377398
assert getattr(DummyRemoteModel, "_gptqmodel_tie_weights_kwargs_patch", False) is True
378399

379400

401+
def test_prepare_remote_model_init_compat_backfills_legacy_flash_attn_flag(monkeypatch):
    """A remote class declaring only ``_supports_flash_attn_2`` gets the legacy flag."""

    class DummyRemoteBase:
        _supports_flash_attn_2 = True

    class DummyRemoteModel(DummyRemoteBase):
        __module__ = "transformers_modules.fake_bailing.modeling_bailing_moe_v2"

    def _resolve_dummy_class(class_ref, model_id_or_path, **kwargs):
        # Stand-in for the dynamic-module loader: always yields the dummy model.
        return DummyRemoteModel

    monkeypatch.setattr(
        "transformers.dynamic_module_utils.get_class_from_dynamic_module",
        _resolve_dummy_class,
    )

    remote_config = SimpleNamespace(
        model_type="bailing_moe",
        auto_map={"AutoModelForCausalLM": "modeling_bailing_moe_v2.BailingMoeV2ForCausalLM"},
    )

    # Precondition: the legacy flag is not declared on the base class yet.
    assert "_supports_flash_attn" not in vars(DummyRemoteBase)

    prepare_remote_model_init_compat("/tmp/ling", remote_config)

    assert DummyRemoteBase._supports_flash_attn is True
    assert DummyRemoteModel._supports_flash_attn is True
424+
425+
380426
def test_prepare_remote_model_init_compat_accepts_tokenizers_backend_for_ovis(monkeypatch):
381427
class DummyRemoteModel:
382428
__module__ = "transformers_modules.fake_ovis.modeling_ovis"

tests/test_local_model_paths.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ class DummyQModel:
496496
require_pkgs = []
497497
supports_desc_act = [True, False]
498498
support_offload_to_disk = True
499+
checkpoint_path_aliases = (("shell_model", "model"),)
499500
config_class = None
500501

501502
@staticmethod
@@ -530,6 +531,7 @@ def __init__(self, model, **kwargs):
530531
assert shell_configs
531532
assert load_calls == []
532533
assert isinstance(instance.turtle_model, LazyTurtle)
534+
assert instance.turtle_model._runtime_to_checkpoint_aliases == DummyQModel.checkpoint_path_aliases
533535
assert instance.turtle_model.config._experts_implementation == "linear_loop"
534536
assert instance.turtle_model.config is not instance.model.config
535537

0 commit comments

Comments
 (0)