Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"aiometer (>=1.0.0,<2.0.0)",
"aiofiles (>=24.1.0,<25.0.0)",
"threadpoolctl (>=3.0.0,<4.0.0)",
"packaging (>=23.2)",
]

[project.optional-dependencies]
Expand Down
201 changes: 201 additions & 0 deletions src/autointent/_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Validate optional-extra dependencies from installed package metadata.

The :func:`require` guard checks that every dependency of an ``autointent`` extra
is installed and version-satisfied. It reads the metadata that the build baked into
the installed distribution (via :mod:`importlib.metadata`) rather than the source
``pyproject.toml``, which is not shipped in the wheel. Nested extras are resolved
recursively, so e.g. the ``transformers`` extra (``transformers[torch]``)
transitively requires ``accelerate`` and that is checked too.
"""

from __future__ import annotations

from functools import cache
from importlib import metadata
from typing import Literal

from packaging.requirements import Requirement
from packaging.utils import canonicalize_name

_DIST = "autointent"

# Names of the optional-dependency extras autointent declares, mirrored from the
# installed ``Provides-Extra`` metadata (and thus pyproject's
# [project.optional-dependencies]). Typing ``require``'s parameter with this makes
# mypy reject misspelled extra names at call sites; the runtime check in ``require``
# stays the source of truth (mypy isn't run at runtime, and ``dist`` overrides or
# dynamic calls bypass static typing). Kept in sync with the real metadata by
# tests/test_deps.py::test_extra_literal_matches_real_metadata.
Extra = Literal[
"catboost",
"codecarbon",
"dspy",
"fastapi",
"fastmcp",
"openai",
"opensearch",
"peft",
"sentence-transformers",
"transformers",
"vllm",
"wandb",
]


def _check(req: Requirement) -> str | None:
"""Check a single requirement against the installed environment.

Args:
req: The parsed requirement to validate.

Returns:
A human-readable problem description if the distribution is missing or its
installed version does not satisfy ``req.specifier``; ``None`` otherwise.
"""
try:
installed = metadata.version(req.name)
except metadata.PackageNotFoundError:
return f"{req.name}{req.specifier} (not installed)"
if req.specifier and not req.specifier.contains(installed, prereleases=True):
return f"{req.name}{req.specifier} (installed: {installed})"
return None


def _iter_extra_reqs(dist: str, extra: str) -> list[Requirement]:
"""Return the requirements of ``dist`` activated by ``extra``.

Args:
dist: Distribution name whose metadata is read.
extra: Extra name whose dependencies are wanted.

Returns:
The parsed requirements activated by ``extra`` in the current environment,
or an empty list if ``dist`` is not installed (its metadata is unavailable).
"""
target = str(canonicalize_name(extra))
result: list[Requirement] = []
try:
reqs = metadata.requires(dist)
except metadata.PackageNotFoundError:
# `dist` itself isn't installed, so we can't read its nested-extra
# requirements. That's fine: the parent requirement that led us to recurse
# here (e.g. `transformers[torch]`) was already collected by the caller and
# `_check` will flag it as "not installed", producing the proper aggregated
# ImportError with the install hint -- rather than letting a raw
# PackageNotFoundError leak out of the resolver.
return []
for spec in reqs or []:
req = Requirement(spec)
# `req.marker` is the parsed `;` clause of the PEP 508 requirement (a
# packaging Marker), or None when the requirement has no `;` clause. There
# are three cases:
# (1) no marker -> an unconditional base dependency;
# (2) a marker that references `extra` -> belongs to an extra;
# (3) a marker with only environment conditions (e.g. `python_version < "3.9"`)
# -> still a base dependency, just platform-conditional.
# So "has a marker" does NOT mean "belongs to an extra";
marker = req.marker
# Here we cancel out case (1)
if marker is None:
continue
# `marker.evaluate(env)` resolves the whole boolean expression to a bool,
# filling any keys we omit (python_version, sys_platform, ...) from the
# running interpreter. A single `evaluate({"extra": target})` is not enough
# to prove membership: an env-conditional base dep also passes it, because
# its truth comes from the environment and the `extra` key is ignored.
# The discriminator is the second evaluation: a *true* extra dependency
# flips active -> inactive when the extra is removed, whereas a base dep is
# unaffected. So "active with the extra AND inactive with no extra" means
# "active *because of* this extra", which keeps extra members and drops
# base deps. We always pass `extra` explicitly (`""` = base install, no
# extras) since a marker that references `extra` can't be evaluated without it.
# So here we cancel out case (3)
if marker.evaluate({"extra": target}) and not marker.evaluate({"extra": ""}):
result.append(req)
return result


def _resolve(dist: str, extra: str, seen: set[tuple[str, str]]) -> list[Requirement]:
"""Recursively collect every leaf requirement activated by ``dist[extra]``.

