diff --git a/nemo_automodel/_transformers/model_init.py b/nemo_automodel/_transformers/model_init.py index 09b40393e7..f66b8e2b9e 100644 --- a/nemo_automodel/_transformers/model_init.py +++ b/nemo_automodel/_transformers/model_init.py @@ -220,6 +220,22 @@ def _resolve_custom_model_cls_for_config(config): return ModelRegistry.resolve_custom_model_cls(arch_name, config) +def _ensure_config_registered_from_config_dict(pretrained_model_name_or_path, **kwargs) -> None: + """Register a matching local config class before delegating to ``AutoConfig``.""" + config_lookup_kwargs = kwargs.copy() + config_lookup_kwargs["_from_auto"] = True + config_lookup_kwargs["name_or_path"] = pretrained_model_name_or_path + try: + config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **config_lookup_kwargs) + except Exception: + logger.debug("Could not inspect config metadata for %s", pretrained_model_name_or_path, exc_info=True) + return + + model_type = config_dict.get("model_type") + if isinstance(model_type, str): + ModelRegistry.ensure_config_registered(model_type) + + def get_hf_config(pretrained_model_name_or_path, attn_implementation, **kwargs): """ Get the HF config for the model. @@ -233,6 +249,11 @@ def get_hf_config(pretrained_model_name_or_path, attn_implementation, **kwargs): # with incomplete dicts, losing all other fields. These nested overrides are # instead handled by _consume_config_overrides which deep-merges them. nested_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if isinstance(kwargs[k], dict)} # noqa: F841 + _ensure_config_registered_from_config_dict( + pretrained_model_name_or_path, + **kwargs, + attn_implementation=attn_implementation, + ) try: hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path, diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py deleted file mode 100644 index 47fd6cecca..0000000000 --- a/nemo_automodel/_transformers/registry.py +++ /dev/null @@ -1,366 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import importlib -import logging -from collections import OrderedDict -from dataclasses import dataclass, field -from functools import lru_cache -from typing import Dict, Tuple, Type, Union - -import torch.nn as nn - -logger = logging.getLogger(__name__) - -# Static mapping: architecture name → (module_path, class_name[, tags]). -# Analogous to HuggingFace transformers' MODEL_FOR_CAUSAL_LM_MAPPING_NAMES. -# Models are loaded lazily on first access rather than imported at startup. -# Optional third element is a set of tags (e.g. {"retrieval"}) used by -# downstream code to classify model archs without importing them. -MODEL_ARCH_MAPPING = OrderedDict( - [ - ( - "BaichuanForCausalLM", - ("nemo_automodel.components.models.baichuan.model", "BaichuanForCausalLM"), - ), - ( - "DeepseekV3ForCausalLM", - ("nemo_automodel.components.models.deepseek_v3.model", "DeepseekV3ForCausalLM"), - ), - ( - "DeepseekV32ForCausalLM", - ("nemo_automodel.components.models.deepseek_v32.model", "DeepseekV32ForCausalLM"), - ), - ( - "DeepseekV4ForCausalLM", - ("nemo_automodel.components.models.deepseek_v4.model", "DeepseekV4ForCausalLM"), - ), - ( - "Ernie4_5_MoeForCausalLM", - ("nemo_automodel.components.models.ernie4_5.model", "Ernie4_5_MoeForCausalLM"), - ), - ( - "Glm4MoeForCausalLM", - ("nemo_automodel.components.models.glm4_moe.model", "Glm4MoeForCausalLM"), - ), - ( - "Glm4MoeLiteForCausalLM", - ("nemo_automodel.components.models.glm4_moe_lite.model", "Glm4MoeLiteForCausalLM"), - ), - ( - "GlmMoeDsaForCausalLM", - ("nemo_automodel.components.models.glm_moe_dsa.model", "GlmMoeDsaForCausalLM"), - ), - ( - "Gemma4ForConditionalGeneration", - ("nemo_automodel.components.models.gemma4_moe.model", "Gemma4ForConditionalGeneration"), - ), - ( - "GptOssForCausalLM", - ("nemo_automodel.components.models.gpt_oss.model", "GptOssForCausalLM"), - ), - ( - "KimiK25ForConditionalGeneration", - ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLForConditionalGeneration"), - ), - ( - "KimiK25VLForConditionalGeneration", - ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLForConditionalGeneration"), - ), - ( - "KimiVLForConditionalGeneration", - ("nemo_automodel.components.models.kimivl.model", "KimiVLForConditionalGeneration"), - ), - ( - "LlamaBidirectionalForSequenceClassification", - ( - "nemo_automodel.components.models.llama_bidirectional.model", - "LlamaBidirectionalForSequenceClassification", - {"retrieval"}, - ), - ), - ( - "LlamaBidirectionalModel", - ("nemo_automodel.components.models.llama_bidirectional.model", "LlamaBidirectionalModel", {"retrieval"}), - ), - ( - "LlamaForCausalLM", - ("nemo_automodel.components.models.llama.model", "LlamaForCausalLM"), - ), - ( - "MiniMaxM2ForCausalLM", - ("nemo_automodel.components.models.minimax_m2.model", "MiniMaxM2ForCausalLM"), - ), - ( - "MiMoV2FlashForCausalLM", - ("nemo_automodel.components.models.mimo_v2_flash.model", "MiMoV2FlashForCausalLM"), - ), - ( - "Ministral3ForCausalLM", - ("nemo_automodel.components.models.mistral3.model", "Ministral3ForCausalLM"), - ), - ( - "Ministral3BidirectionalModel", - ( - "nemo_automodel.components.models.ministral_bidirectional.model", - "Ministral3BidirectionalModel", - {"retrieval"}, - ), - ), - ( - "Mistral4ForCausalLM", - ("nemo_automodel.components.models.mistral4.model", "Mistral4ForCausalLM"), - ), - ( - "Mistral3ForConditionalGeneration", - ("nemo_automodel.components.models.mistral4.model", "Mistral3ForConditionalGeneration"), - ), - ( - "Mistral3FP8VLMForConditionalGeneration", - ( - "nemo_automodel.components.models.mistral3_vlm.model", - "Mistral3FP8VLMForConditionalGeneration", - ), - ), - ( - "NemotronHForCausalLM", - ("nemo_automodel.components.models.nemotron_v3.model", "NemotronHForCausalLM"), - ), - ( - "NemotronH_Nano_Omni_Reasoning_V3", - ( - "nemo_automodel.components.models.nemotron_omni.model", - "NemotronOmniForConditionalGeneration", - ), - ), - ( - "NemotronParseForConditionalGeneration", - ("nemo_automodel.components.models.nemotron_parse.model", "NemotronParseForConditionalGeneration"), - ), - ( - "LLaVAOneVision1_5_ForConditionalGeneration", - ( - "nemo_automodel.components.models.llava_onevision.model", - "LLaVAOneVision1_5_ForConditionalGeneration", - ), - ), - ( - "HYV3ForCausalLM", - ("nemo_automodel.components.models.hy_v3.model", "HYV3ForCausalLM"), - ), - ( - "Qwen2ForCausalLM", - ("nemo_automodel.components.models.qwen2.model", "Qwen2ForCausalLM"), - ), - ( - "Qwen3MoeForCausalLM", - ("nemo_automodel.components.models.qwen3_moe.model", "Qwen3MoeForCausalLM"), - ), - ( - "Qwen3NextForCausalLM", - ("nemo_automodel.components.models.qwen3_next.model", "Qwen3NextForCausalLM"), - ), - ( - "Qwen3OmniMoeForConditionalGeneration", - ( - "nemo_automodel.components.models.qwen3_omni_moe.model", - "Qwen3OmniMoeThinkerForConditionalGeneration", - ), - ), - ( - "Qwen3VLMoeForConditionalGeneration", - ("nemo_automodel.components.models.qwen3_vl_moe.model", "Qwen3VLMoeForConditionalGeneration"), - ), - ( - "Qwen3_5MoeForConditionalGeneration", - ("nemo_automodel.components.models.qwen3_5_moe.model", "Qwen3_5MoeForConditionalGeneration"), - ), - ( - "Step3p5ForCausalLM", - ("nemo_automodel.components.models.step3p5.model", "Step3p5ForCausalLM"), - ), - ] -) - - -# Custom model_type → config class for models that have auto_map in their -# checkpoint config.json. Registered eagerly with AutoConfig so that -# AutoConfig.from_pretrained can resolve them without trust_remote_code. -_CUSTOM_CONFIG_REGISTRATIONS: Dict[str, Tuple[str, str]] = { - "baichuan": ("nemo_automodel.components.models.baichuan.configuration", "BaichuanConfig"), - "deepseek_v4": ("nemo_automodel.components.models.deepseek_v4.config", "DeepseekV4Config"), - "hy_v3": ("nemo_automodel.components.models.hy_v3.config", "HYV3Config"), - "kimi_k25": ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLConfig"), - "kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"), - "llavaonevision1_5": ("nemo_automodel.components.models.llava_onevision.model", "Llavaonevision1_5Config"), - "mimo_v2_flash": ("nemo_automodel.components.models.mimo_v2_flash.config", "MiMoV2FlashConfig"), - "mistral4": ("nemo_automodel.components.models.mistral4.configuration", "Mistral4Config"), -} - - -def _register_custom_configs() -> None: - from transformers import AutoConfig - from transformers.models.auto.configuration_auto import CONFIG_MAPPING - - for model_type, (module_path, cls_name) in _CUSTOM_CONFIG_REGISTRATIONS.items(): - if model_type not in CONFIG_MAPPING: - try: - mod = importlib.import_module(module_path) - cfg_cls = getattr(mod, cls_name) - AutoConfig.register(model_type, cfg_cls) - except Exception: - logger.debug("Failed to register config for model_type=%s", model_type, exc_info=True) - - -_register_custom_configs() - - -class _LazyArchMapping: - """Lazy-loading mapping from architecture name to model class. - - Inspired by HuggingFace transformers' ``_LazyAutoMapping``. Entries from the - static ``auto_map`` are imported on first access and cached. Additional entries - can be added at runtime via ``register``. - """ - - def __init__(self, auto_map: Union[OrderedDict, Dict[str, tuple], None] = None): - # Entries may be (module_path, class_name) or (module_path, class_name, tags). - # Strip the optional tags and store them separately. - self._auto_map: Dict[str, Tuple[str, str]] = OrderedDict() - self._tags: Dict[str, set] = {} - for key, value in (auto_map or {}).items(): - self._auto_map[key] = (value[0], value[1]) - if len(value) > 2: - self._tags[key] = value[2] - self._loaded: Dict[str, Type[nn.Module]] = {} - self._extra: Dict[str, Type[nn.Module]] = {} - self._modules: Dict[str, object] = {} - - def _load(self, key: str) -> Type[nn.Module]: - if key in self._loaded: - return self._loaded[key] - module_path, class_name = self._auto_map[key] - if module_path not in self._modules: - self._modules[module_path] = importlib.import_module(module_path) - cls = getattr(self._modules[module_path], class_name) - self._loaded[key] = cls - return cls - - def __contains__(self, key: str) -> bool: - if key in self._extra or key in self._loaded: - return True - if key not in self._auto_map: - return False - try: - self._load(key) - return True - except Exception: - logger.debug("Model %s unavailable (import failed), removing from auto_map", key) - self._auto_map.pop(key, None) - return False - - def __getitem__(self, key: str) -> Type[nn.Module]: - if key in self._extra: - return self._extra[key] - if key in self._auto_map: - return self._load(key) - raise KeyError(key) - - def __setitem__(self, key: str, value: Type[nn.Module]) -> None: - self._extra[key] = value - - def register(self, key: str, value: Type[nn.Module], exist_ok: bool = False) -> None: - """Register a model class under the given architecture name.""" - if not exist_ok and key in self._extra: - raise ValueError(f"Duplicated model implementation for {key}") - self._extra[key] = value - - def has_tag(self, key: str, tag: str) -> bool: - """Return ``True`` if *key* was registered with *tag*.""" - return tag in self._tags.get(key, set()) - - def keys_with_tag(self, tag: str) -> set: - """Return all architecture names that have *tag*.""" - return {k for k, tags in self._tags.items() if tag in tags} - - def keys(self): - return set(self._auto_map.keys()) | set(self._extra.keys()) - - def __len__(self) -> int: - return len(self.keys()) - - def __repr__(self) -> str: - return f"_LazyArchMapping(auto_map={len(self._auto_map)}, extra={len(self._extra)}, loaded={len(self._loaded)})" - - -@dataclass -class _ModelRegistry: - model_arch_name_to_cls: _LazyArchMapping = field(default=None) - _retrieval_archs: set = field(default_factory=set) - - def __post_init__(self): - if self.model_arch_name_to_cls is None: - self.model_arch_name_to_cls = _LazyArchMapping(MODEL_ARCH_MAPPING) - self._retrieval_archs = self.model_arch_name_to_cls.keys_with_tag("retrieval") - - @property - def supported_models(self): - return self.model_arch_name_to_cls.keys() - - def get_model_cls_from_model_arch(self, model_arch: str) -> Type[nn.Module]: - return self.model_arch_name_to_cls[model_arch] - - def has_custom_model(self, arch_name: str) -> bool: - """Return ``True`` if *arch_name* has a custom (non-HF) implementation.""" - return arch_name in self.model_arch_name_to_cls - - def has_retrieval_model(self, arch_name: str) -> bool: - """Return ``True`` if *arch_name* is a registered retrieval/encoder architecture.""" - return arch_name in self._retrieval_archs - - def register_retrieval(self, arch_name: str) -> None: - """Mark *arch_name* as a retrieval/encoder architecture.""" - self._retrieval_archs.add(arch_name) - - def resolve_custom_model_cls(self, architecture: str, config) -> Union[Type[nn.Module], None]: - """Return the custom model class if it exists and supports *config*, else ``None``. - - Custom model classes may define a ``supports_config(config)`` classmethod - to opt out for specific HF configs (e.g. a Mistral3 VLM with a dense - Ministral3 text backbone instead of the expected Mistral4 MoE+MLA). - """ - if architecture not in self.model_arch_name_to_cls: - return None - model_cls = self.model_arch_name_to_cls[architecture] - if hasattr(model_cls, "supports_config") and not model_cls.supports_config(config): - logger.info( - "Custom model %s does not support config %s, falling back to HF", - model_cls.__name__, - type(config).__name__, - ) - return None - return model_cls - - def register(self, arch_name: str, model_cls: Type[nn.Module], exist_ok: bool = False) -> None: - """Register a custom model class for a given architecture name.""" - self.model_arch_name_to_cls.register(arch_name, model_cls, exist_ok=exist_ok) - - -@lru_cache -def get_registry(): - return _ModelRegistry() - - -ModelRegistry = get_registry() diff --git a/nemo_automodel/_transformers/registry/__init__.py b/nemo_automodel/_transformers/registry/__init__.py new file mode 100644 index 0000000000..28fd4666ce --- /dev/null +++ b/nemo_automodel/_transformers/registry/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public exports for the model registry package.""" + +_MODEL_REGISTRY_EXPORTS = { + "ModelRegistry", + "RetrievalModelRegistry", +} + + +def __getattr__(name: str): + """Lazily expose model-specific registry globals without import cycles.""" + if name not in _MODEL_REGISTRY_EXPORTS: + raise AttributeError(name) + from nemo_automodel._transformers.registry import model_registry + + return getattr(model_registry, name) + + +__all__ = [ + "ModelRegistry", + "RetrievalModelRegistry", +] diff --git a/nemo_automodel/_transformers/registry/base.py b/nemo_automodel/_transformers/registry/base.py new file mode 100644 index 0000000000..ead6c41bc3 --- /dev/null +++ b/nemo_automodel/_transformers/registry/base.py @@ -0,0 +1,254 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import inspect +import logging +from collections import OrderedDict +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from types import ModuleType + +import torch.nn as nn + +from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + +logger = logging.getLogger(__name__) + +_ModelArchMappingInput = Mapping[str, ModelPackageSpec] | Iterable[ModelPackageSpec] | None + + +def _normalize_model_arch_mapping(auto_map: _ModelArchMappingInput = None) -> OrderedDict[str, ModelPackageSpec]: + """Normalize legacy dict mappings or tuple-based specs into an architecture lookup.""" + mapping: OrderedDict[str, ModelPackageSpec] = OrderedDict() + if isinstance(auto_map, Mapping): + spec_items = ((architecture, spec.with_architecture(architecture)) for architecture, spec in auto_map.items()) + else: + spec_items = [] + for spec in auto_map or (): + if spec.class_name is None and not spec.architectures: + continue + if not spec.architectures: + raise ValueError(f"Model architecture spec for {spec.package!r} must declare at least one architecture") + spec_items.extend((architecture, spec) for architecture in spec.architectures) + + for architecture, spec in spec_items: + if spec.class_name is None: + raise ValueError(f"Model architecture entry {architecture!r} must include a class name") + if architecture in mapping: + raise ValueError(f"Duplicated model architecture entry for {architecture!r}") + mapping[architecture] = spec + return mapping + + +@dataclass +class _BaseModelRegistry: + model_specs: _ModelArchMappingInput = None + model_arch_name_to_cls: "_BaseModelRegistry" = field(init=False, repr=False, compare=False) + _model_specs: tuple[ModelPackageSpec, ...] = field(init=False) + _loaded_model_classes: dict[str, type[nn.Module]] = field(default_factory=dict) + _extra_model_classes: dict[str, type[nn.Module]] = field(default_factory=dict) + _discarded_architectures: set[str] = field(default_factory=set) + _architecture_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) + _model_type_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) + + def __post_init__(self): + if isinstance(self.model_specs, Mapping): + raw_specs: Iterable[ModelPackageSpec] = _normalize_model_arch_mapping(self.model_specs).values() + else: + raw_specs = tuple(self.model_specs or ()) + _normalize_model_arch_mapping(raw_specs) + self._model_specs = tuple(dict.fromkeys(raw_specs)) + self._rebuild_spec_indexes() + self.model_arch_name_to_cls = self + + def _rebuild_spec_indexes(self) -> None: + architecture_to_specs: dict[str, list[ModelPackageSpec]] = {} + model_type_to_specs: dict[str, list[ModelPackageSpec]] = {} + dynamic_specs = tuple( + ModelPackageSpec.from_model_class(model_cls, architectures=(architecture,)) + for architecture, model_cls in self._extra_model_classes.items() + ) + for spec in (*dynamic_specs, *self._model_specs): + for architecture in spec.architectures: + if architecture in self._discarded_architectures and architecture not in self._extra_model_classes: + continue + architecture_to_specs.setdefault(architecture, []).append(spec) + for model_type in spec.model_types: + model_type_to_specs.setdefault(model_type, []).append(spec) + self._architecture_to_specs = { + architecture: tuple(specs) for architecture, specs in architecture_to_specs.items() + } + self._model_type_to_specs = {model_type: tuple(specs) for model_type, specs in model_type_to_specs.items()} + + def _discard_architecture(self, architecture: str) -> None: + self._discarded_architectures.add(architecture) + self._architecture_to_specs.pop(architecture, None) + self._extra_model_classes.pop(architecture, None) + self._loaded_model_classes.pop(architecture, None) + + def _load_model_class(self, architecture: str) -> type[nn.Module]: + if architecture in self._loaded_model_classes: + return self._loaded_model_classes[architecture] + spec = self.get_model_package_spec(architecture) + if spec is None or spec.class_name is None: + raise KeyError(architecture) + model_cls = getattr(importlib.import_module(spec.module_path), spec.class_name) + self._loaded_model_classes[architecture] = model_cls + return model_cls + + def __contains__(self, architecture: str) -> bool: + if architecture in self._extra_model_classes or architecture in self._loaded_model_classes: + return True + if self.get_model_package_spec(architecture) is None: + return False + try: + self._load_model_class(architecture) + return True + except Exception: + logger.debug("Model %s unavailable (import failed), removing from registry specs", architecture) + self._discard_architecture(architecture) + return False + + def __getitem__(self, architecture: str) -> type[nn.Module]: + if architecture in self._extra_model_classes: + return self._extra_model_classes[architecture] + return self._load_model_class(architecture) + + def __setitem__(self, architecture: str, model_cls: type[nn.Module]) -> None: + self._discarded_architectures.discard(architecture) + self._extra_model_classes[architecture] = model_cls + self._loaded_model_classes.pop(architecture, None) + self._rebuild_spec_indexes() + + def keys(self): + return set(self._architecture_to_specs) | set(self._extra_model_classes) + + def __len__(self) -> int: + return len(self.keys()) + + def __repr__(self) -> str: + return ( + f"_BaseModelRegistry(specs={len(self._architecture_to_specs)}, " + f"extra={len(self._extra_model_classes)}, loaded={len(self._loaded_model_classes)})" + ) + + @property + def supported_models(self): + return self.keys() + + def get_model_cls_from_model_arch(self, model_arch: str) -> type[nn.Module]: + return self[model_arch] + + def has_custom_model(self, arch_name: str) -> bool: + """Return ``True`` if *arch_name* has a custom (non-HF) implementation.""" + return arch_name in self + + def get_model_package_spec(self, architecture: str) -> ModelPackageSpec | None: + """Return package metadata for an architecture without importing ``model.py``.""" + specs = self.get_model_package_specs_for_architecture(architecture) + if specs: + return specs[0] + return None + + def get_model_package_specs_for_architecture(self, architecture: str) -> tuple[ModelPackageSpec, ...]: + """Return all package metadata entries that declare *architecture*.""" + return self._architecture_to_specs.get(architecture, ()) + + def get_model_package_specs_for_model_type(self, model_type: str) -> tuple[ModelPackageSpec, ...]: + """Return package metadata entries that declare *model_type*.""" + return self._model_type_to_specs.get(model_type, ()) + + @staticmethod + def _iter_config_classes(module: ModuleType): + """Yield local ``PretrainedConfig`` subclasses declared by *module*.""" + from transformers import PretrainedConfig + + for _, cls in inspect.getmembers(module, inspect.isclass): + if cls is PretrainedConfig: + continue + if cls.__module__ != module.__name__: + continue + if not issubclass(cls, PretrainedConfig): + continue + if not isinstance(getattr(cls, "model_type", None), str): + continue + yield cls + + def _resolve_config_class(self, spec: ModelPackageSpec, model_type: str): + module_path = spec.config_module_path + if module_path is None: + return None + + try: + module = importlib.import_module(module_path) + except ImportError: + logger.debug("Config module is unavailable: %s", module_path) + return None + + if spec.config_class_name is not None: + return getattr(module, spec.config_class_name, None) + + for config_cls in self._iter_config_classes(module): + if config_cls.model_type == model_type: + return config_cls + return None + + def ensure_config_registered(self, model_type: str) -> bool: + """Register the local config class for *model_type* with HuggingFace, if declared.""" + from transformers import AutoConfig + from transformers.models.auto.configuration_auto import CONFIG_MAPPING + + if model_type in CONFIG_MAPPING: + return True + + for spec in self.get_model_package_specs_for_model_type(model_type): + config_cls = self._resolve_config_class(spec, model_type) + if config_cls is None: + continue + try: + AutoConfig.register(model_type, config_cls) + except ValueError: + if getattr(config_cls, "model_type", None) != model_type: + CONFIG_MAPPING.register(model_type, config_cls) + else: + raise + return True + return False + + def resolve_custom_model_cls(self, architecture: str, config) -> type[nn.Module] | None: + """Return the custom model class if it exists and supports *config*, else ``None``. + + Custom model classes may define a ``supports_config(config)`` classmethod + to opt out for specific HF configs (e.g. a Mistral3 VLM with a dense + Ministral3 text backbone instead of the expected Mistral4 MoE+MLA). + """ + if architecture not in self: + return None + model_cls = self[architecture] + if hasattr(model_cls, "supports_config") and not model_cls.supports_config(config): + logger.info( + "Custom model %s does not support config %s, falling back to HF", + model_cls.__name__, + type(config).__name__, + ) + return None + return model_cls + + def register(self, arch_name: str, model_cls: type[nn.Module], exist_ok: bool = False) -> None: + """Register a custom model class for a given architecture name.""" + if not exist_ok and arch_name in self._extra_model_classes: + raise ValueError(f"Duplicated model implementation for {arch_name}") + self[arch_name] = model_cls diff --git a/nemo_automodel/_transformers/registry/model_package_spec.py b/nemo_automodel/_transformers/registry/model_package_spec.py new file mode 100644 index 0000000000..2f0703fc8a --- /dev/null +++ b/nemo_automodel/_transformers/registry/model_package_spec.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, replace + + +@dataclass(frozen=True) +class ModelPackageSpec: + """Registry entry for a model package. + + Args: + package: Python package that owns the model implementation, usually + ``nemo_automodel.components.models.``. + class_name: Model class to lazy-load from ``module_path`` for architectures + declared by this spec. Config-only specs leave this unset. + model_module: Module inside ``package`` that contains ``class_name``. Defaults + to ``model``, producing ``.model``. + config_module: Module inside ``package`` that contains a custom HF config + class. Used only for model types that need on-demand AutoConfig registration. + config_class_name: Explicit config class name to register from + ``config_module_path``. If unset, the registry can discover a local + ``PretrainedConfig`` subclass by ``model_type``. + architectures: HF ``config.architectures`` names that should resolve to this + model class. These are expanded into the architecture-to-spec lookup. + model_types: HF ``config.model_type`` names associated with this package, + primarily for resolving and registering custom config classes. + """ + + package: str + class_name: str | None = None + model_module: str = "model" + config_module: str | None = None + config_class_name: str | None = None + architectures: tuple[str, ...] = () + model_types: tuple[str, ...] = () + + def __post_init__(self) -> None: + object.__setattr__(self, "architectures", tuple(self.architectures)) + object.__setattr__(self, "model_types", tuple(self.model_types)) + + @classmethod + def from_model_class( + cls, + model_cls: type, + *, + config_module: str | None = None, + config_class_name: str | None = None, + architectures: list[str] | tuple[str, ...] = (), + model_types: tuple[str, ...] = (), + ) -> "ModelPackageSpec": + """Create a spec from a model class object.""" + package, sep, model_module = model_cls.__module__.rpartition(".") + if not sep: + package = "" + model_module = model_cls.__module__ + return cls( + package=package, + class_name=model_cls.__name__, + model_module=model_module, + config_module=config_module, + config_class_name=config_class_name, + architectures=architectures, + model_types=model_types, + ) + + @property + def module_path(self) -> str: + """Return the import path for the model implementation module.""" + if not self.package: + return self.model_module + return f"{self.package}.{self.model_module}" + + @property + def config_module_path(self) -> str | None: + """Return the import path for this model package's config module, if declared.""" + if self.config_module is None: + return None + if not self.package: + return self.config_module + return f"{self.package}.{self.config_module}" + + def with_architecture(self, architecture: str) -> "ModelPackageSpec": + """Return a copy that records *architecture* as an alias for this package.""" + if architecture in self.architectures: + return self + return replace(self, architectures=(*self.architectures, architecture)) diff --git a/nemo_automodel/_transformers/registry/model_registry.py b/nemo_automodel/_transformers/registry/model_registry.py new file mode 100644 index 0000000000..b115706ca3 --- /dev/null +++ b/nemo_automodel/_transformers/registry/model_registry.py @@ -0,0 +1,231 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import lru_cache + +from nemo_automodel._transformers.registry.base import _BaseModelRegistry +from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + +# Static model package specs. The registry derives architecture lookups from each spec's architectures. +# Analogous to HuggingFace transformers' MODEL_FOR_CAUSAL_LM_MAPPING_NAMES. +# Models are loaded lazily on first access rather than imported at startup. +MODEL_PACKAGE_SPECS: tuple[ModelPackageSpec, ...] = ( + ModelPackageSpec( + package="nemo_automodel.components.models.baichuan", + class_name="BaichuanForCausalLM", + config_module="configuration", + config_class_name="BaichuanConfig", + architectures=("BaichuanForCausalLM",), + model_types=("baichuan",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v3", + class_name="DeepseekV3ForCausalLM", + architectures=("DeepseekV3ForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v32", + class_name="DeepseekV32ForCausalLM", + architectures=("DeepseekV32ForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v4", + class_name="DeepseekV4ForCausalLM", + config_module="config", + config_class_name="DeepseekV4Config", + architectures=("DeepseekV4ForCausalLM",), + model_types=("deepseek_v4",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.ernie4_5", + class_name="Ernie4_5_MoeForCausalLM", + architectures=("Ernie4_5_MoeForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.glm4_moe", + class_name="Glm4MoeForCausalLM", + architectures=("Glm4MoeForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.glm4_moe_lite", + class_name="Glm4MoeLiteForCausalLM", + architectures=("Glm4MoeLiteForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.glm_moe_dsa", + class_name="GlmMoeDsaForCausalLM", + architectures=("GlmMoeDsaForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.gemma4_moe", + class_name="Gemma4ForConditionalGeneration", + architectures=("Gemma4ForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.gpt_oss", + class_name="GptOssForCausalLM", + architectures=("GptOssForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.kimi_k25_vl", + class_name="KimiK25VLForConditionalGeneration", + config_module="model", + config_class_name="KimiK25VLConfig", + architectures=("KimiK25ForConditionalGeneration", "KimiK25VLForConditionalGeneration"), + model_types=("kimi_k25", "kimi_k25_vl"), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.kimivl", + class_name="KimiVLForConditionalGeneration", + config_module="model", + config_class_name="KimiVLConfig", + architectures=("KimiVLForConditionalGeneration",), + model_types=("kimi_vl",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.llama", + class_name="LlamaForCausalLM", + architectures=("LlamaForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.minimax_m2", + class_name="MiniMaxM2ForCausalLM", + architectures=("MiniMaxM2ForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mimo_v2_flash", + class_name="MiMoV2FlashForCausalLM", + config_module="config", + config_class_name="MiMoV2FlashConfig", + architectures=("MiMoV2FlashForCausalLM",), + model_types=("mimo_v2_flash",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral3", + class_name="Ministral3ForCausalLM", + architectures=("Ministral3ForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral4", + class_name="Mistral4ForCausalLM", + config_module="configuration", + config_class_name="Mistral4Config", + architectures=("Mistral4ForCausalLM",), + model_types=("mistral4",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral4", + class_name="Mistral3ForConditionalGeneration", + architectures=("Mistral3ForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral3_vlm", + class_name="Mistral3FP8VLMForConditionalGeneration", + architectures=("Mistral3FP8VLMForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_v3", + class_name="NemotronHForCausalLM", + architectures=("NemotronHForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_omni", + class_name="NemotronOmniForConditionalGeneration", + architectures=("NemotronH_Nano_Omni_Reasoning_V3",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_parse", + class_name="NemotronParseForConditionalGeneration", + architectures=("NemotronParseForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.llava_onevision", + class_name="LLaVAOneVision1_5_ForConditionalGeneration", + config_module="model", + config_class_name="Llavaonevision1_5Config", + architectures=("LLaVAOneVision1_5_ForConditionalGeneration",), + model_types=("llavaonevision1_5",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.hy_v3", + class_name="HYV3ForCausalLM", + config_module="config", + config_class_name="HYV3Config", + architectures=("HYV3ForCausalLM",), + model_types=("hy_v3",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen2", + class_name="Qwen2ForCausalLM", + architectures=("Qwen2ForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_moe", + class_name="Qwen3MoeForCausalLM", + architectures=("Qwen3MoeForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_next", + class_name="Qwen3NextForCausalLM", + architectures=("Qwen3NextForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_omni_moe", + class_name="Qwen3OmniMoeThinkerForConditionalGeneration", + architectures=("Qwen3OmniMoeForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_vl_moe", + class_name="Qwen3VLMoeForConditionalGeneration", + architectures=("Qwen3VLMoeForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_5_moe", + class_name="Qwen3_5MoeForConditionalGeneration", + architectures=("Qwen3_5MoeForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.step3p5", + class_name="Step3p5ForCausalLM", + architectures=("Step3p5ForCausalLM",), + ), +) + +RETRIEVAL_MODEL_PACKAGE_SPECS: tuple[ModelPackageSpec, ...] = ( + ModelPackageSpec( + package="nemo_automodel.components.models.llama_bidirectional", + class_name="LlamaBidirectionalForSequenceClassification", + architectures=("LlamaBidirectionalForSequenceClassification",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.llama_bidirectional", + class_name="LlamaBidirectionalModel", + architectures=("LlamaBidirectionalModel",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.ministral_bidirectional", + class_name="Ministral3BidirectionalModel", + architectures=("Ministral3BidirectionalModel",), + ), +) + + +@lru_cache +def make_registry(model_specs: tuple[ModelPackageSpec, ...]) -> _BaseModelRegistry: + """Return a process-wide model registry singleton for package specs.""" + return _BaseModelRegistry(model_specs=model_specs) + + +ModelRegistry = make_registry(MODEL_PACKAGE_SPECS) +RetrievalModelRegistry = make_registry(RETRIEVAL_MODEL_PACKAGE_SPECS) diff --git a/nemo_automodel/_transformers/retrieval.py b/nemo_automodel/_transformers/retrieval.py index 7287fed41d..e5dba30446 100644 --- a/nemo_automodel/_transformers/retrieval.py +++ b/nemo_automodel/_transformers/retrieval.py @@ -25,7 +25,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING from transformers.utils import logging -from nemo_automodel._transformers.registry import ModelRegistry +from nemo_automodel._transformers.registry import RetrievalModelRegistry from nemo_automodel.components.models.common.bidirectional import EncoderStateDictAdapter logger = logging.get_logger(__name__) @@ -57,11 +57,11 @@ def _get_supported_backbone_class(model_type: str, task: str) -> type[nn.Module] f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}." ) - if arch_name not in ModelRegistry.model_arch_name_to_cls: - raise ValueError(f"Model class '{arch_name}' not found in ModelRegistry.") + if not RetrievalModelRegistry.has_custom_model(arch_name): + raise ValueError(f"Model class '{arch_name}' not found in RetrievalModelRegistry.") logger.info(f"Using {arch_name} from registry") - return ModelRegistry.model_arch_name_to_cls[arch_name] + return RetrievalModelRegistry.get_model_cls_from_model_arch(arch_name) def _move_to_extracted_dtype(model: nn.Module, extracted_model: nn.Module) -> nn.Module: @@ -179,7 +179,7 @@ def configure_encoder_metadata(model: PreTrainedModel, config) -> None: """Configure HuggingFace consolidated checkpoint metadata on a model. Sets ``config.architectures`` unconditionally. For custom retrieval - architectures registered in :class:`ModelRegistry`, also writes + architectures registered in :class:`RetrievalModelRegistry`, also writes ``config.auto_map`` so that the saved checkpoint can be reloaded via HuggingFace Auto classes. Standard HF models already have their own auto-resolution and do not need ``auto_map`` entries. @@ -193,7 +193,7 @@ def configure_encoder_metadata(model: PreTrainedModel, config) -> None: # Only set auto_map for custom retrieval architectures. # Standard HF models don't need auto_map pointing to a local model.py. - if ModelRegistry.has_retrieval_model(encoder_class_name): + if RetrievalModelRegistry.get_model_package_spec(encoder_class_name) is not None: config_class_name = config.__class__.__name__ config_module = config.__class__.__module__.rsplit(".", 1)[-1] model_module = model.__class__.__module__.rsplit(".", 1)[-1] @@ -226,8 +226,8 @@ def build_encoder_backbone( Without ``extract_submodel``, model types listed in :data:`SUPPORTED_BACKBONES` resolve to custom bidirectional classes from - :class:`ModelRegistry`; all other model types fall back to HuggingFace Auto - classes. + :class:`RetrievalModelRegistry`; all other model types fall back to + HuggingFace Auto classes. Args: model_name_or_path: Path or HuggingFace Hub identifier. @@ -247,7 +247,7 @@ def build_encoder_backbone( Raises: ValueError: If the task is unsupported for a known model type, or the - architecture class is missing from :class:`ModelRegistry`. + architecture class is missing from :class:`RetrievalModelRegistry`. """ config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code) model_type = getattr(config, "model_type", "") @@ -320,7 +320,7 @@ def save_encoder_pretrained(model: nn.Module, save_directory: str, **kwargs) -> model.model.save_pretrained(save_directory) -# HuggingFace model_type -> task -> bidirectional architecture class name in ModelRegistry +# HuggingFace model_type -> task -> bidirectional architecture class name in RetrievalModelRegistry _LLAMA_TASKS = { "embedding": "LlamaBidirectionalModel", "score": "LlamaBidirectionalForSequenceClassification", @@ -340,7 +340,7 @@ def _init_encoder_common(encoder: nn.Module, model: PreTrainedModel) -> None: """Shared init for BiEncoderModel and CrossEncoderModel.""" encoder.model = model encoder.config = model.config - if ModelRegistry.has_retrieval_model(model.__class__.__name__): + if RetrievalModelRegistry.get_model_package_spec(model.__class__.__name__) is not None: encoder.name_or_path = os.path.dirname(inspect.getfile(type(model))) else: encoder.name_or_path = getattr(model.config, "name_or_path", "") diff --git a/tests/integration/test_llava_onevision_cpu.py b/tests/integration/test_llava_onevision_cpu.py index 8802734e36..94f7afbc08 100644 --- a/tests/integration/test_llava_onevision_cpu.py +++ b/tests/integration/test_llava_onevision_cpu.py @@ -317,10 +317,10 @@ def test_registry_registration(): arch_name = "LLaVAOneVision1_5_ForConditionalGeneration" - assert arch_name in ModelRegistry.model_arch_name_to_cls, f"{arch_name} not found in MODEL_ARCH_MAPPING" + assert arch_name in ModelRegistry.model_arch_name_to_cls, f"{arch_name} not found in MODEL_PACKAGE_SPECS" model_cls = ModelRegistry.model_arch_name_to_cls[arch_name] - logger.info(" ✓ Model registered in MODEL_ARCH_MAPPING") + logger.info(" ✓ Model registered in MODEL_PACKAGE_SPECS") logger.info(f" - Architecture name: {arch_name}") logger.info(f" - Model class: {model_cls.__name__}") logger.info(f" - Module: {model_cls.__module__}") diff --git a/tests/unit_tests/_transformers/test_doc_coverage.py b/tests/unit_tests/_transformers/test_doc_coverage.py index ab96733721..bedaa81f72 100644 --- a/tests/unit_tests/_transformers/test_doc_coverage.py +++ b/tests/unit_tests/_transformers/test_doc_coverage.py @@ -13,7 +13,7 @@ # limitations under the License. """Guard against the gemma4-style regression where a new model architecture -lands in ``MODEL_ARCH_MAPPING`` without any corresponding page under +lands in ``MODEL_PACKAGE_SPECS`` without any corresponding page under ``docs/model-coverage/``. """ @@ -88,14 +88,14 @@ def _repo_root() -> pathlib.Path: def test_every_registered_arch_has_model_coverage_doc(): - """Every architecture in ``MODEL_ARCH_MAPPING`` must be mentioned in at + """Every architecture in ``MODEL_PACKAGE_SPECS`` must be mentioned in at least one ``docs/model-coverage/*.md`` file, either by its own name or by a mapped alias in ``_DOC_ARCH_ALIASES``. This guards against the regression where a new arch (e.g. gemma4) is registered but never gets a corresponding model card in the docs. """ - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS docs_dir = _repo_root() / "docs" / "model-coverage" assert docs_dir.is_dir(), f"docs/model-coverage/ not found at {docs_dir}" @@ -104,10 +104,11 @@ def test_every_registered_arch_has_model_coverage_doc(): assert md_contents, "No .md files found under docs/model-coverage/" missing = [] - for arch_name in MODEL_ARCH_MAPPING: - needle = _DOC_ARCH_ALIASES.get(arch_name, arch_name) - if not any(needle in content for content in md_contents): - missing.append((arch_name, needle)) + for spec in MODEL_PACKAGE_SPECS: + for arch_name in spec.architectures: + needle = _DOC_ARCH_ALIASES.get(arch_name, arch_name) + if not any(needle in content for content in md_contents): + missing.append((arch_name, needle)) if missing: details = "\n".join(f" - {arch} (looked for {repr(needle)})" for arch, needle in missing) diff --git a/tests/unit_tests/_transformers/test_model_init.py b/tests/unit_tests/_transformers/test_model_init.py index 84660ea7fb..821bfb5689 100644 --- a/tests/unit_tests/_transformers/test_model_init.py +++ b/tests/unit_tests/_transformers/test_model_init.py @@ -552,6 +552,23 @@ def test_raises_when_config_class_cannot_be_resolved(self, mock_get_dict, mock_m class TestGetHfConfigLayerTypesRetry: """get_hf_config should retry via the layer_types fix helper when AutoConfig raises.""" + @patch("nemo_automodel._transformers.model_init.ModelRegistry.ensure_config_registered") + @patch("nemo_automodel._transformers.model_init.PretrainedConfig.get_config_dict") + @patch("nemo_automodel._transformers.model_init.resolve_trust_remote_code", return_value=False) + @patch("nemo_automodel._transformers.model_init.AutoConfig.from_pretrained") + def test_registers_custom_config_before_auto_config( + self, mock_from_pretrained, _mock_trust, mock_get_dict, mock_register + ): + mock_get_dict.return_value = ({"model_type": "fake_model"}, {}) + built_config = MagicMock() + mock_from_pretrained.return_value = built_config + + result = get_hf_config("fake/model", "sdpa") + + assert result is built_config + mock_register.assert_called_once_with("fake_model") + mock_from_pretrained.assert_called_once() + @patch("nemo_automodel._transformers.model_init._load_config_with_layer_types_fix") @patch("nemo_automodel._transformers.model_init.resolve_trust_remote_code", return_value=True) @patch("nemo_automodel._transformers.model_init.AutoConfig.from_pretrained") diff --git a/tests/unit_tests/_transformers/test_recipe_doc_coverage.py b/tests/unit_tests/_transformers/test_recipe_doc_coverage.py index 2217e2692e..17809bd99b 100644 --- a/tests/unit_tests/_transformers/test_recipe_doc_coverage.py +++ b/tests/unit_tests/_transformers/test_recipe_doc_coverage.py @@ -21,7 +21,7 @@ least one ``docs/model-coverage/*.md`` file. This complements ``test_doc_coverage.py`` which only covers archs registered -in ``MODEL_ARCH_MAPPING``. Many recipes fine-tune HF-native archs (e.g., +in ``MODEL_PACKAGE_SPECS``. Many recipes fine-tune HF-native archs (e.g., ``Olmo2ForCausalLM``) that never get added to the registry, and still need documentation. diff --git a/tests/unit_tests/_transformers/test_registry.py b/tests/unit_tests/_transformers/test_registry.py index 46a1f1b687..a70d92d4cc 100644 --- a/tests/unit_tests/_transformers/test_registry.py +++ b/tests/unit_tests/_transformers/test_registry.py @@ -13,16 +13,23 @@ # limitations under the License. import types +from unittest.mock import patch import pytest +from nemo_automodel._transformers.registry.base import _BaseModelRegistry -def _new_registry_instance(registry_module): + +def _new_registry_instance(_registry_module=None): """Create a fresh registry with an empty auto_map for testing.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + return _BaseModelRegistry(model_specs=()) + - mapping = _LazyArchMapping(auto_map={}) - return registry_module._ModelRegistry(model_arch_name_to_cls=mapping) +def _model_arch_lookup(model_arch_mapping): + """Return the architecture-keyed lookup derived from registry package specs.""" + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + + return _normalize_model_arch_mapping(model_arch_mapping) def test_register_single_class(): @@ -104,38 +111,141 @@ class A: assert inst.get_model_cls_from_model_arch("A") is A -def test_get_registry_is_cached(): - from nemo_automodel._transformers import registry as reg +def test_make_registry_is_cached(): + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS, make_registry - reg.get_registry.cache_clear() - r1 = reg.get_registry() - r2 = reg.get_registry() + make_registry.cache_clear() + r1 = make_registry(MODEL_PACKAGE_SPECS) + r2 = make_registry(MODEL_PACKAGE_SPECS) assert r1 is r2 -def test_lazy_arch_mapping_auto_map(): +def test_make_registry_caches_by_model_specs(): + from nemo_automodel._transformers.registry.model_registry import RETRIEVAL_MODEL_PACKAGE_SPECS, make_registry + + make_registry.cache_clear() + r1 = make_registry(RETRIEVAL_MODEL_PACKAGE_SPECS) + r2 = make_registry(RETRIEVAL_MODEL_PACKAGE_SPECS) + assert r1 is r2 + + +def test_registry_reexports_public_singletons(): + """The registry package root exposes only the runtime registry singletons.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry.model_registry import ( + ModelRegistry, + RetrievalModelRegistry, + ) + + assert reg.ModelRegistry is ModelRegistry + assert reg.RetrievalModelRegistry is RetrievalModelRegistry + + +def test_registry_mapping_loads_model_class_on_demand(): """Static auto_map entries are lazily loaded on first access.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeClass: pass fake_module = types.SimpleNamespace(FakeClass=FakeClass) - mapping = _LazyArchMapping({"FakeArch": ("fake.module", "FakeClass")}) + spec = ModelPackageSpec(package="fake", model_module="module", class_name="FakeClass", architectures=("FakeArch",)) + inst = _BaseModelRegistry(model_specs=(spec,)) - mapping._modules["fake.module"] = fake_module - - assert "FakeArch" in mapping - assert mapping["FakeArch"] is FakeClass - assert "FakeArch" in mapping._loaded + with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): + assert "FakeArch" in inst.model_arch_name_to_cls + assert inst.model_arch_name_to_cls["FakeArch"] is FakeClass + assert "FakeArch" in inst._loaded_model_classes with pytest.raises(KeyError): - mapping["NonExistent"] + inst.model_arch_name_to_cls["NonExistent"] + + +def test_registry_accepts_model_package_spec_entries(): + """Static auto_map entries may be declared as ModelPackageSpec metadata.""" + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeClass: + pass + + fake_module = types.SimpleNamespace(FakeClass=FakeClass) + spec = ModelPackageSpec( + package="fake.package", + model_module="model", + class_name="FakeClass", + architectures=["FakeArch"], + model_types=("fake",), + ) + inst = _BaseModelRegistry(model_specs={"FakeArch": spec}) + + with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): + assert inst.model_arch_name_to_cls["FakeArch"] is FakeClass + assert inst.model_arch_name_to_cls.keys() == {"FakeArch"} + + +def test_model_package_spec_from_model_class(): + """ModelPackageSpec can be derived directly from a model class object.""" + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeClass: + pass + + FakeClass.__module__ = "fake.package.model" + + spec = ModelPackageSpec.from_model_class(FakeClass, architectures=["FakeArch"]) + + assert spec.module_path == "fake.package.model" + assert spec.class_name == "FakeClass" + assert spec.architectures == ("FakeArch",) + + +def test_registry_derives_keys_from_spec_architectures(): + """Tuple-based auto_map entries derive lookup keys from spec.architectures.""" + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + spec = ModelPackageSpec( + package="fake.package", + class_name="FakeClass", + architectures=("FakeArch", "FakeAlias"), + ) + inst = _BaseModelRegistry(model_specs=(spec,)) + + assert inst.model_arch_name_to_cls.keys() == {"FakeArch", "FakeAlias"} + + +def test_base_registry_owns_architecture_metadata(): + """Package metadata indexes live on the registry, not on the lazy class loader.""" + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + spec = ModelPackageSpec( + package="fake.package", + class_name="FakeClass", + architectures=("FakeArch", "FakeAlias"), + model_types=("fake",), + ) + inst = _BaseModelRegistry(model_specs=(spec,)) + + assert inst.get_model_package_spec("FakeArch") is spec + assert inst.get_model_package_spec("FakeAlias") is spec + assert inst.get_model_package_specs_for_model_type("fake") == (spec,) + + +def test_registry_rejects_duplicate_architectures(): + """Duplicate architecture names across package specs fail during registry construction.""" + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + specs = ( + ModelPackageSpec(package="fake.package_a", class_name="FakeClassA", architectures=("FakeArch",)), + ModelPackageSpec(package="fake.package_b", class_name="FakeClassB", architectures=("FakeArch",)), + ) + + with pytest.raises(ValueError, match="Duplicated model architecture entry for 'FakeArch'"): + _BaseModelRegistry(model_specs=specs) -def test_lazy_arch_mapping_extra_overrides_auto_map(): +def test_registry_dynamic_entry_overrides_static_entry(): """Dynamically registered entries take precedence over static entries.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class StaticClass: pass @@ -144,32 +254,178 @@ class DynamicClass: pass fake_module = types.SimpleNamespace(StaticClass=StaticClass) - mapping = _LazyArchMapping({"MyArch": ("fake.module", "StaticClass")}) - mapping._modules["fake.module"] = fake_module + spec = ModelPackageSpec(package="fake", model_module="module", class_name="StaticClass", architectures=("MyArch",)) + inst = _BaseModelRegistry(model_specs=(spec,)) - assert mapping["MyArch"] is StaticClass + with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): + assert inst.model_arch_name_to_cls["MyArch"] is StaticClass - mapping["MyArch"] = DynamicClass - assert mapping["MyArch"] is DynamicClass + inst.model_arch_name_to_cls["MyArch"] = DynamicClass + assert inst.model_arch_name_to_cls["MyArch"] is DynamicClass + assert inst.get_model_package_spec("MyArch").class_name == "DynamicClass" -def test_lazy_arch_mapping_unavailable_model(): +def test_registry_unavailable_model(): """Auto_map entries whose imports fail are removed and excluded from containment.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + spec = ModelPackageSpec( + package="nonexistent.module", + model_module="path", + class_name="BadClass", + architectures=("BadArch",), + ) + inst = _BaseModelRegistry(model_specs=(spec,)) + + assert "BadArch" not in inst.model_arch_name_to_cls + assert "BadArch" not in inst.model_arch_name_to_cls.keys() + + +def test_registry_discards_failed_import_from_metadata_index(): + """Failed model imports remove the architecture from the registry-owned metadata index.""" + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + spec = ModelPackageSpec( + package="nonexistent.module", + model_module="path", + class_name="BadClass", + architectures=("BadArch",), + ) + inst = _BaseModelRegistry(model_specs=(spec,)) + + assert "BadArch" not in inst.model_arch_name_to_cls + assert inst.get_model_package_spec("BadArch") is None + assert "BadArch" not in inst.supported_models + + class GoodClass: + pass - mapping = _LazyArchMapping({"BadArch": ("nonexistent.module.path", "BadClass")}) + inst.register("GoodArch", GoodClass) + assert "BadArch" not in inst.model_arch_name_to_cls.keys() - assert "BadArch" not in mapping - assert "BadArch" not in mapping._auto_map + +def test_direct_lazy_registration_updates_registry_metadata(): + """Direct mapping assignment still informs the registry about dynamic model metadata.""" + from nemo_automodel._transformers import registry as reg + + inst = _new_registry_instance(reg) + + class DirectClass: + pass + + DirectClass.__module__ = "fake.direct.model" + inst.model_arch_name_to_cls["DirectArch"] = DirectClass + + spec = inst.get_model_package_spec("DirectArch") + assert spec.module_path == "fake.direct.model" + assert spec.class_name == "DirectClass" + assert inst.model_arch_name_to_cls["DirectArch"] is DirectClass def test_default_registry_has_static_entries(): - """The default registry is populated from MODEL_ARCH_MAPPING.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING, _ModelRegistry + """The default registry is populated from MODEL_PACKAGE_SPECS.""" + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS, make_registry + + inst = make_registry(MODEL_PACKAGE_SPECS) + for spec in MODEL_PACKAGE_SPECS: + for arch_name in spec.architectures: + assert arch_name in inst.model_arch_name_to_cls.keys() + + +def test_retrieval_registry_has_separate_static_entries(): + """Retrieval architectures live in the retrieval registry, not the default registry.""" + from nemo_automodel._transformers.registry.model_registry import ( + MODEL_PACKAGE_SPECS, + RETRIEVAL_MODEL_PACKAGE_SPECS, + ) + + retrieval_arch = "LlamaBidirectionalModel" + default_registry = _BaseModelRegistry(model_specs=MODEL_PACKAGE_SPECS) + retrieval_registry = _BaseModelRegistry(model_specs=RETRIEVAL_MODEL_PACKAGE_SPECS) + + assert retrieval_arch not in _model_arch_lookup(MODEL_PACKAGE_SPECS) + assert retrieval_arch in _model_arch_lookup(RETRIEVAL_MODEL_PACKAGE_SPECS) + assert retrieval_arch not in default_registry.model_arch_name_to_cls.keys() + assert retrieval_arch in retrieval_registry.model_arch_name_to_cls.keys() + + +def test_registry_discovers_config_class_from_config_module(): + """Config classes are imported from declared convention modules on demand.""" + from transformers import PretrainedConfig + + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeConfig(PretrainedConfig): + model_type = "fake_model" + + FakeConfig.__module__ = "fake.package.config" + fake_module = types.SimpleNamespace(__name__="fake.package.config", FakeConfig=FakeConfig) + spec = ModelPackageSpec(package="fake.package", model_types=("fake_model",), config_module="config") + inst = _BaseModelRegistry(model_specs=(spec,)) + + with ( + patch("transformers.AutoConfig.register") as mock_register, + patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module), + ): + assert inst.ensure_config_registered("fake_model") is True + + mock_register.assert_called_once_with("fake_model", FakeConfig) + - inst = _ModelRegistry() - for arch_name in MODEL_ARCH_MAPPING: - assert arch_name in inst.model_arch_name_to_cls.keys() +def test_registry_registers_explicit_config_alias(): + """Explicit config metadata supports aliases whose key differs from class.model_type.""" + from transformers import PretrainedConfig + + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeConfig(PretrainedConfig): + model_type = "canonical_model" + + fake_module = types.SimpleNamespace(__name__="fake.package.model", FakeConfig=FakeConfig) + spec = ModelPackageSpec( + package="fake.package", + model_types=("canonical_model", "alias_model"), + config_module="model", + config_class_name="FakeConfig", + ) + inst = _BaseModelRegistry(model_specs=(spec,)) + + with ( + patch("transformers.AutoConfig.register") as mock_register, + patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module), + ): + assert inst.ensure_config_registered("alias_model") is True + + mock_register.assert_called_once_with("alias_model", FakeConfig) + + +def test_registry_registers_config_alias_when_auto_config_rejects_mismatch(): + """HF validates ``config.model_type``; aliases fall back to CONFIG_MAPPING directly.""" + from transformers import PretrainedConfig + from transformers.models.auto.configuration_auto import CONFIG_MAPPING + + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeConfig(PretrainedConfig): + model_type = "canonical_model" + + fake_module = types.SimpleNamespace(__name__="fake.package.model", FakeConfig=FakeConfig) + spec = ModelPackageSpec( + package="fake.package", + model_types=("alias_model_with_mismatch",), + config_module="model", + config_class_name="FakeConfig", + ) + inst = _BaseModelRegistry(model_specs=(spec,)) + + with ( + patch("transformers.AutoConfig.register", side_effect=ValueError("model_type mismatch")), + patch.object(CONFIG_MAPPING, "register") as mock_mapping_register, + patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module), + ): + assert inst.ensure_config_registered("alias_model_with_mismatch") is True + + mock_mapping_register.assert_called_once_with("alias_model_with_mismatch", FakeConfig) def test_resolve_custom_model_cls_found(): @@ -242,88 +498,86 @@ def supports_config(cls, config): assert inst.resolve_custom_model_cls("ConfigAwareModel", bad) is None -def test_custom_config_registrations_in_config_mapping(): - """Models in _CUSTOM_CONFIG_REGISTRATIONS must be registered in CONFIG_MAPPING after import. +def test_config_metadata_is_metadata_only_until_requested(): + """Config metadata should not require eager registration at registry import time.""" + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - This ensures that AutoConfig.from_pretrained can resolve custom model types - (e.g. kimi_k25, kimi_vl) from local checkpoints without trust_remote_code=True. - """ - from transformers.models.auto.configuration_auto import CONFIG_MAPPING - - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS - - missing = [] - for model_type in _CUSTOM_CONFIG_REGISTRATIONS: - if model_type not in CONFIG_MAPPING: - missing.append(model_type) - - assert not missing, ( - f"Model type(s) {missing} are in _CUSTOM_CONFIG_REGISTRATIONS but not in " - f"CONFIG_MAPPING. The _register_custom_configs() call at module level may " - f"have failed for these entries." - ) + for spec in MODEL_PACKAGE_SPECS: + if spec.config_module_path is None: + continue + assert spec.config_module_path + assert spec.config_class_name -def test_kimi_k25_arch_alias_in_model_arch_mapping(): +def test_kimi_k25_arch_alias_in_model_package_specs(): """KimiK25ForConditionalGeneration (checkpoint arch) must map to KimiK25VLForConditionalGeneration.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - assert "KimiK25ForConditionalGeneration" in MODEL_ARCH_MAPPING, ( - "KimiK25ForConditionalGeneration missing from MODEL_ARCH_MAPPING. " + mapping = _model_arch_lookup(MODEL_PACKAGE_SPECS) + assert "KimiK25ForConditionalGeneration" in mapping, ( + "KimiK25ForConditionalGeneration missing from MODEL_PACKAGE_SPECS. " "Kimi-K2.5 checkpoints use this architecture name and need it mapped " "to KimiK25VLForConditionalGeneration." ) - module_path, cls_name = MODEL_ARCH_MAPPING["KimiK25ForConditionalGeneration"] - assert cls_name == "KimiK25VLForConditionalGeneration" + spec = mapping["KimiK25ForConditionalGeneration"] + assert spec.class_name == "KimiK25VLForConditionalGeneration" -def test_deepseek_v4_registered_in_arch_mapping(): - """DeepseekV4ForCausalLM must be registered in MODEL_ARCH_MAPPING.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING +def test_deepseek_v4_registered_in_model_package_specs(): + """DeepseekV4ForCausalLM must be registered in MODEL_PACKAGE_SPECS.""" + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - assert "DeepseekV4ForCausalLM" in MODEL_ARCH_MAPPING, ( - "DeepseekV4ForCausalLM missing from MODEL_ARCH_MAPPING. " + mapping = _model_arch_lookup(MODEL_PACKAGE_SPECS) + assert "DeepseekV4ForCausalLM" in mapping, ( + "DeepseekV4ForCausalLM missing from MODEL_PACKAGE_SPECS. " "DSV4 checkpoints declare this architecture and need it routed to the " "in-tree model implementation." ) - module_path, cls_name = MODEL_ARCH_MAPPING["DeepseekV4ForCausalLM"] - assert module_path == "nemo_automodel.components.models.deepseek_v4.model" - assert cls_name == "DeepseekV4ForCausalLM" + spec = mapping["DeepseekV4ForCausalLM"] + assert spec.module_path == "nemo_automodel.components.models.deepseek_v4.model" + assert spec.class_name == "DeepseekV4ForCausalLM" -def test_deepseek_v4_in_custom_config_registrations(): - """deepseek_v4 model_type must be registered in _CUSTOM_CONFIG_REGISTRATIONS.""" - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS +def test_deepseek_v4_in_model_package_specs(): + """deepseek_v4 model_type must be declared in MODEL_PACKAGE_SPECS.""" + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - assert "deepseek_v4" in _CUSTOM_CONFIG_REGISTRATIONS, ( - "deepseek_v4 must be in _CUSTOM_CONFIG_REGISTRATIONS so AutoConfig.from_pretrained " - "can resolve DSV4 configs without trust_remote_code=True." + specs_by_model_type = {model_type: spec for spec in MODEL_PACKAGE_SPECS for model_type in spec.model_types} + assert "deepseek_v4" in specs_by_model_type, ( + "deepseek_v4 must be in MODEL_PACKAGE_SPECS so ModelRegistry can register " + "DSV4 configs on demand before AutoConfig.from_pretrained runs." ) - module_path, cls_name = _CUSTOM_CONFIG_REGISTRATIONS["deepseek_v4"] - assert module_path == "nemo_automodel.components.models.deepseek_v4.config" - assert cls_name == "DeepseekV4Config" + spec = specs_by_model_type["deepseek_v4"] + assert spec.config_module_path == "nemo_automodel.components.models.deepseek_v4.config" + assert spec.config_class_name == "DeepseekV4Config" def test_all_model_folders_registered_in_auto_map(): - """Every model folder with a model.py must have at least one entry in MODEL_ARCH_MAPPING. + """Every model folder with a model.py must have at least one registry package-spec entry. This catches the case where a developer adds a new model directory under ``nemo_automodel/components/models/`` but forgets to add it to the static - ``MODEL_ARCH_MAPPING`` in ``registry.py``. + registry package specs in ``registry.py``. """ import pathlib - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.model_registry import ( + MODEL_PACKAGE_SPECS, + RETRIEVAL_MODEL_PACKAGE_SPECS, + ) models_root = pathlib.Path(__file__).resolve().parents[3] / "nemo_automodel" / "components" / "models" - # Collect the set of module paths referenced by the auto_map - registered_module_paths = {v[0] for v in MODEL_ARCH_MAPPING.values()} + # Collect the set of module paths referenced by the registries. + registered_module_paths = {spec.module_path for spec in (*MODEL_PACKAGE_SPECS, *RETRIEVAL_MODEL_PACKAGE_SPECS)} missing = [] + documentation_only_model_dirs = {"blueprint"} for model_dir in sorted(models_root.iterdir()): if not model_dir.is_dir() or model_dir.name.startswith(("_", ".")): continue + if model_dir.name in documentation_only_model_dirs: + continue model_file = model_dir / "model.py" if not model_file.exists(): continue @@ -333,6 +587,6 @@ def test_all_model_folders_registered_in_auto_map(): assert not missing, ( f"Model folder(s) {missing} contain a model.py but are not registered " - f"in MODEL_ARCH_MAPPING (registry.py). Add an entry for each architecture " + f"in the registry package specs (registry.py). Add an entry for each architecture " f"exported by these modules." ) diff --git a/tests/unit_tests/_transformers/test_registry_hy_v3.py b/tests/unit_tests/_transformers/test_registry_hy_v3.py index 9e733ccccb..3f2e9ca509 100644 --- a/tests/unit_tests/_transformers/test_registry_hy_v3.py +++ b/tests/unit_tests/_transformers/test_registry_hy_v3.py @@ -14,45 +14,50 @@ """Verify HYV3 model + config are registered in nemo_automodel._transformers.registry.""" -import pytest - class TestArchMapping: def test_hyv3_arch_registered(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - assert "HYV3ForCausalLM" in MODEL_ARCH_MAPPING + assert "HYV3ForCausalLM" in _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) def test_hyv3_arch_points_at_correct_module(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - entry = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] - assert entry[0] == "nemo_automodel.components.models.hy_v3.model" - assert entry[1] == "HYV3ForCausalLM" + entry = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)["HYV3ForCausalLM"] + assert entry.module_path == "nemo_automodel.components.models.hy_v3.model" + assert entry.class_name == "HYV3ForCausalLM" def test_hyv3_arch_resolves_to_class(self): """Walk the mapping path -- importable + the named class exists.""" import importlib - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - mod_path, cls_name, *_ = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] + spec = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)["HYV3ForCausalLM"] + mod_path = spec.module_path + cls_name = spec.class_name mod = importlib.import_module(mod_path) assert hasattr(mod, cls_name) class TestCustomConfigRegistration: def test_hy_v3_config_registered(self): - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - assert "hy_v3" in _CUSTOM_CONFIG_REGISTRATIONS + assert any("hy_v3" in spec.model_types for spec in MODEL_PACKAGE_SPECS) def test_hy_v3_config_resolves_to_class(self): import importlib - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - mod_path, cls_name = _CUSTOM_CONFIG_REGISTRATIONS["hy_v3"] + spec = next(spec for spec in MODEL_PACKAGE_SPECS if "hy_v3" in spec.model_types) + mod_path = spec.config_module_path + cls_name = spec.config_class_name mod = importlib.import_module(mod_path) cls = getattr(mod, cls_name) assert cls.__name__ == "HYV3Config" diff --git a/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py b/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py index 365856408e..f81f449ca1 100644 --- a/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py +++ b/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast -from nemo_automodel._transformers.registry import ModelRegistry +from nemo_automodel._transformers.registry import RetrievalModelRegistry from nemo_automodel._transformers.retrieval import ( BiEncoderModel, CrossEncoderModel, @@ -139,7 +139,6 @@ def test_bidirectional_attention_is_symmetric(): ) - # --- Fakes for classification and encoder tests --- class FakeOutputs: def __init__(self, last_hidden_state=None, hidden_states=None): @@ -234,9 +233,7 @@ def forward(self, input_ids=None, attention_mask=None, return_dict=True, output_ ) lm = NoTTIDLm(hidden=8) - model = BiEncoderModel( - model=lm, pooling="avg", l2_normalize=True - ) + model = BiEncoderModel(model=lm, pooling="avg", l2_normalize=True) # encode removes token_type_ids and normalizes q = { "input_ids": torch.ones(2, 3, dtype=torch.long), @@ -278,9 +275,7 @@ def forward(self, input_ids=None, attention_mask=None, return_dict=True, output_ return OnlyHiddenOutputs(hidden_states) # Test with model using NoLastLM for query encoder - model_no_last = BiEncoderModel( - model=NoLastLM(hidden=8), pooling="avg", l2_normalize=True - ) + model_no_last = BiEncoderModel(model=NoLastLM(hidden=8), pooling="avg", l2_normalize=True) v2 = model_no_last.encode( {"input_ids": torch.ones(2, 3, dtype=torch.long), "attention_mask": torch.ones(2, 3, dtype=torch.long)}, ) @@ -294,9 +289,13 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - # Patch the registry to return our fake model - ModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + # Patch the retrieval registry to return our fake model + RetrievalModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) # Directory path with config.json to hit config-reading branch model_dir = tmp_path / "model" @@ -380,9 +379,13 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - # Patch the registry to return our fake model - ModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + # Patch the retrieval registry to return our fake model + RetrievalModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) # Create a model directory whose path has no 'llama' substring model_dir = tmp_path / "scratch" / "job" / "model" @@ -415,9 +418,13 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - # Patch the registry to return our fake model - ModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + # Patch the retrieval registry to return our fake model + RetrievalModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) # Model type not in SUPPORTED_BACKBONES should fall back to AutoModel import nemo_automodel._transformers.retrieval as encoder_module @@ -535,10 +542,14 @@ def __init__(self): # Use a class name that is NOT a retrieval arch FakeModel.__name__ = "Qwen3Model" - FakeModel = type("Qwen3Model", (nn.Module,), { - "__init__": FakeModel.__init__, - "config": property(lambda self: self._config), - }) + FakeModel = type( + "Qwen3Model", + (nn.Module,), + { + "__init__": FakeModel.__init__, + "config": property(lambda self: self._config), + }, + ) fake = object.__new__(FakeModel) nn.Module.__init__(fake) fake._config = FakeCfg() diff --git a/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py b/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py index 92b9ded41d..399f866a5f 100644 --- a/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py +++ b/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py @@ -22,8 +22,8 @@ pytest.importorskip("transformers.models.ministral3", reason="Ministral3 not available in this transformers version") -from nemo_automodel._transformers.registry import ModelRegistry -from nemo_automodel._transformers.retrieval import BiEncoderModel, configure_encoder_metadata, _init_encoder_common +from nemo_automodel._transformers.registry import RetrievalModelRegistry +from nemo_automodel._transformers.retrieval import BiEncoderModel, _init_encoder_common, configure_encoder_metadata from nemo_automodel.components.models.ministral_bidirectional.model import ( Ministral3BidirectionalConfig, Ministral3BidirectionalModel, @@ -142,8 +142,12 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - ModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + RetrievalModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) model_dir = tmp_path / "model" model_dir.mkdir() @@ -164,13 +168,18 @@ def from_pretrained(cls, *args, **kwargs): @pytest.mark.parametrize("top_level_model_type", ["ministral3", "ministral3_bidirec"]) def test_encoder_build_ministral_supported_model_types(tmp_path, monkeypatch, top_level_model_type): """Hub / local text configs use ministral3; saved bidirectional checkpoints use ministral3_bidirec.""" + class FakeBidirectionalModel(FakeLM): @classmethod def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - ModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + RetrievalModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) model_dir = tmp_path / "hub" / "checkpoint" model_dir.mkdir(parents=True) diff --git a/tests/unit_tests/models/mistral4/test_mistral4_model.py b/tests/unit_tests/models/mistral4/test_mistral4_model.py index 2b39879108..922a7072df 100644 --- a/tests/unit_tests/models/mistral4/test_mistral4_model.py +++ b/tests/unit_tests/models/mistral4/test_mistral4_model.py @@ -920,17 +920,22 @@ def test_forward_raises_on_pp_stage_without_embeds(self, text_config, backend, d class TestRegistry: def test_mistral4_in_registry(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - assert "Mistral4ForCausalLM" in dict(MODEL_ARCH_MAPPING) - assert "Mistral3ForConditionalGeneration" in dict(MODEL_ARCH_MAPPING) + mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) + assert "Mistral4ForCausalLM" in mapping + assert "Mistral3ForConditionalGeneration" in mapping def test_registry_module_path(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - mapping = dict(MODEL_ARCH_MAPPING) - assert mapping["Mistral4ForCausalLM"][0] == "nemo_automodel.components.models.mistral4.model" - assert mapping["Mistral3ForConditionalGeneration"][0] == "nemo_automodel.components.models.mistral4.model" + mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) + assert mapping["Mistral4ForCausalLM"].module_path == "nemo_automodel.components.models.mistral4.model" + assert ( + mapping["Mistral3ForConditionalGeneration"].module_path == "nemo_automodel.components.models.mistral4.model" + ) # --------------------------------------------------------------------------- diff --git a/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py b/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py index fb2412747b..2832d69ce1 100644 --- a/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py +++ b/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py @@ -16,7 +16,7 @@ These tests exercise the small, dependency-free building blocks of the NemotronOmni wrapper (activation, norm, projectors, config) plus the -architecture registration that lets ``MODEL_ARCH_MAPPING`` resolve the v3 dump. +architecture registration that lets ``MODEL_PACKAGE_SPECS`` resolve the v3 dump. The heavy ``NemotronOmniForConditionalGeneration`` end-to-end forward path requires a full HF v3 checkpoint and is covered by the functional/integration suites — not here. @@ -164,18 +164,20 @@ def test_nemotron_omni_config_overrides_propagate(): def test_registry_entry_present(): - """``MODEL_ARCH_MAPPING`` should resolve the v3 architecture name to our wrapper.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + """``MODEL_PACKAGE_SPECS`` should resolve the v3 architecture name to our wrapper.""" + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - mapping = dict(MODEL_ARCH_MAPPING) + mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) assert "NemotronH_Nano_Omni_Reasoning_V3" in mapping - module_path, class_name, *_ = mapping["NemotronH_Nano_Omni_Reasoning_V3"] - assert module_path == "nemo_automodel.components.models.nemotron_omni.model" - assert class_name == "NemotronOmniForConditionalGeneration" + spec = mapping["NemotronH_Nano_Omni_Reasoning_V3"] + assert spec.module_path == "nemo_automodel.components.models.nemotron_omni.model" + assert spec.class_name == "NemotronOmniForConditionalGeneration" def test_registry_v2_entry_removed(): """V2 dispatch was deleted along with V2 dump support — keep it gone.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS - assert "NemotronH_Nano_VL_V2" not in dict(MODEL_ARCH_MAPPING) + assert "NemotronH_Nano_VL_V2" not in _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)