Skip to content

Commit 483ead2

Browse files
authored
feat(deps): metadata-driven require(extra) dependency guard (#339)
1 parent 413d0ff commit 483ead2

15 files changed

Lines changed: 491 additions & 54 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dependencies = [
4545
"aiometer (>=1.0.0,<2.0.0)",
4646
"aiofiles (>=24.1.0,<25.0.0)",
4747
"threadpoolctl (>=3.0.0,<4.0.0)",
48+
"packaging (>=23.2)",
4849
]
4950

5051
[project.optional-dependencies]

src/autointent/_deps.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""Validate optional-extra dependencies from installed package metadata.
2+
3+
The :func:`require` guard checks that every dependency of an ``autointent`` extra
4+
is installed and version-satisfied. It reads the metadata that the build baked into
5+
the installed distribution (via :mod:`importlib.metadata`) rather than the source
6+
``pyproject.toml``, which is not shipped in the wheel. Nested extras are resolved
7+
recursively, so e.g. the ``transformers`` extra (``transformers[torch]``)
8+
transitively requires ``accelerate`` and that is checked too.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from functools import cache
14+
from importlib import metadata
15+
from typing import Literal
16+
17+
from packaging.requirements import Requirement
18+
from packaging.utils import canonicalize_name
19+
20+
_DIST = "autointent"
21+
22+
# Names of the optional-dependency extras autointent declares, mirrored from the
23+
# installed ``Provides-Extra`` metadata (and thus pyproject's
24+
# [project.optional-dependencies]). Typing ``require``'s parameter with this makes
25+
# mypy reject misspelled extra names at call sites; the runtime check in ``require``
26+
# stays the source of truth (mypy isn't run at runtime, and ``dist`` overrides or
27+
# dynamic calls bypass static typing). Kept in sync with the real metadata by
28+
# tests/test_deps.py::test_extra_literal_matches_real_metadata.
29+
Extra = Literal[
30+
"catboost",
31+
"codecarbon",
32+
"dspy",
33+
"fastapi",
34+
"fastmcp",
35+
"openai",
36+
"opensearch",
37+
"peft",
38+
"sentence-transformers",
39+
"transformers",
40+
"vllm",
41+
"wandb",
42+
]
43+
44+
45+
def _check(req: Requirement) -> str | None:
46+
"""Check a single requirement against the installed environment.
47+
48+
Args:
49+
req: The parsed requirement to validate.
50+
51+
Returns:
52+
A human-readable problem description if the distribution is missing or its
53+
installed version does not satisfy ``req.specifier``; ``None`` otherwise.
54+
"""
55+
try:
56+
installed = metadata.version(req.name)
57+
except metadata.PackageNotFoundError:
58+
return f"{req.name}{req.specifier} (not installed)"
59+
if req.specifier and not req.specifier.contains(installed, prereleases=True):
60+
return f"{req.name}{req.specifier} (installed: {installed})"
61+
return None
62+
63+
64+
def _iter_extra_reqs(dist: str, extra: str) -> list[Requirement]:
65+
"""Return the requirements of ``dist`` activated by ``extra``.
66+
67+
Args:
68+
dist: Distribution name whose metadata is read.
69+
extra: Extra name whose dependencies are wanted.
70+
71+
Returns:
72+
The parsed requirements activated by ``extra`` in the current environment,
73+
or an empty list if ``dist`` is not installed (its metadata is unavailable).
74+
"""
75+
target = str(canonicalize_name(extra))
76+
result: list[Requirement] = []
77+
try:
78+
reqs = metadata.requires(dist)
79+
except metadata.PackageNotFoundError:
80+
# `dist` itself isn't installed, so we can't read its nested-extra
81+
# requirements. That's fine: the parent requirement that led us to recurse
82+
# here (e.g. `transformers[torch]`) was already collected by the caller and
83+
# `_check` will flag it as "not installed", producing the proper aggregated
84+
# ImportError with the install hint -- rather than letting a raw
85+
# PackageNotFoundError leak out of the resolver.
86+
return []
87+
for spec in reqs or []:
88+
req = Requirement(spec)
89+
# `req.marker` is the parsed `;` clause of the PEP 508 requirement (a
90+
# packaging Marker), or None when the requirement has no `;` clause. There
91+
# are three cases:
92+
# (1) no marker -> an unconditional base dependency;
93+
# (2) a marker that references `extra` -> belongs to an extra;
94+
# (3) a marker with only environment conditions (e.g. `python_version < "3.9"`)
95+
# -> still a base dependency, just platform-conditional.
96+
# So "has a marker" does NOT mean "belongs to an extra";
97+
marker = req.marker
98+
# Here we cancel out case (1)
99+
if marker is None:
100+
continue
101+
# `marker.evaluate(env)` resolves the whole boolean expression to a bool,
102+
# filling any keys we omit (python_version, sys_platform, ...) from the
103+
# running interpreter. A single `evaluate({"extra": target})` is not enough
104+
# to prove membership: an env-conditional base dep also passes it, because
105+
# its truth comes from the environment and the `extra` key is ignored.
106+
# The discriminator is the second evaluation: a *true* extra dependency
107+
# flips active -> inactive when the extra is removed, whereas a base dep is
108+
# unaffected. So "active with the extra AND inactive with no extra" means
109+
# "active *because of* this extra", which keeps extra members and drops
110+
# base deps. We always pass `extra` explicitly (`""` = base install, no
111+
# extras) since a marker that references `extra` can't be evaluated without it.
112+
# So here we cancel out case (3)
113+
if marker.evaluate({"extra": target}) and not marker.evaluate({"extra": ""}):
114+
result.append(req)
115+
return result
116+
117+
118+
def _resolve(dist: str, extra: str, seen: set[tuple[str, str]]) -> list[Requirement]:
119+
"""Recursively collect every leaf requirement activated by ``dist[extra]``.
120+
121+
Each activated requirement is returned for version checking, and any nested
122+
extras it declares (e.g. ``transformers[torch]``) are resolved in turn.
123+
124+
Args:
125+
dist: Distribution name to start from.
126+
extra: Extra name to resolve.
127+
seen: Visited ``(dist, extra)`` pairs, used to break dependency cycles.
128+
129+
Returns:
130+
The flattened list of requirements to validate.
131+
"""
132+
key = (str(canonicalize_name(dist)), str(canonicalize_name(extra)))
133+
if key in seen:
134+
return []
135+
seen.add(key)
136+
137+
leaves: list[Requirement] = []
138+
for req in _iter_extra_reqs(dist, extra):
139+
leaves.append(req)
140+
for nested in req.extras:
141+
leaves.extend(_resolve(req.name, nested, seen))
142+
return leaves
143+
144+
145+
@cache
146+
def _resolve_cached(dist: str, extra: str) -> tuple[Requirement, ...]:
147+
"""Memoized :func:`_resolve`; the metadata graph shape is stable per process.
148+
149+
Args:
150+
dist: Distribution name to start from.
151+
extra: Extra name to resolve.
152+
153+
Returns:
154+
The resolved requirements as an immutable tuple.
155+
"""
156+
return tuple(_resolve(dist, extra, set()))
157+
158+
159+
def _provides_extras(dist: str) -> set[str]:
160+
"""Return the normalized set of extras declared by ``dist``.
161+
162+
Args:
163+
dist: Distribution name whose metadata is read.
164+
165+
Returns:
166+
Normalized extra names from the distribution's ``Provides-Extra`` metadata.
167+
"""
168+
md = metadata.metadata(dist)
169+
return {str(canonicalize_name(e)) for e in (md.get_all("Provides-Extra") or [])}
170+
171+
172+
def require(extra: Extra) -> None:
173+
"""Ensure every dependency of an ``autointent`` extra is installed and current.
174+
175+
Args:
176+
extra: The extra to validate, e.g. ``"transformers"``.
177+
178+
Raises:
179+
ValueError: If ``autointent`` declares no such ``extra`` (typically a typo).
180+
ImportError: If any required dependency is missing or its installed version
181+
does not satisfy the constraint declared in the metadata.
182+
"""
183+
known = _provides_extras(_DIST)
184+
if str(canonicalize_name(extra)) not in known:
185+
msg = f"'{_DIST}' declares no extra '{extra}'. Known extras: {', '.join(sorted(known))}."
186+
raise ValueError(msg)
187+
188+
problems: list[str] = []
189+
for req in _resolve_cached(_DIST, extra):
190+
problem = _check(req)
191+
if problem is not None and problem not in problems:
192+
problems.append(problem)
193+
194+
if problems:
195+
bullets = "\n".join(f" - {p}" for p in problems)
196+
msg = (
197+
f"Feature requires extra '{extra}', but dependencies are missing or outdated:\n"
198+
f"{bullets}\n"
199+
f"Install with: pip install '{_DIST}[{extra}]'"
200+
)
201+
raise ImportError(msg)

src/autointent/_dump_tools/unit_dumpers.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sklearn.base import BaseEstimator
1414

1515
from autointent import Embedder, Ranker, VectorIndex
16-
from autointent._utils import require
16+
from autointent._deps import require
1717
from autointent._wrappers import BaseTorchModule
1818
from autointent.schemas import TagsList
1919

@@ -225,8 +225,8 @@ def dump(obj: PeftModel, path: Path, exists_ok: bool) -> None:
225225

226226
@staticmethod
227227
def load(path: Path, **kwargs: Any) -> PeftModel: # noqa: ANN401
228-
require("peft", extra="peft")
229-
require("transformers", extra="transformers")
228+
require("peft")
229+
require("transformers")
230230
import peft
231231
import transformers
232232

@@ -245,7 +245,7 @@ def load(path: Path, **kwargs: Any) -> PeftModel: # noqa: ANN401
245245
@classmethod
246246
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
247247
try:
248-
require("peft", extra="peft")
248+
require("peft")
249249
import peft
250250

251251
return isinstance(obj, peft.PeftModel)
@@ -263,15 +263,15 @@ def dump(obj: PreTrainedModel, path: Path, exists_ok: bool) -> None:
263263

264264
@staticmethod
265265
def load(path: Path, **kwargs: Any) -> PreTrainedModel: # noqa: ANN401
266-
require("transformers", extra="transformers")
266+
require("transformers")
267267
import transformers
268268

269269
return transformers.AutoModelForSequenceClassification.from_pretrained(path) # type: ignore[no-any-return]
270270

271271
@classmethod
272272
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
273273
try:
274-
require("transformers", extra="transformers")
274+
require("transformers")
275275
import transformers
276276

277277
return isinstance(obj, transformers.PreTrainedModel)
@@ -289,15 +289,15 @@ def dump(obj: PreTrainedTokenizer | PreTrainedTokenizerFast, path: Path, exists_
289289

290290
@staticmethod
291291
def load(path: Path, **kwargs: Any) -> PreTrainedTokenizer | PreTrainedTokenizerFast: # noqa: ANN401
292-
require("transformers", extra="transformers")
292+
require("transformers")
293293
import transformers
294294

295295
return transformers.AutoTokenizer.from_pretrained(path) # type: ignore[no-any-return,no-untyped-call]
296296

297297
@classmethod
298298
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
299299
try:
300-
require("transformers", extra="transformers")
300+
require("transformers")
301301
import transformers
302302

303303
return isinstance(obj, transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast)
@@ -342,7 +342,7 @@ def dump(obj: CatBoostClassifier, path: Path, exists_ok: bool) -> None: # noqa:
342342

343343
@staticmethod
344344
def load(path: Path, **kwargs: Any) -> CatBoostClassifier: # noqa: ANN401
345-
require("catboost", extra="catboost")
345+
require("catboost")
346346
from catboost import CatBoostClassifier
347347

348348
model = CatBoostClassifier()
@@ -352,7 +352,7 @@ def load(path: Path, **kwargs: Any) -> CatBoostClassifier: # noqa: ANN401
352352
@classmethod
353353
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
354354
try:
355-
require("catboost", extra="catboost")
355+
require("catboost")
356356
from catboost import CatBoostClassifier
357357

358358
return isinstance(obj, CatBoostClassifier)

src/autointent/_utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Utils."""
22

3-
import importlib
43
from typing import TypeVar
54

65
import torch
@@ -28,22 +27,3 @@ def detect_device() -> str:
2827
return "cpu"
2928

3029

31-
def require(dependency: str, extra: str | None = None) -> None:
32-
"""Try to import dependency, raise informative ImportError if missing.
33-
34-
Args:
35-
dependency: The name of the module to import
36-
extra: Optional extra package name for pip install instructions
37-
38-
Returns:
39-
The imported module
40-
41-
Raises:
42-
ImportError: If the dependency is not installed
43-
"""
44-
try:
45-
importlib.import_module(dependency)
46-
except ImportError as e:
47-
extra_info = f" Install with `pip install autointent[{extra}]`." if extra else ""
48-
msg = f"Missing dependency '{dependency}' required for this feature.{extra_info}"
49-
raise ImportError(msg) from e

src/autointent/_wrappers/embedder/openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import numpy.typing as npt
1313
import torch
1414

15+
from autointent._deps import require
1516
from autointent._hash import Hasher
16-
from autointent._utils import require
1717
from autointent.configs._embedder import OpenaiEmbeddingConfig
1818

1919
from .base import BaseEmbeddingBackend
@@ -77,7 +77,7 @@ def _openai_api_error_message(exc: BaseException, *, batch_size: int) -> str:
7777

7878
def _tiktoken_encoding_for_embedding_model(model_name: str) -> Encoding:
7979
"""Resolve tiktoken encoding for batch sizing; fallback for unknown provider model ids."""
80-
require("tiktoken", "openai")
80+
require("openai")
8181
import tiktoken
8282

8383
try:
@@ -110,7 +110,7 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None:
110110
Args:
111111
config: Configuration for OpenAI embeddings.
112112
"""
113-
require("openai", "openai")
113+
require("openai")
114114
self.config = config
115115
self._event_loop: asyncio.AbstractEventLoop | None = None
116116

src/autointent/_wrappers/embedder/sentence_transformers.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from datasets import Dataset
1515
from sklearn.model_selection import train_test_split
1616

17+
from autointent._deps import require
1718
from autointent._hash import Hasher
18-
from autointent._utils import require
1919
from autointent.configs._embedder import SentenceTransformerEmbeddingConfig
2020

2121
from .base import BaseEmbeddingBackend
@@ -42,7 +42,7 @@ def _set_training_seed(seed: int) -> None:
4242
if torch.cuda.is_available():
4343
torch.cuda.manual_seed_all(seed)
4444

45-
require("transformers", extra="transformers")
45+
require("transformers")
4646
from transformers import set_seed
4747

4848
set_seed(seed)
@@ -130,7 +130,7 @@ def _load_model(self) -> SentenceTransformer:
130130
"""Load sentence transformers model to device."""
131131
if self._model is None:
132132
# Lazy import sentence-transformers
133-
require("sentence_transformers", extra="sentence-transformers")
133+
require("sentence-transformers")
134134
from sentence_transformers import SentenceTransformer
135135

136136
res = SentenceTransformer(
@@ -294,9 +294,8 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
294294
_set_training_seed(config.seed)
295295

296296
# Lazy import sentence-transformers training components (only needed for fine-tuning)
297-
require("sentence_transformers", extra="sentence-transformers")
298-
require("transformers", extra="transformers")
299-
require("accelerate", extra="transformers")
297+
require("sentence-transformers")
298+
require("transformers")
300299
from sentence_transformers import (
301300
SentenceTransformerTrainer,
302301
SentenceTransformerTrainingArguments,

0 commit comments

Comments
 (0)