Skip to content

Commit 085708a

Browse files
committed
fix?
1 parent de7b997 commit 085708a

13 files changed

Lines changed: 650 additions & 112 deletions

pyfuse/graph/analyzer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,9 +346,15 @@ def _resolve_owner_class(qualname: str) -> str | None:
346346
if len(parts) == 1:
347347
return None
348348
prefix = parts[0]
349-
if "<locals>" in prefix:
350-
return None
351-
return prefix
349+
if "<locals>" not in prefix:
350+
return prefix
351+
# For nested classes like "outer.<locals>.MyClass.__init__",
352+
# extract the class name after the last "<locals>." segment.
353+
after_locals = prefix.rsplit("<locals>.", 1)[-1]
354+
# If there's still a class name (not empty, not another scope marker)
355+
if after_locals and "<" not in after_locals:
356+
return after_locals
357+
return None
352358

353359

354360
def _extract_annotation_type_names(annotation: ast.expr) -> set[str]:

pyfuse/graph/graph.py

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,28 +128,39 @@ def _try_get_lambda_source(func: Callable[..., object]) -> str | None:
128128

129129
def _capture_closure(
130130
func: Callable[..., object],
131-
) -> tuple[dict[str, str], dict[str, str], dict[str, Callable[..., object]]]:
131+
) -> tuple[dict[str, str], dict[str, str], dict[str, Callable[..., object]], list[ImportInfo], dict[str, type]]:
132132
"""Extract closure variables and traced function references from *func*.
133133
134-
Returns ``(closure_vars, closure_func_refs, closure_func_objects)``
134+
Returns ``(closure_vars, closure_func_refs, closure_func_objects, closure_module_imports, closure_classes)``
135135
where *closure_vars* maps variable names to repr strings,
136-
*closure_func_refs* maps variable names to qualified names, and
136+
*closure_func_refs* maps variable names to qualified names,
137137
*closure_func_objects* maps qualified names to the actual callable
138-
objects (for auto-registration of non-traced functions).
138+
objects (for auto-registration of non-traced functions),
139+
*closure_module_imports* is a list of :class:`ImportInfo` for module
140+
objects found in the closure (from inline imports), and
141+
*closure_classes* maps variable names to user-defined class objects
142+
that need auto-registration.
139143
"""
140144
closure_vars: dict[str, str] = {}
141145
closure_func_refs: dict[str, str] = {}
142146
closure_func_objects: dict[str, Callable[..., object]] = {}
147+
closure_module_imports: list[ImportInfo] = []
148+
closure_classes: dict[str, type] = {}
143149

144150
if not func.__code__.co_freevars:
145-
return closure_vars, closure_func_refs, closure_func_objects
151+
return closure_vars, closure_func_refs, closure_func_objects, closure_module_imports, closure_classes
146152

147153
try:
148154
closure_info = inspect.getclosurevars(func)
149155
except ValueError:
150-
return closure_vars, closure_func_refs, closure_func_objects
156+
return closure_vars, closure_func_refs, closure_func_objects, closure_module_imports, closure_classes
151157

