Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions nemo_automodel/_transformers/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,22 @@ def _resolve_custom_model_cls_for_config(config):
return ModelRegistry.resolve_custom_model_cls(arch_name, config)


def _ensure_config_registered_from_config_dict(pretrained_model_name_or_path, **kwargs) -> None:
"""Register a matching local config class before delegating to ``AutoConfig``."""
config_lookup_kwargs = kwargs.copy()
config_lookup_kwargs["_from_auto"] = True
config_lookup_kwargs["name_or_path"] = pretrained_model_name_or_path
try:
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **config_lookup_kwargs)
except Exception:
logger.debug("Could not inspect config metadata for %s", pretrained_model_name_or_path, exc_info=True)
return

model_type = config_dict.get("model_type")
if isinstance(model_type, str):
ModelRegistry.ensure_config_registered(model_type)


def get_hf_config(pretrained_model_name_or_path, attn_implementation, **kwargs):
"""
Get the HF config for the model.
Expand All @@ -233,6 +249,11 @@ def get_hf_config(pretrained_model_name_or_path, attn_implementation, **kwargs):
# with incomplete dicts, losing all other fields. These nested overrides are
# instead handled by _consume_config_overrides which deep-merges them.
nested_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if isinstance(kwargs[k], dict)} # noqa: F841
_ensure_config_registered_from_config_dict(
pretrained_model_name_or_path,
**kwargs,
attn_implementation=attn_implementation,
)
try:
hf_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
Expand Down
Loading
Loading