diff --git a/pyproject.toml b/pyproject.toml index 5f65407ae..5fbabaf86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "aiometer (>=1.0.0,<2.0.0)", "aiofiles (>=24.1.0,<25.0.0)", "threadpoolctl (>=3.0.0,<4.0.0)", + "packaging (>=23.2)", ] [project.optional-dependencies] diff --git a/src/autointent/_deps.py b/src/autointent/_deps.py new file mode 100644 index 000000000..f56b2a536 --- /dev/null +++ b/src/autointent/_deps.py @@ -0,0 +1,201 @@ +"""Validate optional-extra dependencies from installed package metadata. + +The :func:`require` guard checks that every dependency of an ``autointent`` extra +is installed and version-satisfied. It reads the metadata that the build baked into +the installed distribution (via :mod:`importlib.metadata`) rather than the source +``pyproject.toml``, which is not shipped in the wheel. Nested extras are resolved +recursively, so e.g. the ``transformers`` extra (``transformers[torch]``) +transitively requires ``accelerate`` and that is checked too. +""" + +from __future__ import annotations + +from functools import cache +from importlib import metadata +from typing import Literal + +from packaging.requirements import Requirement +from packaging.utils import canonicalize_name + +_DIST = "autointent" + +# Names of the optional-dependency extras autointent declares, mirrored from the +# installed ``Provides-Extra`` metadata (and thus pyproject's +# [project.optional-dependencies]). Typing ``require``'s parameter with this makes +# mypy reject misspelled extra names at call sites; the runtime check in ``require`` +# stays the source of truth (mypy isn't run at runtime, and ``dist`` overrides or +# dynamic calls bypass static typing). Kept in sync with the real metadata by +# tests/test_deps.py::test_extra_literal_matches_real_metadata. +Extra = Literal[ + "catboost", + "codecarbon", + "dspy", + "fastapi", + "fastmcp", + "openai", + "opensearch", + "peft", + "sentence-transformers", + "transformers", + "vllm", + "wandb", +] + + +def _check(req: Requirement) -> str | None: + """Check a single requirement against the installed environment. + + Args: + req: The parsed requirement to validate. + + Returns: + A human-readable problem description if the distribution is missing or its + installed version does not satisfy ``req.specifier``; ``None`` otherwise. + """ + try: + installed = metadata.version(req.name) + except metadata.PackageNotFoundError: + return f"{req.name}{req.specifier} (not installed)" + if req.specifier and not req.specifier.contains(installed, prereleases=True): + return f"{req.name}{req.specifier} (installed: {installed})" + return None + + +def _iter_extra_reqs(dist: str, extra: str) -> list[Requirement]: + """Return the requirements of ``dist`` activated by ``extra``. + + Args: + dist: Distribution name whose metadata is read. + extra: Extra name whose dependencies are wanted. + + Returns: + The parsed requirements activated by ``extra`` in the current environment, + or an empty list if ``dist`` is not installed (its metadata is unavailable). + """ + target = str(canonicalize_name(extra)) + result: list[Requirement] = [] + try: + reqs = metadata.requires(dist) + except metadata.PackageNotFoundError: + # `dist` itself isn't installed, so we can't read its nested-extra + # requirements. That's fine: the parent requirement that led us to recurse + # here (e.g. `transformers[torch]`) was already collected by the caller and + # `_check` will flag it as "not installed", producing the proper aggregated + # ImportError with the install hint -- rather than letting a raw + # PackageNotFoundError leak out of the resolver. + return [] + for spec in reqs or []: + req = Requirement(spec) + # `req.marker` is the parsed `;` clause of the PEP 508 requirement (a + # packaging Marker), or None when the requirement has no `;` clause. There + # are three cases: + # (1) no marker -> an unconditional base dependency; + # (2) a marker that references `extra` -> belongs to an extra; + # (3) a marker with only environment conditions (e.g. `python_version < "3.9"`) + # -> still a base dependency, just platform-conditional. + # So "has a marker" does NOT mean "belongs to an extra"; + marker = req.marker + # Here we cancel out case (1) + if marker is None: + continue + # `marker.evaluate(env)` resolves the whole boolean expression to a bool, + # filling any keys we omit (python_version, sys_platform, ...) from the + # running interpreter. A single `evaluate({"extra": target})` is not enough + # to prove membership: an env-conditional base dep also passes it, because + # its truth comes from the environment and the `extra` key is ignored. + # The discriminator is the second evaluation: a *true* extra dependency + # flips active -> inactive when the extra is removed, whereas a base dep is + # unaffected. So "active with the extra AND inactive with no extra" means + # "active *because of* this extra", which keeps extra members and drops + # base deps. We always pass `extra` explicitly (`""` = base install, no + # extras) since a marker that references `extra` can't be evaluated without it. + # So here we cancel out case (3) + if marker.evaluate({"extra": target}) and not marker.evaluate({"extra": ""}): + result.append(req) + return result + + +def _resolve(dist: str, extra: str, seen: set[tuple[str, str]]) -> list[Requirement]: + """Recursively collect every leaf requirement activated by ``dist[extra]``. + + Each activated requirement is returned for version checking, and any nested + extras it declares (e.g. ``transformers[torch]``) are resolved in turn. + + Args: + dist: Distribution name to start from. + extra: Extra name to resolve. + seen: Visited ``(dist, extra)`` pairs, used to break dependency cycles. + + Returns: + The flattened list of requirements to validate. + """ + key = (str(canonicalize_name(dist)), str(canonicalize_name(extra))) + if key in seen: + return [] + seen.add(key) + + leaves: list[Requirement] = [] + for req in _iter_extra_reqs(dist, extra): + leaves.append(req) + for nested in req.extras: + leaves.extend(_resolve(req.name, nested, seen)) + return leaves + + +@cache +def _resolve_cached(dist: str, extra: str) -> tuple[Requirement, ...]: + """Memoized :func:`_resolve`; the metadata graph shape is stable per process. + + Args: + dist: Distribution name to start from. + extra: Extra name to resolve. + + Returns: + The resolved requirements as an immutable tuple. + """ + return tuple(_resolve(dist, extra, set())) + + +def _provides_extras(dist: str) -> set[str]: + """Return the normalized set of extras declared by ``dist``. + + Args: + dist: Distribution name whose metadata is read. + + Returns: + Normalized extra names from the distribution's ``Provides-Extra`` metadata. + """ + md = metadata.metadata(dist) + return {str(canonicalize_name(e)) for e in (md.get_all("Provides-Extra") or [])} + + +def require(extra: Extra) -> None: + """Ensure every dependency of an ``autointent`` extra is installed and current. + + Args: + extra: The extra to validate, e.g. ``"transformers"``. + + Raises: + ValueError: If ``autointent`` declares no such ``extra`` (typically a typo). + ImportError: If any required dependency is missing or its installed version + does not satisfy the constraint declared in the metadata. + """ + known = _provides_extras(_DIST) + if str(canonicalize_name(extra)) not in known: + msg = f"'{_DIST}' declares no extra '{extra}'. Known extras: {', '.join(sorted(known))}." + raise ValueError(msg) + + problems: list[str] = [] + for req in _resolve_cached(_DIST, extra): + problem = _check(req) + if problem is not None and problem not in problems: + problems.append(problem) + + if problems: + bullets = "\n".join(f" - {p}" for p in problems) + msg = ( + f"Feature requires extra '{extra}', but dependencies are missing or outdated:\n" + f"{bullets}\n" + f"Install with: pip install '{_DIST}[{extra}]'" + ) + raise ImportError(msg) diff --git a/src/autointent/_dump_tools/unit_dumpers.py b/src/autointent/_dump_tools/unit_dumpers.py index 69b5a276d..2b1c201bb 100644 --- a/src/autointent/_dump_tools/unit_dumpers.py +++ b/src/autointent/_dump_tools/unit_dumpers.py @@ -13,7 +13,7 @@ from sklearn.base import BaseEstimator from autointent import Embedder, Ranker, VectorIndex -from autointent._utils import require +from autointent._deps import require from autointent._wrappers import BaseTorchModule from autointent.schemas import TagsList @@ -225,8 +225,8 @@ def dump(obj: PeftModel, path: Path, exists_ok: bool) -> None: @staticmethod def load(path: Path, **kwargs: Any) -> PeftModel: # noqa: ANN401 - require("peft", extra="peft") - require("transformers", extra="transformers") + require("peft") + require("transformers") import peft import transformers @@ -245,7 +245,7 @@ def load(path: Path, **kwargs: Any) -> PeftModel: # noqa: ANN401 @classmethod def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401 try: - require("peft", extra="peft") + require("peft") import peft return isinstance(obj, peft.PeftModel) @@ -263,7 +263,7 @@ def dump(obj: PreTrainedModel, path: Path, exists_ok: bool) -> None: @staticmethod def load(path: Path, **kwargs: Any) -> PreTrainedModel: # noqa: ANN401 - require("transformers", extra="transformers") + require("transformers") import transformers return transformers.AutoModelForSequenceClassification.from_pretrained(path) # type: ignore[no-any-return] @@ -271,7 +271,7 @@ def load(path: Path, **kwargs: Any) -> PreTrainedModel: # noqa: ANN401 @classmethod def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401 try: - require("transformers", extra="transformers") + require("transformers") import transformers return isinstance(obj, transformers.PreTrainedModel) @@ -289,7 +289,7 @@ def dump(obj: PreTrainedTokenizer | PreTrainedTokenizerFast, path: Path, exists_ @staticmethod def load(path: Path, **kwargs: Any) -> PreTrainedTokenizer | PreTrainedTokenizerFast: # noqa: ANN401 - require("transformers", extra="transformers") + require("transformers") import transformers return transformers.AutoTokenizer.from_pretrained(path) # type: ignore[no-any-return,no-untyped-call] @@ -297,7 +297,7 @@ def load(path: Path, **kwargs: Any) -> PreTrainedTokenizer | PreTrainedTokenizer @classmethod def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401 try: - require("transformers", extra="transformers") + require("transformers") import transformers return isinstance(obj, transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast) @@ -342,7 +342,7 @@ def dump(obj: CatBoostClassifier, path: Path, exists_ok: bool) -> None: # noqa: @staticmethod def load(path: Path, **kwargs: Any) -> CatBoostClassifier: # noqa: ANN401 - require("catboost", extra="catboost") + require("catboost") from catboost import CatBoostClassifier model = CatBoostClassifier() @@ -352,7 +352,7 @@ def load(path: Path, **kwargs: Any) -> CatBoostClassifier: # noqa: ANN401 @classmethod def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401 try: - require("catboost", extra="catboost") + require("catboost") from catboost import CatBoostClassifier return isinstance(obj, CatBoostClassifier) diff --git a/src/autointent/_utils.py b/src/autointent/_utils.py index 4ecbb6e4d..c8a8614b7 100644 --- a/src/autointent/_utils.py +++ b/src/autointent/_utils.py @@ -1,6 +1,5 @@ """Utils.""" -import importlib from typing import TypeVar import torch @@ -28,22 +27,3 @@ def detect_device() -> str: return "cpu" -def require(dependency: str, extra: str | None = None) -> None: - """Try to import dependency, raise informative ImportError if missing. - - Args: - dependency: The name of the module to import - extra: Optional extra package name for pip install instructions - - Returns: - The imported module - - Raises: - ImportError: If the dependency is not installed - """ - try: - importlib.import_module(dependency) - except ImportError as e: - extra_info = f" Install with `pip install autointent[{extra}]`." if extra else "" - msg = f"Missing dependency '{dependency}' required for this feature.{extra_info}" - raise ImportError(msg) from e diff --git a/src/autointent/_wrappers/embedder/openai.py b/src/autointent/_wrappers/embedder/openai.py index 9eada4f05..c1ae6ee3f 100644 --- a/src/autointent/_wrappers/embedder/openai.py +++ b/src/autointent/_wrappers/embedder/openai.py @@ -12,8 +12,8 @@ import numpy.typing as npt import torch +from autointent._deps import require from autointent._hash import Hasher -from autointent._utils import require from autointent.configs._embedder import OpenaiEmbeddingConfig from .base import BaseEmbeddingBackend @@ -77,7 +77,7 @@ def _openai_api_error_message(exc: BaseException, *, batch_size: int) -> str: def _tiktoken_encoding_for_embedding_model(model_name: str) -> Encoding: """Resolve tiktoken encoding for batch sizing; fallback for unknown provider model ids.""" - require("tiktoken", "openai") + require("openai") import tiktoken try: @@ -110,7 +110,7 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None: Args: config: Configuration for OpenAI embeddings. """ - require("openai", "openai") + require("openai") self.config = config self._event_loop: asyncio.AbstractEventLoop | None = None diff --git a/src/autointent/_wrappers/embedder/sentence_transformers.py b/src/autointent/_wrappers/embedder/sentence_transformers.py index cbb921409..772737fed 100644 --- a/src/autointent/_wrappers/embedder/sentence_transformers.py +++ b/src/autointent/_wrappers/embedder/sentence_transformers.py @@ -14,8 +14,8 @@ from datasets import Dataset from sklearn.model_selection import train_test_split +from autointent._deps import require from autointent._hash import Hasher -from autointent._utils import require from autointent.configs._embedder import SentenceTransformerEmbeddingConfig from .base import BaseEmbeddingBackend @@ -42,7 +42,7 @@ def _set_training_seed(seed: int) -> None: if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - require("transformers", extra="transformers") + require("transformers") from transformers import set_seed set_seed(seed) @@ -130,7 +130,7 @@ def _load_model(self) -> SentenceTransformer: """Load sentence transformers model to device.""" if self._model is None: # Lazy import sentence-transformers - require("sentence_transformers", extra="sentence-transformers") + require("sentence-transformers") from sentence_transformers import SentenceTransformer res = SentenceTransformer( @@ -294,9 +294,8 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin _set_training_seed(config.seed) # Lazy import sentence-transformers training components (only needed for fine-tuning) - require("sentence_transformers", extra="sentence-transformers") - require("transformers", extra="transformers") - require("accelerate", extra="transformers") + require("sentence-transformers") + require("transformers") from sentence_transformers import ( SentenceTransformerTrainer, SentenceTransformerTrainingArguments, diff --git a/src/autointent/_wrappers/embedder/vllm.py b/src/autointent/_wrappers/embedder/vllm.py index 170c73aa5..691686ca3 100644 --- a/src/autointent/_wrappers/embedder/vllm.py +++ b/src/autointent/_wrappers/embedder/vllm.py @@ -9,8 +9,8 @@ import numpy as np import torch +from autointent._deps import require from autointent._hash import Hasher -from autointent._utils import require from autointent.configs._embedder import VllmEmbeddingConfig from .base import BaseEmbeddingBackend @@ -44,7 +44,7 @@ def __init__(self, config: VllmEmbeddingConfig) -> None: def _load_model(self) -> LLM: """Lazy-load the vLLM LLM engine on first use.""" if self._model is None: - require("vllm", extra="vllm") + require("vllm") from vllm import LLM kwargs = { diff --git a/src/autointent/_wrappers/ranker.py b/src/autointent/_wrappers/ranker.py index 5e3826693..8e73f0be5 100644 --- a/src/autointent/_wrappers/ranker.py +++ b/src/autointent/_wrappers/ranker.py @@ -20,7 +20,7 @@ from sklearn.linear_model import LogisticRegressionCV from torch import nn -from autointent._utils import require +from autointent._deps import require from autointent.configs import CrossEncoderConfig from autointent.custom_types import RerankedItem @@ -118,7 +118,7 @@ def __init__( output_range: Range of the output probabilities ([0, 1] for sigmoid, [-1, 1] for tanh) """ # Lazy import sentence-transformers - require("sentence_transformers", extra="sentence-transformers") + require("sentence-transformers") from sentence_transformers import CrossEncoder self.config = CrossEncoderConfig.from_search_config(cross_encoder_config) diff --git a/src/autointent/generation/_generator.py b/src/autointent/generation/_generator.py index 3a919c656..44d1dcdee 100644 --- a/src/autointent/generation/_generator.py +++ b/src/autointent/generation/_generator.py @@ -11,7 +11,7 @@ from dotenv import load_dotenv from pydantic import BaseModel, ValidationError -from autointent._utils import require +from autointent._deps import require from autointent.generation.chat_templates import Message, Role from ._cache import StructuredOutputCache @@ -139,7 +139,7 @@ def __init__( client_params: Additional parameters for client. **generation_params: Additional generation parameters to override defaults passed to OpenAI completions API. """ - require("openai", "openai") + require("openai") import openai base_url = base_url or os.getenv("OPENAI_BASE_URL") diff --git a/src/autointent/modules/scoring/_bert.py b/src/autointent/modules/scoring/_bert.py index d5da50c79..816fd863e 100644 --- a/src/autointent/modules/scoring/_bert.py +++ b/src/autointent/modules/scoring/_bert.py @@ -13,7 +13,7 @@ from autointent import Context from autointent._callbacks import REPORTERS_NAMES -from autointent._utils import require +from autointent._deps import require from autointent.configs import EarlyStoppingConfig, HFModelConfig from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL from autointent.modules.base import BaseScorer @@ -88,7 +88,7 @@ def __init__( early_stopping_config: EarlyStoppingConfig | dict[str, Any] | None = None, print_progress: bool = False, ) -> None: - require("transformers", "transformers") + require("transformers") self.classification_model_config = HFModelConfig.from_search_config(classification_model_config) self.num_train_epochs = num_train_epochs self.batch_size = batch_size diff --git a/src/autointent/modules/scoring/_catboost/catboost_scorer.py b/src/autointent/modules/scoring/_catboost/catboost_scorer.py index cb1770374..82eba5a40 100644 --- a/src/autointent/modules/scoring/_catboost/catboost_scorer.py +++ b/src/autointent/modules/scoring/_catboost/catboost_scorer.py @@ -11,7 +11,7 @@ from pydantic import PositiveInt from autointent import Context, Embedder -from autointent._utils import require +from autointent._deps import require from autointent.configs import EmbedderConfig, TaskTypeEnum, initialize_embedder_config from autointent.custom_types import FloatFromZeroToOne, ListOfLabels from autointent.modules.base import BaseScorer @@ -110,7 +110,7 @@ def __init__( **catboost_kwargs: Any, # noqa: ANN401 ) -> None: # Lazy import catboost - require("catboost", extra="catboost") + require("catboost") self.val_fraction = val_fraction self.early_stopping_rounds = early_stopping_rounds diff --git a/src/autointent/modules/scoring/_lora/lora.py b/src/autointent/modules/scoring/_lora/lora.py index 2b0030815..f310bd3eb 100644 --- a/src/autointent/modules/scoring/_lora/lora.py +++ b/src/autointent/modules/scoring/_lora/lora.py @@ -6,8 +6,8 @@ from typing import TYPE_CHECKING, Any, Literal from autointent import Context +from autointent._deps import require from autointent._dump_tools import Dumper -from autointent._utils import require from autointent.configs import EarlyStoppingConfig, HFModelConfig from autointent.modules.scoring._bert import BertScorer @@ -75,7 +75,7 @@ def __init__( **lora_kwargs: Any, # noqa: ANN401 ) -> None: # Lazy import peft - require("peft", extra="peft") + require("peft") from peft import LoraConfig # early stopping doesn't work with lora for now https://github.com/huggingface/transformers/issues/38130 diff --git a/src/autointent/modules/scoring/_ptuning/ptuning.py b/src/autointent/modules/scoring/_ptuning/ptuning.py index 4edb86749..162ee502d 100644 --- a/src/autointent/modules/scoring/_ptuning/ptuning.py +++ b/src/autointent/modules/scoring/_ptuning/ptuning.py @@ -8,8 +8,8 @@ from pydantic import PositiveInt from autointent import Context +from autointent._deps import require from autointent._dump_tools import Dumper -from autointent._utils import require from autointent.configs import EarlyStoppingConfig, HFModelConfig from autointent.modules.scoring._bert import BertScorer @@ -73,7 +73,7 @@ def __init__( # noqa: PLR0913 **ptuning_kwargs: Any, # noqa: ANN401 ) -> None: # Lazy import peft - require("peft", extra="peft") + require("peft") from peft import PromptEncoderConfig, PromptEncoderReparameterizationType, TaskType diff --git a/tests/test_deps.py b/tests/test_deps.py new file mode 100644 index 000000000..a5c74a716 --- /dev/null +++ b/tests/test_deps.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import re +from importlib import metadata +from typing import TYPE_CHECKING, get_args + +import pytest +from packaging.requirements import Requirement + +import autointent._deps as deps + +if TYPE_CHECKING: + from collections.abc import Iterator + +_EXTRA_RE = re.compile(r"""extra\s*==\s*['"]([^'"]+)['"]""") + + +class _FakeMeta: + def __init__(self, extras: list[str]) -> None: + self._extras = extras + + def get_all(self, name: str, failobj: list[str] | None = None) -> list[str] | None: + if name == "Provides-Extra": + return list(self._extras) + return failobj + + +def _patch_metadata( + monkeypatch: pytest.MonkeyPatch, + requires_map: dict[str, list[str]], + versions: dict[str, str], +) -> None: + """Patch importlib.metadata so deps.* sees a synthetic dependency graph. + + requires_map: {dist_name: [PEP 508 requirement string, ...]} + versions: {dist_name: installed_version_string} (absent key => not installed) + """ + def fake_requires(dist: str) -> list[str]: + # Mirror the real importlib.metadata.requires: a dist with no metadata + # (i.e. not installed) raises PackageNotFoundError rather than returning []. + # A dist that is installed but has no requirements is modelled by an empty + # list in requires_map. + if dist not in requires_map: + raise metadata.PackageNotFoundError(dist) + return requires_map[dist] + + def fake_version(name: str) -> str: + if name not in versions: + raise metadata.PackageNotFoundError(name) + return versions[name] + + def fake_metadata(dist: str) -> _FakeMeta: + extras = sorted({e for s in requires_map.get(dist, []) for e in _EXTRA_RE.findall(s)}) + return _FakeMeta(extras) + + monkeypatch.setattr(metadata, "requires", fake_requires) + monkeypatch.setattr(metadata, "version", fake_version) + monkeypatch.setattr(metadata, "metadata", fake_metadata) + + +def test_check_returns_none_when_satisfied(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata(monkeypatch, {}, {"catboost": "1.5.0"}) + assert deps._check(Requirement("catboost>=1.2.8,<2.0.0")) is None + + +def test_check_reports_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata(monkeypatch, {}, {}) + problem = deps._check(Requirement("catboost>=1.2.8")) + assert problem is not None + assert "catboost" in problem + assert "not installed" in problem + + +def test_check_reports_outdated(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata(monkeypatch, {}, {"catboost": "1.0.0"}) + problem = deps._check(Requirement("catboost>=1.2.8,<2.0.0")) + assert problem is not None + assert "installed: 1.0.0" in problem + + +def test_iter_extra_reqs_selects_only_extra_members(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + {"autointent": [ + "numpy>=1.0 ; python_version >= '3.0'", # base dep w/ env marker -> excluded + "torch>=2.0", # base dep, no marker -> excluded + "catboost>=1.2.8,<2.0.0 ; extra == 'catboost'", # extra member -> included + "peft>=0.10.0 ; extra == 'peft'", # different extra -> excluded + ]}, + {}, + ) + reqs = deps._iter_extra_reqs("autointent", "catboost") + assert {r.name for r in reqs} == {"catboost"} + + +@pytest.fixture(autouse=True) +def _clear_resolve_cache() -> Iterator[None]: + deps._resolve_cached.cache_clear() + yield + deps._resolve_cached.cache_clear() + + +def test_resolve_recurses_into_nested_extra(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + { + "autointent": ["transformers[torch]>=4.49.0,<5.0.0 ; extra == 'transformers'"], + "transformers": [ + "torch>=2.2 ; extra == 'torch'", + "accelerate>=0.26.0 ; extra == 'torch'", + ], + }, + {}, + ) + reqs = deps._resolve("autointent", "transformers", set()) + assert {r.name for r in reqs} == {"transformers", "torch", "accelerate"} + + +def test_resolve_terminates_on_cycle(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + {"pkg": [ + "pkg[b]>=1.0 ; extra == 'a'", + "pkg[a]>=1.0 ; extra == 'b'", + ]}, + {}, + ) + reqs = deps._resolve("pkg", "a", set()) + assert {r.name for r in reqs} == {"pkg"} + + +def test_resolve_cached_returns_tuple(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + {"autointent": ["catboost>=1.2.8 ; extra == 'catboost'"]}, + {}, + ) + result = deps._resolve_cached("autointent", "catboost") + assert isinstance(result, tuple) + assert {r.name for r in result} == {"catboost"} + + +def test_require_passes_when_all_present(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + {"autointent": ["catboost>=1.2.8,<2.0.0 ; extra == 'catboost'"]}, + {"catboost": "1.5.0"}, + ) + deps.require("catboost") # must not raise + + +def test_require_raises_for_missing_leaf(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + {"autointent": ["catboost>=1.2.8,<2.0.0 ; extra == 'catboost'"]}, + {}, + ) + with pytest.raises(ImportError) as exc: + deps.require("catboost") + text = str(exc.value) + assert "catboost" in text + assert "not installed" in text + assert "pip install 'autointent[catboost]'" in text + + +def test_require_raises_for_outdated_version(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + {"autointent": ["catboost>=1.2.8,<2.0.0 ; extra == 'catboost'"]}, + {"catboost": "1.0.0"}, + ) + with pytest.raises(ImportError) as exc: + deps.require("catboost") + assert "installed: 1.0.0" in str(exc.value) + + +def test_require_detects_missing_nested_accelerate(monkeypatch: pytest.MonkeyPatch) -> None: + # Regression for #322: accelerate lives in transformers' own [torch] extra, + # so a transformers-present-but-accelerate-absent env must still be flagged. + _patch_metadata( + monkeypatch, + { + "autointent": ["transformers[torch]>=4.49.0,<5.0.0 ; extra == 'transformers'"], + "transformers": [ + "torch>=2.2 ; extra == 'torch'", + "accelerate>=0.26.0 ; extra == 'torch'", + ], + }, + {"transformers": "4.49.0", "torch": "2.2.0"}, # accelerate absent + ) + with pytest.raises(ImportError) as exc: + deps.require("transformers") + assert "accelerate" in str(exc.value) + + +def test_require_reports_extra_package_entirely_missing(monkeypatch: pytest.MonkeyPatch) -> None: + # Bare-install path: the extra's top-level package isn't installed at all, so + # recursing into its nested [torch] extra would read missing metadata. The + # resolver must NOT leak a raw PackageNotFoundError; instead the parent + # `transformers[torch]` requirement is flagged as missing with the install hint. + _patch_metadata( + monkeypatch, + {"autointent": ["transformers[torch]>=4.49.0,<5.0.0 ; extra == 'transformers'"]}, + {}, # transformers (and everything else) absent + ) + with pytest.raises(ImportError) as exc: + deps.require("transformers") + text = str(exc.value) + assert "transformers" in text + assert "not installed" in text + assert "pip install 'autointent[transformers]'" in text + + +def test_iter_extra_reqs_returns_empty_for_uninstalled_dist(monkeypatch: pytest.MonkeyPatch) -> None: + # The metadata read for a not-installed dist must be swallowed and yield []. + _patch_metadata(monkeypatch, {}, {}) + assert deps._iter_extra_reqs("not-installed", "torch") == [] + + +def test_require_rejects_unknown_extra(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_metadata( + monkeypatch, + {"autointent": ["catboost>=1.2.8 ; extra == 'catboost'"]}, + {"catboost": "1.5.0"}, + ) + # "transfomers" is intentionally invalid (typo) to exercise the runtime guard; + # the type: ignore is required because `Extra` now rejects it at type-check time. + with pytest.raises(ValueError, match="no extra 'transfomers'"): + deps.require("transfomers") # type: ignore[arg-type] + + +def test_resolve_reads_real_autointent_metadata() -> None: + # catboost is deliberately chosen: a flat, recursion-free extra, so this real- + # metadata wiring check is deterministic regardless of what CI installs. + reqs = deps._resolve_cached("autointent", "catboost") + assert any(r.name == "catboost" for r in reqs) + assert all(not r.extras for r in reqs) # documents the "no nested extra" premise + + +def test_extra_literal_matches_real_metadata() -> None: + # The `Extra` Literal is a hand-maintained mirror of the real Provides-Extra + # metadata. This fails if pyproject gains/loses an extra without the Literal being + # updated (or vice versa), keeping the static type honest and preventing the + # manual-sync drift the metadata-driven design otherwise removes. + assert set(get_args(deps.Extra)) == deps._provides_extras("autointent") + + +def test_resolve_every_real_extra_without_raising() -> None: + # Walk every extra autointent actually declares (incl. transformers[torch], + # which recurses into a nested extra) against real metadata. The resolver must + # never raise, regardless of which optional packages CI installed -- this is the + # real-metadata guard for the not-installed-nested-dist fix. + extras = deps._provides_extras("autointent") + assert extras # sanity: metadata wiring returns *something* + for extra in extras: + deps._resolve_cached("autointent", extra) # must not raise diff --git a/user_guides/advanced/02_embedder_configuration.py b/user_guides/advanced/02_embedder_configuration.py index f788e3334..1dd436ecf 100644 --- a/user_guides/advanced/02_embedder_configuration.py +++ b/user_guides/advanced/02_embedder_configuration.py @@ -19,7 +19,7 @@ pip install "autointent[sentence-transformers]" ``` -Other backends need their own extras, for example `autointent[openai]` or `autointent[vllm]`, as shown in the sections below. When a backend package is missing, code paths that need it typically call `autointent._utils.require`, which raises an `ImportError` that includes the matching `pip install autointent[]` hint. +Other backends need their own extras, for example `autointent[openai]` or `autointent[vllm]`, as shown in the sections below. When a backend package is missing, code paths that need it typically call `autointent._deps.require`, which raises an `ImportError` that includes the matching `pip install autointent[]` hint. ## Configuration Approaches