152158
for name, value in closure_info.nonlocals.items():
159+
# Skip the implicit __class__ cell injected by Python for super() calls.
160+
# Reconstructed code uses explicit super(ClassName, self) instead.
161+
if name == "__class__" and inspect.isclass(value):
162+
continue
163+
153164
try:
154165
repr_value = repr(value)
155166
except Exception:
@@ -167,6 +178,17 @@ def _capture_closure(
167178
except SyntaxError:
168179
pass
169180

181+
# Module objects from inline imports (e.g. `import time as _time`)
182+
if inspect.ismodule(value):
183+
mod_name = value.__name__
184+
if name == mod_name or name == mod_name.split(".")[0]:
185+
stmt = f"import {mod_name}"
186+
else:
187+
stmt = f"import {mod_name} as {name}"
188+
closure_module_imports.append(ImportInfo(statement=stmt, bound_name=name))
189+
logger.debug("Closure var '%s' is module %s", name, mod_name)
190+
continue
191+
170192
# Callables: prefer source-level capture over serialization
171193
if getattr(value, "__pyfuse_traced__", False):
172194
unwrapped = value
@@ -189,6 +211,12 @@ def _capture_closure(
189211
logger.debug("Closure var '%s' is untraced user function %s", name, ref_qname)
190212
continue
191213

214+
# User-defined classes: auto-register all their methods
215+
if inspect.isclass(value) and _is_user_class(value):
216+
closure_classes[name] = value
217+
logger.debug("Closure var '%s' is user class %s", name, value.__qualname__)
218+
continue
219+
192220
# Non-callable fallbacks
193221
ctor_expr = _try_constructor_expr(value)
194222
if ctor_expr is not None:
@@ -210,7 +238,7 @@ def _capture_closure(
210238
stacklevel=3,
211239
)
212240

213-
return closure_vars, closure_func_refs, closure_func_objects
241+
return closure_vars, closure_func_refs, closure_func_objects, closure_module_imports, closure_classes
214242

215243

216244
def _mermaid_node_id(qname: str) -> str:
@@ -325,18 +353,27 @@ def register(self, func: Callable[..., object]) -> None:
325353
"unavailable. Functions must be defined in .py source files."
326354
) from exc
327355

328-
closure_vars, closure_func_refs, closure_func_objects = _capture_closure(original)
356+
closure_vars, closure_func_refs, closure_func_objects, closure_module_imports, closure_classes = _capture_closure(original)
329357

330358
for ref_qname, func_obj in closure_func_objects.items():
331359
if ref_qname not in self._nodes:
332360
self._auto_register(func_obj)
333361

362+
for cls_obj in closure_classes.values():
363+
self._auto_register_class(cls_obj)
364+
365+
# Add module imports from closures (inline imports like `import time as _time`)
366+
existing = {imp.bound_name for imp in analysis.imports}
367+
for imp in closure_module_imports:
368+
if imp.bound_name not in existing:
369+
analysis.imports.append(imp)
370+
existing.add(imp.bound_name)
371+
334372
if closure_vars:
335373
closure_names: set[str] = set()
336374
for cv in closure_vars.values():
337375
closure_names |= get_used_names(cv)
338376
if closure_names:
339-
existing = {imp.bound_name for imp in analysis.imports}
340377
all_imports = get_module_imports(original)
341378
for imp in all_imports:
342379
if imp.bound_name in closure_names and imp.bound_name not in existing:
@@ -396,13 +433,40 @@ def _auto_register(self, func: Callable[..., object]) -> bool:
396433
)
397434
return False
398435

436+
closure_vars, closure_func_refs, closure_func_objects, closure_module_imports, closure_classes = _capture_closure(func)
437+
438+
# Add module imports from closures (inline imports)
439+
existing_names = {imp.bound_name for imp in analysis.imports}
440+
for imp in closure_module_imports:
441+
if imp.bound_name not in existing_names:
442+
analysis.imports.append(imp)
443+
existing_names.add(imp.bound_name)
444+
445+
# Add imports needed by closure var expressions
446+
if closure_vars:
447+
closure_names: set[str] = set()
448+
for cv in closure_vars.values():
449+
closure_names |= get_used_names(cv)
450+
if closure_names:
451+
try:
452+
all_imports = get_module_imports(func)
453+
for imp in all_imports:
454+
if imp.bound_name in closure_names and imp.bound_name not in existing_names:
455+
analysis.imports.append(imp)
456+
existing_names.add(imp.bound_name)
457+
except (OSError, TypeError):
458+
pass
459+
399460
dependencies = [
400461
dep for dep in detect_traced_dependencies(
401462
analysis.source, func.__module__, self._nodes,
402463
owner_class=analysis.owner_class,
403464
)
404465
if dep != qualified_name
405466
]
467+
for ref_qname in closure_func_refs.values():
468+
if ref_qname != qualified_name and ref_qname not in dependencies:
469+
dependencies.append(ref_qname)
406470

