|
| 1 | +# Copyright 2024, UChicago Argonne, LLC |
| 2 | +# All Rights Reserved |
| 3 | +# Software Name: NEML2 -- the New Engineering material Model Library, version 2 |
| 4 | +# By: Argonne National Laboratory |
| 5 | +# OPEN SOURCE LICENSE (MIT) |
| 6 | +# |
| 7 | +# Permission is hereby granted, free of charge, to any person obtaining a copy |
| 8 | +# of this software and associated documentation files (the "Software"), to deal |
| 9 | +# in the Software without restriction, including without limitation the rights |
| 10 | +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 11 | +# copies of the Software, and to permit persons to whom the Software is |
| 12 | +# furnished to do so, subject to the following conditions: |
| 13 | +# |
| 14 | +# The above copyright notice and this permission notice shall be included in |
| 15 | +# all copies or substantial portions of the Software. |
| 16 | +# |
| 17 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 18 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 19 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 20 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 21 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 22 | +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| 23 | +# THE SOFTWARE. |
| 24 | + |
| 25 | +"""``neml2-stub`` -- regenerate ``.pyi`` stubs for every pybind11 extension. |
| 26 | +
|
| 27 | +Walks the installed ``neml2`` package, finds every binary extension module |
| 28 | +(currently just ``neml2.aoti._aoti``), and runs ``pybind11_stubgen`` against |
| 29 | +each one. The generated stubs land next to their ``.so`` so pyright / |
| 30 | +IDE autocompletion resolve ``from ._aoti import Model`` cleanly. |
| 31 | +
|
| 32 | +Invocation order matters: the script ``import``s every extension to |
| 33 | +introspect it, so the package must be **fully installed** before this runs. |
| 34 | +That means after ``pip install`` (CI runs ``neml2-stub`` between install |
| 35 | +and pyright) or after ``cibuildwheel``'s repair step has unpacked the |
| 36 | +wheel into the test env (``scripts/repair_wheel.py`` calls the helper |
| 37 | +in-process, then re-packs the wheel with the stubs injected). |
| 38 | +
|
| 39 | +Any extra arguments are forwarded verbatim to ``pybind11_stubgen`` |
| 40 | +(e.g. ``neml2-stub --exit-code`` makes the script return non-zero when |
| 41 | +stubgen reports unresolved names). |
| 42 | +""" |
| 43 | + |
| 44 | +from __future__ import annotations |
| 45 | + |
| 46 | +import importlib |
| 47 | +import importlib.machinery |
| 48 | +import importlib.util |
| 49 | +import pkgutil |
| 50 | +import sys |
| 51 | +from pathlib import Path |
| 52 | + |
| 53 | + |
| 54 | +def discover_extension_modules(package_name: str = "neml2") -> list[str]: |
| 55 | + """Return the dotted names of every pybind11 extension under ``package_name``. |
| 56 | +
|
| 57 | + Uses :func:`pkgutil.walk_packages` to enumerate submodules and filters |
| 58 | + by spec origin -- a module is an extension iff its file extension is |
| 59 | + in :data:`importlib.machinery.EXTENSION_SUFFIXES` (e.g. |
| 60 | + ``.cpython-314-x86_64-linux-gnu.so`` on CPython 3.14 Linux). |
| 61 | + """ |
| 62 | + package = importlib.import_module(package_name) |
| 63 | + prefix = f"{package_name}." |
| 64 | + found: list[str] = [] |
| 65 | + for module in pkgutil.walk_packages(package.__path__, prefix): |
| 66 | + spec = importlib.util.find_spec(module.name) |
| 67 | + if spec is None or spec.origin is None: |
| 68 | + continue |
| 69 | + if any(spec.origin.endswith(suffix) for suffix in importlib.machinery.EXTENSION_SUFFIXES): |
| 70 | + found.append(module.name) |
| 71 | + return sorted(found) |
| 72 | + |
| 73 | + |
| 74 | +def _output_dir_for(module_name: str) -> Path: |
| 75 | + """Return the dir to pass to ``pybind11-stubgen -o`` for ``module_name``. |
| 76 | +
|
| 77 | + pybind11-stubgen writes to ``<output_dir>/<dotted_path>.pyi``, so for |
| 78 | + the stub to land next to the ``.so`` we need ``output_dir`` = |
| 79 | + ``.so`` location minus the dotted-path's worth of directories. |
| 80 | +
|
| 81 | + Computing this from the spec instead of ``package.__path__[0]`` matters |
| 82 | + for editable installs: ``neml2.__path__`` carries two entries (the |
| 83 | + minimal site-packages shim and the real source tree), and the shim's |
| 84 | + parent is site-packages -- writing the stub there strands it next to |
| 85 | + a different (incomplete) copy of ``neml2/`` instead of next to the |
| 86 | + extension. The spec's ``origin`` always points at the actual ``.so``, |
| 87 | + so this is robust to whatever the path-list ordering happens to be. |
| 88 | + """ |
| 89 | + spec = importlib.util.find_spec(module_name) |
| 90 | + if spec is None or spec.origin is None: |
| 91 | + raise RuntimeError(f"can't resolve spec for {module_name}") |
| 92 | + depth = len(module_name.split(".")) |
| 93 | + return Path(spec.origin).resolve().parents[depth - 1] |
| 94 | + |
| 95 | + |
| 96 | +def generate_stubs(extra_args: list[str] | None = None) -> int: |
| 97 | + """Run ``pybind11_stubgen`` on every discovered extension. |
| 98 | +
|
| 99 | + Each stub is written next to its ``.so`` (see |
| 100 | + :func:`_output_dir_for`). Returns the last non-zero exit code from |
| 101 | + ``pybind11_stubgen``, or 0 if every module succeeded. Missing |
| 102 | + ``pybind11_stubgen`` is a hard error -- this is a dev/CI tool and |
| 103 | + silent skipping would mask the pyright-fails-in-CI symptom that |
| 104 | + motivated it. |
| 105 | + """ |
| 106 | + if importlib.util.find_spec("pybind11_stubgen") is None: |
| 107 | + print( |
| 108 | + "neml2-stub: pybind11-stubgen is not installed. Install it via " |
| 109 | + "`pip install pybind11-stubgen` (or the [dev] extras).", |
| 110 | + file=sys.stderr, |
| 111 | + ) |
| 112 | + return 1 |
| 113 | + |
| 114 | + import pybind11_stubgen # noqa: PLC0415 |
| 115 | + |
| 116 | + modules = discover_extension_modules("neml2") |
| 117 | + if not modules: |
| 118 | + print("neml2-stub: no pybind11 extension modules discovered under neml2/.") |
| 119 | + return 0 |
| 120 | + |
| 121 | + last_failure = 0 |
| 122 | + for module in modules: |
| 123 | + output_dir = _output_dir_for(module) |
| 124 | + argv = ["-o", str(output_dir), *(extra_args or []), module] |
| 125 | + print(f"neml2-stub: generating stubs for {module} -> {output_dir}") |
| 126 | + try: |
| 127 | + pybind11_stubgen.main(argv) |
| 128 | + except SystemExit as exc: |
| 129 | + if exc.code: |
| 130 | + print( |
| 131 | + f"neml2-stub: pybind11-stubgen failed for {module} (exit {exc.code})", |
| 132 | + file=sys.stderr, |
| 133 | + ) |
| 134 | + last_failure = int(exc.code) if isinstance(exc.code, int) else 1 |
| 135 | + return last_failure |
| 136 | + |
| 137 | + |
| 138 | +def main(argv: list[str] | None = None) -> int: |
| 139 | + args = list(sys.argv[1:]) if argv is None else list(argv) |
| 140 | + return generate_stubs(args) |
| 141 | + |
| 142 | + |
| 143 | +if __name__ == "__main__": |
| 144 | + raise SystemExit(main()) |
0 commit comments