Skip to content

Commit 28d95d2

Browse files
committed
Limitations
1 parent 651c9ee commit 28d95d2

9 files changed

Lines changed: 704 additions & 43 deletions

File tree

pyfuse/core/models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ class FunctionNode:
4444
closure_func_refs: dict[str, str] = field(default_factory=dict)
4545
module_vars: dict[str, str] = field(default_factory=dict)
4646
class_bases: list[str] = field(default_factory=list)
47+
class_keywords: dict[str, str] = field(default_factory=dict)
48+
class_attrs: list[str] = field(default_factory=list)
49+
class_decorators: list[str] = field(default_factory=list)
4750

4851
def to_dict(self) -> dict[str, Any]:
4952
d: dict[str, Any] = {
@@ -63,6 +66,12 @@ def to_dict(self) -> dict[str, Any]:
6366
d["module_vars"] = dict(self.module_vars)
6467
if self.class_bases:
6568
d["class_bases"] = list(self.class_bases)
69+
if self.class_keywords:
70+
d["class_keywords"] = dict(self.class_keywords)
71+
if self.class_attrs:
72+
d["class_attrs"] = list(self.class_attrs)
73+
if self.class_decorators:
74+
d["class_decorators"] = list(self.class_decorators)
6675
return d
6776

6877
def to_content_blob(self) -> dict[str, Any]:
@@ -86,6 +95,12 @@ def to_content_blob(self) -> dict[str, Any]:
8695
blob["module_vars"] = dict(self.module_vars)
8796
if self.class_bases:
8897
blob["class_bases"] = list(self.class_bases)
98+
if self.class_keywords:
99+
blob["class_keywords"] = dict(self.class_keywords)
100+
if self.class_attrs:
101+
blob["class_attrs"] = list(self.class_attrs)
102+
if self.class_decorators:
103+
blob["class_decorators"] = list(self.class_decorators)
89104
return blob
90105

91106
def content_hash(self) -> str:
@@ -107,6 +122,9 @@ def content_hash(self) -> str:
107122
"closure_func_refs": dict(sorted(self.closure_func_refs.items())),
108123
"module_vars": dict(sorted(self.module_vars.items())),
109124
"class_bases": list(self.class_bases),
125+
"class_keywords": dict(sorted(self.class_keywords.items())),
126+
"class_attrs": list(self.class_attrs),
127+
"class_decorators": list(self.class_decorators),
110128
}
111129
raw = json.dumps(canonical, sort_keys=True, separators=(",", ":"))
112130
return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]
@@ -125,4 +143,7 @@ def from_dict(cls, data: dict[str, Any]) -> Self:
125143
closure_func_refs=data.get("closure_func_refs", {}),
126144
module_vars=data.get("module_vars", {}),
127145
class_bases=data.get("class_bases", []),
146+
class_keywords=data.get("class_keywords", {}),
147+
class_attrs=data.get("class_attrs", []),
148+
class_decorators=data.get("class_decorators", []),
128149
)

pyfuse/core/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
_VERSION = "0.3.0"
1+
_VERSION = "0.4.0"

pyfuse/graph/analyzer.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -212,27 +212,65 @@ def has_super_call(func_source: str) -> bool:
212212
return False
213213

214214

215-
def get_class_bases_from_source(cls: type) -> list[str]:
216-
"""Extract base class names from the class definition source AST.
217-
218-
Returns simple base class names (not fully qualified). Skips ``object``
219-
since it's implicit.
215+
def get_class_bases_from_source(
216+
cls: type,
217+
) -> tuple[list[str], dict[str, str]]:
218+
"""Extract base class names and keyword arguments from the class definition.
219+
220+
Returns ``(bases, keywords)`` where *bases* is a list of base class names
221+
(excluding ``object``) and *keywords* maps keyword names to their unparsed
222+
AST values (e.g. ``{"metaclass": "ABCMeta"}``).
220223
"""
221224
try:
222225
source = textwrap.dedent(inspect.getsource(cls))
223226
except (OSError, TypeError):
224-
return []
227+
return [], {}
225228
try:
226229
tree = ast.parse(source)
227230
except SyntaxError:
228-
return []
231+
return [], {}
229232
for node in tree.body:
230233
if isinstance(node, ast.ClassDef) and node.name == cls.__name__:
231-
bases: list[str] = []
232-
for base in node.bases:
233-
bases.append(ast.unparse(base))
234-
return [b for b in bases if b != "object"]
235-
return []
234+
bases = [ast.unparse(b) for b in node.bases if ast.unparse(b) != "object"]
235+
keywords: dict[str, str] = {}
236+
for kw in node.keywords:
237+
if kw.arg is not None:
238+
keywords[kw.arg] = ast.unparse(kw.value)
239+
return bases, keywords
240+
return [], {}
241+
242+
243+
def get_class_attrs(cls: type) -> tuple[list[str], list[str]]:
244+
"""Extract class-level attributes and decorators from the class source AST.
245+
246+
Returns ``(attrs, decorators)`` where *attrs* is a list of source code
247+
strings for class body statements (assignments, annotated assignments,
248+
docstrings) and *decorators* is a list of decorator source strings
249+
(without the ``@`` prefix).
250+
"""
251+
try:
252+
source = textwrap.dedent(inspect.getsource(cls))
253+
except (OSError, TypeError):
254+
return [], []
255+
try:
256+
tree = ast.parse(source)
257+
except SyntaxError:
258+
return [], []
259+
for node in tree.body:
260+
if not (isinstance(node, ast.ClassDef) and node.name == cls.__name__):
261+
continue
262+
attrs: list[str] = []
263+
for child in node.body:
264+
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
265+
continue
266+
segment = ast.get_source_segment(source, child)
267+
if segment is not None:
268+
attrs.append(textwrap.dedent(segment))
269+
else:
270+
attrs.append(ast.unparse(child))
271+
decorators = [ast.unparse(d) for d in node.decorator_list]
272+
return attrs, decorators
273+
return [], []
236274

237275

238276
def find_bare_calls(func_source: str) -> set[str]:

0 commit comments

Comments
 (0)