Skip to content

Commit 2bf269d

Browse files
codeSamuraiiCopilot
andcommitted
Fix __init_subclass__
Co-authored-by: Copilot <copilot@github.com>
1 parent 3108028 commit 2bf269d

4 files changed

Lines changed: 279 additions & 21 deletions

File tree

docs/TECHNICAL_OVERVIEW.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,11 @@ When a traced function calls an untraced user-defined function, pyfuse automatic
237237

238238
Cross-module imports (e.g., `from utils import helper`) are converted from import statements to inline dependency edges, so the reconstructed code is self-contained.
239239

240-
Class constructors (`MyClass()`) are auto-discovered: pyfuse registers all user-defined methods of the class. `@staticmethod` and `@classmethod` descriptors are unwrapped and registered correctly. When a method uses `super()`, base classes and their methods are discovered recursively, and `class Foo(Base):` headers are emitted in reconstructed source.
240+
Class constructors (`MyClass()`) are auto-discovered: pyfuse registers all user-defined methods of the class, even when the class relies on the implicit `object.__init__`. `@staticmethod` and `@classmethod` descriptors are unwrapped and registered correctly. Base classes and their methods are pulled in via `__mro__` walking (independent of `super()` usage), and `class Foo(Base):` headers are emitted in reconstructed source.
241241

242-
Class-level attributes (assignments, annotated assignments, docstrings) are extracted from the class source AST and emitted in reconstructed class blocks. Class decorators (e.g., `@dataclass`) are captured and emitted above the class header. Metaclass keywords (e.g., `metaclass=ABCMeta`) and other class keywords are extracted from the class definition and included in the reconstructed header.
242+
Classes that hook `__init_subclass__` are treated as registries: every user-defined subclass is auto-registered so its definition fires the parent hook on the worker. The same path catches subclasses looked up indirectly (e.g., by name from a registry dict) without requiring an explicit reference in the traced body.
243+
244+
Class-level attributes (assignments, annotated assignments, docstrings) are extracted from the class source AST and emitted in reconstructed class blocks. User classes referenced from class-body RHS expressions (e.g., descriptors installed via `field = Doubler()`) are auto-registered transitively. Class decorators (e.g., `@dataclass`) are captured and emitted above the class header. Metaclass keywords (e.g., `metaclass=ABCMeta`) and other class keywords are extracted from the class definition and included in the reconstructed header.
243245

244246
Module-level constants and variables referenced by traced functions (e.g., `MAX_RETRIES = 5`) are captured and emitted in reconstructed source.
245247