Each activated requirement is returned for version checking, and any nested
extras it declares (e.g. ``transformers[torch]``) are resolved in turn.

Args:
dist: Distribution name to start from.
extra: Extra name to resolve.
seen: Visited ``(dist, extra)`` pairs, used to break dependency cycles.

Returns:
The flattened list of requirements to validate.
"""
key = (str(canonicalize_name(dist)), str(canonicalize_name(extra)))
if key in seen:
return []
seen.add(key)

leaves: list[Requirement] = []
for req in _iter_extra_reqs(dist, extra):
leaves.append(req)
for nested in req.extras:
leaves.extend(_resolve(req.name, nested, seen))
return leaves


@cache
def _resolve_cached(dist: str, extra: str) -> tuple[Requirement, ...]:
"""Memoized :func:`_resolve`; the metadata graph shape is stable per process.

Args:
dist: Distribution name to start from.
extra: Extra name to resolve.

Returns:
The resolved requirements as an immutable tuple.
"""
return tuple(_resolve(dist, extra, set()))


def _provides_extras(dist: str) -> set[str]:
"""Return the normalized set of extras declared by ``dist``.

Args:
dist: Distribution name whose metadata is read.

Returns:
Normalized extra names from the distribution's ``Provides-Extra`` metadata.
"""
md = metadata.metadata(dist)
return {str(canonicalize_name(e)) for e in (md.get_all("Provides-Extra") or [])}


def require(extra: Extra) -> None:
"""Ensure every dependency of an ``autointent`` extra is installed and current.

Args:
extra: The extra to validate, e.g. ``"transformers"``.

Raises:
ValueError: If ``autointent`` declares no such ``extra`` (typically a typo).
ImportError: If any required dependency is missing or its installed version
does not satisfy the constraint declared in the metadata.
"""
known = _provides_extras(_DIST)
if str(canonicalize_name(extra)) not in known:
msg = f"'{_DIST}' declares no extra '{extra}'. Known extras: {', '.join(sorted(known))}."
raise ValueError(msg)

problems: list[str] = []
for req in _resolve_cached(_DIST, extra):
problem = _check(req)
if problem is not None and problem not in problems:
problems.append(problem)

if problems:
bullets = "\n".join(f" - {p}" for p in problems)
msg = (
f"Feature requires extra '{extra}', but dependencies are missing or outdated:\n"
f"{bullets}\n"
f"Install with: pip install '{_DIST}[{extra}]'"
)
raise ImportError(msg)
20 changes: 10 additions & 10 deletions src/autointent/_dump_tools/unit_dumpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sklearn.base import BaseEstimator

from autointent import Embedder, Ranker, VectorIndex
from autointent._utils import require
from autointent._deps import require
from autointent._wrappers import BaseTorchModule
from autointent.schemas import TagsList

Expand Down Expand Up @@ -225,8 +225,8 @@ def dump(obj: PeftModel, path: Path, exists_ok: bool) -> None:

@staticmethod
def load(path: Path, **kwargs: Any) -> PeftModel: # noqa: ANN401
require("peft", extra="peft")
require("transformers", extra="transformers")
require("peft")
require("transformers")
import peft
import transformers

Expand All @@ -245,7 +245,7 @@ def load(path: Path, **kwargs: Any) -> PeftModel: # noqa: ANN401
@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
try:
require("peft", extra="peft")
require("peft")
import peft

return isinstance(obj, peft.PeftModel)
Expand All @@ -263,15 +263,15 @@ def dump(obj: PreTrainedModel, path: Path, exists_ok: bool) -> None:

@staticmethod
def load(path: Path, **kwargs: Any) -> PreTrainedModel: # noqa: ANN401
require("transformers", extra="transformers")
require("transformers")
import transformers

return transformers.AutoModelForSequenceClassification.from_pretrained(path) # type: ignore[no-any-return]

@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
try:
require("transformers", extra="transformers")
require("transformers")
import transformers

return isinstance(obj, transformers.PreTrainedModel)
Expand All @@ -289,15 +289,15 @@ def dump(obj: PreTrainedTokenizer | PreTrainedTokenizerFast, path: Path, exists_

@staticmethod
def load(path: Path, **kwargs: Any) -> PreTrainedTokenizer | PreTrainedTokenizerFast: # noqa: ANN401
require("transformers", extra="transformers")
require("transformers")
import transformers

return transformers.AutoTokenizer.from_pretrained(path) # type: ignore[no-any-return,no-untyped-call]

@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
try:
require("transformers", extra="transformers")
require("transformers")
import transformers

return isinstance(obj, transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast)
Expand Down Expand Up @@ -342,7 +342,7 @@ def dump(obj: CatBoostClassifier, path: Path, exists_ok: bool) -> None: # noqa:

@staticmethod
def load(path: Path, **kwargs: Any) -> CatBoostClassifier: # noqa: ANN401
require("catboost", extra="catboost")
require("catboost")
from catboost import CatBoostClassifier

model = CatBoostClassifier()
Expand All @@ -352,7 +352,7 @@ def load(path: Path, **kwargs: Any) -> CatBoostClassifier: # noqa: ANN401
@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
try:
require("catboost", extra="catboost")
require("catboost")
from catboost import CatBoostClassifier

return isinstance(obj, CatBoostClassifier)
Expand Down
20 changes: 0 additions & 20 deletions src/autointent/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Utils."""

