Commit e4e1c2a

hugary1995 and claude committed
neml2-stub: pybind11-stubgen as a post-install CLI
Replaces the hand-written ``_aoti.pyi`` with a v2-style ``neml2-stub`` console script that runs ``pybind11-stubgen`` against every discovered extension in the installed package. Walks ``neml2.*`` via ``pkgutil.walk_packages`` + filters on ``importlib.machinery.EXTENSION_SUFFIXES`` so new pybind11 modules (when we add them) are picked up automatically.

* Stub output dir is derived per-module from the spec's ``.so`` location, not ``neml2.__path__[0]``. Editable installs carry both the site-packages shim and the source tree in ``__path__``, and the shim's parent is site-packages -- writing there strands the stub next to the wrong (incomplete) copy of the package.
* CI typecheck job grows a ``neml2-stub`` step between ``pip install`` and ``pyright``.
* scripts/repair_wheel.py installs the freshly-built wheel, runs ``neml2-stub``, then unpacks the wheel + injects every ``.pyi`` preserving its relative path under ``neml2/`` + repacks. Mirrors v2's repair flow so published wheels ship the stubs.
* cibuildwheel ``before-build`` grows ``pybind11-stubgen>=2.5.5``.
* ``.gitignore`` keeps ``*.pyi`` ignored; the previous hand-written carve-out is gone (stubs are pure build output now).

Worked around pybind/pybind11-stubgen#279 by taking ``meta_path`` as ``std::string`` in the pybind binding (stub annotation becomes ``str`` instead of ``os.PathLike``, sidestepping the missing ``import os``). Drop the ``pybind11/stl/filesystem.h`` include; the lambda converts to ``std::filesystem::path`` for the C++ ctor. Drop the lambda + restore ``py::init<const std::filesystem::path &>()`` once pybind/pybind11-stubgen#280 lands in a release.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
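The commit message describes scripts/repair_wheel.py injecting every ``.pyi`` into the unpacked wheel while preserving its relative path under ``neml2/``, but that script is not shown in this view. A minimal sketch of just the injection step follows. The function name ``inject_stubs`` and both directory arguments are hypothetical, and the unpack/repack steps are left out.

```python
from __future__ import annotations

import shutil
from pathlib import Path


def inject_stubs(installed_pkg: Path, unpacked_pkg: Path) -> list[Path]:
    """Copy every ``.pyi`` under ``installed_pkg`` into ``unpacked_pkg``.

    Hypothetical helper: ``installed_pkg`` is the installed ``neml2/``
    directory where the stubs were generated, ``unpacked_pkg`` is the
    ``neml2/`` directory of the unpacked wheel. Relative paths are kept.
    """
    copied: list[Path] = []
    for stub in sorted(installed_pkg.rglob("*.pyi")):
        rel = stub.relative_to(installed_pkg)
        dest = unpacked_pkg / rel
        # The unpacked wheel should already have the directory, but be safe.
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(stub, dest)
        copied.append(rel)
    return copied
```

Returning the relative paths makes it easy to log what was injected, or to fail the repair step when the list comes back empty.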
1 parent 4f5254b commit e4e1c2a

7 files changed

Lines changed: 244 additions & 89 deletions

.github/workflows/python.yml

Lines changed: 7 additions & 0 deletions
@@ -55,6 +55,13 @@ jobs:
           cpu-torch: true
       - name: Install
         run: pip install ".[dev]" -v
+      - name: Regenerate pybind11 stubs
+        # pybind11 extensions ship as .so; pyright can't introspect them
+        # without a .pyi sibling. The `neml2-stub` console script (added
+        # to [project.scripts]) wraps pybind11-stubgen and writes the
+        # stubs next to each .so. Run AFTER `pip install` so the
+        # package is fully importable.
+        run: neml2-stub
       - name: Run pyright
         run: pyright
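The new step exists because pyright needs a ``.pyi`` next to each extension binary. A hedged sketch of a sanity check that could run after ``neml2-stub`` is below. It is not part of this commit, and the function name ``extensions_missing_stubs`` is made up for illustration. It assumes the stub sits beside the binary and is named after the module (``_aoti.pyi`` for ``_aoti.<tag>.so``).

```python
from __future__ import annotations

import importlib.machinery
from pathlib import Path


def extensions_missing_stubs(root: Path) -> list[Path]:
    """List extension binaries under ``root`` that have no sibling ``.pyi``."""
    # Longest suffix first, so ``.cpython-XY-<platform>.so`` wins over ``.so``.
    suffixes = sorted(importlib.machinery.EXTENSION_SUFFIXES, key=len, reverse=True)
    missing: list[Path] = []
    for path in sorted(root.rglob("*")):
        if not path.is_file():
            continue
        suffix = next((s for s in suffixes if path.name.endswith(s)), None)
        if suffix is None:
            continue  # not an extension module
        module_stem = path.name[: -len(suffix)]
        if not (path.parent / f"{module_stem}.pyi").exists():
            missing.append(path)
    return missing
```

A non-empty result would point at a stub that landed in the wrong directory, which is the editable-install failure mode the commit message describes.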

.gitignore

Lines changed: 3 additions & 3 deletions
@@ -33,10 +33,10 @@ DartConfiguration.tcl
 # Python
 .venv
 __pycache__
-# Most .pyi stubs are pyright/IDE scratch output -- ignore by default.
-# Hand-maintained stubs for our pybind modules are committed exceptions.
+# Stubs are auto-generated by `neml2-stub` after install (see
+# neml2/cli/stub.py); regenerated whenever any pybind11 extension
+# changes, never tracked.
 *.pyi
-!neml2/aoti/_aoti.pyi
 neml2.egg-info

 # Tests

neml2/aoti/_aoti.pyi

Lines changed: 0 additions & 71 deletions
This file was deleted.

neml2/cli/stub.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+# Copyright 2024, UChicago Argonne, LLC
+# All Rights Reserved
+# Software Name: NEML2 -- the New Engineering material Model Library, version 2
+# By: Argonne National Laboratory
+# OPEN SOURCE LICENSE (MIT)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"""``neml2-stub`` -- regenerate ``.pyi`` stubs for every pybind11 extension.
+
+Walks the installed ``neml2`` package, finds every binary extension module
+(currently just ``neml2.aoti._aoti``), and runs ``pybind11_stubgen`` against
+each one. The generated stubs land next to their ``.so`` so pyright /
+IDE autocompletion resolve ``from ._aoti import Model`` cleanly.
+
+Invocation order matters: the script ``import``s every extension to
+introspect it, so the package must be **fully installed** before this runs.
+That means after ``pip install`` (CI runs ``neml2-stub`` between install
+and pyright) or after ``cibuildwheel``'s repair step has unpacked the
+wheel into the test env (``scripts/repair_wheel.py`` calls the helper
+in-process, then re-packs the wheel with the stubs injected).
+
+Any extra arguments are forwarded verbatim to ``pybind11_stubgen``
+(e.g. ``neml2-stub --exit-code`` makes the script return non-zero when
+stubgen reports unresolved names).
+"""
+
+from __future__ import annotations
+
+import importlib
+import importlib.machinery
+import importlib.util
+import pkgutil
+import sys
+from pathlib import Path
+
+
+def discover_extension_modules(package_name: str = "neml2") -> list[str]:
+    """Return the dotted names of every pybind11 extension under ``package_name``.
+
+    Uses :func:`pkgutil.walk_packages` to enumerate submodules and filters
+    by spec origin -- a module is an extension iff its file extension is
+    in :data:`importlib.machinery.EXTENSION_SUFFIXES` (e.g.
+    ``.cpython-314-x86_64-linux-gnu.so`` on CPython 3.14 Linux).
+    """
+    package = importlib.import_module(package_name)
+    prefix = f"{package_name}."
+    found: list[str] = []
+    for module in pkgutil.walk_packages(package.__path__, prefix):
+        spec = importlib.util.find_spec(module.name)
+        if spec is None or spec.origin is None:
+            continue
+        if any(spec.origin.endswith(suffix) for suffix in importlib.machinery.EXTENSION_SUFFIXES):
+            found.append(module.name)
+    return sorted(found)
+
+
+def _output_dir_for(module_name: str) -> Path:
+    """Return the dir to pass to ``pybind11-stubgen -o`` for ``module_name``.
+
+    pybind11-stubgen writes to ``<output_dir>/<dotted_path>.pyi``, so for
+    the stub to land next to the ``.so`` we need ``output_dir`` =
+    ``.so`` location minus the dotted-path's worth of directories.
+
+    Computing this from the spec instead of ``package.__path__[0]`` matters
+    for editable installs: ``neml2.__path__`` carries two entries (the
+    minimal site-packages shim and the real source tree), and the shim's
+    parent is site-packages -- writing the stub there strands it next to
+    a different (incomplete) copy of ``neml2/`` instead of next to the
+    extension. The spec's ``origin`` always points at the actual ``.so``,
+    so this is robust to whatever the path-list ordering happens to be.
+    """
+    spec = importlib.util.find_spec(module_name)
+    if spec is None or spec.origin is None:
+        raise RuntimeError(f"can't resolve spec for {module_name}")
+    depth = len(module_name.split("."))
+    return Path(spec.origin).resolve().parents[depth - 1]
+
+
+def generate_stubs(extra_args: list[str] | None = None) -> int:
+    """Run ``pybind11_stubgen`` on every discovered extension.
+
+    Each stub is written next to its ``.so`` (see
+    :func:`_output_dir_for`). Returns the last non-zero exit code from
+    ``pybind11_stubgen``, or 0 if every module succeeded. Missing
+    ``pybind11_stubgen`` is a hard error -- this is a dev/CI tool and
+    silent skipping would mask the pyright-fails-in-CI symptom that
+    motivated it.
+    """
+    if importlib.util.find_spec("pybind11_stubgen") is None:
+        print(
+            "neml2-stub: pybind11-stubgen is not installed. Install it via "
+            "`pip install pybind11-stubgen` (or the [dev] extras).",
+            file=sys.stderr,
+        )
+        return 1
+
+    import pybind11_stubgen  # noqa: PLC0415
+
+    modules = discover_extension_modules("neml2")
+    if not modules:
+        print("neml2-stub: no pybind11 extension modules discovered under neml2/.")
+        return 0
+
+    last_failure = 0
+    for module in modules:
+        output_dir = _output_dir_for(module)
+        argv = ["-o", str(output_dir), *(extra_args or []), module]
+        print(f"neml2-stub: generating stubs for {module} -> {output_dir}")
+        try:
+            pybind11_stubgen.main(argv)
+        except SystemExit as exc:
+            if exc.code:
+                print(
+                    f"neml2-stub: pybind11-stubgen failed for {module} (exit {exc.code})",
+                    file=sys.stderr,
+                )
+                last_failure = int(exc.code) if isinstance(exc.code, int) else 1
+    return last_failure
+
+
+def main(argv: list[str] | None = None) -> int:
+    args = list(sys.argv[1:]) if argv is None else list(argv)
+    return generate_stubs(args)
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
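The ``parents[depth - 1]` arithmetic in ``_output_dir_for`` is easiest to see on a concrete path. The sketch below re-implements only that calculation with ``PurePosixPath`` so that it runs on any platform. The install location is made up for illustration, and unlike the real helper it neither resolves symlinks nor consults the import system.

```python
from pathlib import PurePosixPath


def output_dir_for(module_name: str, origin: str) -> PurePosixPath:
    """Strip one directory level per dotted component of ``module_name``."""
    depth = len(module_name.split("."))
    return PurePosixPath(origin).parents[depth - 1]


# Hypothetical location of the extension binary, for illustration only.
origin = (
    "/opt/venv/lib/python3.12/site-packages/"
    "neml2/aoti/_aoti.cpython-312-x86_64-linux-gnu.so"
)
out = output_dir_for("neml2.aoti._aoti", origin)

# pybind11-stubgen appends the dotted path to the output dir, so the stub
# ends up in the same directory as the binary.
stub = out / "neml2" / "aoti" / "_aoti.pyi"
print(out)   # /opt/venv/lib/python3.12/site-packages
print(stub)  # /opt/venv/lib/python3.12/site-packages/neml2/aoti/_aoti.pyi
```

Three dotted components (``neml2``, ``aoti``, ``_aoti``) give ``depth = 3``, so ``parents[2]`` is the directory that contains the top-level ``neml2/``.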

neml2/csrc/aoti/_aoti.cxx

Lines changed: 11 additions & 2 deletions
@@ -35,7 +35,6 @@

 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
-#include <pybind11/stl/filesystem.h>
 #include <torch/csrc/utils/pybind.h>

 #include "neml2/csrc/aoti/Model.h"
@@ -68,7 +67,17 @@ compile time are reachable through ``named_parameters()`` and may be
 mutated in-place (e.g. ``model.named_parameters()['E'].fill_(210000.0)``).
 Everything else is baked into the graph as a constant.
 )")
-      .def(py::init<const std::filesystem::path &>(),
+      // Take ``meta_path`` as ``std::string`` (rather than
+      // ``std::filesystem::path`` via the stl/filesystem caster) so the
+      // pybind11-stubgen-generated annotation comes out as ``str``
+      // instead of ``os.PathLike``. The current stubgen release
+      // (≤2.5.5) emits ``os.PathLike`` without an accompanying
+      // ``import os``, which trips pyright; pybind/pybind11-stubgen#280
+      // fixes this upstream, drop this lambda + restore
+      // ``py::init<const std::filesystem::path &>()`` once a release
+      // with that PR lands.
+      .def(py::init([](const std::string & meta_path)
+                    { return std::make_unique<Model>(std::filesystem::path{meta_path}); }),
            py::arg("meta_path"),
            "Load all .pt2 segments + metadata from `meta_path`. Throws on "
            "any missing file or schema mismatch.")
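One likely consequence for callers: with the ``std::filesystem::path`` caster gone, the constructor probably no longer accepts ``pathlib.Path`` objects, because pybind11's default ``std::string`` conversion takes ``str`` (and bytes) rather than arbitrary path-like objects. That behaviour is an assumption here, not something this commit states. A small wrapper along these lines would keep call sites path-friendly until the upstream fix lands; the name ``as_meta_path`` is hypothetical.

```python
from __future__ import annotations

import os
from pathlib import Path


def as_meta_path(path: str | os.PathLike[str]) -> str:
    """Convert a path-like object to the plain ``str`` the binding expects."""
    return os.fspath(path)


# Intended use (not executed here; the file name is illustrative):
#   model = Model(as_meta_path(Path("build") / "model_meta.json"))
```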

pyproject.toml

Lines changed: 19 additions & 8 deletions
@@ -89,6 +89,10 @@ dev = [
     # Chromium binary itself is fetched on first use via
     # `playwright install chromium` (CI runs that step explicitly).
     "playwright",
+    # Auto-generates the .pyi stubs for the pybind11 extensions; invoked
+    # via the `neml2-stub` console script after pip install completes.
+    # CI's typecheck workflow + scripts/repair_wheel.py both call it.
+    "pybind11-stubgen>=2.5.5",
     # git
     "pre-commit",
 ]
@@ -97,6 +101,7 @@ dev = [
 neml2-compile = "neml2.cli.aoti_compile:main"
 neml2-inspect = "neml2.cli.inspect:main"
 neml2-run = "neml2.cli.run:main"
+neml2-stub = "neml2.cli.stub:main"
 neml2-syntax = "neml2.cli.syntax:main"

 [tool.ruff]
@@ -128,13 +133,16 @@ ignore = ["E741"]

 [tool.pyright]
 # The neml2 package source, the benchmark harness, scripts, and docs are
-# all in scope. The compiled ``neml2.aoti._aoti`` pybind module is
-# resolved via the hand-written stub at ``neml2/aoti/_aoti.pyi`` (the
-# .so is built at wheel time and is not necessarily present in the
-# source tree pyright analyses -- CI installs the wheel non-editably,
-# so the binary lives in site-packages). Cross-module ``import neml2.x``
-# resolves through whatever environment pyright sees; a source file
-# missing from ``wheel.packages`` surfaces as an unresolved import.
+# all in scope. pybind11 extensions (currently ``neml2.aoti._aoti``) are
+# resolved via auto-generated ``.pyi`` stubs co-located with their
+# ``.so``. The ``neml2-stub`` console script (see
+# ``neml2/cli/stub.py``) runs ``pybind11-stubgen`` against the fully
+# installed package; CI invokes it between ``pip install`` and
+# ``pyright``, and ``scripts/repair_wheel.py`` invokes it during
+# cibuildwheel so the published wheels ship the stubs. Cross-module
+# ``import neml2.x`` resolves through whatever environment pyright sees;
+# a source file missing from ``wheel.packages`` surfaces as an
+# unresolved import.
 include = [
     "neml2",
     "benchmark",
@@ -191,7 +199,10 @@ build-verbosity = 1
 build-frontend = { name = "build", args = ["--no-isolation"] }

 before-all = ""
-before-build = "python -m pip install scikit-build-core ninja torch wheel"
+# pybind11-stubgen is needed by scripts/repair_wheel.py to generate
+# .pyi stubs for the published wheel; `wheel` is needed for the
+# unpack/repack dance there.
+before-build = "python -m pip install scikit-build-core ninja torch wheel \"pybind11-stubgen>=2.5.5\""

 # dependencies: pyzag.version
 test-requires = "pytest torch pyzag==1.1.1"