@@ -670,8 +672,7 @@ Environment variables `PYFUSE_SANDBOX_DOCKER_IMAGE` and `PYFUSE_SANDBOX_DOCKER_C
670672
- Aliased cross-module imports (`from utils import helper as h`) are skipped to avoid name mismatches.
671673

672674
### Classes
673-
- Metaclasses and `__init_subclass__` hooks are replayed when the parent class is in the dependency tree (i.e., referenced via `super()` or constructor call).
674-
- Class-level attributes defined via complex descriptors or external decorators (beyond simple assignments) may not be captured.
675+
- Class-level attributes are captured verbatim from the class source AST. User classes referenced from class-body expressions (e.g., descriptor instances assigned to class attributes) are auto-registered along with the owning class. External decorators applied to attributes via more dynamic patterns may still be missed.
675676

676677
## CLI reference
677678

pyfuse/graph/analyzer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,16 @@ def _resolve_bare_calls(
515515
chosen = _prefer_same_module(init_matches, func_module)
516516
logger.debug("Constructor call %s() -> %s", called, chosen.qualified_name)
517517
deps.append(chosen.qualified_name)
518+
continue
519+
# Class with no user-defined ``__init__`` (inherits ``object.__init__``):
520+
# link to every registered method so the whole class block is emitted.
521+
method_matches = [
522+
node for node in registry.values() if node.owner_class == called
523+
]
524+
if method_matches:
525+
for node in method_matches:
526+
logger.debug("Bare class call %s() -> %s", called, node.qualified_name)
527+
deps.append(node.qualified_name)
518528
return deps
519529

520530

pyfuse/graph/graph.py

Lines changed: 154 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ def __init__(self) -> None:
291291
)
292292
self._runtime_deps: dict[str, set[str]] = {}
293293
self._lock: threading.Lock = threading.Lock()
294+
self._classes_in_progress: set[str] = set()
295+
self._inclusion_deps: dict[str, set[str]] = {}
294296

295297
@classmethod
296298
def default(cls) -> "Graph":
@@ -499,23 +501,39 @@ def _auto_register_class(self, cls: type) -> None:
499501
"""Auto-register all user-defined methods of a class into the graph."""
500502
class_name = cls.__name__
501503
module_name = cls.__module__
502-
503-
for attr_name, raw in cls.__dict__.items():
504-
if isinstance(raw, (classmethod, staticmethod)):
505-
func = raw.__func__
506-
elif inspect.isfunction(raw):
507-
func = raw
508-
else:
509-
continue
510-
if not _is_user_function(func):
511-
continue
512-
qname = f"{module_name}.{class_name}.{attr_name}"
513-
if qname in self._nodes:
514-
continue
515-
self._auto_register(func)
516-
517-
self._set_class_metadata(cls)
518-
self._resolve_class_bases(cls)
504+
cls_key = f"{module_name}.{cls.__qualname__}"
505+
if cls_key in self._classes_in_progress:
506+
return
507+
self._classes_in_progress.add(cls_key)
508+
try:
509+
for attr_name, raw in cls.__dict__.items():
510+
if isinstance(raw, (classmethod, staticmethod)):
511+
func = raw.__func__
512+
elif inspect.isfunction(raw):
513+
func = raw
514+
else:
515+
continue
516+
if not _is_user_function(func):
517+
continue
518+
qname = f"{module_name}.{class_name}.{attr_name}"
519+
if qname in self._nodes:
520+
continue
521+
self._auto_register(func)
522+
523+
self._set_class_metadata(cls)
524+
self._resolve_class_bases(cls)
525+
526+
# Subclass registry pattern: classes that hook ``__init_subclass__``
527+
# populate registries from subclass definitions. The traced source
528+
# may look subclasses up indirectly (e.g. by name); to make that
529+
# work on the worker, pull every user-defined subclass into the
530+
# graph so its definition fires the parent hook on reconstruct.
531+
if "__init_subclass__" in cls.__dict__:
532+
for sub in cls.__subclasses__():
533+
if _is_user_class(sub):
534+
self._auto_register_class(sub)
535+
finally:
536+
self._classes_in_progress.discard(cls_key)
519537

520538
def _set_class_metadata(self, cls: type) -> None:
521539
"""Capture class-level attributes and decorators onto method nodes."""
@@ -532,10 +550,18 @@ def _set_class_metadata(self, cls: type) -> None:
532550
for deco_src in decorators:
533551
extra_names |= get_used_names(deco_src)
534552

553+
# User classes referenced from class-body RHS (e.g. descriptors like
554+
# ``field = Doubler()``) are not visible to bare-call discovery on
555+
# function bodies; register them here so they survive reconstruction.
556+
ref_method_qnames = self._register_class_attr_refs(cls, extra_names)
557+
535558
for node in self._nodes.values():
536559
if node.owner_class == class_name and node.module == module_name:
537560
node.class_attrs = attrs
538561
node.class_decorators = decorators
562+
for ref_qname in ref_method_qnames:
563+
if ref_qname != node.qualified_name:
564+
self._inclusion_deps.setdefault(node.qualified_name, set()).add(ref_qname)
539565
if extra_names:
540566
existing_names = {imp.bound_name for imp in node.imports}
541567
try:
@@ -551,6 +577,38 @@ def _set_class_metadata(self, cls: type) -> None:
551577
except StopIteration:
552578
pass
553579

580+
def _register_class_attr_refs(
581+
self, cls: type, extra_names: set[str]
582+
) -> list[str]:
583+
"""Auto-register user classes referenced from the class body.
584+
585+
Returns the qualified names of one method per referenced class so
586+
callers can wire dependency edges that keep them in the subgraph.
587+
"""
588+
if not extra_names:
589+
return []
590+
module_obj = sys.modules.get(cls.__module__)
591+
if module_obj is None:
592+
return []
593+
594+
ref_method_qnames: list[str] = []
595+
for name in extra_names:
596+
if name in _BUILTIN_NAMES:
597+
continue
598+
obj = getattr(module_obj, name, None)
599+
if obj is None or obj is cls:
600+
continue
601+
if not (inspect.isclass(obj) and _is_user_class(obj)):
602+
continue
603+
self._auto_register_class(obj)
604+
for ref_node in self._nodes.values():
605+
if (
606+
ref_node.owner_class == obj.__name__
607+
and ref_node.module == obj.__module__
608+
):
609+
ref_method_qnames.append(ref_node.qualified_name)
610+
return ref_method_qnames
611+
554612
def _resolve_class_bases(self, cls: type) -> None:
555613
"""Detect class bases and store them on method nodes.
556614
@@ -592,6 +650,36 @@ def _resolve_class_bases(self, cls: type) -> None:
592650
if _is_user_class(base_cls):
593651
self._auto_register_class(base_cls)
594652

653+
# Add ordering edges from child methods to every parent method per
654+
# direct user base, so the topological reconstruction emits parents
655+
# first when the subclass is included via the registry pattern
656+
# (without relying on ``super()`` being present in the subclass body).
657+
for base_cls in cls.__bases__:
658+
if base_cls is object or not _is_user_class(base_cls):
659+
continue
660+
parent_method_qnames = [
661+
parent_node.qualified_name
662+
for parent_node in self._nodes.values()
663+
if (
664+
parent_node.owner_class == base_cls.__name__
665+
and parent_node.module == base_cls.__module__
666+
)
667+
]
668+
if not parent_method_qnames:
669+
continue
670+
for child_node in self._nodes.values():
671+
if (
672+
child_node.owner_class != class_name
673+
or child_node.module != module_name
674+
):
675+
continue
676+
bucket = self._inclusion_deps.setdefault(
677+
child_node.qualified_name, set()
678+
)
679+
for parent_qname in parent_method_qnames:
680+
if parent_qname != child_node.qualified_name:
681+
bucket.add(parent_qname)
682+
595683
def _discover_untraced_deps(
596684
self, module_name: str, node: FunctionNode
597685
) -> None:
@@ -609,6 +697,7 @@ def _discover_untraced_deps(
609697
self._discover_bare_call_deps(node, module_obj)
610698
if node.owner_class:
611699
self._discover_self_call_deps(node, module_obj, module_name)
700+
self._discover_init_subclass_deps(node, module_obj)
612701

613702
def _discover_bare_call_deps(
614703
self,
@@ -675,6 +764,51 @@ def _discover_self_call_deps(
675764
self._set_class_metadata(cls_obj)
676765
self._resolve_class_bases(cls_obj)
677766

767+
def _discover_init_subclass_deps(
768+
self,
769+
node: FunctionNode,
770+
module_obj: object,
771+
) -> None:
772+
"""Pull subclasses of registry-style parents into the caller's deps.
773+
774+
When the traced source references a class with a user-defined
775+
``__init_subclass__``, its subclasses participate by being defined --
776+
not by being named in the source. Add inclusion edges from this
777+
node to one method of each user subclass so the subgraph keeps them.
778+
"""
779+
for name in find_bare_calls(node.source) | get_used_names(node.source):
780+
if name in _BUILTIN_NAMES:
781+
continue
782+
obj = getattr(module_obj, name, None)
783+
if obj is None or not inspect.isclass(obj):
784+
continue
785+
if "__init_subclass__" not in obj.__dict__:
786+
continue
787+
# Skip when this node is itself a method of obj or an ancestor;
788+
# adding child-class edges from a parent method causes cycles
789+
# via the parent ordering edges added in ``_resolve_class_bases``.
790+
if node.owner_class is not None:
791+
node_cls = getattr(module_obj, node.owner_class.rsplit(".", 1)[-1], None)
792+
if (
793+
inspect.isclass(node_cls)
794+
and node_cls is not None
795+
and (node_cls is obj or issubclass(obj, node_cls))
796+
):
797+
continue
798+
for sub in obj.__subclasses__():
799+
if not _is_user_class(sub):
800+
continue
801+
self._auto_register_class(sub)
802+
for sub_node in self._nodes.values():
803+
if (
804+
sub_node.owner_class == sub.__name__
805+
and sub_node.module == sub.__module__
806+
and sub_node.qualified_name != node.qualified_name
807+
):
808+
self._inclusion_deps.setdefault(
809+
node.qualified_name, set()
810+
).add(sub_node.qualified_name)
811+
678812
# -- Refresh & dependency merging ------------------------------------------
679813

680814
def refresh(self) -> None:
@@ -693,6 +827,9 @@ def refresh(self) -> None:
693827
for ref_qname in node.closure_func_refs.values():
694828
if ref_qname != qname and ref_qname not in deps:
695829
deps.append(ref_qname)
830+
for incl_qname in self._inclusion_deps.get(qname, set()):
831+
if incl_qname in self._nodes and incl_qname not in deps:
832+
deps.append(incl_qname)
696833
node.dependencies = deps
697834

698835
def _add_super_deps(self) -> None:

tests/test_auto_discovery.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,3 +939,113 @@ def test_class_docstring_captured(tmp_path: Path) -> None:
939939
)
940940
source = reconstruct(serialize(), "method")
941941
assert "A documented class" in source
942+
943+
# ---------------------------------------------------------------------------
944+
# Feature 5: Bare class call without user-defined ``__init__``
945+
# ---------------------------------------------------------------------------
946+
947+
948+
def test_bare_class_no_init_chained(tmp_path: Path) -> None:
949+
"""``Tag().label()`` works when ``Tag`` has no user-defined ``__init__``."""
950+
create_module(
951+
tmp_path,
952+
"adbarenoinit",
953+
(
954+
"from pyfuse import trace\n\n"
955+
"class Tag:\n"
956+
" def label(self):\n"
957+
" return 'tag'\n\n"
958+
"@trace\n"
959+
"def construct_tag():\n"
960+
" return Tag().label()\n"
961+
),
962+
)
963+
source = reconstruct(serialize(), "construct_tag")
964+
assert "class Tag" in source
965+
assert "def label" in source
966+
967+
ns: dict[str, object] = {}
968+
exec(source, ns) # noqa: S102
969+
assert ns["construct_tag"]() == "tag" # type: ignore[operator]
970+
971+
972+
# ---------------------------------------------------------------------------
973+
# Feature 6: ``__init_subclass__`` registry pattern
974+
# ---------------------------------------------------------------------------
975+
976+
977+
def test_init_subclass_registry_pulls_subclasses(tmp_path: Path) -> None:
978+
"""Subclasses of registry-style parents survive serialization."""
979+
create_module(
980+
tmp_path,
981+
"adsubregistry",
982+
(
983+
"from pyfuse import trace\n\n"
984+
"class Plugin:\n"
985+
" _registry = {}\n"
986+
" def __init__(self):\n"
987+
" pass\n"
988+
" def __init_subclass__(cls, name=None, **kwargs):\n"
989+
" super().__init_subclass__(**kwargs)\n"
990+
" if name is not None:\n"
991+
" Plugin._registry[name] = cls\n"
992+
" def run(self, payload):\n"
993+
" raise NotImplementedError\n\n"
994+
"class GreetPlugin(Plugin, name='greet'):\n"
995+
" def run(self, payload):\n"
996+
" return f'hello, {payload}!'\n\n"
997+
"@trace\n"
998+
"def dispatch(plugin_name, payload):\n"
999+
" Plugin()\n"
1000+
" cls = Plugin._registry[plugin_name]\n"
1001+
" return cls().run(payload)\n"
1002+
),
1003+
)
1004+
source = reconstruct(serialize(), "dispatch")
1005+
assert "class Plugin" in source
1006+
assert "class GreetPlugin(Plugin, name='greet'):" in source
1007+
1008+
ns: dict[str, object] = {}
1009+
exec(source, ns) # noqa: S102
1010+
assert ns["dispatch"]("greet", "world") == "hello, world!" # type: ignore[operator]
1011+
1012+
1013+
# ---------------------------------------------------------------------------
1014+
# Feature 7: Class-level descriptor (user class referenced in class body)
1015+
# ---------------------------------------------------------------------------
1016+
1017+
1018+
def test_class_level_descriptor_registered(tmp_path: Path) -> None:
1019+
"""A user class referenced from a class body (``field = Doubler()``) is captured."""
1020+
create_module(
1021+
tmp_path,
1022+
"addescriptor",
1023+
(
1024+
"from pyfuse import trace\n\n"
1025+
"class Doubler:\n"
1026+
" def __set_name__(self, owner, name):\n"
1027+
" self._attr = f'_{name}'\n"
1028+
" def __get__(self, obj, objtype=None):\n"
1029+
" if obj is None:\n"
1030+
" return self\n"
1031+
" return getattr(obj, self._attr) * 2\n"
1032+
" def __set__(self, obj, value):\n"
1033+
" setattr(obj, self._attr, value)\n\n"
1034+
"class Counter:\n"
1035+
" value = Doubler()\n"
1036+
" def __init__(self, start):\n"
1037+
" self.value = start\n\n"
1038+
"@trace\n"
1039+
"def doubled(start):\n"
1040+
" c = Counter(start)\n"
1041+
" return c.value\n"
1042+
),
1043+
)
1044+
source = reconstruct(serialize(), "doubled")
1045+
assert "class Doubler" in source
1046+
assert "class Counter" in source
1047+
assert "value = Doubler()" in source
1048+
1049+
ns: dict[str, object] = {}
1050+
exec(source, ns) # noqa: S102
1051+
assert ns["doubled"](5) == 10 # type: ignore[operator]

0 commit comments

Comments
 (0)