import importlib
from typing import TypeVar

import torch
Expand Down Expand Up @@ -28,22 +27,3 @@ def detect_device() -> str:
return "cpu"


def require(dependency: str, extra: str | None = None) -> None:
"""Try to import dependency, raise informative ImportError if missing.

Args:
dependency: The name of the module to import
extra: Optional extra package name for pip install instructions

Returns:
The imported module

Raises:
ImportError: If the dependency is not installed
"""
try:
importlib.import_module(dependency)
except ImportError as e:
extra_info = f" Install with `pip install autointent[{extra}]`." if extra else ""
msg = f"Missing dependency '{dependency}' required for this feature.{extra_info}"
raise ImportError(msg) from e
6 changes: 3 additions & 3 deletions src/autointent/_wrappers/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import numpy.typing as npt
import torch

from autointent._deps import require
from autointent._hash import Hasher
from autointent._utils import require
from autointent.configs._embedder import OpenaiEmbeddingConfig

from .base import BaseEmbeddingBackend
Expand Down Expand Up @@ -77,7 +77,7 @@ def _openai_api_error_message(exc: BaseException, *, batch_size: int) -> str:

def _tiktoken_encoding_for_embedding_model(model_name: str) -> Encoding:
"""Resolve tiktoken encoding for batch sizing; fallback for unknown provider model ids."""
require("tiktoken", "openai")
require("openai")
import tiktoken

try:
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None:
Args:
config: Configuration for OpenAI embeddings.
"""
require("openai", "openai")
require("openai")
self.config = config
self._event_loop: asyncio.AbstractEventLoop | None = None

Expand Down
11 changes: 5 additions & 6 deletions src/autointent/_wrappers/embedder/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from datasets import Dataset
from sklearn.model_selection import train_test_split

from autointent._deps import require
from autointent._hash import Hasher
from autointent._utils import require
from autointent.configs._embedder import SentenceTransformerEmbeddingConfig

from .base import BaseEmbeddingBackend
Expand All @@ -42,7 +42,7 @@ def _set_training_seed(seed: int) -> None:
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)

require("transformers", extra="transformers")
require("transformers")
from transformers import set_seed

set_seed(seed)
Expand Down Expand Up @@ -130,7 +130,7 @@ def _load_model(self) -> SentenceTransformer:
"""Load sentence transformers model to device."""
if self._model is None:
# Lazy import sentence-transformers
require("sentence_transformers", extra="sentence-transformers")
require("sentence-transformers")
from sentence_transformers import SentenceTransformer

res = SentenceTransformer(
Expand Down Expand Up @@ -294,9 +294,8 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
_set_training_seed(config.seed)

# Lazy import sentence-transformers training components (only needed for fine-tuning)
require("sentence_transformers", extra="sentence-transformers")
require("transformers", extra="transformers")
require("accelerate", extra="transformers")
require("sentence-transformers")
require("transformers")
from sentence_transformers import (
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
Expand Down
Loading