407471
node = FunctionNode(
408472
qualified_name=qualified_name,
@@ -412,12 +476,22 @@ def _auto_register(self, func: Callable[..., object]) -> bool:
412476
imports=analysis.imports,
413477
dependencies=dependencies,
414478
owner_class=analysis.owner_class,
479+
closure_vars=closure_vars,
480+
closure_func_refs=closure_func_refs,
415481
module_vars=analysis.module_vars,
416482
)
417483
self._nodes[qualified_name] = node
418484
self._funcs[qualified_name] = func
419485
logger.info("Auto-registered untraced dependency %s", qualified_name)
420486

487+
# Auto-register closure function deps (after node is in self._nodes to prevent re-entry)
488+
for ref_qname, func_obj in closure_func_objects.items():
489+
if ref_qname not in self._nodes:
490+
self._auto_register(func_obj)
491+
492+
for cls_obj in closure_classes.values():
493+
self._auto_register_class(cls_obj)
494+
421495
self._discover_untraced_deps(func.__module__, node)
422496
return True
423497

pyfuse/graph/store.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Content-addressable store for serializing and reconstructing functions."""
22

3+
import ast
34
import json
45
import logging
56
from typing import Any, Self
@@ -72,6 +73,46 @@ def _apply_closure_transforms(
7273
return source
7374

7475

76+
class _SuperRewriter(ast.NodeTransformer):
77+
"""Replace ``super()`` with ``super(ClassName, self)`` or ``super(ClassName, cls)``."""
78+
79+
def __init__(self, class_name: str) -> None:
80+
self._class_name = class_name
81+
self._first_param: str | None = None
82+
83+
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
84+
self._first_param = node.args.args[0].arg if node.args.args else "self"
85+
self.generic_visit(node)
86+
self._first_param = None
87+
return node
88+
89+
visit_AsyncFunctionDef = visit_FunctionDef # type: ignore[assignment]
90+
91+
def visit_Call(self, node: ast.Call) -> ast.Call:
92+
self.generic_visit(node)
93+
if (
94+
isinstance(node.func, ast.Name)
95+
and node.func.id == "super"
96+
and not node.args
97+
and not node.keywords
98+
):
99+
node.args = [
100+
ast.Name(id=self._class_name, ctx=ast.Load()),
101+
ast.Name(id=self._first_param or "self", ctx=ast.Load()),
102+
]
103+
return node
104+
105+
106+
def _rewrite_bare_super(source: str, class_name: str) -> str:
107+
"""Replace zero-arg ``super()`` with ``super(ClassName, self/cls)``."""
108+
if "super()" not in source:
109+
return source
110+
tree = ast.parse(source)
111+
tree = _SuperRewriter(class_name).visit(tree)
112+
ast.fix_missing_locations(tree)
113+
return ast.unparse(tree)
114+
115+
75116
def _indent_method(source: str) -> str:
76117
"""Indent a method source for embedding inside a class block."""
77118
return "\n".join(
@@ -89,7 +130,9 @@ def _build_class_block(
89130
class_name = owner_class.rsplit(".", 1)[-1]
90131

91132
method_sources = [
92-
_indent_method(_apply_closure_transforms(nodes[qname], nodes))
133+
_indent_method(_rewrite_bare_super(
134+
_apply_closure_transforms(nodes[qname], nodes), class_name,
135+
))
93136
for qname in member_qnames
94137
]
95138

@@ -119,7 +162,8 @@ def _build_class_block(
119162

120163
decorator_lines = "".join(f"@{d}\n" for d in class_decorators)
121164
attr_block = "".join(
122-
_indent_method(attr) + "\n\n" for attr in class_attrs
165+
_indent_method(_rewrite_bare_super(attr, class_name)) + "\n\n"
166+
for attr in class_attrs
123167
)
124168

125169
return decorator_lines + header + attr_block + "\n\n".join(method_sources)

pyfuse/typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ async def run_in(self, delay: timedelta | float, *args: P.args, **kwargs: P.kwar
3636
async def run_every(
3737
self,
3838
frequency: timedelta | float,
39-
*args: P.args,
39+
*args: Any,
4040
_start_at: datetime | None = ...,
41-
**kwargs: P.kwargs,
41+
**kwargs: Any,
4242
) -> ScheduleHandle: ...
4343

4444

0 commit comments

Comments
 (0)