|
| 1 | +"""Validate optional-extra dependencies from installed package metadata. |
| 2 | +
|
| 3 | +The :func:`require` guard checks that every dependency of an ``autointent`` extra |
| 4 | +is installed and version-satisfied. It reads the metadata that the build baked into |
| 5 | +the installed distribution (via :mod:`importlib.metadata`) rather than the source |
| 6 | +``pyproject.toml``, which is not shipped in the wheel. Nested extras are resolved |
| 7 | +recursively, so e.g. the ``transformers`` extra (``transformers[torch]``) |
| 8 | +transitively requires ``accelerate`` and that is checked too. |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +from functools import cache |
| 14 | +from importlib import metadata |
| 15 | +from typing import Literal |
| 16 | + |
| 17 | +from packaging.requirements import Requirement |
| 18 | +from packaging.utils import canonicalize_name |
| 19 | + |
| 20 | +_DIST = "autointent" |
| 21 | + |
| 22 | +# Names of the optional-dependency extras autointent declares, mirrored from the |
| 23 | +# installed ``Provides-Extra`` metadata (and thus pyproject's |
| 24 | +# [project.optional-dependencies]). Typing ``require``'s parameter with this makes |
| 25 | +# mypy reject misspelled extra names at call sites; the runtime check in ``require`` |
| 26 | +# stays the source of truth (mypy isn't run at runtime, and ``dist`` overrides or |
| 27 | +# dynamic calls bypass static typing). Kept in sync with the real metadata by |
| 28 | +# tests/test_deps.py::test_extra_literal_matches_real_metadata. |
| 29 | +Extra = Literal[ |
| 30 | + "catboost", |
| 31 | + "codecarbon", |
| 32 | + "dspy", |
| 33 | + "fastapi", |
| 34 | + "fastmcp", |
| 35 | + "openai", |
| 36 | + "opensearch", |
| 37 | + "peft", |
| 38 | + "sentence-transformers", |
| 39 | + "transformers", |
| 40 | + "vllm", |
| 41 | + "wandb", |
| 42 | +] |
| 43 | + |
| 44 | + |
| 45 | +def _check(req: Requirement) -> str | None: |
| 46 | + """Check a single requirement against the installed environment. |
| 47 | +
|
| 48 | + Args: |
| 49 | + req: The parsed requirement to validate. |
| 50 | +
|
| 51 | + Returns: |
| 52 | + A human-readable problem description if the distribution is missing or its |
| 53 | + installed version does not satisfy ``req.specifier``; ``None`` otherwise. |
| 54 | + """ |
| 55 | + try: |
| 56 | + installed = metadata.version(req.name) |
| 57 | + except metadata.PackageNotFoundError: |
| 58 | + return f"{req.name}{req.specifier} (not installed)" |
| 59 | + if req.specifier and not req.specifier.contains(installed, prereleases=True): |
| 60 | + return f"{req.name}{req.specifier} (installed: {installed})" |
| 61 | + return None |
| 62 | + |
| 63 | + |
| 64 | +def _iter_extra_reqs(dist: str, extra: str) -> list[Requirement]: |
| 65 | + """Return the requirements of ``dist`` activated by ``extra``. |
| 66 | +
|
| 67 | + Args: |
| 68 | + dist: Distribution name whose metadata is read. |
| 69 | + extra: Extra name whose dependencies are wanted. |
| 70 | +
|
| 71 | + Returns: |
| 72 | + The parsed requirements activated by ``extra`` in the current environment, |
| 73 | + or an empty list if ``dist`` is not installed (its metadata is unavailable). |
| 74 | + """ |
| 75 | + target = str(canonicalize_name(extra)) |
| 76 | + result: list[Requirement] = [] |
| 77 | + try: |
| 78 | + reqs = metadata.requires(dist) |
| 79 | + except metadata.PackageNotFoundError: |
| 80 | + # `dist` itself isn't installed, so we can't read its nested-extra |
| 81 | + # requirements. That's fine: the parent requirement that led us to recurse |
| 82 | + # here (e.g. `transformers[torch]`) was already collected by the caller and |
| 83 | + # `_check` will flag it as "not installed", producing the proper aggregated |
| 84 | + # ImportError with the install hint -- rather than letting a raw |
| 85 | + # PackageNotFoundError leak out of the resolver. |
| 86 | + return [] |
| 87 | + for spec in reqs or []: |
| 88 | + req = Requirement(spec) |
| 89 | + # `req.marker` is the parsed `;` clause of the PEP 508 requirement (a |
| 90 | + # packaging Marker), or None when the requirement has no `;` clause. There |
| 91 | + # are three cases: |
| 92 | + # (1) no marker -> an unconditional base dependency; |
| 93 | + # (2) a marker that references `extra` -> belongs to an extra; |
| 94 | + # (3) a marker with only environment conditions (e.g. `python_version < "3.9"`) |
| 95 | + # -> still a base dependency, just platform-conditional. |
| 96 | + # So "has a marker" does NOT mean "belongs to an extra"; |
| 97 | + marker = req.marker |
| 98 | + # Here we cancel out case (1) |
| 99 | + if marker is None: |
| 100 | + continue |
| 101 | + # `marker.evaluate(env)` resolves the whole boolean expression to a bool, |
| 102 | + # filling any keys we omit (python_version, sys_platform, ...) from the |
| 103 | + # running interpreter. A single `evaluate({"extra": target})` is not enough |
| 104 | + # to prove membership: an env-conditional base dep also passes it, because |
| 105 | + # its truth comes from the environment and the `extra` key is ignored. |
| 106 | + # The discriminator is the second evaluation: a *true* extra dependency |
| 107 | + # flips active -> inactive when the extra is removed, whereas a base dep is |
| 108 | + # unaffected. So "active with the extra AND inactive with no extra" means |
| 109 | + # "active *because of* this extra", which keeps extra members and drops |
| 110 | + # base deps. We always pass `extra` explicitly (`""` = base install, no |
| 111 | + # extras) since a marker that references `extra` can't be evaluated without it. |
| 112 | + # So here we cancel out case (3) |
| 113 | + if marker.evaluate({"extra": target}) and not marker.evaluate({"extra": ""}): |
| 114 | + result.append(req) |
| 115 | + return result |
| 116 | + |
| 117 | + |
| 118 | +def _resolve(dist: str, extra: str, seen: set[tuple[str, str]]) -> list[Requirement]: |
| 119 | + """Recursively collect every leaf requirement activated by ``dist[extra]``. |
| 120 | +
|
| 121 | + Each activated requirement is returned for version checking, and any nested |
| 122 | + extras it declares (e.g. ``transformers[torch]``) are resolved in turn. |
| 123 | +
|
| 124 | + Args: |
| 125 | + dist: Distribution name to start from. |
| 126 | + extra: Extra name to resolve. |
| 127 | + seen: Visited ``(dist, extra)`` pairs, used to break dependency cycles. |
| 128 | +
|
| 129 | + Returns: |
| 130 | + The flattened list of requirements to validate. |
| 131 | + """ |
| 132 | + key = (str(canonicalize_name(dist)), str(canonicalize_name(extra))) |
| 133 | + if key in seen: |
| 134 | + return [] |
| 135 | + seen.add(key) |
| 136 | + |
| 137 | + leaves: list[Requirement] = [] |
| 138 | + for req in _iter_extra_reqs(dist, extra): |
| 139 | + leaves.append(req) |
| 140 | + for nested in req.extras: |
| 141 | + leaves.extend(_resolve(req.name, nested, seen)) |
| 142 | + return leaves |
| 143 | + |
| 144 | + |
| 145 | +@cache |
| 146 | +def _resolve_cached(dist: str, extra: str) -> tuple[Requirement, ...]: |
| 147 | + """Memoized :func:`_resolve`; the metadata graph shape is stable per process. |
| 148 | +
|
| 149 | + Args: |
| 150 | + dist: Distribution name to start from. |
| 151 | + extra: Extra name to resolve. |
| 152 | +
|
| 153 | + Returns: |
| 154 | + The resolved requirements as an immutable tuple. |
| 155 | + """ |
| 156 | + return tuple(_resolve(dist, extra, set())) |
| 157 | + |
| 158 | + |
| 159 | +def _provides_extras(dist: str) -> set[str]: |
| 160 | + """Return the normalized set of extras declared by ``dist``. |
| 161 | +
|
| 162 | + Args: |
| 163 | + dist: Distribution name whose metadata is read. |
| 164 | +
|
| 165 | + Returns: |
| 166 | + Normalized extra names from the distribution's ``Provides-Extra`` metadata. |
| 167 | + """ |
| 168 | + md = metadata.metadata(dist) |
| 169 | + return {str(canonicalize_name(e)) for e in (md.get_all("Provides-Extra") or [])} |
| 170 | + |
| 171 | + |
| 172 | +def require(extra: Extra) -> None: |
| 173 | + """Ensure every dependency of an ``autointent`` extra is installed and current. |
| 174 | +
|
| 175 | + Args: |
| 176 | + extra: The extra to validate, e.g. ``"transformers"``. |
| 177 | +
|
| 178 | + Raises: |
| 179 | + ValueError: If ``autointent`` declares no such ``extra`` (typically a typo). |
| 180 | + ImportError: If any required dependency is missing or its installed version |
| 181 | + does not satisfy the constraint declared in the metadata. |
| 182 | + """ |
| 183 | + known = _provides_extras(_DIST) |
| 184 | + if str(canonicalize_name(extra)) not in known: |
| 185 | + msg = f"'{_DIST}' declares no extra '{extra}'. Known extras: {', '.join(sorted(known))}." |
| 186 | + raise ValueError(msg) |
| 187 | + |
| 188 | + problems: list[str] = [] |
| 189 | + for req in _resolve_cached(_DIST, extra): |
| 190 | + problem = _check(req) |
| 191 | + if problem is not None and problem not in problems: |
| 192 | + problems.append(problem) |
| 193 | + |
| 194 | + if problems: |
| 195 | + bullets = "\n".join(f" - {p}" for p in problems) |
| 196 | + msg = ( |
| 197 | + f"Feature requires extra '{extra}', but dependencies are missing or outdated:\n" |
| 198 | + f"{bullets}\n" |
| 199 | + f"Install with: pip install '{_DIST}[{extra}]'" |
| 200 | + ) |
| 201 | + raise ImportError(msg) |
0 commit comments