9 changes: 7 additions & 2 deletions src/transformers/conversion_mapping.py
@@ -1060,6 +1060,8 @@ def get_model_conversion_mapping(
     seen to prevent `XForY` / `XModel` pairs from applying the same mapping
     twice via different lookup paths.
     """
+    from .modeling_utils import PreTrainedModel
+
     # note: this function is used in PEFT, so changing the API requires coordination
     weight_conversions = []
 
@@ -1075,8 +1077,11 @@
     # prevents a parent's transforms from being duplicated with a scoped copy for the child.
     seen_identifiers: defaultdict[str, list[str]] = defaultdict(list)
 
-    named_pretrained = model._named_pretrained_submodules
-    for module_name, submodule in named_pretrained:
+    for module_name, submodule in model.named_modules():
+        # Skip if it's not a submodel
+        if not isinstance(submodule, PreTrainedModel):
+            continue
+
         class_name = type(submodule).__name__
         model_type = submodule.config.model_type

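Note on the new traversal: the loop above replaces the cached `_named_pretrained_submodules` list (removed from `modeling_utils.py` below) with a plain `named_modules()` walk filtered by an `isinstance` check. A minimal, self-contained sketch of the pattern, using a hypothetical `FakePreTrained` tag class in place of the real `PreTrainedModel`:

```python
import torch.nn as nn


class FakePreTrained(nn.Module):
    """Stand-in for PreTrainedModel; only used here to tag sub-models."""


class Child(FakePreTrained):
    pass


class Root(FakePreTrained):
    def __init__(self):
        super().__init__()
        self.encoder = Child()
        self.proj = nn.Linear(4, 4)  # plain module: must be skipped


root = Root()
# Equivalent of the new loop: walk the whole tree, keep only sub-models.
submodels = [
    (name, module)
    for name, module in root.named_modules()
    if isinstance(module, FakePreTrained)
]
assert [name for name, _ in submodels] == ["", "encoder"]
```

Unlike the removed cache, this always reflects the current module tree, including modules attached after `post_init()`, which the updated tests below rely on.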
6 changes: 0 additions & 6 deletions src/transformers/modeling_utils.py
@@ -1394,12 +1394,6 @@ def post_init(self):
         # Maybe initialize the weights and tie the keys
         self.init_weights()
         self._backward_compatibility_gradient_checkpointing()
-        # Cache the list of (name, submodule) pairs where the submodule is a PreTrainedModel.
-        # This pattern is used in several places across the codebase; computing it once avoids
-        # repeated traversal of the full module tree.
-        self._named_pretrained_submodules: list[tuple[str, PreTrainedModel]] = [
-            (name, module) for name, module in self.named_modules() if isinstance(module, PreTrainedModel)
-        ]
 
     @property
     def tp_plan(self) -> dict[str, str]:
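For context on the deletion above, a sketch of why the snapshot was fragile (again with a hypothetical `FakePreTrained` standing in for `PreTrainedModel`): a list computed once at `post_init()` time goes stale as soon as a sub-model is attached later, which is exactly the pattern the updated tests use (`root.encoder = child_a`).

```python
import torch.nn as nn


class FakePreTrained(nn.Module):
    """Stand-in for PreTrainedModel, used only for isinstance checks."""


root = FakePreTrained()
# Snapshot taken at "post_init" time, before any sub-model is attached.
cached = [(n, m) for n, m in root.named_modules() if isinstance(m, FakePreTrained)]

# Attach a sub-model afterwards, as the tests below do with root.encoder.
root.encoder = FakePreTrained()

# The snapshot is stale; a fresh traversal sees the new child.
assert [n for n, _ in cached] == [""]
assert [n for n, m in root.named_modules() if isinstance(m, FakePreTrained)] == ["", "encoder"]
```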
71 changes: 33 additions & 38 deletions tests/utils/test_core_model_loading.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 from transformers.conversion_mapping import (
     get_checkpoint_conversion_mapping,
     get_model_conversion_mapping,
@@ -222,9 +222,6 @@ def __init__(self, add_extra_moe=False, with_mlp=True):
         if with_mlp:
             self.mlp = DummyMLP()
         self.config = PretrainedConfig()
-        # Mirror what PreTrainedModel.post_init() does so that
-        # get_model_conversion_mapping() can be called on DummyRoot.
-        self._named_pretrained_submodules = [("", self)]
 
 
 class TestConvertAndLoadStateDict(unittest.TestCase):
Expand Down Expand Up @@ -1171,20 +1168,20 @@ def test_class_name_wins_over_model_type(self):
"_tst_mtype", [WeightRenaming(r"^type_key", "type_renamed")], overwrite=True
)

def make_mock(class_name):
m = type(class_name, (), {})()
m.config = SimpleNamespace(model_type="_tst_mtype")
m._named_pretrained_submodules = [("", m)]
return m
class _TstCls(PreTrainedModel): ...

class _TstOther(PreTrainedModel): ...

# A module whose class name has a registry entry → class entry wins.
transforms = get_model_conversion_mapping(make_mock("_TstCls"), add_legacy=False)
transforms = get_model_conversion_mapping(_TstCls(PretrainedConfig(model_type="_tst_mtype")), add_legacy=False)
patterns = [t.source_patterns for t in transforms]
self.assertIn(["^cls_key"], patterns)
self.assertNotIn(["^type_key"], patterns)

# A module with no class entry falls through to the model_type entry.
transforms = get_model_conversion_mapping(make_mock("_TstOther"), add_legacy=False)
transforms = get_model_conversion_mapping(
_TstOther(PretrainedConfig(model_type="_tst_mtype")), add_legacy=False
)
patterns = [t.source_patterns for t in transforms]
self.assertIn(["^type_key"], patterns)
self.assertNotIn(["^cls_key"], patterns)
@@ -1200,22 +1197,19 @@ def test_sibling_submodels_same_model_type_both_get_transforms(self):
"_tst_shared_type", [WeightRenaming(r"^w", "renamed_w")], overwrite=True
)

def make_root_with_siblings(cls_a, cls_b):
"""Root model with two children of different classes but the same model_type."""
child_a = type(cls_a, (), {})()
child_a.config = SimpleNamespace(model_type="_tst_shared_type")
class _TstEncCls(PreTrainedModel): ...

child_b = type(cls_b, (), {})()
child_b.config = SimpleNamespace(model_type="_tst_shared_type")
class _TstDecCls(PreTrainedModel): ...

root = type("_TstRoot", (), {})()
root.config = SimpleNamespace(model_type="_tst_root_only")
root._named_pretrained_submodules = [("encoder", child_a), ("decoder", child_b)]
return root
class _TstRoot(PreTrainedModel): ...

transforms = get_model_conversion_mapping(
make_root_with_siblings("_TstEncCls", "_TstDecCls"), add_legacy=False
)
child_a = _TstEncCls(PretrainedConfig(model_type="_tst_shared_type"))
child_b = _TstDecCls(PretrainedConfig(model_type="_tst_shared_type"))
root = _TstRoot(PretrainedConfig(model_type="_tst_root_only"))
root.encoder = child_a
root.decoder = child_b

transforms = get_model_conversion_mapping(root, add_legacy=False)
scope_prefixes = [t.scope_prefix for t in transforms]

# Both siblings must be represented with their own scoped transforms.
@@ -1226,15 +1220,15 @@ def test_sibling_submodels_same_class_both_get_transforms(self):
"""Two sibling sub-models of the *same* class must each get their own scoped transforms."""
register_checkpoint_conversion_mapping("_TstSharedCls", [WeightRenaming(r"^w", "renamed_w")], overwrite=True)

child_a = type("_TstSharedCls", (), {})()
child_a.config = SimpleNamespace(model_type="_tst_shared_cls_mtype")
class _TstSharedCls(PreTrainedModel): ...

child_b = type("_TstSharedCls", (), {})()
child_b.config = SimpleNamespace(model_type="_tst_shared_cls_mtype")
class _TstRootSharedCls(PreTrainedModel): ...

root = type("_TstRootSharedCls", (), {})()
root.config = SimpleNamespace(model_type="_tst_root_only2")
root._named_pretrained_submodules = [("encoder", child_a), ("decoder", child_b)]
child_a = _TstSharedCls(PretrainedConfig(model_type="_tst_shared_cls_mtype"))
child_b = _TstSharedCls(PretrainedConfig(model_type="_tst_shared_cls_mtype"))
root = _TstRootSharedCls(PretrainedConfig(model_type="_tst_root_only2"))
root.encoder = child_a
root.decoder = child_b

transforms = get_model_conversion_mapping(root, add_legacy=False)
scope_prefixes = [t.scope_prefix for t in transforms]
@@ -1246,17 +1240,18 @@ def test_child_with_same_model_type_as_root_is_skipped(self):
"""When the root model claims a model_type unscoped, a nested child with the
same model_type must NOT produce a second (incorrectly scoped) copy of those
transforms — the root's unscoped transforms already cover all keys."""

class _TstChildSame(PreTrainedModel): ...

class _TstRootSame(PreTrainedModel): ...

register_checkpoint_conversion_mapping(
"_tst_root_child_shared", [WeightRenaming(r"^w", "renamed_w")], overwrite=True
)

child = type("_TstChildSame", (), {})()
child.config = SimpleNamespace(model_type="_tst_root_child_shared")

root = type("_TstRootSame", (), {})()
root.config = SimpleNamespace(model_type="_tst_root_child_shared")
# Root ("") appears before child ("submodel") — mirrors DFS order.
root._named_pretrained_submodules = [("", root), ("submodel", child)]
child = _TstChildSame(PretrainedConfig(model_type="_tst_root_child_shared"))
root = _TstRootSame(PretrainedConfig(model_type="_tst_root_child_shared"))
root.submodel = child

transforms = get_model_conversion_mapping(root, add_legacy=False)
# Only one unscoped transform (from the root); child must be suppressed.