From a38a5b4677515d0b5d8b4ea20ce69a125536520d Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 18 May 2026 18:56:28 -0700 Subject: [PATCH 1/5] refactor(registry): split model registry package --- nemo_automodel/_transformers/model_init.py | 21 + nemo_automodel/_transformers/registry.py | 366 ---------------- .../_transformers/registry/__init__.py | 50 +++ nemo_automodel/_transformers/registry/base.py | 362 ++++++++++++++++ .../registry/model_package_spec.py | 106 +++++ .../_transformers/registry/model_registry.py | 390 ++++++++++++++++++ .../_transformers/test_model_init.py | 17 + .../unit_tests/_transformers/test_registry.py | 251 +++++++++-- .../_transformers/test_registry_hy_v3.py | 20 +- 9 files changed, 1170 insertions(+), 413 deletions(-) delete mode 100644 nemo_automodel/_transformers/registry.py create mode 100644 nemo_automodel/_transformers/registry/__init__.py create mode 100644 nemo_automodel/_transformers/registry/base.py create mode 100644 nemo_automodel/_transformers/registry/model_package_spec.py create mode 100644 nemo_automodel/_transformers/registry/model_registry.py diff --git a/nemo_automodel/_transformers/model_init.py b/nemo_automodel/_transformers/model_init.py index 09b40393e7..f66b8e2b9e 100644 --- a/nemo_automodel/_transformers/model_init.py +++ b/nemo_automodel/_transformers/model_init.py @@ -220,6 +220,22 @@ def _resolve_custom_model_cls_for_config(config): return ModelRegistry.resolve_custom_model_cls(arch_name, config) +def _ensure_config_registered_from_config_dict(pretrained_model_name_or_path, **kwargs) -> None: + """Register a matching local config class before delegating to ``AutoConfig``.""" + config_lookup_kwargs = kwargs.copy() + config_lookup_kwargs["_from_auto"] = True + config_lookup_kwargs["name_or_path"] = pretrained_model_name_or_path + try: + config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **config_lookup_kwargs) + except Exception: + logger.debug("Could not inspect config metadata for %s", pretrained_model_name_or_path, exc_info=True) + return + + model_type = config_dict.get("model_type") + if isinstance(model_type, str): + ModelRegistry.ensure_config_registered(model_type) + + def get_hf_config(pretrained_model_name_or_path, attn_implementation, **kwargs): """ Get the HF config for the model. @@ -233,6 +249,11 @@ def get_hf_config(pretrained_model_name_or_path, attn_implementation, **kwargs): # with incomplete dicts, losing all other fields. These nested overrides are # instead handled by _consume_config_overrides which deep-merges them. nested_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if isinstance(kwargs[k], dict)} # noqa: F841 + _ensure_config_registered_from_config_dict( + pretrained_model_name_or_path, + **kwargs, + attn_implementation=attn_implementation, + ) try: hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path, diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py deleted file mode 100644 index 47fd6cecca..0000000000 --- a/nemo_automodel/_transformers/registry.py +++ /dev/null @@ -1,366 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import importlib -import logging -from collections import OrderedDict -from dataclasses import dataclass, field -from functools import lru_cache -from typing import Dict, Tuple, Type, Union - -import torch.nn as nn - -logger = logging.getLogger(__name__) - -# Static mapping: architecture name → (module_path, class_name[, tags]). -# Analogous to HuggingFace transformers' MODEL_FOR_CAUSAL_LM_MAPPING_NAMES. -# Models are loaded lazily on first access rather than imported at startup. -# Optional third element is a set of tags (e.g. {"retrieval"}) used by -# downstream code to classify model archs without importing them. -MODEL_ARCH_MAPPING = OrderedDict( - [ - ( - "BaichuanForCausalLM", - ("nemo_automodel.components.models.baichuan.model", "BaichuanForCausalLM"), - ), - ( - "DeepseekV3ForCausalLM", - ("nemo_automodel.components.models.deepseek_v3.model", "DeepseekV3ForCausalLM"), - ), - ( - "DeepseekV32ForCausalLM", - ("nemo_automodel.components.models.deepseek_v32.model", "DeepseekV32ForCausalLM"), - ), - ( - "DeepseekV4ForCausalLM", - ("nemo_automodel.components.models.deepseek_v4.model", "DeepseekV4ForCausalLM"), - ), - ( - "Ernie4_5_MoeForCausalLM", - ("nemo_automodel.components.models.ernie4_5.model", "Ernie4_5_MoeForCausalLM"), - ), - ( - "Glm4MoeForCausalLM", - ("nemo_automodel.components.models.glm4_moe.model", "Glm4MoeForCausalLM"), - ), - ( - "Glm4MoeLiteForCausalLM", - ("nemo_automodel.components.models.glm4_moe_lite.model", "Glm4MoeLiteForCausalLM"), - ), - ( - "GlmMoeDsaForCausalLM", - ("nemo_automodel.components.models.glm_moe_dsa.model", "GlmMoeDsaForCausalLM"), - ), - ( - "Gemma4ForConditionalGeneration", - ("nemo_automodel.components.models.gemma4_moe.model", "Gemma4ForConditionalGeneration"), - ), - ( - "GptOssForCausalLM", - ("nemo_automodel.components.models.gpt_oss.model", "GptOssForCausalLM"), - ), - ( - "KimiK25ForConditionalGeneration", - ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLForConditionalGeneration"), - ), - ( - "KimiK25VLForConditionalGeneration", - ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLForConditionalGeneration"), - ), - ( - "KimiVLForConditionalGeneration", - ("nemo_automodel.components.models.kimivl.model", "KimiVLForConditionalGeneration"), - ), - ( - "LlamaBidirectionalForSequenceClassification", - ( - "nemo_automodel.components.models.llama_bidirectional.model", - "LlamaBidirectionalForSequenceClassification", - {"retrieval"}, - ), - ), - ( - "LlamaBidirectionalModel", - ("nemo_automodel.components.models.llama_bidirectional.model", "LlamaBidirectionalModel", {"retrieval"}), - ), - ( - "LlamaForCausalLM", - ("nemo_automodel.components.models.llama.model", "LlamaForCausalLM"), - ), - ( - "MiniMaxM2ForCausalLM", - ("nemo_automodel.components.models.minimax_m2.model", "MiniMaxM2ForCausalLM"), - ), - ( - "MiMoV2FlashForCausalLM", - ("nemo_automodel.components.models.mimo_v2_flash.model", "MiMoV2FlashForCausalLM"), - ), - ( - "Ministral3ForCausalLM", - ("nemo_automodel.components.models.mistral3.model", "Ministral3ForCausalLM"), - ), - ( - "Ministral3BidirectionalModel", - ( - "nemo_automodel.components.models.ministral_bidirectional.model", - "Ministral3BidirectionalModel", - {"retrieval"}, - ), - ), - ( - "Mistral4ForCausalLM", - ("nemo_automodel.components.models.mistral4.model", "Mistral4ForCausalLM"), - ), - ( - "Mistral3ForConditionalGeneration", - ("nemo_automodel.components.models.mistral4.model", "Mistral3ForConditionalGeneration"), - ), - ( - "Mistral3FP8VLMForConditionalGeneration", - ( - "nemo_automodel.components.models.mistral3_vlm.model", - "Mistral3FP8VLMForConditionalGeneration", - ), - ), - ( - "NemotronHForCausalLM", - ("nemo_automodel.components.models.nemotron_v3.model", "NemotronHForCausalLM"), - ), - ( - "NemotronH_Nano_Omni_Reasoning_V3", - ( - "nemo_automodel.components.models.nemotron_omni.model", - "NemotronOmniForConditionalGeneration", - ), - ), - ( - "NemotronParseForConditionalGeneration", - ("nemo_automodel.components.models.nemotron_parse.model", "NemotronParseForConditionalGeneration"), - ), - ( - "LLaVAOneVision1_5_ForConditionalGeneration", - ( - "nemo_automodel.components.models.llava_onevision.model", - "LLaVAOneVision1_5_ForConditionalGeneration", - ), - ), - ( - "HYV3ForCausalLM", - ("nemo_automodel.components.models.hy_v3.model", "HYV3ForCausalLM"), - ), - ( - "Qwen2ForCausalLM", - ("nemo_automodel.components.models.qwen2.model", "Qwen2ForCausalLM"), - ), - ( - "Qwen3MoeForCausalLM", - ("nemo_automodel.components.models.qwen3_moe.model", "Qwen3MoeForCausalLM"), - ), - ( - "Qwen3NextForCausalLM", - ("nemo_automodel.components.models.qwen3_next.model", "Qwen3NextForCausalLM"), - ), - ( - "Qwen3OmniMoeForConditionalGeneration", - ( - "nemo_automodel.components.models.qwen3_omni_moe.model", - "Qwen3OmniMoeThinkerForConditionalGeneration", - ), - ), - ( - "Qwen3VLMoeForConditionalGeneration", - ("nemo_automodel.components.models.qwen3_vl_moe.model", "Qwen3VLMoeForConditionalGeneration"), - ), - ( - "Qwen3_5MoeForConditionalGeneration", - ("nemo_automodel.components.models.qwen3_5_moe.model", "Qwen3_5MoeForConditionalGeneration"), - ), - ( - "Step3p5ForCausalLM", - ("nemo_automodel.components.models.step3p5.model", "Step3p5ForCausalLM"), - ), - ] -) - - -# Custom model_type → config class for models that have auto_map in their -# checkpoint config.json. Registered eagerly with AutoConfig so that -# AutoConfig.from_pretrained can resolve them without trust_remote_code. -_CUSTOM_CONFIG_REGISTRATIONS: Dict[str, Tuple[str, str]] = { - "baichuan": ("nemo_automodel.components.models.baichuan.configuration", "BaichuanConfig"), - "deepseek_v4": ("nemo_automodel.components.models.deepseek_v4.config", "DeepseekV4Config"), - "hy_v3": ("nemo_automodel.components.models.hy_v3.config", "HYV3Config"), - "kimi_k25": ("nemo_automodel.components.models.kimi_k25_vl.model", "KimiK25VLConfig"), - "kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"), - "llavaonevision1_5": ("nemo_automodel.components.models.llava_onevision.model", "Llavaonevision1_5Config"), - "mimo_v2_flash": ("nemo_automodel.components.models.mimo_v2_flash.config", "MiMoV2FlashConfig"), - "mistral4": ("nemo_automodel.components.models.mistral4.configuration", "Mistral4Config"), -} - - -def _register_custom_configs() -> None: - from transformers import AutoConfig - from transformers.models.auto.configuration_auto import CONFIG_MAPPING - - for model_type, (module_path, cls_name) in _CUSTOM_CONFIG_REGISTRATIONS.items(): - if model_type not in CONFIG_MAPPING: - try: - mod = importlib.import_module(module_path) - cfg_cls = getattr(mod, cls_name) - AutoConfig.register(model_type, cfg_cls) - except Exception: - logger.debug("Failed to register config for model_type=%s", model_type, exc_info=True) - - -_register_custom_configs() - - -class _LazyArchMapping: - """Lazy-loading mapping from architecture name to model class. - - Inspired by HuggingFace transformers' ``_LazyAutoMapping``. Entries from the - static ``auto_map`` are imported on first access and cached. Additional entries - can be added at runtime via ``register``. - """ - - def __init__(self, auto_map: Union[OrderedDict, Dict[str, tuple], None] = None): - # Entries may be (module_path, class_name) or (module_path, class_name, tags). - # Strip the optional tags and store them separately. - self._auto_map: Dict[str, Tuple[str, str]] = OrderedDict() - self._tags: Dict[str, set] = {} - for key, value in (auto_map or {}).items(): - self._auto_map[key] = (value[0], value[1]) - if len(value) > 2: - self._tags[key] = value[2] - self._loaded: Dict[str, Type[nn.Module]] = {} - self._extra: Dict[str, Type[nn.Module]] = {} - self._modules: Dict[str, object] = {} - - def _load(self, key: str) -> Type[nn.Module]: - if key in self._loaded: - return self._loaded[key] - module_path, class_name = self._auto_map[key] - if module_path not in self._modules: - self._modules[module_path] = importlib.import_module(module_path) - cls = getattr(self._modules[module_path], class_name) - self._loaded[key] = cls - return cls - - def __contains__(self, key: str) -> bool: - if key in self._extra or key in self._loaded: - return True - if key not in self._auto_map: - return False - try: - self._load(key) - return True - except Exception: - logger.debug("Model %s unavailable (import failed), removing from auto_map", key) - self._auto_map.pop(key, None) - return False - - def __getitem__(self, key: str) -> Type[nn.Module]: - if key in self._extra: - return self._extra[key] - if key in self._auto_map: - return self._load(key) - raise KeyError(key) - - def __setitem__(self, key: str, value: Type[nn.Module]) -> None: - self._extra[key] = value - - def register(self, key: str, value: Type[nn.Module], exist_ok: bool = False) -> None: - """Register a model class under the given architecture name.""" - if not exist_ok and key in self._extra: - raise ValueError(f"Duplicated model implementation for {key}") - self._extra[key] = value - - def has_tag(self, key: str, tag: str) -> bool: - """Return ``True`` if *key* was registered with *tag*.""" - return tag in self._tags.get(key, set()) - - def keys_with_tag(self, tag: str) -> set: - """Return all architecture names that have *tag*.""" - return {k for k, tags in self._tags.items() if tag in tags} - - def keys(self): - return set(self._auto_map.keys()) | set(self._extra.keys()) - - def __len__(self) -> int: - return len(self.keys()) - - def __repr__(self) -> str: - return f"_LazyArchMapping(auto_map={len(self._auto_map)}, extra={len(self._extra)}, loaded={len(self._loaded)})" - - -@dataclass -class _ModelRegistry: - model_arch_name_to_cls: _LazyArchMapping = field(default=None) - _retrieval_archs: set = field(default_factory=set) - - def __post_init__(self): - if self.model_arch_name_to_cls is None: - self.model_arch_name_to_cls = _LazyArchMapping(MODEL_ARCH_MAPPING) - self._retrieval_archs = self.model_arch_name_to_cls.keys_with_tag("retrieval") - - @property - def supported_models(self): - return self.model_arch_name_to_cls.keys() - - def get_model_cls_from_model_arch(self, model_arch: str) -> Type[nn.Module]: - return self.model_arch_name_to_cls[model_arch] - - def has_custom_model(self, arch_name: str) -> bool: - """Return ``True`` if *arch_name* has a custom (non-HF) implementation.""" - return arch_name in self.model_arch_name_to_cls - - def has_retrieval_model(self, arch_name: str) -> bool: - """Return ``True`` if *arch_name* is a registered retrieval/encoder architecture.""" - return arch_name in self._retrieval_archs - - def register_retrieval(self, arch_name: str) -> None: - """Mark *arch_name* as a retrieval/encoder architecture.""" - self._retrieval_archs.add(arch_name) - - def resolve_custom_model_cls(self, architecture: str, config) -> Union[Type[nn.Module], None]: - """Return the custom model class if it exists and supports *config*, else ``None``. - - Custom model classes may define a ``supports_config(config)`` classmethod - to opt out for specific HF configs (e.g. a Mistral3 VLM with a dense - Ministral3 text backbone instead of the expected Mistral4 MoE+MLA). - """ - if architecture not in self.model_arch_name_to_cls: - return None - model_cls = self.model_arch_name_to_cls[architecture] - if hasattr(model_cls, "supports_config") and not model_cls.supports_config(config): - logger.info( - "Custom model %s does not support config %s, falling back to HF", - model_cls.__name__, - type(config).__name__, - ) - return None - return model_cls - - def register(self, arch_name: str, model_cls: Type[nn.Module], exist_ok: bool = False) -> None: - """Register a custom model class for a given architecture name.""" - self.model_arch_name_to_cls.register(arch_name, model_cls, exist_ok=exist_ok) - - -@lru_cache -def get_registry(): - return _ModelRegistry() - - -ModelRegistry = get_registry() diff --git a/nemo_automodel/_transformers/registry/__init__.py b/nemo_automodel/_transformers/registry/__init__.py new file mode 100644 index 0000000000..8fe4f60160 --- /dev/null +++ b/nemo_automodel/_transformers/registry/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public exports for the model registry package.""" + +from nemo_automodel._transformers.registry.base import _LazyArchMapping +from nemo_automodel._transformers.registry.base import _ModelRegistry as _BaseModelRegistry +from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + +_MODEL_REGISTRY_EXPORTS = { + "MODEL_ARCH_MAPPING", + "_CUSTOM_CONFIG_SPECS", + "MODEL_PACKAGE_SPECS", + "ModelRegistry", + "_ModelRegistry", + "get_registry", +} + + +def __getattr__(name: str): + """Lazily expose model-specific registry globals without import cycles.""" + if name not in _MODEL_REGISTRY_EXPORTS: + raise AttributeError(name) + from nemo_automodel._transformers.registry import model_registry + + return getattr(model_registry, name) + + +__all__ = [ + "MODEL_ARCH_MAPPING", + "_CUSTOM_CONFIG_SPECS", + "MODEL_PACKAGE_SPECS", + "ModelRegistry", + "ModelPackageSpec", + "_BaseModelRegistry", + "_LazyArchMapping", + "_ModelRegistry", + "get_registry", +] diff --git a/nemo_automodel/_transformers/registry/base.py b/nemo_automodel/_transformers/registry/base.py new file mode 100644 index 0000000000..0cea2efe0d --- /dev/null +++ b/nemo_automodel/_transformers/registry/base.py @@ -0,0 +1,362 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import inspect +import logging +from collections import OrderedDict +from dataclasses import dataclass, field +from types import ModuleType + +import torch.nn as nn + +from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + +logger = logging.getLogger(__name__) + + +class _LazyArchMapping: + """Lazy-loading mapping from architecture name to model class. + + Inspired by HuggingFace transformers' ``_LazyAutoMapping``. Entries from the + static ``ModelPackageSpec`` mapping are imported on first access and cached. Additional entries + can be added at runtime via ``register``. + """ + + def __init__(self, auto_map: OrderedDict[str, ModelPackageSpec] | dict[str, ModelPackageSpec] | None = None): + self._specs: dict[str, ModelPackageSpec] = OrderedDict() + self._tags: dict[str, set] = {} + for key, spec in (auto_map or {}).items(): + if spec.class_name is None: + raise ValueError(f"Model architecture entry {key!r} must include a class name") + spec = spec.with_architecture(key) + self._specs[key] = spec + if spec.tags: + self._tags[key] = set(spec.tags) + self._loaded: dict[str, type[nn.Module]] = {} + self._extra: dict[str, type[nn.Module]] = {} + self._extra_specs: dict[str, ModelPackageSpec] = {} + self._modules: dict[str, object] = {} + + def _load(self, key: str) -> type[nn.Module]: + if key in self._loaded: + return self._loaded[key] + spec = self._specs[key] + module_path = spec.module_path + class_name = spec.class_name + if module_path not in self._modules: + self._modules[module_path] = importlib.import_module(module_path) + cls = getattr(self._modules[module_path], class_name) + self._loaded[key] = cls + return cls + + def __contains__(self, key: str) -> bool: + if key in self._extra or key in self._loaded: + return True + if key not in self._specs: + return False + try: + self._load(key) + return True + except Exception: + logger.debug("Model %s unavailable (import failed), removing from registry specs", key) + self._specs.pop(key, None) + self._tags.pop(key, None) + return False + + def __getitem__(self, key: str) -> type[nn.Module]: + if key in self._extra: + return self._extra[key] + if key in self._specs: + return self._load(key) + raise KeyError(key) + + def __setitem__(self, key: str, value: type[nn.Module]) -> None: + self._extra[key] = value + self._extra_specs[key] = ModelPackageSpec.from_module_path(value.__module__, value.__name__).with_architecture( + key + ) + + def register(self, key: str, value: type[nn.Module], exist_ok: bool = False) -> None: + """Register a model class under the given architecture name.""" + if not exist_ok and key in self._extra: + raise ValueError(f"Duplicated model implementation for {key}") + self._extra[key] = value + self._extra_specs[key] = ModelPackageSpec.from_module_path(value.__module__, value.__name__).with_architecture( + key + ) + + def has_tag(self, key: str, tag: str) -> bool: + """Return ``True`` if *key* was registered with *tag*.""" + return tag in self._tags.get(key, set()) + + def keys_with_tag(self, tag: str) -> set: + """Return all architecture names that have *tag*.""" + return {k for k, tags in self._tags.items() if tag in tags} + + def keys(self): + return set(self._specs.keys()) | set(self._extra.keys()) + + def get_spec(self, key: str) -> ModelPackageSpec | None: + """Return package metadata for *key* without importing the model class.""" + if key in self._extra_specs: + return self._extra_specs[key] + return self._specs.get(key) + + def specs(self) -> tuple[ModelPackageSpec, ...]: + """Return metadata for all statically and dynamically registered model classes.""" + return (*self._specs.values(), *self._extra_specs.values()) + + def __len__(self) -> int: + return len(self.keys()) + + def __repr__(self) -> str: + return f"_LazyArchMapping(specs={len(self._specs)}, extra={len(self._extra)}, loaded={len(self._loaded)})" + + +@dataclass +class _ModelRegistry: + model_arch_mapping: OrderedDict[str, ModelPackageSpec] | dict[str, ModelPackageSpec] | None = None + model_arch_name_to_cls: _LazyArchMapping = field(default=None) + package_specs: tuple[ModelPackageSpec, ...] = () + _retrieval_archs: set = field(default_factory=set) + _architecture_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) + _model_type_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) + + def __post_init__(self): + if self.model_arch_name_to_cls is None: + self.model_arch_name_to_cls = _LazyArchMapping(self.model_arch_mapping or {}) + self._retrieval_archs = self.model_arch_name_to_cls.keys_with_tag("retrieval") + self._rebuild_spec_indexes() + + def _iter_specs(self) -> tuple[ModelPackageSpec, ...]: + return (*self.model_arch_name_to_cls.specs(), *self.package_specs) + + def _rebuild_spec_indexes(self) -> None: + architecture_to_specs: dict[str, list[ModelPackageSpec]] = {} + model_type_to_specs: dict[str, list[ModelPackageSpec]] = {} + for spec in self._iter_specs(): + for architecture in spec.architectures: + architecture_to_specs.setdefault(architecture, []).append(spec) + for model_type in spec.model_types: + model_type_to_specs.setdefault(model_type, []).append(spec) + self._architecture_to_specs = { + architecture: tuple(specs) for architecture, specs in architecture_to_specs.items() + } + self._model_type_to_specs = { + model_type: tuple(self._dedupe_specs(specs)) for model_type, specs in model_type_to_specs.items() + } + + @staticmethod + def _dedupe_specs(specs: list[ModelPackageSpec] | tuple[ModelPackageSpec, ...]) -> tuple[ModelPackageSpec, ...]: + seen: set[str] = set() + deduped: list[ModelPackageSpec] = [] + for spec in specs: + if spec.package in seen: + continue + seen.add(spec.package) + deduped.append(spec) + return tuple(deduped) + + @staticmethod + def _import_optional_module(spec: ModelPackageSpec, module_name: str) -> ModuleType | None: + if module_name not in spec.optional_modules: + return None + module_path = spec.optional_module_path(module_name) + try: + return importlib.import_module(module_path) + except ImportError: + logger.debug("Optional model module is unavailable: %s", module_path) + return None + + @property + def supported_models(self): + return self.model_arch_name_to_cls.keys() + + def get_model_cls_from_model_arch(self, model_arch: str) -> type[nn.Module]: + return self.model_arch_name_to_cls[model_arch] + + def has_custom_model(self, arch_name: str) -> bool: + """Return ``True`` if *arch_name* has a custom (non-HF) implementation.""" + return arch_name in self.model_arch_name_to_cls + + def has_retrieval_model(self, arch_name: str) -> bool: + """Return ``True`` if *arch_name* is a registered retrieval/encoder architecture.""" + return arch_name in self._retrieval_archs + + def register_retrieval(self, arch_name: str) -> None: + """Mark *arch_name* as a retrieval/encoder architecture.""" + self._retrieval_archs.add(arch_name) + + def get_model_package_spec(self, architecture: str) -> ModelPackageSpec | None: + """Return package metadata for an architecture without importing ``model.py``.""" + specs = self.get_model_package_specs_for_architecture(architecture) + if specs: + return specs[0] + return self.model_arch_name_to_cls.get_spec(architecture) + + def get_model_package_specs_for_architecture(self, architecture: str) -> tuple[ModelPackageSpec, ...]: + """Return all package metadata entries that declare *architecture*.""" + return self._architecture_to_specs.get(architecture, ()) + + def get_model_package_specs_for_model_type(self, model_type: str) -> tuple[ModelPackageSpec, ...]: + """Return package metadata entries that declare *model_type*.""" + return self._model_type_to_specs.get(model_type, ()) + + def get_optional_module_for_architecture(self, architecture: str, module_name: str) -> ModuleType | None: + """Import ``.`` for an architecture if declared.""" + modules = self.iter_optional_modules_for_architectures((architecture,), module_name) + return modules[0] if modules else None + + def iter_optional_modules_for_architectures( + self, architectures: tuple[str, ...] | list[str], module_name: str + ) -> tuple[ModuleType, ...]: + """Import convention modules for the requested architectures only.""" + modules: list[ModuleType] = [] + seen_paths: set[str] = set() + for architecture in architectures: + for spec in self.get_model_package_specs_for_architecture(architecture): + if module_name not in spec.optional_modules: + continue + module_path = spec.optional_module_path(module_name) + if module_path in seen_paths: + continue + module = self._import_optional_module(spec, module_name) + if module is not None: + seen_paths.add(module_path) + modules.append(module) + return tuple(modules) + + def iter_optional_modules_for_model_type(self, model_type: str, module_name: str) -> tuple[ModuleType, ...]: + """Import convention modules for packages that declare *model_type*.""" + modules: list[ModuleType] = [] + for spec in self.get_model_package_specs_for_model_type(model_type): + module = self._import_optional_module(spec, module_name) + if module is not None: + modules.append(module) + return tuple(modules) + + def iter_optional_modules( + self, + module_name: str, + *, + global_patches: bool | None = None, + pre_config_patches: bool | None = None, + post_shard_patches: bool | None = None, + tokenizer_registrations: bool | None = None, + ) -> tuple[ModuleType, ...]: + """Import declared convention modules matching the provided metadata filters.""" + modules: list[ModuleType] = [] + seen_paths: set[str] = set() + for spec in self._iter_specs(): + if module_name not in spec.optional_modules: + continue + if global_patches is not None and spec.global_patches is not global_patches: + continue + if pre_config_patches is not None and spec.pre_config_patches is not pre_config_patches: + continue + if post_shard_patches is not None and spec.post_shard_patches is not post_shard_patches: + continue + if tokenizer_registrations is not None and spec.tokenizer_registrations is not tokenizer_registrations: + continue + module_path = spec.optional_module_path(module_name) + if module_path in seen_paths: + continue + module = self._import_optional_module(spec, module_name) + if module is not None: + seen_paths.add(module_path) + modules.append(module) + return tuple(modules) + + @staticmethod + def _iter_config_classes(module: ModuleType): + """Yield local ``PretrainedConfig`` subclasses declared by *module*.""" + from transformers import PretrainedConfig + + for _, cls in inspect.getmembers(module, inspect.isclass): + if cls is PretrainedConfig: + continue + if cls.__module__ != module.__name__: + continue + if not issubclass(cls, PretrainedConfig): + continue + if not isinstance(getattr(cls, "model_type", None), str): + continue + yield cls + + def _resolve_config_class(self, spec: ModelPackageSpec, model_type: str): + module_path = spec.config_module_path + if module_path is None: + return None + + try: + module = importlib.import_module(module_path) + except ImportError: + logger.debug("Config module is unavailable: %s", module_path) + return None + + if spec.config_class_name is not None: + return getattr(module, spec.config_class_name, None) + + for config_cls in self._iter_config_classes(module): + if config_cls.model_type == model_type: + return config_cls + return None + + def ensure_config_registered(self, model_type: str) -> bool: + """Register the local config class for *model_type* with HuggingFace, if declared.""" + from transformers import AutoConfig + from transformers.models.auto.configuration_auto import CONFIG_MAPPING + + if model_type in CONFIG_MAPPING: + return True + + for spec in self.get_model_package_specs_for_model_type(model_type): + config_cls = self._resolve_config_class(spec, model_type) + if config_cls is None: + continue + try: + AutoConfig.register(model_type, config_cls) + except ValueError: + if getattr(config_cls, "model_type", None) != model_type: + CONFIG_MAPPING.register(model_type, config_cls) + else: + raise + return True + return False + + def resolve_custom_model_cls(self, architecture: str, config) -> type[nn.Module] | None: + """Return the custom model class if it exists and supports *config*, else ``None``. + + Custom model classes may define a ``supports_config(config)`` classmethod + to opt out for specific HF configs (e.g. a Mistral3 VLM with a dense + Ministral3 text backbone instead of the expected Mistral4 MoE+MLA). + """ + if architecture not in self.model_arch_name_to_cls: + return None + model_cls = self.model_arch_name_to_cls[architecture] + if hasattr(model_cls, "supports_config") and not model_cls.supports_config(config): + logger.info( + "Custom model %s does not support config %s, falling back to HF", + model_cls.__name__, + type(config).__name__, + ) + return None + return model_cls + + def register(self, arch_name: str, model_cls: type[nn.Module], exist_ok: bool = False) -> None: + """Register a custom model class for a given architecture name.""" + self.model_arch_name_to_cls.register(arch_name, model_cls, exist_ok=exist_ok) + self._rebuild_spec_indexes() diff --git a/nemo_automodel/_transformers/registry/model_package_spec.py b/nemo_automodel/_transformers/registry/model_package_spec.py new file mode 100644 index 0000000000..a896e0e34a --- /dev/null +++ b/nemo_automodel/_transformers/registry/model_package_spec.py @@ -0,0 +1,106 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field, replace + + +@dataclass(frozen=True) +class ModelPackageSpec: + """Registry metadata for a model package and its optional convention modules.""" + + package: str + class_name: str | None = None + model_module: str = "model" + config_module: str | None = None + config_class_name: str | None = None + tags: frozenset[str] = field(default_factory=frozenset) + architectures: tuple[str, ...] = () + model_types: tuple[str, ...] = () + optional_modules: frozenset[str] = field(default_factory=frozenset) + global_patches: bool = False + pre_config_patches: bool = False + post_shard_patches: bool = False + tokenizer_registrations: bool = False + + def __post_init__(self) -> None: + object.__setattr__(self, "tags", frozenset(self.tags)) + object.__setattr__(self, "architectures", tuple(self.architectures)) + object.__setattr__(self, "model_types", tuple(self.model_types)) + object.__setattr__(self, "optional_modules", frozenset(self.optional_modules)) + + @classmethod + def from_module_path( + cls, + module_path: str, + class_name: str, + *, + config_module: str | None = None, + config_class_name: str | None = None, + tags: set[str] | frozenset[str] | tuple[str, ...] = (), + architectures: tuple[str, ...] = (), + model_types: tuple[str, ...] = (), + optional_modules: set[str] | frozenset[str] | tuple[str, ...] = (), + global_patches: bool = False, + pre_config_patches: bool = False, + post_shard_patches: bool = False, + tokenizer_registrations: bool = False, + ) -> "ModelPackageSpec": + """Create a spec from a fully qualified model module path.""" + package, sep, model_module = module_path.rpartition(".") + if not sep: + package = "" + model_module = module_path + return cls( + package=package, + class_name=class_name, + model_module=model_module, + config_module=config_module, + config_class_name=config_class_name, + tags=frozenset(tags), + architectures=architectures, + model_types=model_types, + optional_modules=frozenset(optional_modules), + global_patches=global_patches, + pre_config_patches=pre_config_patches, + post_shard_patches=post_shard_patches, + tokenizer_registrations=tokenizer_registrations, + ) + + @property + def module_path(self) -> str: + """Return the import path for the model implementation module.""" + if not self.package: + return self.model_module + return f"{self.package}.{self.model_module}" + + def optional_module_path(self, module_name: str) -> str: + """Return the import path for a convention module such as ``patches``.""" + if not self.package: + return f"{self.model_module}.{module_name}" + return f"{self.package}.{module_name}" + + @property + def config_module_path(self) -> str | None: + """Return the import path for this model package's config module, if declared.""" + if self.config_module is None: + return None + if not self.package: + return self.config_module + return f"{self.package}.{self.config_module}" + + def with_architecture(self, architecture: str) -> "ModelPackageSpec": + """Return a copy that records *architecture* as an alias for this package.""" + if architecture in self.architectures: + return self + return replace(self, architectures=(*self.architectures, architecture)) diff --git a/nemo_automodel/_transformers/registry/model_registry.py b/nemo_automodel/_transformers/registry/model_registry.py new file mode 100644 index 0000000000..1038703798 --- /dev/null +++ b/nemo_automodel/_transformers/registry/model_registry.py @@ -0,0 +1,390 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from functools import lru_cache + +from nemo_automodel._transformers.registry.base import _ModelRegistry as _BaseModelRegistry +from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + +# Static mapping: architecture name → ModelPackageSpec. +# Analogous to HuggingFace transformers' MODEL_FOR_CAUSAL_LM_MAPPING_NAMES. +# Models are loaded lazily on first access rather than imported at startup. +MODEL_ARCH_MAPPING: OrderedDict[str, ModelPackageSpec] = OrderedDict( + [ + ( + "BaichuanForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.baichuan", + class_name="BaichuanForCausalLM", + architectures=("BaichuanForCausalLM",), + ), + ), + ( + "DeepseekV3ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v3", + class_name="DeepseekV3ForCausalLM", + architectures=("DeepseekV3ForCausalLM",), + ), + ), + ( + "DeepseekV32ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v32", + class_name="DeepseekV32ForCausalLM", + architectures=("DeepseekV32ForCausalLM",), + ), + ), + ( + "DeepseekV4ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v4", + class_name="DeepseekV4ForCausalLM", + architectures=("DeepseekV4ForCausalLM",), + ), + ), + ( + "Ernie4_5_MoeForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.ernie4_5", + class_name="Ernie4_5_MoeForCausalLM", + architectures=("Ernie4_5_MoeForCausalLM",), + ), + ), + ( + "Glm4MoeForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.glm4_moe", + class_name="Glm4MoeForCausalLM", + architectures=("Glm4MoeForCausalLM",), + ), + ), + ( + "Glm4MoeLiteForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.glm4_moe_lite", + class_name="Glm4MoeLiteForCausalLM", + architectures=("Glm4MoeLiteForCausalLM",), + ), + ), + ( + "GlmMoeDsaForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.glm_moe_dsa", + class_name="GlmMoeDsaForCausalLM", + architectures=("GlmMoeDsaForCausalLM",), + ), + ), + ( + "Gemma4ForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.gemma4_moe", + class_name="Gemma4ForConditionalGeneration", + architectures=("Gemma4ForConditionalGeneration",), + ), + ), + ( + "GptOssForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.gpt_oss", + class_name="GptOssForCausalLM", + architectures=("GptOssForCausalLM",), + ), + ), + ( + "KimiK25ForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.kimi_k25_vl", + class_name="KimiK25VLForConditionalGeneration", + architectures=("KimiK25ForConditionalGeneration",), + ), + ), + ( + "KimiK25VLForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.kimi_k25_vl", + class_name="KimiK25VLForConditionalGeneration", + architectures=("KimiK25VLForConditionalGeneration",), + ), + ), + ( + "KimiVLForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.kimivl", + class_name="KimiVLForConditionalGeneration", + architectures=("KimiVLForConditionalGeneration",), + ), + ), + ( + "LlamaBidirectionalForSequenceClassification", + ModelPackageSpec( + package="nemo_automodel.components.models.llama_bidirectional", + class_name="LlamaBidirectionalForSequenceClassification", + tags=frozenset({"retrieval"}), + architectures=("LlamaBidirectionalForSequenceClassification",), + ), + ), + ( + "LlamaBidirectionalModel", + ModelPackageSpec( + package="nemo_automodel.components.models.llama_bidirectional", + class_name="LlamaBidirectionalModel", + tags=frozenset({"retrieval"}), + architectures=("LlamaBidirectionalModel",), + ), + ), + ( + "LlamaForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.llama", + class_name="LlamaForCausalLM", + architectures=("LlamaForCausalLM",), + ), + ), + ( + "MiniMaxM2ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.minimax_m2", + class_name="MiniMaxM2ForCausalLM", + architectures=("MiniMaxM2ForCausalLM",), + ), + ), + ( + "MiMoV2FlashForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.mimo_v2_flash", + class_name="MiMoV2FlashForCausalLM", + architectures=("MiMoV2FlashForCausalLM",), + ), + ), + ( + "Ministral3ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.mistral3", + class_name="Ministral3ForCausalLM", + architectures=("Ministral3ForCausalLM",), + ), + ), + ( + "Ministral3BidirectionalModel", + ModelPackageSpec( + package="nemo_automodel.components.models.ministral_bidirectional", + class_name="Ministral3BidirectionalModel", + tags=frozenset({"retrieval"}), + architectures=("Ministral3BidirectionalModel",), + ), + ), + ( + "Mistral4ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.mistral4", + class_name="Mistral4ForCausalLM", + architectures=("Mistral4ForCausalLM",), + ), + ), + ( + "Mistral3ForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.mistral4", + class_name="Mistral3ForConditionalGeneration", + architectures=("Mistral3ForConditionalGeneration",), + ), + ), + ( + "Mistral3FP8VLMForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.mistral3_vlm", + class_name="Mistral3FP8VLMForConditionalGeneration", + architectures=("Mistral3FP8VLMForConditionalGeneration",), + ), + ), + ( + "NemotronHForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_v3", + class_name="NemotronHForCausalLM", + architectures=("NemotronHForCausalLM",), + ), + ), + ( + "NemotronH_Nano_Omni_Reasoning_V3", + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_omni", + class_name="NemotronOmniForConditionalGeneration", + architectures=("NemotronH_Nano_Omni_Reasoning_V3",), + ), + ), + ( + "NemotronParseForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_parse", + class_name="NemotronParseForConditionalGeneration", + architectures=("NemotronParseForConditionalGeneration",), + ), + ), + ( + "LLaVAOneVision1_5_ForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.llava_onevision", + class_name="LLaVAOneVision1_5_ForConditionalGeneration", + architectures=("LLaVAOneVision1_5_ForConditionalGeneration",), + ), + ), + ( + "HYV3ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.hy_v3", + class_name="HYV3ForCausalLM", + architectures=("HYV3ForCausalLM",), + ), + ), + ( + "Qwen2ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.qwen2", + class_name="Qwen2ForCausalLM", + architectures=("Qwen2ForCausalLM",), + ), + ), + ( + "Qwen3MoeForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_moe", + class_name="Qwen3MoeForCausalLM", + architectures=("Qwen3MoeForCausalLM",), + ), + ), + ( + "Qwen3NextForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_next", + class_name="Qwen3NextForCausalLM", + architectures=("Qwen3NextForCausalLM",), + ), + ), + ( + "Qwen3OmniMoeForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_omni_moe", + class_name="Qwen3OmniMoeThinkerForConditionalGeneration", + architectures=("Qwen3OmniMoeForConditionalGeneration",), + ), + ), + ( + "Qwen3VLMoeForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_vl_moe", + class_name="Qwen3VLMoeForConditionalGeneration", + architectures=("Qwen3VLMoeForConditionalGeneration",), + ), + ), + ( + "Qwen3_5MoeForConditionalGeneration", + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_5_moe", + class_name="Qwen3_5MoeForConditionalGeneration", + architectures=("Qwen3_5MoeForConditionalGeneration",), + ), + ), + ( + "Step3p5ForCausalLM", + ModelPackageSpec( + package="nemo_automodel.components.models.step3p5", + class_name="Step3p5ForCausalLM", + architectures=("Step3p5ForCausalLM",), + ), + ), + ] +) + + +_CUSTOM_CONFIG_SPECS: tuple[ModelPackageSpec, ...] = ( + ModelPackageSpec( + package="nemo_automodel.components.models.baichuan", + model_types=("baichuan",), + config_module="configuration", + config_class_name="BaichuanConfig", + ), + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v4", + model_types=("deepseek_v4",), + config_module="config", + config_class_name="DeepseekV4Config", + ), + ModelPackageSpec( + package="nemo_automodel.components.models.hy_v3", + model_types=("hy_v3",), + config_module="config", + config_class_name="HYV3Config", + ), + ModelPackageSpec( + package="nemo_automodel.components.models.kimi_k25_vl", + model_types=("kimi_k25", "kimi_k25_vl"), + config_module="model", + config_class_name="KimiK25VLConfig", + ), + ModelPackageSpec( + package="nemo_automodel.components.models.kimivl", + model_types=("kimi_vl",), + config_module="model", + config_class_name="KimiVLConfig", + ), + ModelPackageSpec( + package="nemo_automodel.components.models.llava_onevision", + model_types=("llavaonevision1_5",), + config_module="model", + config_class_name="Llavaonevision1_5Config", + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mimo_v2_flash", + model_types=("mimo_v2_flash",), + config_module="config", + config_class_name="MiMoV2FlashConfig", + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral4", + model_types=("mistral4",), + config_module="configuration", + config_class_name="Mistral4Config", + ), +) + + +MODEL_PACKAGE_SPECS: tuple[ModelPackageSpec, ...] = (*_CUSTOM_CONFIG_SPECS,) + + +class _ModelRegistry(_BaseModelRegistry): + """Model-specific registry initialized from AutoModel's static mappings.""" + + def __init__( + self, + model_arch_name_to_cls=None, + package_specs: tuple[ModelPackageSpec, ...] = MODEL_PACKAGE_SPECS, + model_arch_mapping: OrderedDict[str, ModelPackageSpec] | dict[str, ModelPackageSpec] = MODEL_ARCH_MAPPING, + ) -> None: + super().__init__( + model_arch_mapping=model_arch_mapping, + model_arch_name_to_cls=model_arch_name_to_cls, + package_specs=package_specs, + ) + + +@lru_cache +def get_registry() -> _ModelRegistry: + """Return the process-wide model registry singleton.""" + return _ModelRegistry() + + +ModelRegistry = get_registry() diff --git a/tests/unit_tests/_transformers/test_model_init.py b/tests/unit_tests/_transformers/test_model_init.py index 84660ea7fb..821bfb5689 100644 --- a/tests/unit_tests/_transformers/test_model_init.py +++ b/tests/unit_tests/_transformers/test_model_init.py @@ -552,6 +552,23 @@ def test_raises_when_config_class_cannot_be_resolved(self, mock_get_dict, mock_m class TestGetHfConfigLayerTypesRetry: """get_hf_config should retry via the layer_types fix helper when AutoConfig raises.""" + @patch("nemo_automodel._transformers.model_init.ModelRegistry.ensure_config_registered") + @patch("nemo_automodel._transformers.model_init.PretrainedConfig.get_config_dict") + @patch("nemo_automodel._transformers.model_init.resolve_trust_remote_code", return_value=False) + @patch("nemo_automodel._transformers.model_init.AutoConfig.from_pretrained") + def test_registers_custom_config_before_auto_config( + self, mock_from_pretrained, _mock_trust, mock_get_dict, mock_register + ): + mock_get_dict.return_value = ({"model_type": "fake_model"}, {}) + built_config = MagicMock() + mock_from_pretrained.return_value = built_config + + result = get_hf_config("fake/model", "sdpa") + + assert result is built_config + mock_register.assert_called_once_with("fake_model") + mock_from_pretrained.assert_called_once() + @patch("nemo_automodel._transformers.model_init._load_config_with_layer_types_fix") @patch("nemo_automodel._transformers.model_init.resolve_trust_remote_code", return_value=True) @patch("nemo_automodel._transformers.model_init.AutoConfig.from_pretrained") diff --git a/tests/unit_tests/_transformers/test_registry.py b/tests/unit_tests/_transformers/test_registry.py index 46a1f1b687..5508e61921 100644 --- a/tests/unit_tests/_transformers/test_registry.py +++ b/tests/unit_tests/_transformers/test_registry.py @@ -13,6 +13,7 @@ # limitations under the License. import types +from unittest.mock import patch import pytest @@ -113,15 +114,33 @@ def test_get_registry_is_cached(): assert r1 is r2 +def test_registry_reexports_split_modules(): + """The registry package remains the public import path while definitions live in split modules.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + from nemo_automodel._transformers.registry.model_registry import ( + _CUSTOM_CONFIG_SPECS, + MODEL_ARCH_MAPPING, + MODEL_PACKAGE_SPECS, + ) + + assert reg.ModelPackageSpec is ModelPackageSpec + assert reg.MODEL_ARCH_MAPPING is MODEL_ARCH_MAPPING + assert reg.MODEL_PACKAGE_SPECS is MODEL_PACKAGE_SPECS + assert reg._CUSTOM_CONFIG_SPECS is _CUSTOM_CONFIG_SPECS + + def test_lazy_arch_mapping_auto_map(): """Static auto_map entries are lazily loaded on first access.""" from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeClass: pass fake_module = types.SimpleNamespace(FakeClass=FakeClass) - mapping = _LazyArchMapping({"FakeArch": ("fake.module", "FakeClass")}) + spec = ModelPackageSpec(package="fake", model_module="module", class_name="FakeClass") + mapping = _LazyArchMapping({"FakeArch": spec}) mapping._modules["fake.module"] = fake_module @@ -133,9 +152,33 @@ class FakeClass: mapping["NonExistent"] +def test_lazy_arch_mapping_accepts_model_package_spec_entries(): + """Static auto_map entries may be declared as ModelPackageSpec metadata.""" + from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeClass: + pass + + fake_module = types.SimpleNamespace(FakeClass=FakeClass) + spec = ModelPackageSpec.from_module_path( + "fake.package.model", + "FakeClass", + model_types=("fake",), + optional_modules=frozenset({"patches"}), + ) + mapping = _LazyArchMapping({"FakeArch": spec}) + mapping._modules["fake.package.model"] = fake_module + + assert mapping["FakeArch"] is FakeClass + assert mapping.get_spec("FakeArch").package == "fake.package" + assert mapping.get_spec("FakeArch").model_types == ("fake",) + + def test_lazy_arch_mapping_extra_overrides_auto_map(): """Dynamically registered entries take precedence over static entries.""" from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class StaticClass: pass @@ -144,7 +187,8 @@ class DynamicClass: pass fake_module = types.SimpleNamespace(StaticClass=StaticClass) - mapping = _LazyArchMapping({"MyArch": ("fake.module", "StaticClass")}) + spec = ModelPackageSpec(package="fake", model_module="module", class_name="StaticClass") + mapping = _LazyArchMapping({"MyArch": spec}) mapping._modules["fake.module"] = fake_module assert mapping["MyArch"] is StaticClass @@ -156,11 +200,13 @@ class DynamicClass: def test_lazy_arch_mapping_unavailable_model(): """Auto_map entries whose imports fail are removed and excluded from containment.""" from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec - mapping = _LazyArchMapping({"BadArch": ("nonexistent.module.path", "BadClass")}) + spec = ModelPackageSpec(package="nonexistent.module", model_module="path", class_name="BadClass") + mapping = _LazyArchMapping({"BadArch": spec}) assert "BadArch" not in mapping - assert "BadArch" not in mapping._auto_map + assert mapping.get_spec("BadArch") is None def test_default_registry_has_static_entries(): @@ -172,6 +218,144 @@ def test_default_registry_has_static_entries(): assert arch_name in inst.model_arch_name_to_cls.keys() +def test_registry_imports_optional_module_for_architecture(): + """Convention modules are imported through ModelPackageSpec package metadata.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + fake_module = types.SimpleNamespace(__name__="fake.package.patches") + spec = ModelPackageSpec( + package="fake.package", + architectures=("FakeArch",), + optional_modules=frozenset({"patches"}), + ) + inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + + with patch( + "nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module + ) as mock_import: + modules = inst.iter_optional_modules_for_architectures(("FakeArch",), "patches") + + mock_import.assert_called_once_with("fake.package.patches") + assert modules == (fake_module,) + + +def test_registry_filters_optional_modules_by_scope(): + """Global patch lookup imports only specs that advertise global patches.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + global_module = types.SimpleNamespace(__name__="fake.global_model.patches") + specs = ( + ModelPackageSpec( + package="fake.global_model", + optional_modules=frozenset({"patches"}), + global_patches=True, + ), + ModelPackageSpec( + package="fake.runtime_model", + optional_modules=frozenset({"patches"}), + global_patches=False, + ), + ) + inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=specs) + + with patch( + "nemo_automodel._transformers.registry.base.importlib.import_module", return_value=global_module + ) as mock_import: + modules = inst.iter_optional_modules("patches", global_patches=True) + + mock_import.assert_called_once_with("fake.global_model.patches") + assert modules == (global_module,) + + +def test_registry_discovers_config_class_from_config_module(): + """Config classes are imported from declared convention modules on demand.""" + from transformers import PretrainedConfig + + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeConfig(PretrainedConfig): + model_type = "fake_model" + + FakeConfig.__module__ = "fake.package.config" + fake_module = types.SimpleNamespace(__name__="fake.package.config", FakeConfig=FakeConfig) + spec = ModelPackageSpec(package="fake.package", model_types=("fake_model",), config_module="config") + inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + + with ( + patch("transformers.AutoConfig.register") as mock_register, + patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module), + ): + assert inst.ensure_config_registered("fake_model") is True + + mock_register.assert_called_once_with("fake_model", FakeConfig) + + +def test_registry_registers_explicit_config_alias(): + """Explicit config metadata supports aliases whose key differs from class.model_type.""" + from transformers import PretrainedConfig + + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeConfig(PretrainedConfig): + model_type = "canonical_model" + + fake_module = types.SimpleNamespace(__name__="fake.package.model", FakeConfig=FakeConfig) + spec = ModelPackageSpec( + package="fake.package", + model_types=("canonical_model", "alias_model"), + config_module="model", + config_class_name="FakeConfig", + ) + inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + + with ( + patch("transformers.AutoConfig.register") as mock_register, + patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module), + ): + assert inst.ensure_config_registered("alias_model") is True + + mock_register.assert_called_once_with("alias_model", FakeConfig) + + +def test_registry_registers_config_alias_when_auto_config_rejects_mismatch(): + """HF validates ``config.model_type``; aliases fall back to CONFIG_MAPPING directly.""" + from transformers import PretrainedConfig + from transformers.models.auto.configuration_auto import CONFIG_MAPPING + + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeConfig(PretrainedConfig): + model_type = "canonical_model" + + fake_module = types.SimpleNamespace(__name__="fake.package.model", FakeConfig=FakeConfig) + spec = ModelPackageSpec( + package="fake.package", + model_types=("alias_model_with_mismatch",), + config_module="model", + config_class_name="FakeConfig", + ) + inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + + with ( + patch("transformers.AutoConfig.register", side_effect=ValueError("model_type mismatch")), + patch.object(CONFIG_MAPPING, "register") as mock_mapping_register, + patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module), + ): + assert inst.ensure_config_registered("alias_model_with_mismatch") is True + + mock_mapping_register.assert_called_once_with("alias_model_with_mismatch", FakeConfig) + + def test_resolve_custom_model_cls_found(): """resolve_custom_model_cls returns the class when it exists and has no supports_config.""" from nemo_automodel._transformers import registry as reg @@ -242,26 +426,13 @@ def supports_config(cls, config): assert inst.resolve_custom_model_cls("ConfigAwareModel", bad) is None -def test_custom_config_registrations_in_config_mapping(): - """Models in _CUSTOM_CONFIG_REGISTRATIONS must be registered in CONFIG_MAPPING after import. +def test_custom_config_specs_are_metadata_only_until_requested(): + """Custom config entries should not require eager registration at registry import time.""" + from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS - This ensures that AutoConfig.from_pretrained can resolve custom model types - (e.g. kimi_k25, kimi_vl) from local checkpoints without trust_remote_code=True. - """ - from transformers.models.auto.configuration_auto import CONFIG_MAPPING - - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS - - missing = [] - for model_type in _CUSTOM_CONFIG_REGISTRATIONS: - if model_type not in CONFIG_MAPPING: - missing.append(model_type) - - assert not missing, ( - f"Model type(s) {missing} are in _CUSTOM_CONFIG_REGISTRATIONS but not in " - f"CONFIG_MAPPING. The _register_custom_configs() call at module level may " - f"have failed for these entries." - ) + for spec in _CUSTOM_CONFIG_SPECS: + assert spec.config_module_path + assert spec.config_class_name def test_kimi_k25_arch_alias_in_model_arch_mapping(): @@ -273,8 +444,8 @@ def test_kimi_k25_arch_alias_in_model_arch_mapping(): "Kimi-K2.5 checkpoints use this architecture name and need it mapped " "to KimiK25VLForConditionalGeneration." ) - module_path, cls_name = MODEL_ARCH_MAPPING["KimiK25ForConditionalGeneration"] - assert cls_name == "KimiK25VLForConditionalGeneration" + spec = MODEL_ARCH_MAPPING["KimiK25ForConditionalGeneration"] + assert spec.class_name == "KimiK25VLForConditionalGeneration" def test_deepseek_v4_registered_in_arch_mapping(): @@ -286,22 +457,23 @@ def test_deepseek_v4_registered_in_arch_mapping(): "DSV4 checkpoints declare this architecture and need it routed to the " "in-tree model implementation." ) - module_path, cls_name = MODEL_ARCH_MAPPING["DeepseekV4ForCausalLM"] - assert module_path == "nemo_automodel.components.models.deepseek_v4.model" - assert cls_name == "DeepseekV4ForCausalLM" + spec = MODEL_ARCH_MAPPING["DeepseekV4ForCausalLM"] + assert spec.module_path == "nemo_automodel.components.models.deepseek_v4.model" + assert spec.class_name == "DeepseekV4ForCausalLM" -def test_deepseek_v4_in_custom_config_registrations(): - """deepseek_v4 model_type must be registered in _CUSTOM_CONFIG_REGISTRATIONS.""" - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS +def test_deepseek_v4_in_custom_config_specs(): + """deepseek_v4 model_type must be declared in _CUSTOM_CONFIG_SPECS.""" + from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS - assert "deepseek_v4" in _CUSTOM_CONFIG_REGISTRATIONS, ( - "deepseek_v4 must be in _CUSTOM_CONFIG_REGISTRATIONS so AutoConfig.from_pretrained " - "can resolve DSV4 configs without trust_remote_code=True." + specs_by_model_type = {model_type: spec for spec in _CUSTOM_CONFIG_SPECS for model_type in spec.model_types} + assert "deepseek_v4" in specs_by_model_type, ( + "deepseek_v4 must be in _CUSTOM_CONFIG_SPECS so ModelRegistry can register " + "DSV4 configs on demand before AutoConfig.from_pretrained runs." ) - module_path, cls_name = _CUSTOM_CONFIG_REGISTRATIONS["deepseek_v4"] - assert module_path == "nemo_automodel.components.models.deepseek_v4.config" - assert cls_name == "DeepseekV4Config" + spec = specs_by_model_type["deepseek_v4"] + assert spec.config_module_path == "nemo_automodel.components.models.deepseek_v4.config" + assert spec.config_class_name == "DeepseekV4Config" def test_all_model_folders_registered_in_auto_map(): @@ -318,12 +490,15 @@ def test_all_model_folders_registered_in_auto_map(): models_root = pathlib.Path(__file__).resolve().parents[3] / "nemo_automodel" / "components" / "models" # Collect the set of module paths referenced by the auto_map - registered_module_paths = {v[0] for v in MODEL_ARCH_MAPPING.values()} + registered_module_paths = {spec.module_path for spec in MODEL_ARCH_MAPPING.values()} missing = [] + documentation_only_model_dirs = {"blueprint"} for model_dir in sorted(models_root.iterdir()): if not model_dir.is_dir() or model_dir.name.startswith(("_", ".")): continue + if model_dir.name in documentation_only_model_dirs: + continue model_file = model_dir / "model.py" if not model_file.exists(): continue diff --git a/tests/unit_tests/_transformers/test_registry_hy_v3.py b/tests/unit_tests/_transformers/test_registry_hy_v3.py index 9e733ccccb..d4d0988157 100644 --- a/tests/unit_tests/_transformers/test_registry_hy_v3.py +++ b/tests/unit_tests/_transformers/test_registry_hy_v3.py @@ -14,8 +14,6 @@ """Verify HYV3 model + config are registered in nemo_automodel._transformers.registry.""" -import pytest - class TestArchMapping: def test_hyv3_arch_registered(self): @@ -27,8 +25,8 @@ def test_hyv3_arch_points_at_correct_module(self): from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING entry = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] - assert entry[0] == "nemo_automodel.components.models.hy_v3.model" - assert entry[1] == "HYV3ForCausalLM" + assert entry.module_path == "nemo_automodel.components.models.hy_v3.model" + assert entry.class_name == "HYV3ForCausalLM" def test_hyv3_arch_resolves_to_class(self): """Walk the mapping path -- importable + the named class exists.""" @@ -36,23 +34,27 @@ def test_hyv3_arch_resolves_to_class(self): from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING - mod_path, cls_name, *_ = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] + spec = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] + mod_path = spec.module_path + cls_name = spec.class_name mod = importlib.import_module(mod_path) assert hasattr(mod, cls_name) class TestCustomConfigRegistration: def test_hy_v3_config_registered(self): - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS + from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS - assert "hy_v3" in _CUSTOM_CONFIG_REGISTRATIONS + assert any("hy_v3" in spec.model_types for spec in _CUSTOM_CONFIG_SPECS) def test_hy_v3_config_resolves_to_class(self): import importlib - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_REGISTRATIONS + from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS - mod_path, cls_name = _CUSTOM_CONFIG_REGISTRATIONS["hy_v3"] + spec = next(spec for spec in _CUSTOM_CONFIG_SPECS if "hy_v3" in spec.model_types) + mod_path = spec.config_module_path + cls_name = spec.config_class_name mod = importlib.import_module(mod_path) cls = getattr(mod, cls_name) assert cls.__name__ == "HYV3Config" From 3824d05a113d5e3ea027847a01fd602c2681db91 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 18 May 2026 21:52:10 -0700 Subject: [PATCH 2/5] update registry Signed-off-by: Alexandros Koumparoulis --- .../_transformers/registry/__init__.py | 20 +- nemo_automodel/_transformers/registry/base.py | 332 +++++------- .../registry/model_package_spec.py | 68 ++- .../_transformers/registry/model_registry.py | 503 ++++++------------ nemo_automodel/_transformers/retrieval.py | 22 +- tests/integration/test_llava_onevision_cpu.py | 4 +- .../_transformers/test_doc_coverage.py | 15 +- .../_transformers/test_recipe_doc_coverage.py | 2 +- .../unit_tests/_transformers/test_registry.py | 355 +++++++----- .../_transformers/test_registry_hy_v3.py | 23 +- .../test_llama_bidirectional_model.py | 53 +- .../test_ministral_bidirectional_model.py | 21 +- .../models/mistral4/test_mistral4_model.py | 19 +- .../test_nemotron_omni_layers.py | 20 +- 14 files changed, 694 insertions(+), 763 deletions(-) diff --git a/nemo_automodel/_transformers/registry/__init__.py b/nemo_automodel/_transformers/registry/__init__.py index 8fe4f60160..0d9005e6b6 100644 --- a/nemo_automodel/_transformers/registry/__init__.py +++ b/nemo_automodel/_transformers/registry/__init__.py @@ -14,17 +14,16 @@ """Public exports for the model registry package.""" -from nemo_automodel._transformers.registry.base import _LazyArchMapping -from nemo_automodel._transformers.registry.base import _ModelRegistry as _BaseModelRegistry +from nemo_automodel._transformers.registry.base import _BaseModelRegistry from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec _MODEL_REGISTRY_EXPORTS = { - "MODEL_ARCH_MAPPING", - "_CUSTOM_CONFIG_SPECS", "MODEL_PACKAGE_SPECS", + "RETRIEVAL_MODEL_PACKAGE_SPECS", "ModelRegistry", - "_ModelRegistry", - "get_registry", + "RetrievalModelRegistry", + "make_registry", + "make_retrieval_registry", } @@ -38,13 +37,12 @@ def __getattr__(name: str): __all__ = [ - "MODEL_ARCH_MAPPING", - "_CUSTOM_CONFIG_SPECS", "MODEL_PACKAGE_SPECS", + "RETRIEVAL_MODEL_PACKAGE_SPECS", "ModelRegistry", + "RetrievalModelRegistry", "ModelPackageSpec", "_BaseModelRegistry", - "_LazyArchMapping", - "_ModelRegistry", - "get_registry", + "make_registry", + "make_retrieval_registry", ] diff --git a/nemo_automodel/_transformers/registry/base.py b/nemo_automodel/_transformers/registry/base.py index 0cea2efe0d..8838c85138 100644 --- a/nemo_automodel/_transformers/registry/base.py +++ b/nemo_automodel/_transformers/registry/base.py @@ -16,7 +16,9 @@ import importlib import inspect import logging +import warnings from collections import OrderedDict +from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from types import ModuleType @@ -26,186 +28,170 @@ logger = logging.getLogger(__name__) +_ModelArchMappingInput = Mapping[str, ModelPackageSpec] | Iterable[ModelPackageSpec] | None -class _LazyArchMapping: - """Lazy-loading mapping from architecture name to model class. - - Inspired by HuggingFace transformers' ``_LazyAutoMapping``. Entries from the - static ``ModelPackageSpec`` mapping are imported on first access and cached. Additional entries - can be added at runtime via ``register``. - """ - - def __init__(self, auto_map: OrderedDict[str, ModelPackageSpec] | dict[str, ModelPackageSpec] | None = None): - self._specs: dict[str, ModelPackageSpec] = OrderedDict() - self._tags: dict[str, set] = {} - for key, spec in (auto_map or {}).items(): - if spec.class_name is None: - raise ValueError(f"Model architecture entry {key!r} must include a class name") - spec = spec.with_architecture(key) - self._specs[key] = spec - if spec.tags: - self._tags[key] = set(spec.tags) - self._loaded: dict[str, type[nn.Module]] = {} - self._extra: dict[str, type[nn.Module]] = {} - self._extra_specs: dict[str, ModelPackageSpec] = {} - self._modules: dict[str, object] = {} - - def _load(self, key: str) -> type[nn.Module]: - if key in self._loaded: - return self._loaded[key] - spec = self._specs[key] - module_path = spec.module_path - class_name = spec.class_name - if module_path not in self._modules: - self._modules[module_path] = importlib.import_module(module_path) - cls = getattr(self._modules[module_path], class_name) - self._loaded[key] = cls - return cls - - def __contains__(self, key: str) -> bool: - if key in self._extra or key in self._loaded: - return True - if key not in self._specs: - return False - try: - self._load(key) - return True - except Exception: - logger.debug("Model %s unavailable (import failed), removing from registry specs", key) - self._specs.pop(key, None) - self._tags.pop(key, None) - return False - - def __getitem__(self, key: str) -> type[nn.Module]: - if key in self._extra: - return self._extra[key] - if key in self._specs: - return self._load(key) - raise KeyError(key) - - def __setitem__(self, key: str, value: type[nn.Module]) -> None: - self._extra[key] = value - self._extra_specs[key] = ModelPackageSpec.from_module_path(value.__module__, value.__name__).with_architecture( - key - ) - def register(self, key: str, value: type[nn.Module], exist_ok: bool = False) -> None: - """Register a model class under the given architecture name.""" - if not exist_ok and key in self._extra: - raise ValueError(f"Duplicated model implementation for {key}") - self._extra[key] = value - self._extra_specs[key] = ModelPackageSpec.from_module_path(value.__module__, value.__name__).with_architecture( - key - ) - - def has_tag(self, key: str, tag: str) -> bool: - """Return ``True`` if *key* was registered with *tag*.""" - return tag in self._tags.get(key, set()) - - def keys_with_tag(self, tag: str) -> set: - """Return all architecture names that have *tag*.""" - return {k for k, tags in self._tags.items() if tag in tags} - - def keys(self): - return set(self._specs.keys()) | set(self._extra.keys()) - - def get_spec(self, key: str) -> ModelPackageSpec | None: - """Return package metadata for *key* without importing the model class.""" - if key in self._extra_specs: - return self._extra_specs[key] - return self._specs.get(key) - - def specs(self) -> tuple[ModelPackageSpec, ...]: - """Return metadata for all statically and dynamically registered model classes.""" - return (*self._specs.values(), *self._extra_specs.values()) - - def __len__(self) -> int: - return len(self.keys()) +def _normalize_model_arch_mapping(auto_map: _ModelArchMappingInput = None) -> OrderedDict[str, ModelPackageSpec]: + """Normalize legacy dict mappings or tuple-based specs into an architecture lookup.""" + mapping: OrderedDict[str, ModelPackageSpec] = OrderedDict() + if isinstance(auto_map, Mapping): + spec_items = ((architecture, spec.with_architecture(architecture)) for architecture, spec in auto_map.items()) + else: + spec_items = [] + for spec in auto_map or (): + if spec.class_name is None and not spec.architectures: + continue + if not spec.architectures: + raise ValueError(f"Model architecture spec for {spec.package!r} must declare at least one architecture") + spec_items.extend((architecture, spec) for architecture in spec.architectures) - def __repr__(self) -> str: - return f"_LazyArchMapping(specs={len(self._specs)}, extra={len(self._extra)}, loaded={len(self._loaded)})" + for architecture, spec in spec_items: + if spec.class_name is None: + raise ValueError(f"Model architecture entry {architecture!r} must include a class name") + if architecture in mapping: + raise ValueError(f"Duplicated model architecture entry for {architecture!r}") + mapping[architecture] = spec + return mapping @dataclass -class _ModelRegistry: - model_arch_mapping: OrderedDict[str, ModelPackageSpec] | dict[str, ModelPackageSpec] | None = None - model_arch_name_to_cls: _LazyArchMapping = field(default=None) - package_specs: tuple[ModelPackageSpec, ...] = () - _retrieval_archs: set = field(default_factory=set) +class _BaseModelRegistry: + model_specs: _ModelArchMappingInput = None + model_arch_name_to_cls: "_BaseModelRegistry" = field(init=False, repr=False, compare=False) + _model_specs: tuple[ModelPackageSpec, ...] = field(init=False) + _loaded_model_classes: dict[str, type[nn.Module]] = field(default_factory=dict) + _extra_model_classes: dict[str, type[nn.Module]] = field(default_factory=dict) + _discarded_architectures: set[str] = field(default_factory=set) + _manual_architecture_tags: dict[str, set[str]] = field(default_factory=dict) _architecture_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) _model_type_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) def __post_init__(self): - if self.model_arch_name_to_cls is None: - self.model_arch_name_to_cls = _LazyArchMapping(self.model_arch_mapping or {}) - self._retrieval_archs = self.model_arch_name_to_cls.keys_with_tag("retrieval") + if isinstance(self.model_specs, Mapping): + raw_specs: Iterable[ModelPackageSpec] = _normalize_model_arch_mapping(self.model_specs).values() + else: + raw_specs = tuple(self.model_specs or ()) + _normalize_model_arch_mapping(raw_specs) + self._model_specs = tuple(dict.fromkeys(raw_specs)) self._rebuild_spec_indexes() - - def _iter_specs(self) -> tuple[ModelPackageSpec, ...]: - return (*self.model_arch_name_to_cls.specs(), *self.package_specs) + self.model_arch_name_to_cls = self def _rebuild_spec_indexes(self) -> None: architecture_to_specs: dict[str, list[ModelPackageSpec]] = {} model_type_to_specs: dict[str, list[ModelPackageSpec]] = {} - for spec in self._iter_specs(): + dynamic_specs = tuple( + ModelPackageSpec.from_model_class(model_cls, architectures=(architecture,)) + for architecture, model_cls in self._extra_model_classes.items() + ) + for spec in (*dynamic_specs, *self._model_specs): for architecture in spec.architectures: + if architecture in self._discarded_architectures and architecture not in self._extra_model_classes: + continue architecture_to_specs.setdefault(architecture, []).append(spec) for model_type in spec.model_types: model_type_to_specs.setdefault(model_type, []).append(spec) self._architecture_to_specs = { architecture: tuple(specs) for architecture, specs in architecture_to_specs.items() } - self._model_type_to_specs = { - model_type: tuple(self._dedupe_specs(specs)) for model_type, specs in model_type_to_specs.items() - } - - @staticmethod - def _dedupe_specs(specs: list[ModelPackageSpec] | tuple[ModelPackageSpec, ...]) -> tuple[ModelPackageSpec, ...]: - seen: set[str] = set() - deduped: list[ModelPackageSpec] = [] - for spec in specs: - if spec.package in seen: - continue - seen.add(spec.package) - deduped.append(spec) - return tuple(deduped) + self._model_type_to_specs = {model_type: tuple(specs) for model_type, specs in model_type_to_specs.items()} + + def _discard_architecture(self, architecture: str) -> None: + self._discarded_architectures.add(architecture) + self._architecture_to_specs.pop(architecture, None) + self._extra_model_classes.pop(architecture, None) + self._loaded_model_classes.pop(architecture, None) + self._manual_architecture_tags.pop(architecture, None) + + def _load_model_class(self, architecture: str) -> type[nn.Module]: + if architecture in self._loaded_model_classes: + return self._loaded_model_classes[architecture] + spec = self.get_model_package_spec(architecture) + if spec is None or spec.class_name is None: + raise KeyError(architecture) + model_cls = getattr(importlib.import_module(spec.module_path), spec.class_name) + self._loaded_model_classes[architecture] = model_cls + return model_cls - @staticmethod - def _import_optional_module(spec: ModelPackageSpec, module_name: str) -> ModuleType | None: - if module_name not in spec.optional_modules: - return None - module_path = spec.optional_module_path(module_name) + def __contains__(self, architecture: str) -> bool: + if architecture in self._extra_model_classes or architecture in self._loaded_model_classes: + return True + if self.get_model_package_spec(architecture) is None: + return False try: - return importlib.import_module(module_path) - except ImportError: - logger.debug("Optional model module is unavailable: %s", module_path) - return None + self._load_model_class(architecture) + return True + except Exception: + logger.debug("Model %s unavailable (import failed), removing from registry specs", architecture) + self._discard_architecture(architecture) + return False + + def __getitem__(self, architecture: str) -> type[nn.Module]: + if architecture in self._extra_model_classes: + return self._extra_model_classes[architecture] + return self._load_model_class(architecture) + + def __setitem__(self, architecture: str, model_cls: type[nn.Module]) -> None: + self._discarded_architectures.discard(architecture) + self._extra_model_classes[architecture] = model_cls + self._loaded_model_classes.pop(architecture, None) + self._rebuild_spec_indexes() + + def keys(self): + return set(self._architecture_to_specs) | set(self._extra_model_classes) + + def __len__(self) -> int: + return len(self.keys()) + + def __repr__(self) -> str: + return ( + f"_BaseModelRegistry(specs={len(self._architecture_to_specs)}, " + f"extra={len(self._extra_model_classes)}, loaded={len(self._loaded_model_classes)})" + ) + + def _architecture_has_tag(self, architecture: str, tag: str) -> bool: + if tag in self._manual_architecture_tags.get(architecture, set()): + return True + return any(tag in spec.tags for spec in self.get_model_package_specs_for_architecture(architecture)) @property def supported_models(self): - return self.model_arch_name_to_cls.keys() + return self.keys() def get_model_cls_from_model_arch(self, model_arch: str) -> type[nn.Module]: - return self.model_arch_name_to_cls[model_arch] + return self[model_arch] def has_custom_model(self, arch_name: str) -> bool: """Return ``True`` if *arch_name* has a custom (non-HF) implementation.""" - return arch_name in self.model_arch_name_to_cls + return arch_name in self def has_retrieval_model(self, arch_name: str) -> bool: """Return ``True`` if *arch_name* is a registered retrieval/encoder architecture.""" - return arch_name in self._retrieval_archs + warnings.warn( + "has_retrieval_model() is deprecated; use RetrievalModelRegistry.get_model_package_spec() instead.", + DeprecationWarning, + stacklevel=2, + ) + if self._architecture_has_tag(arch_name, "retrieval"): + return True + from nemo_automodel._transformers.registry import RetrievalModelRegistry + + return RetrievalModelRegistry.get_model_package_spec(arch_name) is not None def register_retrieval(self, arch_name: str) -> None: """Mark *arch_name* as a retrieval/encoder architecture.""" - self._retrieval_archs.add(arch_name) + warnings.warn( + "register_retrieval() is deprecated; register retrieval classes with RetrievalModelRegistry.register().", + DeprecationWarning, + stacklevel=2, + ) + self._manual_architecture_tags.setdefault(arch_name, set()).add("retrieval") def get_model_package_spec(self, architecture: str) -> ModelPackageSpec | None: """Return package metadata for an architecture without importing ``model.py``.""" specs = self.get_model_package_specs_for_architecture(architecture) if specs: return specs[0] - return self.model_arch_name_to_cls.get_spec(architecture) + return None def get_model_package_specs_for_architecture(self, architecture: str) -> tuple[ModelPackageSpec, ...]: """Return all package metadata entries that declare *architecture*.""" @@ -215,71 +201,6 @@ def get_model_package_specs_for_model_type(self, model_type: str) -> tuple[Model """Return package metadata entries that declare *model_type*.""" return self._model_type_to_specs.get(model_type, ()) - def get_optional_module_for_architecture(self, architecture: str, module_name: str) -> ModuleType | None: - """Import ``.`` for an architecture if declared.""" - modules = self.iter_optional_modules_for_architectures((architecture,), module_name) - return modules[0] if modules else None - - def iter_optional_modules_for_architectures( - self, architectures: tuple[str, ...] | list[str], module_name: str - ) -> tuple[ModuleType, ...]: - """Import convention modules for the requested architectures only.""" - modules: list[ModuleType] = [] - seen_paths: set[str] = set() - for architecture in architectures: - for spec in self.get_model_package_specs_for_architecture(architecture): - if module_name not in spec.optional_modules: - continue - module_path = spec.optional_module_path(module_name) - if module_path in seen_paths: - continue - module = self._import_optional_module(spec, module_name) - if module is not None: - seen_paths.add(module_path) - modules.append(module) - return tuple(modules) - - def iter_optional_modules_for_model_type(self, model_type: str, module_name: str) -> tuple[ModuleType, ...]: - """Import convention modules for packages that declare *model_type*.""" - modules: list[ModuleType] = [] - for spec in self.get_model_package_specs_for_model_type(model_type): - module = self._import_optional_module(spec, module_name) - if module is not None: - modules.append(module) - return tuple(modules) - - def iter_optional_modules( - self, - module_name: str, - *, - global_patches: bool | None = None, - pre_config_patches: bool | None = None, - post_shard_patches: bool | None = None, - tokenizer_registrations: bool | None = None, - ) -> tuple[ModuleType, ...]: - """Import declared convention modules matching the provided metadata filters.""" - modules: list[ModuleType] = [] - seen_paths: set[str] = set() - for spec in self._iter_specs(): - if module_name not in spec.optional_modules: - continue - if global_patches is not None and spec.global_patches is not global_patches: - continue - if pre_config_patches is not None and spec.pre_config_patches is not pre_config_patches: - continue - if post_shard_patches is not None and spec.post_shard_patches is not post_shard_patches: - continue - if tokenizer_registrations is not None and spec.tokenizer_registrations is not tokenizer_registrations: - continue - module_path = spec.optional_module_path(module_name) - if module_path in seen_paths: - continue - module = self._import_optional_module(spec, module_name) - if module is not None: - seen_paths.add(module_path) - modules.append(module) - return tuple(modules) - @staticmethod def _iter_config_classes(module: ModuleType): """Yield local ``PretrainedConfig`` subclasses declared by *module*.""" @@ -344,9 +265,9 @@ def resolve_custom_model_cls(self, architecture: str, config) -> type[nn.Module] to opt out for specific HF configs (e.g. a Mistral3 VLM with a dense Ministral3 text backbone instead of the expected Mistral4 MoE+MLA). """ - if architecture not in self.model_arch_name_to_cls: + if architecture not in self: return None - model_cls = self.model_arch_name_to_cls[architecture] + model_cls = self[architecture] if hasattr(model_cls, "supports_config") and not model_cls.supports_config(config): logger.info( "Custom model %s does not support config %s, falling back to HF", @@ -358,5 +279,6 @@ def resolve_custom_model_cls(self, architecture: str, config) -> type[nn.Module] def register(self, arch_name: str, model_cls: type[nn.Module], exist_ok: bool = False) -> None: """Register a custom model class for a given architecture name.""" - self.model_arch_name_to_cls.register(arch_name, model_cls, exist_ok=exist_ok) - self._rebuild_spec_indexes() + if not exist_ok and arch_name in self._extra_model_classes: + raise ValueError(f"Duplicated model implementation for {arch_name}") + self[arch_name] = model_cls diff --git a/nemo_automodel/_transformers/registry/model_package_spec.py b/nemo_automodel/_transformers/registry/model_package_spec.py index a896e0e34a..29e41bd631 100644 --- a/nemo_automodel/_transformers/registry/model_package_spec.py +++ b/nemo_automodel/_transformers/registry/model_package_spec.py @@ -17,7 +17,27 @@ @dataclass(frozen=True) class ModelPackageSpec: - """Registry metadata for a model package and its optional convention modules.""" + """Registry entry for a model package. + + Args: + package: Python package that owns the model implementation, usually + ``nemo_automodel.components.models.``. + class_name: Model class to lazy-load from ``module_path`` for architectures + declared by this spec. Config-only specs leave this unset. + model_module: Module inside ``package`` that contains ``class_name``. Defaults + to ``model``, producing ``.model``. + config_module: Module inside ``package`` that contains a custom HF config + class. Used only for model types that need on-demand AutoConfig registration. + config_class_name: Explicit config class name to register from + ``config_module_path``. If unset, the registry can discover a local + ``PretrainedConfig`` subclass by ``model_type``. + tags: Registry metadata flags for the model class. For example, + ``retrieval`` marks bidirectional encoder architectures. + architectures: HF ``config.architectures`` names that should resolve to this + model class. These are expanded into the architecture-to-spec lookup. + model_types: HF ``config.model_type`` names associated with this package, + primarily for resolving and registering custom config classes. + """ package: str class_name: str | None = None @@ -27,17 +47,11 @@ class ModelPackageSpec: tags: frozenset[str] = field(default_factory=frozenset) architectures: tuple[str, ...] = () model_types: tuple[str, ...] = () - optional_modules: frozenset[str] = field(default_factory=frozenset) - global_patches: bool = False - pre_config_patches: bool = False - post_shard_patches: bool = False - tokenizer_registrations: bool = False def __post_init__(self) -> None: object.__setattr__(self, "tags", frozenset(self.tags)) object.__setattr__(self, "architectures", tuple(self.architectures)) object.__setattr__(self, "model_types", tuple(self.model_types)) - object.__setattr__(self, "optional_modules", frozenset(self.optional_modules)) @classmethod def from_module_path( @@ -48,13 +62,8 @@ def from_module_path( config_module: str | None = None, config_class_name: str | None = None, tags: set[str] | frozenset[str] | tuple[str, ...] = (), - architectures: tuple[str, ...] = (), + architectures: list[str] | tuple[str, ...] = (), model_types: tuple[str, ...] = (), - optional_modules: set[str] | frozenset[str] | tuple[str, ...] = (), - global_patches: bool = False, - pre_config_patches: bool = False, - post_shard_patches: bool = False, - tokenizer_registrations: bool = False, ) -> "ModelPackageSpec": """Create a spec from a fully qualified model module path.""" package, sep, model_module = module_path.rpartition(".") @@ -70,11 +79,28 @@ def from_module_path( tags=frozenset(tags), architectures=architectures, model_types=model_types, - optional_modules=frozenset(optional_modules), - global_patches=global_patches, - pre_config_patches=pre_config_patches, - post_shard_patches=post_shard_patches, - tokenizer_registrations=tokenizer_registrations, + ) + + @classmethod + def from_model_class( + cls, + model_cls: type, + *, + config_module: str | None = None, + config_class_name: str | None = None, + tags: set[str] | frozenset[str] | tuple[str, ...] = (), + architectures: list[str] | tuple[str, ...] = (), + model_types: tuple[str, ...] = (), + ) -> "ModelPackageSpec": + """Create a spec from a model class object.""" + return cls.from_module_path( + model_cls.__module__, + model_cls.__name__, + config_module=config_module, + config_class_name=config_class_name, + tags=tags, + architectures=architectures, + model_types=model_types, ) @property @@ -84,12 +110,6 @@ def module_path(self) -> str: return self.model_module return f"{self.package}.{self.model_module}" - def optional_module_path(self, module_name: str) -> str: - """Return the import path for a convention module such as ``patches``.""" - if not self.package: - return f"{self.model_module}.{module_name}" - return f"{self.package}.{module_name}" - @property def config_module_path(self) -> str | None: """Return the import path for this model package's config module, if declared.""" diff --git a/nemo_automodel/_transformers/registry/model_registry.py b/nemo_automodel/_transformers/registry/model_registry.py index 1038703798..0f339e2105 100644 --- a/nemo_automodel/_transformers/registry/model_registry.py +++ b/nemo_automodel/_transformers/registry/model_registry.py @@ -12,379 +12,228 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict from functools import lru_cache -from nemo_automodel._transformers.registry.base import _ModelRegistry as _BaseModelRegistry +from nemo_automodel._transformers.registry.base import _BaseModelRegistry from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec -# Static mapping: architecture name → ModelPackageSpec. +# Static model package specs. The registry derives architecture lookups from each spec's architectures. # Analogous to HuggingFace transformers' MODEL_FOR_CAUSAL_LM_MAPPING_NAMES. # Models are loaded lazily on first access rather than imported at startup. -MODEL_ARCH_MAPPING: OrderedDict[str, ModelPackageSpec] = OrderedDict( - [ - ( - "BaichuanForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.baichuan", - class_name="BaichuanForCausalLM", - architectures=("BaichuanForCausalLM",), - ), - ), - ( - "DeepseekV3ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.deepseek_v3", - class_name="DeepseekV3ForCausalLM", - architectures=("DeepseekV3ForCausalLM",), - ), - ), - ( - "DeepseekV32ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.deepseek_v32", - class_name="DeepseekV32ForCausalLM", - architectures=("DeepseekV32ForCausalLM",), - ), - ), - ( - "DeepseekV4ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.deepseek_v4", - class_name="DeepseekV4ForCausalLM", - architectures=("DeepseekV4ForCausalLM",), - ), - ), - ( - "Ernie4_5_MoeForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.ernie4_5", - class_name="Ernie4_5_MoeForCausalLM", - architectures=("Ernie4_5_MoeForCausalLM",), - ), - ), - ( - "Glm4MoeForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.glm4_moe", - class_name="Glm4MoeForCausalLM", - architectures=("Glm4MoeForCausalLM",), - ), - ), - ( - "Glm4MoeLiteForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.glm4_moe_lite", - class_name="Glm4MoeLiteForCausalLM", - architectures=("Glm4MoeLiteForCausalLM",), - ), - ), - ( - "GlmMoeDsaForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.glm_moe_dsa", - class_name="GlmMoeDsaForCausalLM", - architectures=("GlmMoeDsaForCausalLM",), - ), - ), - ( - "Gemma4ForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.gemma4_moe", - class_name="Gemma4ForConditionalGeneration", - architectures=("Gemma4ForConditionalGeneration",), - ), - ), - ( - "GptOssForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.gpt_oss", - class_name="GptOssForCausalLM", - architectures=("GptOssForCausalLM",), - ), - ), - ( - "KimiK25ForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.kimi_k25_vl", - class_name="KimiK25VLForConditionalGeneration", - architectures=("KimiK25ForConditionalGeneration",), - ), - ), - ( - "KimiK25VLForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.kimi_k25_vl", - class_name="KimiK25VLForConditionalGeneration", - architectures=("KimiK25VLForConditionalGeneration",), - ), - ), - ( - "KimiVLForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.kimivl", - class_name="KimiVLForConditionalGeneration", - architectures=("KimiVLForConditionalGeneration",), - ), - ), - ( - "LlamaBidirectionalForSequenceClassification", - ModelPackageSpec( - package="nemo_automodel.components.models.llama_bidirectional", - class_name="LlamaBidirectionalForSequenceClassification", - tags=frozenset({"retrieval"}), - architectures=("LlamaBidirectionalForSequenceClassification",), - ), - ), - ( - "LlamaBidirectionalModel", - ModelPackageSpec( - package="nemo_automodel.components.models.llama_bidirectional", - class_name="LlamaBidirectionalModel", - tags=frozenset({"retrieval"}), - architectures=("LlamaBidirectionalModel",), - ), - ), - ( - "LlamaForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.llama", - class_name="LlamaForCausalLM", - architectures=("LlamaForCausalLM",), - ), - ), - ( - "MiniMaxM2ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.minimax_m2", - class_name="MiniMaxM2ForCausalLM", - architectures=("MiniMaxM2ForCausalLM",), - ), - ), - ( - "MiMoV2FlashForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.mimo_v2_flash", - class_name="MiMoV2FlashForCausalLM", - architectures=("MiMoV2FlashForCausalLM",), - ), - ), - ( - "Ministral3ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.mistral3", - class_name="Ministral3ForCausalLM", - architectures=("Ministral3ForCausalLM",), - ), - ), - ( - "Ministral3BidirectionalModel", - ModelPackageSpec( - package="nemo_automodel.components.models.ministral_bidirectional", - class_name="Ministral3BidirectionalModel", - tags=frozenset({"retrieval"}), - architectures=("Ministral3BidirectionalModel",), - ), - ), - ( - "Mistral4ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.mistral4", - class_name="Mistral4ForCausalLM", - architectures=("Mistral4ForCausalLM",), - ), - ), - ( - "Mistral3ForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.mistral4", - class_name="Mistral3ForConditionalGeneration", - architectures=("Mistral3ForConditionalGeneration",), - ), - ), - ( - "Mistral3FP8VLMForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.mistral3_vlm", - class_name="Mistral3FP8VLMForConditionalGeneration", - architectures=("Mistral3FP8VLMForConditionalGeneration",), - ), - ), - ( - "NemotronHForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.nemotron_v3", - class_name="NemotronHForCausalLM", - architectures=("NemotronHForCausalLM",), - ), - ), - ( - "NemotronH_Nano_Omni_Reasoning_V3", - ModelPackageSpec( - package="nemo_automodel.components.models.nemotron_omni", - class_name="NemotronOmniForConditionalGeneration", - architectures=("NemotronH_Nano_Omni_Reasoning_V3",), - ), - ), - ( - "NemotronParseForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.nemotron_parse", - class_name="NemotronParseForConditionalGeneration", - architectures=("NemotronParseForConditionalGeneration",), - ), - ), - ( - "LLaVAOneVision1_5_ForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.llava_onevision", - class_name="LLaVAOneVision1_5_ForConditionalGeneration", - architectures=("LLaVAOneVision1_5_ForConditionalGeneration",), - ), - ), - ( - "HYV3ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.hy_v3", - class_name="HYV3ForCausalLM", - architectures=("HYV3ForCausalLM",), - ), - ), - ( - "Qwen2ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.qwen2", - class_name="Qwen2ForCausalLM", - architectures=("Qwen2ForCausalLM",), - ), - ), - ( - "Qwen3MoeForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.qwen3_moe", - class_name="Qwen3MoeForCausalLM", - architectures=("Qwen3MoeForCausalLM",), - ), - ), - ( - "Qwen3NextForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.qwen3_next", - class_name="Qwen3NextForCausalLM", - architectures=("Qwen3NextForCausalLM",), - ), - ), - ( - "Qwen3OmniMoeForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.qwen3_omni_moe", - class_name="Qwen3OmniMoeThinkerForConditionalGeneration", - architectures=("Qwen3OmniMoeForConditionalGeneration",), - ), - ), - ( - "Qwen3VLMoeForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.qwen3_vl_moe", - class_name="Qwen3VLMoeForConditionalGeneration", - architectures=("Qwen3VLMoeForConditionalGeneration",), - ), - ), - ( - "Qwen3_5MoeForConditionalGeneration", - ModelPackageSpec( - package="nemo_automodel.components.models.qwen3_5_moe", - class_name="Qwen3_5MoeForConditionalGeneration", - architectures=("Qwen3_5MoeForConditionalGeneration",), - ), - ), - ( - "Step3p5ForCausalLM", - ModelPackageSpec( - package="nemo_automodel.components.models.step3p5", - class_name="Step3p5ForCausalLM", - architectures=("Step3p5ForCausalLM",), - ), - ), - ] -) - - -_CUSTOM_CONFIG_SPECS: tuple[ModelPackageSpec, ...] = ( +MODEL_PACKAGE_SPECS: tuple[ModelPackageSpec, ...] = ( ModelPackageSpec( package="nemo_automodel.components.models.baichuan", - model_types=("baichuan",), + class_name="BaichuanForCausalLM", config_module="configuration", config_class_name="BaichuanConfig", + architectures=("BaichuanForCausalLM",), + model_types=("baichuan",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v3", + class_name="DeepseekV3ForCausalLM", + architectures=("DeepseekV3ForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.deepseek_v32", + class_name="DeepseekV32ForCausalLM", + architectures=("DeepseekV32ForCausalLM",), ), ModelPackageSpec( package="nemo_automodel.components.models.deepseek_v4", - model_types=("deepseek_v4",), + class_name="DeepseekV4ForCausalLM", config_module="config", config_class_name="DeepseekV4Config", + architectures=("DeepseekV4ForCausalLM",), + model_types=("deepseek_v4",), ), ModelPackageSpec( - package="nemo_automodel.components.models.hy_v3", - model_types=("hy_v3",), - config_module="config", - config_class_name="HYV3Config", + package="nemo_automodel.components.models.ernie4_5", + class_name="Ernie4_5_MoeForCausalLM", + architectures=("Ernie4_5_MoeForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.glm4_moe", + class_name="Glm4MoeForCausalLM", + architectures=("Glm4MoeForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.glm4_moe_lite", + class_name="Glm4MoeLiteForCausalLM", + architectures=("Glm4MoeLiteForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.glm_moe_dsa", + class_name="GlmMoeDsaForCausalLM", + architectures=("GlmMoeDsaForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.gemma4_moe", + class_name="Gemma4ForConditionalGeneration", + architectures=("Gemma4ForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.gpt_oss", + class_name="GptOssForCausalLM", + architectures=("GptOssForCausalLM",), ), ModelPackageSpec( package="nemo_automodel.components.models.kimi_k25_vl", - model_types=("kimi_k25", "kimi_k25_vl"), + class_name="KimiK25VLForConditionalGeneration", config_module="model", config_class_name="KimiK25VLConfig", + architectures=("KimiK25ForConditionalGeneration", "KimiK25VLForConditionalGeneration"), + model_types=("kimi_k25", "kimi_k25_vl"), ), ModelPackageSpec( package="nemo_automodel.components.models.kimivl", - model_types=("kimi_vl",), + class_name="KimiVLForConditionalGeneration", config_module="model", config_class_name="KimiVLConfig", + architectures=("KimiVLForConditionalGeneration",), + model_types=("kimi_vl",), ), ModelPackageSpec( - package="nemo_automodel.components.models.llava_onevision", - model_types=("llavaonevision1_5",), - config_module="model", - config_class_name="Llavaonevision1_5Config", + package="nemo_automodel.components.models.llama", + class_name="LlamaForCausalLM", + architectures=("LlamaForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.minimax_m2", + class_name="MiniMaxM2ForCausalLM", + architectures=("MiniMaxM2ForCausalLM",), ), ModelPackageSpec( package="nemo_automodel.components.models.mimo_v2_flash", - model_types=("mimo_v2_flash",), + class_name="MiMoV2FlashForCausalLM", config_module="config", config_class_name="MiMoV2FlashConfig", + architectures=("MiMoV2FlashForCausalLM",), + model_types=("mimo_v2_flash",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral3", + class_name="Ministral3ForCausalLM", + architectures=("Ministral3ForCausalLM",), ), ModelPackageSpec( package="nemo_automodel.components.models.mistral4", - model_types=("mistral4",), + class_name="Mistral4ForCausalLM", config_module="configuration", config_class_name="Mistral4Config", + architectures=("Mistral4ForCausalLM",), + model_types=("mistral4",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral4", + class_name="Mistral3ForConditionalGeneration", + architectures=("Mistral3ForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.mistral3_vlm", + class_name="Mistral3FP8VLMForConditionalGeneration", + architectures=("Mistral3FP8VLMForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_v3", + class_name="NemotronHForCausalLM", + architectures=("NemotronHForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_omni", + class_name="NemotronOmniForConditionalGeneration", + architectures=("NemotronH_Nano_Omni_Reasoning_V3",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.nemotron_parse", + class_name="NemotronParseForConditionalGeneration", + architectures=("NemotronParseForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.llava_onevision", + class_name="LLaVAOneVision1_5_ForConditionalGeneration", + config_module="model", + config_class_name="Llavaonevision1_5Config", + architectures=("LLaVAOneVision1_5_ForConditionalGeneration",), + model_types=("llavaonevision1_5",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.hy_v3", + class_name="HYV3ForCausalLM", + config_module="config", + config_class_name="HYV3Config", + architectures=("HYV3ForCausalLM",), + model_types=("hy_v3",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen2", + class_name="Qwen2ForCausalLM", + architectures=("Qwen2ForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_moe", + class_name="Qwen3MoeForCausalLM", + architectures=("Qwen3MoeForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_next", + class_name="Qwen3NextForCausalLM", + architectures=("Qwen3NextForCausalLM",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_omni_moe", + class_name="Qwen3OmniMoeThinkerForConditionalGeneration", + architectures=("Qwen3OmniMoeForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_vl_moe", + class_name="Qwen3VLMoeForConditionalGeneration", + architectures=("Qwen3VLMoeForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.qwen3_5_moe", + class_name="Qwen3_5MoeForConditionalGeneration", + architectures=("Qwen3_5MoeForConditionalGeneration",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.step3p5", + class_name="Step3p5ForCausalLM", + architectures=("Step3p5ForCausalLM",), ), ) +RETRIEVAL_MODEL_PACKAGE_SPECS: tuple[ModelPackageSpec, ...] = ( + ModelPackageSpec( + package="nemo_automodel.components.models.llama_bidirectional", + class_name="LlamaBidirectionalForSequenceClassification", + architectures=("LlamaBidirectionalForSequenceClassification",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.llama_bidirectional", + class_name="LlamaBidirectionalModel", + architectures=("LlamaBidirectionalModel",), + ), + ModelPackageSpec( + package="nemo_automodel.components.models.ministral_bidirectional", + class_name="Ministral3BidirectionalModel", + architectures=("Ministral3BidirectionalModel",), + ), +) -MODEL_PACKAGE_SPECS: tuple[ModelPackageSpec, ...] = (*_CUSTOM_CONFIG_SPECS,) - - -class _ModelRegistry(_BaseModelRegistry): - """Model-specific registry initialized from AutoModel's static mappings.""" - def __init__( - self, - model_arch_name_to_cls=None, - package_specs: tuple[ModelPackageSpec, ...] = MODEL_PACKAGE_SPECS, - model_arch_mapping: OrderedDict[str, ModelPackageSpec] | dict[str, ModelPackageSpec] = MODEL_ARCH_MAPPING, - ) -> None: - super().__init__( - model_arch_mapping=model_arch_mapping, - model_arch_name_to_cls=model_arch_name_to_cls, - package_specs=package_specs, - ) +@lru_cache +def make_registry(model_specs: tuple[ModelPackageSpec, ...] = MODEL_PACKAGE_SPECS) -> _BaseModelRegistry: + """Return the process-wide model registry singleton.""" + return _BaseModelRegistry(model_specs=model_specs) @lru_cache -def get_registry() -> _ModelRegistry: - """Return the process-wide model registry singleton.""" - return _ModelRegistry() +def make_retrieval_registry( + model_specs: tuple[ModelPackageSpec, ...] = RETRIEVAL_MODEL_PACKAGE_SPECS, +) -> _BaseModelRegistry: + """Return the process-wide retrieval model registry singleton.""" + return _BaseModelRegistry(model_specs=model_specs) -ModelRegistry = get_registry() +ModelRegistry = make_registry() +RetrievalModelRegistry = make_retrieval_registry() diff --git a/nemo_automodel/_transformers/retrieval.py b/nemo_automodel/_transformers/retrieval.py index 7287fed41d..e5dba30446 100644 --- a/nemo_automodel/_transformers/retrieval.py +++ b/nemo_automodel/_transformers/retrieval.py @@ -25,7 +25,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING from transformers.utils import logging -from nemo_automodel._transformers.registry import ModelRegistry +from nemo_automodel._transformers.registry import RetrievalModelRegistry from nemo_automodel.components.models.common.bidirectional import EncoderStateDictAdapter logger = logging.get_logger(__name__) @@ -57,11 +57,11 @@ def _get_supported_backbone_class(model_type: str, task: str) -> type[nn.Module] f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}." ) - if arch_name not in ModelRegistry.model_arch_name_to_cls: - raise ValueError(f"Model class '{arch_name}' not found in ModelRegistry.") + if not RetrievalModelRegistry.has_custom_model(arch_name): + raise ValueError(f"Model class '{arch_name}' not found in RetrievalModelRegistry.") logger.info(f"Using {arch_name} from registry") - return ModelRegistry.model_arch_name_to_cls[arch_name] + return RetrievalModelRegistry.get_model_cls_from_model_arch(arch_name) def _move_to_extracted_dtype(model: nn.Module, extracted_model: nn.Module) -> nn.Module: @@ -179,7 +179,7 @@ def configure_encoder_metadata(model: PreTrainedModel, config) -> None: """Configure HuggingFace consolidated checkpoint metadata on a model. Sets ``config.architectures`` unconditionally. For custom retrieval - architectures registered in :class:`ModelRegistry`, also writes + architectures registered in :class:`RetrievalModelRegistry`, also writes ``config.auto_map`` so that the saved checkpoint can be reloaded via HuggingFace Auto classes. Standard HF models already have their own auto-resolution and do not need ``auto_map`` entries. @@ -193,7 +193,7 @@ def configure_encoder_metadata(model: PreTrainedModel, config) -> None: # Only set auto_map for custom retrieval architectures. # Standard HF models don't need auto_map pointing to a local model.py. - if ModelRegistry.has_retrieval_model(encoder_class_name): + if RetrievalModelRegistry.get_model_package_spec(encoder_class_name) is not None: config_class_name = config.__class__.__name__ config_module = config.__class__.__module__.rsplit(".", 1)[-1] model_module = model.__class__.__module__.rsplit(".", 1)[-1] @@ -226,8 +226,8 @@ def build_encoder_backbone( Without ``extract_submodel``, model types listed in :data:`SUPPORTED_BACKBONES` resolve to custom bidirectional classes from - :class:`ModelRegistry`; all other model types fall back to HuggingFace Auto - classes. + :class:`RetrievalModelRegistry`; all other model types fall back to + HuggingFace Auto classes. Args: model_name_or_path: Path or HuggingFace Hub identifier. @@ -247,7 +247,7 @@ def build_encoder_backbone( Raises: ValueError: If the task is unsupported for a known model type, or the - architecture class is missing from :class:`ModelRegistry`. + architecture class is missing from :class:`RetrievalModelRegistry`. """ config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code) model_type = getattr(config, "model_type", "") @@ -320,7 +320,7 @@ def save_encoder_pretrained(model: nn.Module, save_directory: str, **kwargs) -> model.model.save_pretrained(save_directory) -# HuggingFace model_type -> task -> bidirectional architecture class name in ModelRegistry +# HuggingFace model_type -> task -> bidirectional architecture class name in RetrievalModelRegistry _LLAMA_TASKS = { "embedding": "LlamaBidirectionalModel", "score": "LlamaBidirectionalForSequenceClassification", @@ -340,7 +340,7 @@ def _init_encoder_common(encoder: nn.Module, model: PreTrainedModel) -> None: """Shared init for BiEncoderModel and CrossEncoderModel.""" encoder.model = model encoder.config = model.config - if ModelRegistry.has_retrieval_model(model.__class__.__name__): + if RetrievalModelRegistry.get_model_package_spec(model.__class__.__name__) is not None: encoder.name_or_path = os.path.dirname(inspect.getfile(type(model))) else: encoder.name_or_path = getattr(model.config, "name_or_path", "") diff --git a/tests/integration/test_llava_onevision_cpu.py b/tests/integration/test_llava_onevision_cpu.py index 8802734e36..94f7afbc08 100644 --- a/tests/integration/test_llava_onevision_cpu.py +++ b/tests/integration/test_llava_onevision_cpu.py @@ -317,10 +317,10 @@ def test_registry_registration(): arch_name = "LLaVAOneVision1_5_ForConditionalGeneration" - assert arch_name in ModelRegistry.model_arch_name_to_cls, f"{arch_name} not found in MODEL_ARCH_MAPPING" + assert arch_name in ModelRegistry.model_arch_name_to_cls, f"{arch_name} not found in MODEL_PACKAGE_SPECS" model_cls = ModelRegistry.model_arch_name_to_cls[arch_name] - logger.info(" ✓ Model registered in MODEL_ARCH_MAPPING") + logger.info(" ✓ Model registered in MODEL_PACKAGE_SPECS") logger.info(f" - Architecture name: {arch_name}") logger.info(f" - Model class: {model_cls.__name__}") logger.info(f" - Module: {model_cls.__module__}") diff --git a/tests/unit_tests/_transformers/test_doc_coverage.py b/tests/unit_tests/_transformers/test_doc_coverage.py index ab96733721..31cc333147 100644 --- a/tests/unit_tests/_transformers/test_doc_coverage.py +++ b/tests/unit_tests/_transformers/test_doc_coverage.py @@ -13,7 +13,7 @@ # limitations under the License. """Guard against the gemma4-style regression where a new model architecture -lands in ``MODEL_ARCH_MAPPING`` without any corresponding page under +lands in ``MODEL_PACKAGE_SPECS`` without any corresponding page under ``docs/model-coverage/``. """ @@ -88,14 +88,14 @@ def _repo_root() -> pathlib.Path: def test_every_registered_arch_has_model_coverage_doc(): - """Every architecture in ``MODEL_ARCH_MAPPING`` must be mentioned in at + """Every architecture in ``MODEL_PACKAGE_SPECS`` must be mentioned in at least one ``docs/model-coverage/*.md`` file, either by its own name or by a mapped alias in ``_DOC_ARCH_ALIASES``. This guards against the regression where a new arch (e.g. gemma4) is registered but never gets a corresponding model card in the docs. """ - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS docs_dir = _repo_root() / "docs" / "model-coverage" assert docs_dir.is_dir(), f"docs/model-coverage/ not found at {docs_dir}" @@ -104,10 +104,11 @@ def test_every_registered_arch_has_model_coverage_doc(): assert md_contents, "No .md files found under docs/model-coverage/" missing = [] - for arch_name in MODEL_ARCH_MAPPING: - needle = _DOC_ARCH_ALIASES.get(arch_name, arch_name) - if not any(needle in content for content in md_contents): - missing.append((arch_name, needle)) + for spec in MODEL_PACKAGE_SPECS: + for arch_name in spec.architectures: + needle = _DOC_ARCH_ALIASES.get(arch_name, arch_name) + if not any(needle in content for content in md_contents): + missing.append((arch_name, needle)) if missing: details = "\n".join(f" - {arch} (looked for {repr(needle)})" for arch, needle in missing) diff --git a/tests/unit_tests/_transformers/test_recipe_doc_coverage.py b/tests/unit_tests/_transformers/test_recipe_doc_coverage.py index 2217e2692e..17809bd99b 100644 --- a/tests/unit_tests/_transformers/test_recipe_doc_coverage.py +++ b/tests/unit_tests/_transformers/test_recipe_doc_coverage.py @@ -21,7 +21,7 @@ least one ``docs/model-coverage/*.md`` file. This complements ``test_doc_coverage.py`` which only covers archs registered -in ``MODEL_ARCH_MAPPING``. Many recipes fine-tune HF-native archs (e.g., +in ``MODEL_PACKAGE_SPECS``. Many recipes fine-tune HF-native archs (e.g., ``Olmo2ForCausalLM``) that never get added to the registry, and still need documentation. diff --git a/tests/unit_tests/_transformers/test_registry.py b/tests/unit_tests/_transformers/test_registry.py index 5508e61921..8ebf34ce88 100644 --- a/tests/unit_tests/_transformers/test_registry.py +++ b/tests/unit_tests/_transformers/test_registry.py @@ -20,10 +20,14 @@ def _new_registry_instance(registry_module): """Create a fresh registry with an empty auto_map for testing.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + return registry_module._BaseModelRegistry(model_specs=()) - mapping = _LazyArchMapping(auto_map={}) - return registry_module._ModelRegistry(model_arch_name_to_cls=mapping) + +def _model_arch_lookup(model_arch_mapping): + """Return the architecture-keyed lookup derived from registry package specs.""" + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + + return _normalize_model_arch_mapping(model_arch_mapping) def test_register_single_class(): @@ -91,6 +95,27 @@ class ReplacementClass: assert inst.model_arch_name_to_cls["MyArch"] is ReplacementClass +def test_deprecated_retrieval_registration_methods_warn(): + from nemo_automodel._transformers import registry as reg + + inst = _new_registry_instance(reg) + + with pytest.warns(DeprecationWarning, match="register_retrieval"): + inst.register_retrieval("ManualRetrieval") + + with pytest.warns(DeprecationWarning, match="has_retrieval_model"): + assert inst.has_retrieval_model("ManualRetrieval") is True + + +def test_deprecated_retrieval_lookup_delegates_to_retrieval_registry(): + from nemo_automodel._transformers import registry as reg + + inst = _new_registry_instance(reg) + + with pytest.warns(DeprecationWarning, match="has_retrieval_model"): + assert inst.has_retrieval_model("LlamaBidirectionalModel") is True + + def test_supported_models_and_getter(): from nemo_automodel._transformers import registry as reg @@ -105,12 +130,21 @@ class A: assert inst.get_model_cls_from_model_arch("A") is A -def test_get_registry_is_cached(): +def test_make_registry_is_cached(): from nemo_automodel._transformers import registry as reg - reg.get_registry.cache_clear() - r1 = reg.get_registry() - r2 = reg.get_registry() + reg.make_registry.cache_clear() + r1 = reg.make_registry() + r2 = reg.make_registry() + assert r1 is r2 + + +def test_make_retrieval_registry_is_cached(): + from nemo_automodel._transformers import registry as reg + + reg.make_retrieval_registry.cache_clear() + r1 = reg.make_retrieval_registry() + r2 = reg.make_retrieval_registry() assert r1 is r2 @@ -119,42 +153,45 @@ def test_registry_reexports_split_modules(): from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec from nemo_automodel._transformers.registry.model_registry import ( - _CUSTOM_CONFIG_SPECS, - MODEL_ARCH_MAPPING, MODEL_PACKAGE_SPECS, + RETRIEVAL_MODEL_PACKAGE_SPECS, + RetrievalModelRegistry, + make_registry, + make_retrieval_registry, ) assert reg.ModelPackageSpec is ModelPackageSpec - assert reg.MODEL_ARCH_MAPPING is MODEL_ARCH_MAPPING assert reg.MODEL_PACKAGE_SPECS is MODEL_PACKAGE_SPECS - assert reg._CUSTOM_CONFIG_SPECS is _CUSTOM_CONFIG_SPECS + assert reg.RETRIEVAL_MODEL_PACKAGE_SPECS is RETRIEVAL_MODEL_PACKAGE_SPECS + assert reg.RetrievalModelRegistry is RetrievalModelRegistry + assert reg.make_registry is make_registry + assert reg.make_retrieval_registry is make_retrieval_registry -def test_lazy_arch_mapping_auto_map(): +def test_registry_mapping_loads_model_class_on_demand(): """Static auto_map entries are lazily loaded on first access.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeClass: pass fake_module = types.SimpleNamespace(FakeClass=FakeClass) - spec = ModelPackageSpec(package="fake", model_module="module", class_name="FakeClass") - mapping = _LazyArchMapping({"FakeArch": spec}) - - mapping._modules["fake.module"] = fake_module + spec = ModelPackageSpec(package="fake", model_module="module", class_name="FakeClass", architectures=("FakeArch",)) + inst = reg._BaseModelRegistry(model_specs=(spec,)) - assert "FakeArch" in mapping - assert mapping["FakeArch"] is FakeClass - assert "FakeArch" in mapping._loaded + with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): + assert "FakeArch" in inst.model_arch_name_to_cls + assert inst.model_arch_name_to_cls["FakeArch"] is FakeClass + assert "FakeArch" in inst._loaded_model_classes with pytest.raises(KeyError): - mapping["NonExistent"] + inst.model_arch_name_to_cls["NonExistent"] -def test_lazy_arch_mapping_accepts_model_package_spec_entries(): +def test_registry_accepts_model_package_spec_entries(): """Static auto_map entries may be declared as ModelPackageSpec metadata.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeClass: @@ -164,20 +201,82 @@ class FakeClass: spec = ModelPackageSpec.from_module_path( "fake.package.model", "FakeClass", + architectures=["FakeArch"], + model_types=("fake",), + ) + inst = reg._BaseModelRegistry(model_specs={"FakeArch": spec}) + + with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): + assert inst.model_arch_name_to_cls["FakeArch"] is FakeClass + assert inst.model_arch_name_to_cls.keys() == {"FakeArch"} + + +def test_model_package_spec_from_model_class(): + """ModelPackageSpec can be derived directly from a model class object.""" + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + class FakeClass: + pass + + FakeClass.__module__ = "fake.package.model" + + spec = ModelPackageSpec.from_model_class(FakeClass, architectures=["FakeArch"]) + + assert spec.module_path == "fake.package.model" + assert spec.class_name == "FakeClass" + assert spec.architectures == ("FakeArch",) + + +def test_registry_derives_keys_from_spec_architectures(): + """Tuple-based auto_map entries derive lookup keys from spec.architectures.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + spec = ModelPackageSpec( + package="fake.package", + class_name="FakeClass", + architectures=("FakeArch", "FakeAlias"), + ) + inst = reg._BaseModelRegistry(model_specs=(spec,)) + + assert inst.model_arch_name_to_cls.keys() == {"FakeArch", "FakeAlias"} + + +def test_base_registry_owns_architecture_metadata(): + """Package metadata indexes live on the registry, not on the lazy class loader.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + spec = ModelPackageSpec( + package="fake.package", + class_name="FakeClass", + architectures=("FakeArch", "FakeAlias"), model_types=("fake",), - optional_modules=frozenset({"patches"}), ) - mapping = _LazyArchMapping({"FakeArch": spec}) - mapping._modules["fake.package.model"] = fake_module + inst = reg._BaseModelRegistry(model_specs=(spec,)) + + assert inst.get_model_package_spec("FakeArch") is spec + assert inst.get_model_package_spec("FakeAlias") is spec + assert inst.get_model_package_specs_for_model_type("fake") == (spec,) + + +def test_registry_rejects_duplicate_architectures(): + """Duplicate architecture names across package specs fail during registry construction.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec + + specs = ( + ModelPackageSpec(package="fake.package_a", class_name="FakeClassA", architectures=("FakeArch",)), + ModelPackageSpec(package="fake.package_b", class_name="FakeClassB", architectures=("FakeArch",)), + ) - assert mapping["FakeArch"] is FakeClass - assert mapping.get_spec("FakeArch").package == "fake.package" - assert mapping.get_spec("FakeArch").model_types == ("fake",) + with pytest.raises(ValueError, match="Duplicated model architecture entry for 'FakeArch'"): + reg._BaseModelRegistry(model_specs=specs) -def test_lazy_arch_mapping_extra_overrides_auto_map(): +def test_registry_dynamic_entry_overrides_static_entry(): """Dynamically registered entries take precedence over static entries.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class StaticClass: @@ -187,88 +286,99 @@ class DynamicClass: pass fake_module = types.SimpleNamespace(StaticClass=StaticClass) - spec = ModelPackageSpec(package="fake", model_module="module", class_name="StaticClass") - mapping = _LazyArchMapping({"MyArch": spec}) - mapping._modules["fake.module"] = fake_module + spec = ModelPackageSpec(package="fake", model_module="module", class_name="StaticClass", architectures=("MyArch",)) + inst = reg._BaseModelRegistry(model_specs=(spec,)) - assert mapping["MyArch"] is StaticClass + with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): + assert inst.model_arch_name_to_cls["MyArch"] is StaticClass - mapping["MyArch"] = DynamicClass - assert mapping["MyArch"] is DynamicClass + inst.model_arch_name_to_cls["MyArch"] = DynamicClass + assert inst.model_arch_name_to_cls["MyArch"] is DynamicClass + assert inst.get_model_package_spec("MyArch").class_name == "DynamicClass" -def test_lazy_arch_mapping_unavailable_model(): +def test_registry_unavailable_model(): """Auto_map entries whose imports fail are removed and excluded from containment.""" - from nemo_automodel._transformers.registry import _LazyArchMapping + from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec - spec = ModelPackageSpec(package="nonexistent.module", model_module="path", class_name="BadClass") - mapping = _LazyArchMapping({"BadArch": spec}) - - assert "BadArch" not in mapping - assert mapping.get_spec("BadArch") is None - - -def test_default_registry_has_static_entries(): - """The default registry is populated from MODEL_ARCH_MAPPING.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING, _ModelRegistry + spec = ModelPackageSpec( + package="nonexistent.module", + model_module="path", + class_name="BadClass", + architectures=("BadArch",), + ) + inst = reg._BaseModelRegistry(model_specs=(spec,)) - inst = _ModelRegistry() - for arch_name in MODEL_ARCH_MAPPING: - assert arch_name in inst.model_arch_name_to_cls.keys() + assert "BadArch" not in inst.model_arch_name_to_cls + assert "BadArch" not in inst.model_arch_name_to_cls.keys() -def test_registry_imports_optional_module_for_architecture(): - """Convention modules are imported through ModelPackageSpec package metadata.""" +def test_registry_discards_failed_import_from_metadata_index(): + """Failed model imports remove the architecture from the registry-owned metadata index.""" from nemo_automodel._transformers import registry as reg - from nemo_automodel._transformers.registry import _LazyArchMapping from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec - fake_module = types.SimpleNamespace(__name__="fake.package.patches") spec = ModelPackageSpec( - package="fake.package", - architectures=("FakeArch",), - optional_modules=frozenset({"patches"}), + package="nonexistent.module", + model_module="path", + class_name="BadClass", + architectures=("BadArch",), ) - inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + inst = reg._BaseModelRegistry(model_specs=(spec,)) - with patch( - "nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module - ) as mock_import: - modules = inst.iter_optional_modules_for_architectures(("FakeArch",), "patches") + assert "BadArch" not in inst.model_arch_name_to_cls + assert inst.get_model_package_spec("BadArch") is None + assert "BadArch" not in inst.supported_models - mock_import.assert_called_once_with("fake.package.patches") - assert modules == (fake_module,) + class GoodClass: + pass + + inst.register("GoodArch", GoodClass) + assert "BadArch" not in inst.model_arch_name_to_cls.keys() -def test_registry_filters_optional_modules_by_scope(): - """Global patch lookup imports only specs that advertise global patches.""" +def test_direct_lazy_registration_updates_registry_metadata(): + """Direct mapping assignment still informs the registry about dynamic model metadata.""" from nemo_automodel._transformers import registry as reg - from nemo_automodel._transformers.registry import _LazyArchMapping - from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec - global_module = types.SimpleNamespace(__name__="fake.global_model.patches") - specs = ( - ModelPackageSpec( - package="fake.global_model", - optional_modules=frozenset({"patches"}), - global_patches=True, - ), - ModelPackageSpec( - package="fake.runtime_model", - optional_modules=frozenset({"patches"}), - global_patches=False, - ), - ) - inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=specs) + inst = _new_registry_instance(reg) + + class DirectClass: + pass - with patch( - "nemo_automodel._transformers.registry.base.importlib.import_module", return_value=global_module - ) as mock_import: - modules = inst.iter_optional_modules("patches", global_patches=True) + DirectClass.__module__ = "fake.direct.model" + inst.model_arch_name_to_cls["DirectArch"] = DirectClass - mock_import.assert_called_once_with("fake.global_model.patches") - assert modules == (global_module,) + spec = inst.get_model_package_spec("DirectArch") + assert spec.module_path == "fake.direct.model" + assert spec.class_name == "DirectClass" + assert inst.model_arch_name_to_cls["DirectArch"] is DirectClass + + +def test_default_registry_has_static_entries(): + """The default registry is populated from MODEL_PACKAGE_SPECS.""" + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS, make_registry + + inst = make_registry() + for spec in MODEL_PACKAGE_SPECS: + for arch_name in spec.architectures: + assert arch_name in inst.model_arch_name_to_cls.keys() + + +def test_retrieval_registry_has_separate_static_entries(): + """Retrieval architectures live in the retrieval registry, not the default registry.""" + from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS, RETRIEVAL_MODEL_PACKAGE_SPECS + + retrieval_arch = "LlamaBidirectionalModel" + default_registry = reg._BaseModelRegistry(model_specs=MODEL_PACKAGE_SPECS) + retrieval_registry = reg._BaseModelRegistry(model_specs=RETRIEVAL_MODEL_PACKAGE_SPECS) + + assert retrieval_arch not in _model_arch_lookup(MODEL_PACKAGE_SPECS) + assert retrieval_arch in _model_arch_lookup(RETRIEVAL_MODEL_PACKAGE_SPECS) + assert retrieval_arch not in default_registry.model_arch_name_to_cls.keys() + assert retrieval_arch in retrieval_registry.model_arch_name_to_cls.keys() def test_registry_discovers_config_class_from_config_module(): @@ -276,7 +386,6 @@ def test_registry_discovers_config_class_from_config_module(): from transformers import PretrainedConfig from nemo_automodel._transformers import registry as reg - from nemo_automodel._transformers.registry import _LazyArchMapping from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeConfig(PretrainedConfig): @@ -285,7 +394,7 @@ class FakeConfig(PretrainedConfig): FakeConfig.__module__ = "fake.package.config" fake_module = types.SimpleNamespace(__name__="fake.package.config", FakeConfig=FakeConfig) spec = ModelPackageSpec(package="fake.package", model_types=("fake_model",), config_module="config") - inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + inst = reg._BaseModelRegistry(model_specs=(spec,)) with ( patch("transformers.AutoConfig.register") as mock_register, @@ -301,7 +410,6 @@ def test_registry_registers_explicit_config_alias(): from transformers import PretrainedConfig from nemo_automodel._transformers import registry as reg - from nemo_automodel._transformers.registry import _LazyArchMapping from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeConfig(PretrainedConfig): @@ -314,7 +422,7 @@ class FakeConfig(PretrainedConfig): config_module="model", config_class_name="FakeConfig", ) - inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + inst = reg._BaseModelRegistry(model_specs=(spec,)) with ( patch("transformers.AutoConfig.register") as mock_register, @@ -331,7 +439,6 @@ def test_registry_registers_config_alias_when_auto_config_rejects_mismatch(): from transformers.models.auto.configuration_auto import CONFIG_MAPPING from nemo_automodel._transformers import registry as reg - from nemo_automodel._transformers.registry import _LazyArchMapping from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeConfig(PretrainedConfig): @@ -344,7 +451,7 @@ class FakeConfig(PretrainedConfig): config_module="model", config_class_name="FakeConfig", ) - inst = reg._ModelRegistry(model_arch_name_to_cls=_LazyArchMapping(auto_map={}), package_specs=(spec,)) + inst = reg._BaseModelRegistry(model_specs=(spec,)) with ( patch("transformers.AutoConfig.register", side_effect=ValueError("model_type mismatch")), @@ -426,49 +533,53 @@ def supports_config(cls, config): assert inst.resolve_custom_model_cls("ConfigAwareModel", bad) is None -def test_custom_config_specs_are_metadata_only_until_requested(): - """Custom config entries should not require eager registration at registry import time.""" - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS +def test_config_metadata_is_metadata_only_until_requested(): + """Config metadata should not require eager registration at registry import time.""" + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS - for spec in _CUSTOM_CONFIG_SPECS: + for spec in MODEL_PACKAGE_SPECS: + if spec.config_module_path is None: + continue assert spec.config_module_path assert spec.config_class_name -def test_kimi_k25_arch_alias_in_model_arch_mapping(): +def test_kimi_k25_arch_alias_in_model_package_specs(): """KimiK25ForConditionalGeneration (checkpoint arch) must map to KimiK25VLForConditionalGeneration.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS - assert "KimiK25ForConditionalGeneration" in MODEL_ARCH_MAPPING, ( - "KimiK25ForConditionalGeneration missing from MODEL_ARCH_MAPPING. " + mapping = _model_arch_lookup(MODEL_PACKAGE_SPECS) + assert "KimiK25ForConditionalGeneration" in mapping, ( + "KimiK25ForConditionalGeneration missing from MODEL_PACKAGE_SPECS. " "Kimi-K2.5 checkpoints use this architecture name and need it mapped " "to KimiK25VLForConditionalGeneration." ) - spec = MODEL_ARCH_MAPPING["KimiK25ForConditionalGeneration"] + spec = mapping["KimiK25ForConditionalGeneration"] assert spec.class_name == "KimiK25VLForConditionalGeneration" -def test_deepseek_v4_registered_in_arch_mapping(): - """DeepseekV4ForCausalLM must be registered in MODEL_ARCH_MAPPING.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING +def test_deepseek_v4_registered_in_model_package_specs(): + """DeepseekV4ForCausalLM must be registered in MODEL_PACKAGE_SPECS.""" + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS - assert "DeepseekV4ForCausalLM" in MODEL_ARCH_MAPPING, ( - "DeepseekV4ForCausalLM missing from MODEL_ARCH_MAPPING. " + mapping = _model_arch_lookup(MODEL_PACKAGE_SPECS) + assert "DeepseekV4ForCausalLM" in mapping, ( + "DeepseekV4ForCausalLM missing from MODEL_PACKAGE_SPECS. " "DSV4 checkpoints declare this architecture and need it routed to the " "in-tree model implementation." ) - spec = MODEL_ARCH_MAPPING["DeepseekV4ForCausalLM"] + spec = mapping["DeepseekV4ForCausalLM"] assert spec.module_path == "nemo_automodel.components.models.deepseek_v4.model" assert spec.class_name == "DeepseekV4ForCausalLM" -def test_deepseek_v4_in_custom_config_specs(): - """deepseek_v4 model_type must be declared in _CUSTOM_CONFIG_SPECS.""" - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS +def test_deepseek_v4_in_model_package_specs(): + """deepseek_v4 model_type must be declared in MODEL_PACKAGE_SPECS.""" + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS - specs_by_model_type = {model_type: spec for spec in _CUSTOM_CONFIG_SPECS for model_type in spec.model_types} + specs_by_model_type = {model_type: spec for spec in MODEL_PACKAGE_SPECS for model_type in spec.model_types} assert "deepseek_v4" in specs_by_model_type, ( - "deepseek_v4 must be in _CUSTOM_CONFIG_SPECS so ModelRegistry can register " + "deepseek_v4 must be in MODEL_PACKAGE_SPECS so ModelRegistry can register " "DSV4 configs on demand before AutoConfig.from_pretrained runs." ) spec = specs_by_model_type["deepseek_v4"] @@ -477,20 +588,20 @@ def test_deepseek_v4_in_custom_config_specs(): def test_all_model_folders_registered_in_auto_map(): - """Every model folder with a model.py must have at least one entry in MODEL_ARCH_MAPPING. + """Every model folder with a model.py must have at least one registry package-spec entry. This catches the case where a developer adds a new model directory under ``nemo_automodel/components/models/`` but forgets to add it to the static - ``MODEL_ARCH_MAPPING`` in ``registry.py``. + registry package specs in ``registry.py``. """ import pathlib - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS, RETRIEVAL_MODEL_PACKAGE_SPECS models_root = pathlib.Path(__file__).resolve().parents[3] / "nemo_automodel" / "components" / "models" - # Collect the set of module paths referenced by the auto_map - registered_module_paths = {spec.module_path for spec in MODEL_ARCH_MAPPING.values()} + # Collect the set of module paths referenced by the registries. + registered_module_paths = {spec.module_path for spec in (*MODEL_PACKAGE_SPECS, *RETRIEVAL_MODEL_PACKAGE_SPECS)} missing = [] documentation_only_model_dirs = {"blueprint"} @@ -508,6 +619,6 @@ def test_all_model_folders_registered_in_auto_map(): assert not missing, ( f"Model folder(s) {missing} contain a model.py but are not registered " - f"in MODEL_ARCH_MAPPING (registry.py). Add an entry for each architecture " + f"in the registry package specs (registry.py). Add an entry for each architecture " f"exported by these modules." ) diff --git a/tests/unit_tests/_transformers/test_registry_hy_v3.py b/tests/unit_tests/_transformers/test_registry_hy_v3.py index d4d0988157..f375689b74 100644 --- a/tests/unit_tests/_transformers/test_registry_hy_v3.py +++ b/tests/unit_tests/_transformers/test_registry_hy_v3.py @@ -17,14 +17,16 @@ class TestArchMapping: def test_hyv3_arch_registered(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping - assert "HYV3ForCausalLM" in MODEL_ARCH_MAPPING + assert "HYV3ForCausalLM" in _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) def test_hyv3_arch_points_at_correct_module(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping - entry = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] + entry = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)["HYV3ForCausalLM"] assert entry.module_path == "nemo_automodel.components.models.hy_v3.model" assert entry.class_name == "HYV3ForCausalLM" @@ -32,9 +34,10 @@ def test_hyv3_arch_resolves_to_class(self): """Walk the mapping path -- importable + the named class exists.""" import importlib - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping - spec = MODEL_ARCH_MAPPING["HYV3ForCausalLM"] + spec = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)["HYV3ForCausalLM"] mod_path = spec.module_path cls_name = spec.class_name mod = importlib.import_module(mod_path) @@ -43,16 +46,16 @@ def test_hyv3_arch_resolves_to_class(self): class TestCustomConfigRegistration: def test_hy_v3_config_registered(self): - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS - assert any("hy_v3" in spec.model_types for spec in _CUSTOM_CONFIG_SPECS) + assert any("hy_v3" in spec.model_types for spec in MODEL_PACKAGE_SPECS) def test_hy_v3_config_resolves_to_class(self): import importlib - from nemo_automodel._transformers.registry import _CUSTOM_CONFIG_SPECS + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS - spec = next(spec for spec in _CUSTOM_CONFIG_SPECS if "hy_v3" in spec.model_types) + spec = next(spec for spec in MODEL_PACKAGE_SPECS if "hy_v3" in spec.model_types) mod_path = spec.config_module_path cls_name = spec.config_class_name mod = importlib.import_module(mod_path) diff --git a/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py b/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py index 365856408e..f81f449ca1 100644 --- a/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py +++ b/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast -from nemo_automodel._transformers.registry import ModelRegistry +from nemo_automodel._transformers.registry import RetrievalModelRegistry from nemo_automodel._transformers.retrieval import ( BiEncoderModel, CrossEncoderModel, @@ -139,7 +139,6 @@ def test_bidirectional_attention_is_symmetric(): ) - # --- Fakes for classification and encoder tests --- class FakeOutputs: def __init__(self, last_hidden_state=None, hidden_states=None): @@ -234,9 +233,7 @@ def forward(self, input_ids=None, attention_mask=None, return_dict=True, output_ ) lm = NoTTIDLm(hidden=8) - model = BiEncoderModel( - model=lm, pooling="avg", l2_normalize=True - ) + model = BiEncoderModel(model=lm, pooling="avg", l2_normalize=True) # encode removes token_type_ids and normalizes q = { "input_ids": torch.ones(2, 3, dtype=torch.long), @@ -278,9 +275,7 @@ def forward(self, input_ids=None, attention_mask=None, return_dict=True, output_ return OnlyHiddenOutputs(hidden_states) # Test with model using NoLastLM for query encoder - model_no_last = BiEncoderModel( - model=NoLastLM(hidden=8), pooling="avg", l2_normalize=True - ) + model_no_last = BiEncoderModel(model=NoLastLM(hidden=8), pooling="avg", l2_normalize=True) v2 = model_no_last.encode( {"input_ids": torch.ones(2, 3, dtype=torch.long), "attention_mask": torch.ones(2, 3, dtype=torch.long)}, ) @@ -294,9 +289,13 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - # Patch the registry to return our fake model - ModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + # Patch the retrieval registry to return our fake model + RetrievalModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) # Directory path with config.json to hit config-reading branch model_dir = tmp_path / "model" @@ -380,9 +379,13 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - # Patch the registry to return our fake model - ModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + # Patch the retrieval registry to return our fake model + RetrievalModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) # Create a model directory whose path has no 'llama' substring model_dir = tmp_path / "scratch" / "job" / "model" @@ -415,9 +418,13 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - # Patch the registry to return our fake model - ModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + # Patch the retrieval registry to return our fake model + RetrievalModelRegistry.model_arch_name_to_cls["LlamaBidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) # Model type not in SUPPORTED_BACKBONES should fall back to AutoModel import nemo_automodel._transformers.retrieval as encoder_module @@ -535,10 +542,14 @@ def __init__(self): # Use a class name that is NOT a retrieval arch FakeModel.__name__ = "Qwen3Model" - FakeModel = type("Qwen3Model", (nn.Module,), { - "__init__": FakeModel.__init__, - "config": property(lambda self: self._config), - }) + FakeModel = type( + "Qwen3Model", + (nn.Module,), + { + "__init__": FakeModel.__init__, + "config": property(lambda self: self._config), + }, + ) fake = object.__new__(FakeModel) nn.Module.__init__(fake) fake._config = FakeCfg() diff --git a/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py b/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py index 92b9ded41d..399f866a5f 100644 --- a/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py +++ b/tests/unit_tests/models/bi_encoder/test_ministral_bidirectional_model.py @@ -22,8 +22,8 @@ pytest.importorskip("transformers.models.ministral3", reason="Ministral3 not available in this transformers version") -from nemo_automodel._transformers.registry import ModelRegistry -from nemo_automodel._transformers.retrieval import BiEncoderModel, configure_encoder_metadata, _init_encoder_common +from nemo_automodel._transformers.registry import RetrievalModelRegistry +from nemo_automodel._transformers.retrieval import BiEncoderModel, _init_encoder_common, configure_encoder_metadata from nemo_automodel.components.models.ministral_bidirectional.model import ( Ministral3BidirectionalConfig, Ministral3BidirectionalModel, @@ -142,8 +142,12 @@ class FakeBidirectionalModel(FakeLM): def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - ModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + RetrievalModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) model_dir = tmp_path / "model" model_dir.mkdir() @@ -164,13 +168,18 @@ def from_pretrained(cls, *args, **kwargs): @pytest.mark.parametrize("top_level_model_type", ["ministral3", "ministral3_bidirec"]) def test_encoder_build_ministral_supported_model_types(tmp_path, monkeypatch, top_level_model_type): """Hub / local text configs use ministral3; saved bidirectional checkpoints use ministral3_bidirec.""" + class FakeBidirectionalModel(FakeLM): @classmethod def from_pretrained(cls, *args, **kwargs): return cls(hidden=16) - ModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel - monkeypatch.setattr(ModelRegistry, "model_arch_name_to_cls", ModelRegistry.model_arch_name_to_cls) + RetrievalModelRegistry.model_arch_name_to_cls["Ministral3BidirectionalModel"] = FakeBidirectionalModel + monkeypatch.setattr( + RetrievalModelRegistry, + "model_arch_name_to_cls", + RetrievalModelRegistry.model_arch_name_to_cls, + ) model_dir = tmp_path / "hub" / "checkpoint" model_dir.mkdir(parents=True) diff --git a/tests/unit_tests/models/mistral4/test_mistral4_model.py b/tests/unit_tests/models/mistral4/test_mistral4_model.py index 2b39879108..6e00c98a08 100644 --- a/tests/unit_tests/models/mistral4/test_mistral4_model.py +++ b/tests/unit_tests/models/mistral4/test_mistral4_model.py @@ -920,17 +920,22 @@ def test_forward_raises_on_pp_stage_without_embeds(self, text_config, backend, d class TestRegistry: def test_mistral4_in_registry(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping - assert "Mistral4ForCausalLM" in dict(MODEL_ARCH_MAPPING) - assert "Mistral3ForConditionalGeneration" in dict(MODEL_ARCH_MAPPING) + mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) + assert "Mistral4ForCausalLM" in mapping + assert "Mistral3ForConditionalGeneration" in mapping def test_registry_module_path(self): - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping - mapping = dict(MODEL_ARCH_MAPPING) - assert mapping["Mistral4ForCausalLM"][0] == "nemo_automodel.components.models.mistral4.model" - assert mapping["Mistral3ForConditionalGeneration"][0] == "nemo_automodel.components.models.mistral4.model" + mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) + assert mapping["Mistral4ForCausalLM"].module_path == "nemo_automodel.components.models.mistral4.model" + assert ( + mapping["Mistral3ForConditionalGeneration"].module_path == "nemo_automodel.components.models.mistral4.model" + ) # --------------------------------------------------------------------------- diff --git a/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py b/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py index fb2412747b..463bc8d741 100644 --- a/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py +++ b/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py @@ -16,7 +16,7 @@ These tests exercise the small, dependency-free building blocks of the NemotronOmni wrapper (activation, norm, projectors, config) plus the -architecture registration that lets ``MODEL_ARCH_MAPPING`` resolve the v3 dump. +architecture registration that lets ``MODEL_PACKAGE_SPECS`` resolve the v3 dump. The heavy ``NemotronOmniForConditionalGeneration`` end-to-end forward path requires a full HF v3 checkpoint and is covered by the functional/integration suites — not here. @@ -164,18 +164,20 @@ def test_nemotron_omni_config_overrides_propagate(): def test_registry_entry_present(): - """``MODEL_ARCH_MAPPING`` should resolve the v3 architecture name to our wrapper.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + """``MODEL_PACKAGE_SPECS`` should resolve the v3 architecture name to our wrapper.""" + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping - mapping = dict(MODEL_ARCH_MAPPING) + mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) assert "NemotronH_Nano_Omni_Reasoning_V3" in mapping - module_path, class_name, *_ = mapping["NemotronH_Nano_Omni_Reasoning_V3"] - assert module_path == "nemo_automodel.components.models.nemotron_omni.model" - assert class_name == "NemotronOmniForConditionalGeneration" + spec = mapping["NemotronH_Nano_Omni_Reasoning_V3"] + assert spec.module_path == "nemo_automodel.components.models.nemotron_omni.model" + assert spec.class_name == "NemotronOmniForConditionalGeneration" def test_registry_v2_entry_removed(): """V2 dispatch was deleted along with V2 dump support — keep it gone.""" - from nemo_automodel._transformers.registry import MODEL_ARCH_MAPPING + from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping - assert "NemotronH_Nano_VL_V2" not in dict(MODEL_ARCH_MAPPING) + assert "NemotronH_Nano_VL_V2" not in _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) From e003c166fdd33b263dc2dfe4d662f61c32685937 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 18 May 2026 21:59:26 -0700 Subject: [PATCH 3/5] remove register_retrieval() and has_retrieval_model() Signed-off-by: Alexandros Koumparoulis --- nemo_automodel/_transformers/registry/base.py | 30 ------------------- .../registry/model_package_spec.py | 10 +------ .../unit_tests/_transformers/test_registry.py | 21 ------------- 3 files changed, 1 insertion(+), 60 deletions(-) diff --git a/nemo_automodel/_transformers/registry/base.py b/nemo_automodel/_transformers/registry/base.py index 8838c85138..ead6c41bc3 100644 --- a/nemo_automodel/_transformers/registry/base.py +++ b/nemo_automodel/_transformers/registry/base.py @@ -16,7 +16,6 @@ import importlib import inspect import logging -import warnings from collections import OrderedDict from collections.abc import Iterable, Mapping from dataclasses import dataclass, field @@ -62,7 +61,6 @@ class _BaseModelRegistry: _loaded_model_classes: dict[str, type[nn.Module]] = field(default_factory=dict) _extra_model_classes: dict[str, type[nn.Module]] = field(default_factory=dict) _discarded_architectures: set[str] = field(default_factory=set) - _manual_architecture_tags: dict[str, set[str]] = field(default_factory=dict) _architecture_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) _model_type_to_specs: dict[str, tuple[ModelPackageSpec, ...]] = field(default_factory=dict) @@ -100,7 +98,6 @@ def _discard_architecture(self, architecture: str) -> None: self._architecture_to_specs.pop(architecture, None) self._extra_model_classes.pop(architecture, None) self._loaded_model_classes.pop(architecture, None) - self._manual_architecture_tags.pop(architecture, None) def _load_model_class(self, architecture: str) -> type[nn.Module]: if architecture in self._loaded_model_classes: @@ -148,11 +145,6 @@ def __repr__(self) -> str: f"extra={len(self._extra_model_classes)}, loaded={len(self._loaded_model_classes)})" ) - def _architecture_has_tag(self, architecture: str, tag: str) -> bool: - if tag in self._manual_architecture_tags.get(architecture, set()): - return True - return any(tag in spec.tags for spec in self.get_model_package_specs_for_architecture(architecture)) - @property def supported_models(self): return self.keys() @@ -164,28 +156,6 @@ def has_custom_model(self, arch_name: str) -> bool: """Return ``True`` if *arch_name* has a custom (non-HF) implementation.""" return arch_name in self - def has_retrieval_model(self, arch_name: str) -> bool: - """Return ``True`` if *arch_name* is a registered retrieval/encoder architecture.""" - warnings.warn( - "has_retrieval_model() is deprecated; use RetrievalModelRegistry.get_model_package_spec() instead.", - DeprecationWarning, - stacklevel=2, - ) - if self._architecture_has_tag(arch_name, "retrieval"): - return True - from nemo_automodel._transformers.registry import RetrievalModelRegistry - - return RetrievalModelRegistry.get_model_package_spec(arch_name) is not None - - def register_retrieval(self, arch_name: str) -> None: - """Mark *arch_name* as a retrieval/encoder architecture.""" - warnings.warn( - "register_retrieval() is deprecated; register retrieval classes with RetrievalModelRegistry.register().", - DeprecationWarning, - stacklevel=2, - ) - self._manual_architecture_tags.setdefault(arch_name, set()).add("retrieval") - def get_model_package_spec(self, architecture: str) -> ModelPackageSpec | None: """Return package metadata for an architecture without importing ``model.py``.""" specs = self.get_model_package_specs_for_architecture(architecture) diff --git a/nemo_automodel/_transformers/registry/model_package_spec.py b/nemo_automodel/_transformers/registry/model_package_spec.py index 29e41bd631..2c4aa39e7e 100644 --- a/nemo_automodel/_transformers/registry/model_package_spec.py +++ b/nemo_automodel/_transformers/registry/model_package_spec.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, replace @dataclass(frozen=True) @@ -31,8 +31,6 @@ class ModelPackageSpec: config_class_name: Explicit config class name to register from ``config_module_path``. If unset, the registry can discover a local ``PretrainedConfig`` subclass by ``model_type``. - tags: Registry metadata flags for the model class. For example, - ``retrieval`` marks bidirectional encoder architectures. architectures: HF ``config.architectures`` names that should resolve to this model class. These are expanded into the architecture-to-spec lookup. model_types: HF ``config.model_type`` names associated with this package, @@ -44,12 +42,10 @@ class ModelPackageSpec: model_module: str = "model" config_module: str | None = None config_class_name: str | None = None - tags: frozenset[str] = field(default_factory=frozenset) architectures: tuple[str, ...] = () model_types: tuple[str, ...] = () def __post_init__(self) -> None: - object.__setattr__(self, "tags", frozenset(self.tags)) object.__setattr__(self, "architectures", tuple(self.architectures)) object.__setattr__(self, "model_types", tuple(self.model_types)) @@ -61,7 +57,6 @@ def from_module_path( *, config_module: str | None = None, config_class_name: str | None = None, - tags: set[str] | frozenset[str] | tuple[str, ...] = (), architectures: list[str] | tuple[str, ...] = (), model_types: tuple[str, ...] = (), ) -> "ModelPackageSpec": @@ -76,7 +71,6 @@ def from_module_path( model_module=model_module, config_module=config_module, config_class_name=config_class_name, - tags=frozenset(tags), architectures=architectures, model_types=model_types, ) @@ -88,7 +82,6 @@ def from_model_class( *, config_module: str | None = None, config_class_name: str | None = None, - tags: set[str] | frozenset[str] | tuple[str, ...] = (), architectures: list[str] | tuple[str, ...] = (), model_types: tuple[str, ...] = (), ) -> "ModelPackageSpec": @@ -98,7 +91,6 @@ def from_model_class( model_cls.__name__, config_module=config_module, config_class_name=config_class_name, - tags=tags, architectures=architectures, model_types=model_types, ) diff --git a/tests/unit_tests/_transformers/test_registry.py b/tests/unit_tests/_transformers/test_registry.py index 8ebf34ce88..88c48f8f8d 100644 --- a/tests/unit_tests/_transformers/test_registry.py +++ b/tests/unit_tests/_transformers/test_registry.py @@ -95,27 +95,6 @@ class ReplacementClass: assert inst.model_arch_name_to_cls["MyArch"] is ReplacementClass -def test_deprecated_retrieval_registration_methods_warn(): - from nemo_automodel._transformers import registry as reg - - inst = _new_registry_instance(reg) - - with pytest.warns(DeprecationWarning, match="register_retrieval"): - inst.register_retrieval("ManualRetrieval") - - with pytest.warns(DeprecationWarning, match="has_retrieval_model"): - assert inst.has_retrieval_model("ManualRetrieval") is True - - -def test_deprecated_retrieval_lookup_delegates_to_retrieval_registry(): - from nemo_automodel._transformers import registry as reg - - inst = _new_registry_instance(reg) - - with pytest.warns(DeprecationWarning, match="has_retrieval_model"): - assert inst.has_retrieval_model("LlamaBidirectionalModel") is True - - def test_supported_models_and_getter(): from nemo_automodel._transformers import registry as reg From b5bd2e08195b090ab86542c490e92b2ee863dd19 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 18 May 2026 22:03:01 -0700 Subject: [PATCH 4/5] trim Signed-off-by: Alexandros Koumparoulis --- .../registry/model_package_spec.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/nemo_automodel/_transformers/registry/model_package_spec.py b/nemo_automodel/_transformers/registry/model_package_spec.py index 2c4aa39e7e..014c530483 100644 --- a/nemo_automodel/_transformers/registry/model_package_spec.py +++ b/nemo_automodel/_transformers/registry/model_package_spec.py @@ -49,32 +49,6 @@ def __post_init__(self) -> None: object.__setattr__(self, "architectures", tuple(self.architectures)) object.__setattr__(self, "model_types", tuple(self.model_types)) - @classmethod - def from_module_path( - cls, - module_path: str, - class_name: str, - *, - config_module: str | None = None, - config_class_name: str | None = None, - architectures: list[str] | tuple[str, ...] = (), - model_types: tuple[str, ...] = (), - ) -> "ModelPackageSpec": - """Create a spec from a fully qualified model module path.""" - package, sep, model_module = module_path.rpartition(".") - if not sep: - package = "" - model_module = module_path - return cls( - package=package, - class_name=class_name, - model_module=model_module, - config_module=config_module, - config_class_name=config_class_name, - architectures=architectures, - model_types=model_types, - ) - @classmethod def from_model_class( cls, From c2e70b79521d072f32a2a18abb8d8621f93b5e25 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 18 May 2026 22:12:06 -0700 Subject: [PATCH 5/5] trim Signed-off-by: Alexandros Koumparoulis --- .../_transformers/registry/__init__.py | 13 --- .../registry/model_package_spec.py | 11 +- .../_transformers/registry/model_registry.py | 16 +-- .../_transformers/test_doc_coverage.py | 2 +- .../unit_tests/_transformers/test_registry.py | 107 ++++++++---------- .../_transformers/test_registry_hy_v3.py | 10 +- .../models/mistral4/test_mistral4_model.py | 4 +- .../test_nemotron_omni_layers.py | 4 +- 8 files changed, 70 insertions(+), 97 deletions(-) diff --git a/nemo_automodel/_transformers/registry/__init__.py b/nemo_automodel/_transformers/registry/__init__.py index 0d9005e6b6..28fd4666ce 100644 --- a/nemo_automodel/_transformers/registry/__init__.py +++ b/nemo_automodel/_transformers/registry/__init__.py @@ -14,16 +14,9 @@ """Public exports for the model registry package.""" -from nemo_automodel._transformers.registry.base import _BaseModelRegistry -from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec - _MODEL_REGISTRY_EXPORTS = { - "MODEL_PACKAGE_SPECS", - "RETRIEVAL_MODEL_PACKAGE_SPECS", "ModelRegistry", "RetrievalModelRegistry", - "make_registry", - "make_retrieval_registry", } @@ -37,12 +30,6 @@ def __getattr__(name: str): __all__ = [ - "MODEL_PACKAGE_SPECS", - "RETRIEVAL_MODEL_PACKAGE_SPECS", "ModelRegistry", "RetrievalModelRegistry", - "ModelPackageSpec", - "_BaseModelRegistry", - "make_registry", - "make_retrieval_registry", ] diff --git a/nemo_automodel/_transformers/registry/model_package_spec.py b/nemo_automodel/_transformers/registry/model_package_spec.py index 014c530483..2f0703fc8a 100644 --- a/nemo_automodel/_transformers/registry/model_package_spec.py +++ b/nemo_automodel/_transformers/registry/model_package_spec.py @@ -60,9 +60,14 @@ def from_model_class( model_types: tuple[str, ...] = (), ) -> "ModelPackageSpec": """Create a spec from a model class object.""" - return cls.from_module_path( - model_cls.__module__, - model_cls.__name__, + package, sep, model_module = model_cls.__module__.rpartition(".") + if not sep: + package = "" + model_module = model_cls.__module__ + return cls( + package=package, + class_name=model_cls.__name__, + model_module=model_module, config_module=config_module, config_class_name=config_class_name, architectures=architectures, diff --git a/nemo_automodel/_transformers/registry/model_registry.py b/nemo_automodel/_transformers/registry/model_registry.py index 0f339e2105..b115706ca3 100644 --- a/nemo_automodel/_transformers/registry/model_registry.py +++ b/nemo_automodel/_transformers/registry/model_registry.py @@ -222,18 +222,10 @@ @lru_cache -def make_registry(model_specs: tuple[ModelPackageSpec, ...] = MODEL_PACKAGE_SPECS) -> _BaseModelRegistry: - """Return the process-wide model registry singleton.""" +def make_registry(model_specs: tuple[ModelPackageSpec, ...]) -> _BaseModelRegistry: + """Return a process-wide model registry singleton for package specs.""" return _BaseModelRegistry(model_specs=model_specs) -@lru_cache -def make_retrieval_registry( - model_specs: tuple[ModelPackageSpec, ...] = RETRIEVAL_MODEL_PACKAGE_SPECS, -) -> _BaseModelRegistry: - """Return the process-wide retrieval model registry singleton.""" - return _BaseModelRegistry(model_specs=model_specs) - - -ModelRegistry = make_registry() -RetrievalModelRegistry = make_retrieval_registry() +ModelRegistry = make_registry(MODEL_PACKAGE_SPECS) +RetrievalModelRegistry = make_registry(RETRIEVAL_MODEL_PACKAGE_SPECS) diff --git a/tests/unit_tests/_transformers/test_doc_coverage.py b/tests/unit_tests/_transformers/test_doc_coverage.py index 31cc333147..bedaa81f72 100644 --- a/tests/unit_tests/_transformers/test_doc_coverage.py +++ b/tests/unit_tests/_transformers/test_doc_coverage.py @@ -95,7 +95,7 @@ def test_every_registered_arch_has_model_coverage_doc(): This guards against the regression where a new arch (e.g. gemma4) is registered but never gets a corresponding model card in the docs. """ - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS docs_dir = _repo_root() / "docs" / "model-coverage" assert docs_dir.is_dir(), f"docs/model-coverage/ not found at {docs_dir}" diff --git a/tests/unit_tests/_transformers/test_registry.py b/tests/unit_tests/_transformers/test_registry.py index 88c48f8f8d..a70d92d4cc 100644 --- a/tests/unit_tests/_transformers/test_registry.py +++ b/tests/unit_tests/_transformers/test_registry.py @@ -17,10 +17,12 @@ import pytest +from nemo_automodel._transformers.registry.base import _BaseModelRegistry -def _new_registry_instance(registry_module): + +def _new_registry_instance(_registry_module=None): """Create a fresh registry with an empty auto_map for testing.""" - return registry_module._BaseModelRegistry(model_specs=()) + return _BaseModelRegistry(model_specs=()) def _model_arch_lookup(model_arch_mapping): @@ -110,46 +112,37 @@ class A: def test_make_registry_is_cached(): - from nemo_automodel._transformers import registry as reg + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS, make_registry - reg.make_registry.cache_clear() - r1 = reg.make_registry() - r2 = reg.make_registry() + make_registry.cache_clear() + r1 = make_registry(MODEL_PACKAGE_SPECS) + r2 = make_registry(MODEL_PACKAGE_SPECS) assert r1 is r2 -def test_make_retrieval_registry_is_cached(): - from nemo_automodel._transformers import registry as reg +def test_make_registry_caches_by_model_specs(): + from nemo_automodel._transformers.registry.model_registry import RETRIEVAL_MODEL_PACKAGE_SPECS, make_registry - reg.make_retrieval_registry.cache_clear() - r1 = reg.make_retrieval_registry() - r2 = reg.make_retrieval_registry() + make_registry.cache_clear() + r1 = make_registry(RETRIEVAL_MODEL_PACKAGE_SPECS) + r2 = make_registry(RETRIEVAL_MODEL_PACKAGE_SPECS) assert r1 is r2 -def test_registry_reexports_split_modules(): - """The registry package remains the public import path while definitions live in split modules.""" +def test_registry_reexports_public_singletons(): + """The registry package root exposes only the runtime registry singletons.""" from nemo_automodel._transformers import registry as reg - from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec from nemo_automodel._transformers.registry.model_registry import ( - MODEL_PACKAGE_SPECS, - RETRIEVAL_MODEL_PACKAGE_SPECS, + ModelRegistry, RetrievalModelRegistry, - make_registry, - make_retrieval_registry, ) - assert reg.ModelPackageSpec is ModelPackageSpec - assert reg.MODEL_PACKAGE_SPECS is MODEL_PACKAGE_SPECS - assert reg.RETRIEVAL_MODEL_PACKAGE_SPECS is RETRIEVAL_MODEL_PACKAGE_SPECS + assert reg.ModelRegistry is ModelRegistry assert reg.RetrievalModelRegistry is RetrievalModelRegistry - assert reg.make_registry is make_registry - assert reg.make_retrieval_registry is make_retrieval_registry def test_registry_mapping_loads_model_class_on_demand(): """Static auto_map entries are lazily loaded on first access.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeClass: @@ -157,7 +150,7 @@ class FakeClass: fake_module = types.SimpleNamespace(FakeClass=FakeClass) spec = ModelPackageSpec(package="fake", model_module="module", class_name="FakeClass", architectures=("FakeArch",)) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): assert "FakeArch" in inst.model_arch_name_to_cls @@ -170,20 +163,20 @@ class FakeClass: def test_registry_accepts_model_package_spec_entries(): """Static auto_map entries may be declared as ModelPackageSpec metadata.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeClass: pass fake_module = types.SimpleNamespace(FakeClass=FakeClass) - spec = ModelPackageSpec.from_module_path( - "fake.package.model", - "FakeClass", + spec = ModelPackageSpec( + package="fake.package", + model_module="model", + class_name="FakeClass", architectures=["FakeArch"], model_types=("fake",), ) - inst = reg._BaseModelRegistry(model_specs={"FakeArch": spec}) + inst = _BaseModelRegistry(model_specs={"FakeArch": spec}) with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): assert inst.model_arch_name_to_cls["FakeArch"] is FakeClass @@ -208,7 +201,6 @@ class FakeClass: def test_registry_derives_keys_from_spec_architectures(): """Tuple-based auto_map entries derive lookup keys from spec.architectures.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec spec = ModelPackageSpec( @@ -216,14 +208,13 @@ def test_registry_derives_keys_from_spec_architectures(): class_name="FakeClass", architectures=("FakeArch", "FakeAlias"), ) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) assert inst.model_arch_name_to_cls.keys() == {"FakeArch", "FakeAlias"} def test_base_registry_owns_architecture_metadata(): """Package metadata indexes live on the registry, not on the lazy class loader.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec spec = ModelPackageSpec( @@ -232,7 +223,7 @@ def test_base_registry_owns_architecture_metadata(): architectures=("FakeArch", "FakeAlias"), model_types=("fake",), ) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) assert inst.get_model_package_spec("FakeArch") is spec assert inst.get_model_package_spec("FakeAlias") is spec @@ -241,7 +232,6 @@ def test_base_registry_owns_architecture_metadata(): def test_registry_rejects_duplicate_architectures(): """Duplicate architecture names across package specs fail during registry construction.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec specs = ( @@ -250,12 +240,11 @@ def test_registry_rejects_duplicate_architectures(): ) with pytest.raises(ValueError, match="Duplicated model architecture entry for 'FakeArch'"): - reg._BaseModelRegistry(model_specs=specs) + _BaseModelRegistry(model_specs=specs) def test_registry_dynamic_entry_overrides_static_entry(): """Dynamically registered entries take precedence over static entries.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class StaticClass: @@ -266,7 +255,7 @@ class DynamicClass: fake_module = types.SimpleNamespace(StaticClass=StaticClass) spec = ModelPackageSpec(package="fake", model_module="module", class_name="StaticClass", architectures=("MyArch",)) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) with patch("nemo_automodel._transformers.registry.base.importlib.import_module", return_value=fake_module): assert inst.model_arch_name_to_cls["MyArch"] is StaticClass @@ -278,7 +267,6 @@ class DynamicClass: def test_registry_unavailable_model(): """Auto_map entries whose imports fail are removed and excluded from containment.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec spec = ModelPackageSpec( @@ -287,7 +275,7 @@ def test_registry_unavailable_model(): class_name="BadClass", architectures=("BadArch",), ) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) assert "BadArch" not in inst.model_arch_name_to_cls assert "BadArch" not in inst.model_arch_name_to_cls.keys() @@ -295,7 +283,6 @@ def test_registry_unavailable_model(): def test_registry_discards_failed_import_from_metadata_index(): """Failed model imports remove the architecture from the registry-owned metadata index.""" - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec spec = ModelPackageSpec( @@ -304,7 +291,7 @@ def test_registry_discards_failed_import_from_metadata_index(): class_name="BadClass", architectures=("BadArch",), ) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) assert "BadArch" not in inst.model_arch_name_to_cls assert inst.get_model_package_spec("BadArch") is None @@ -337,9 +324,9 @@ class DirectClass: def test_default_registry_has_static_entries(): """The default registry is populated from MODEL_PACKAGE_SPECS.""" - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS, make_registry + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS, make_registry - inst = make_registry() + inst = make_registry(MODEL_PACKAGE_SPECS) for spec in MODEL_PACKAGE_SPECS: for arch_name in spec.architectures: assert arch_name in inst.model_arch_name_to_cls.keys() @@ -347,12 +334,14 @@ def test_default_registry_has_static_entries(): def test_retrieval_registry_has_separate_static_entries(): """Retrieval architectures live in the retrieval registry, not the default registry.""" - from nemo_automodel._transformers import registry as reg - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS, RETRIEVAL_MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import ( + MODEL_PACKAGE_SPECS, + RETRIEVAL_MODEL_PACKAGE_SPECS, + ) retrieval_arch = "LlamaBidirectionalModel" - default_registry = reg._BaseModelRegistry(model_specs=MODEL_PACKAGE_SPECS) - retrieval_registry = reg._BaseModelRegistry(model_specs=RETRIEVAL_MODEL_PACKAGE_SPECS) + default_registry = _BaseModelRegistry(model_specs=MODEL_PACKAGE_SPECS) + retrieval_registry = _BaseModelRegistry(model_specs=RETRIEVAL_MODEL_PACKAGE_SPECS) assert retrieval_arch not in _model_arch_lookup(MODEL_PACKAGE_SPECS) assert retrieval_arch in _model_arch_lookup(RETRIEVAL_MODEL_PACKAGE_SPECS) @@ -364,7 +353,6 @@ def test_registry_discovers_config_class_from_config_module(): """Config classes are imported from declared convention modules on demand.""" from transformers import PretrainedConfig - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeConfig(PretrainedConfig): @@ -373,7 +361,7 @@ class FakeConfig(PretrainedConfig): FakeConfig.__module__ = "fake.package.config" fake_module = types.SimpleNamespace(__name__="fake.package.config", FakeConfig=FakeConfig) spec = ModelPackageSpec(package="fake.package", model_types=("fake_model",), config_module="config") - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) with ( patch("transformers.AutoConfig.register") as mock_register, @@ -388,7 +376,6 @@ def test_registry_registers_explicit_config_alias(): """Explicit config metadata supports aliases whose key differs from class.model_type.""" from transformers import PretrainedConfig - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeConfig(PretrainedConfig): @@ -401,7 +388,7 @@ class FakeConfig(PretrainedConfig): config_module="model", config_class_name="FakeConfig", ) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) with ( patch("transformers.AutoConfig.register") as mock_register, @@ -417,7 +404,6 @@ def test_registry_registers_config_alias_when_auto_config_rejects_mismatch(): from transformers import PretrainedConfig from transformers.models.auto.configuration_auto import CONFIG_MAPPING - from nemo_automodel._transformers import registry as reg from nemo_automodel._transformers.registry.model_package_spec import ModelPackageSpec class FakeConfig(PretrainedConfig): @@ -430,7 +416,7 @@ class FakeConfig(PretrainedConfig): config_module="model", config_class_name="FakeConfig", ) - inst = reg._BaseModelRegistry(model_specs=(spec,)) + inst = _BaseModelRegistry(model_specs=(spec,)) with ( patch("transformers.AutoConfig.register", side_effect=ValueError("model_type mismatch")), @@ -514,7 +500,7 @@ def supports_config(cls, config): def test_config_metadata_is_metadata_only_until_requested(): """Config metadata should not require eager registration at registry import time.""" - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS for spec in MODEL_PACKAGE_SPECS: if spec.config_module_path is None: @@ -525,7 +511,7 @@ def test_config_metadata_is_metadata_only_until_requested(): def test_kimi_k25_arch_alias_in_model_package_specs(): """KimiK25ForConditionalGeneration (checkpoint arch) must map to KimiK25VLForConditionalGeneration.""" - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS mapping = _model_arch_lookup(MODEL_PACKAGE_SPECS) assert "KimiK25ForConditionalGeneration" in mapping, ( @@ -539,7 +525,7 @@ def test_kimi_k25_arch_alias_in_model_package_specs(): def test_deepseek_v4_registered_in_model_package_specs(): """DeepseekV4ForCausalLM must be registered in MODEL_PACKAGE_SPECS.""" - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS mapping = _model_arch_lookup(MODEL_PACKAGE_SPECS) assert "DeepseekV4ForCausalLM" in mapping, ( @@ -554,7 +540,7 @@ def test_deepseek_v4_registered_in_model_package_specs(): def test_deepseek_v4_in_model_package_specs(): """deepseek_v4 model_type must be declared in MODEL_PACKAGE_SPECS.""" - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS specs_by_model_type = {model_type: spec for spec in MODEL_PACKAGE_SPECS for model_type in spec.model_types} assert "deepseek_v4" in specs_by_model_type, ( @@ -575,7 +561,10 @@ def test_all_model_folders_registered_in_auto_map(): """ import pathlib - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS, RETRIEVAL_MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import ( + MODEL_PACKAGE_SPECS, + RETRIEVAL_MODEL_PACKAGE_SPECS, + ) models_root = pathlib.Path(__file__).resolve().parents[3] / "nemo_automodel" / "components" / "models" diff --git a/tests/unit_tests/_transformers/test_registry_hy_v3.py b/tests/unit_tests/_transformers/test_registry_hy_v3.py index f375689b74..3f2e9ca509 100644 --- a/tests/unit_tests/_transformers/test_registry_hy_v3.py +++ b/tests/unit_tests/_transformers/test_registry_hy_v3.py @@ -17,14 +17,14 @@ class TestArchMapping: def test_hyv3_arch_registered(self): - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS assert "HYV3ForCausalLM" in _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) def test_hyv3_arch_points_at_correct_module(self): - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS entry = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)["HYV3ForCausalLM"] assert entry.module_path == "nemo_automodel.components.models.hy_v3.model" @@ -34,8 +34,8 @@ def test_hyv3_arch_resolves_to_class(self): """Walk the mapping path -- importable + the named class exists.""" import importlib - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS spec = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)["HYV3ForCausalLM"] mod_path = spec.module_path @@ -46,14 +46,14 @@ def test_hyv3_arch_resolves_to_class(self): class TestCustomConfigRegistration: def test_hy_v3_config_registered(self): - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS assert any("hy_v3" in spec.model_types for spec in MODEL_PACKAGE_SPECS) def test_hy_v3_config_resolves_to_class(self): import importlib - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS spec = next(spec for spec in MODEL_PACKAGE_SPECS if "hy_v3" in spec.model_types) mod_path = spec.config_module_path diff --git a/tests/unit_tests/models/mistral4/test_mistral4_model.py b/tests/unit_tests/models/mistral4/test_mistral4_model.py index 6e00c98a08..922a7072df 100644 --- a/tests/unit_tests/models/mistral4/test_mistral4_model.py +++ b/tests/unit_tests/models/mistral4/test_mistral4_model.py @@ -920,16 +920,16 @@ def test_forward_raises_on_pp_stage_without_embeds(self, text_config, backend, d class TestRegistry: def test_mistral4_in_registry(self): - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) assert "Mistral4ForCausalLM" in mapping assert "Mistral3ForConditionalGeneration" in mapping def test_registry_module_path(self): - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) assert mapping["Mistral4ForCausalLM"].module_path == "nemo_automodel.components.models.mistral4.model" diff --git a/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py b/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py index 463bc8d741..2832d69ce1 100644 --- a/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py +++ b/tests/unit_tests/models/nemotron_omni/test_nemotron_omni_layers.py @@ -165,8 +165,8 @@ def test_nemotron_omni_config_overrides_propagate(): def test_registry_entry_present(): """``MODEL_PACKAGE_SPECS`` should resolve the v3 architecture name to our wrapper.""" - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS mapping = _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS) assert "NemotronH_Nano_Omni_Reasoning_V3" in mapping @@ -177,7 +177,7 @@ def test_registry_entry_present(): def test_registry_v2_entry_removed(): """V2 dispatch was deleted along with V2 dump support — keep it gone.""" - from nemo_automodel._transformers.registry import MODEL_PACKAGE_SPECS from nemo_automodel._transformers.registry.base import _normalize_model_arch_mapping + from nemo_automodel._transformers.registry.model_registry import MODEL_PACKAGE_SPECS assert "NemotronH_Nano_VL_V2" not in _normalize_model_arch_mapping(MODEL_PACKAGE_SPECS)