Skip to content

Commit 13dbbf3

Browse files
ax3lnavinpvdaltairwalterskarndevsizmailov
committed
Fix class ordering to preserve inheritance hierarchy (#231)
Classes in generated .pyi stubs were sorted alphabetically, causing derived classes to appear before their base classes and breaking type checkers. Three changes fix this: - Parser: use module.__dict__.items() instead of inspect.getmembers() to preserve the pybind11 registration order (definition order) - Printer: replace alphabetical sort with a configurable _order_classes() dispatch supporting "definition" (default), "topological" (Kahn's algorithm ensuring bases precede derived classes), and "alphabetical" - CLI: add --sort-by option to select the class ordering strategy The topological sort ignores external bases (from other modules) and breaks ties by input position for deterministic output. Cyclic cross- references between classes (e.g. aliases, method signatures) are not inheritance cycles and are already handled by `from __future__ import annotations` in the generated stubs. Closes #231 Based on the approaches in PR #275 by @juelg and PR #294 by @daltairwalter, informed by review feedback from @skarndev and @sizmailov. Co-Authored-By: juelg <20750040+juelg@users.noreply.github.com> Co-Authored-By: daltairwalter <daltairwalter@users.noreply.github.com> Co-Authored-By: skarndev <skarndev@users.noreply.github.com> Co-Authored-By: sizmailov <sizmailov@users.noreply.github.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0b039d0 commit 13dbbf3

File tree

3 files changed

+86
-8
lines changed

3 files changed

+86
-8
lines changed

pybind11_stubgen/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class CLIArgs(Namespace):
7777
exit_code: bool
7878
dry_run: bool
7979
stub_extension: str
80+
sort_by: str
8081
module_names: list[str]
8182

8283

@@ -216,6 +217,16 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
216217
"Must be 'pyi' (default) or 'py'",
217218
)
218219

220+
parser.add_argument(
221+
"--sort-by",
222+
type=str,
223+
default="definition",
224+
choices=["definition", "topological"],
225+
help="Order of classes in generated stubs. "
226+
"'definition' (default) preserves the order from the module. "
227+
"'topological' sorts by inheritance hierarchy.",
228+
)
229+
219230
parser.add_argument(
220231
"module_names",
221232
metavar="MODULE_NAMES",
@@ -310,7 +321,10 @@ def main(argv: Sequence[str] | None = None) -> None:
310321
args = arg_parser().parse_args(argv, namespace=CLIArgs())
311322

312323
parser = stub_parser_from_args(args)
313-
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
324+
printer = Printer(
325+
invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is,
326+
sort_by=args.sort_by,
327+
)
314328

315329
run(
316330
parser,

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def handle_module(
8989
self, path: QualifiedName, module: types.ModuleType
9090
) -> Module | None:
9191
result = Module(name=path[-1])
92-
for name, member in inspect.getmembers(module):
92+
for name, member in module.__dict__.items():
9393
obj = self.handle_module_member(
9494
QualifiedName([*path, Identifier(name)]), module, member
9595
)
@@ -647,9 +647,7 @@ def parse_function_docstring(
647647
# This syntax is not supported before Python 3.12.
648648
return []
649649
type_vars: list[str] = list(
650-
filter(
651-
bool, map(str.strip, (type_vars_group or "").split(","))
652-
)
650+
filter(bool, map(str.strip, (type_vars_group or "").split(",")))
653651
)
654652
args = self.call_with_local_types(
655653
type_vars, lambda: self.parse_args_str(match.group("args"))

pybind11_stubgen/printer.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
import logging
45
import sys
56

67
from pybind11_stubgen.structs import (
@@ -24,14 +25,79 @@
2425
Value,
2526
)
2627

28+
log = logging.getLogger("pybind11_stubgen")
29+
2730

2831
def indent_lines(lines: list[str], by=4) -> list[str]:
2932
return [" " * by + line for line in lines]
3033

3134

35+
def _topological_sort_classes(classes: list[Class]) -> list[Class]:
36+
"""Sort classes so that base classes appear before derived classes.
37+
38+
Uses Kahn's algorithm. Ties are broken by input position for stability.
39+
External bases (not in the current scope) are ignored.
40+
"""
41+
if not classes:
42+
return classes
43+
44+
name_to_index = {c.name: i for i, c in enumerate(classes)}
45+
name_to_class = {c.name: c for c in classes}
46+
47+
# Build adjacency list: base -> [derived, ...]
48+
# and in-degree count for each class
49+
children: dict[str, list[str]] = {c.name: [] for c in classes}
50+
in_degree: dict[str, int] = {c.name: 0 for c in classes}
51+
52+
for c in classes:
53+
for base in c.bases:
54+
base_name = str(base[-1])
55+
if base_name in name_to_class:
56+
children[base_name].append(c.name)
57+
in_degree[c.name] += 1
58+
59+
# Initialize queue with zero in-degree classes, sorted by input position
60+
queue = sorted(
61+
[name for name, deg in in_degree.items() if deg == 0],
62+
key=lambda n: name_to_index[n],
63+
)
64+
65+
result = []
66+
while queue:
67+
name = queue.pop(0)
68+
result.append(name_to_class[name])
69+
# Sort children by input position for stable ordering
70+
for child in sorted(children[name], key=lambda n: name_to_index[n]):
71+
in_degree[child] -= 1
72+
if in_degree[child] == 0:
73+
queue.append(child)
74+
# Re-sort queue to maintain input-position priority
75+
queue.sort(key=lambda n: name_to_index[n])
76+
77+
if len(result) < len(classes):
78+
remaining = [c for c in classes if c.name not in {r.name for r in result}]
79+
log.warning(
80+
"Cycle detected in class inheritance involving: %s. "
81+
"Appending in original order.",
82+
[c.name for c in remaining],
83+
)
84+
result.extend(remaining)
85+
86+
return result
87+
88+
3289
class Printer:
33-
def __init__(self, invalid_expr_as_ellipses: bool):
90+
def __init__(self, invalid_expr_as_ellipses: bool, sort_by: str = "definition"):
3491
self.invalid_expr_as_ellipses = invalid_expr_as_ellipses
92+
self.sort_by = sort_by
93+
94+
def _order_classes(self, classes: list[Class]) -> list[Class]:
95+
if self.sort_by == "alphabetical":
96+
return sorted(classes, key=lambda c: c.name)
97+
elif self.sort_by == "definition":
98+
return classes
99+
else: # "topological"
100+
return _topological_sort_classes(classes)
35101

36102
def print_alias(self, alias: Alias) -> list[str]:
37103
return [f"{alias.name} = {alias.origin}"]
@@ -90,7 +156,7 @@ def print_class_body(self, class_: Class) -> list[str]:
90156
if class_.doc is not None:
91157
result.extend(self.print_docstring(class_.doc))
92158

93-
for sub_class in sorted(class_.classes, key=lambda c: c.name):
159+
for sub_class in self._order_classes(class_.classes):
94160
result.extend(self.print_class(sub_class))
95161

96162
modifier_order: dict[Modifier, int] = {
@@ -232,7 +298,7 @@ def print_module(self, module: Module) -> list[str]:
232298
for type_var in sorted(module.type_vars, key=lambda t: t.name):
233299
result.extend(self.print_type_var(type_var))
234300

235-
for class_ in sorted(module.classes, key=lambda c: c.name):
301+
for class_ in self._order_classes(module.classes):
236302
result.extend(self.print_class(class_))
237303

238304
for func in sorted(module.functions, key=lambda f: f.name):

0 commit comments

Comments
 (0)