Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions codeflash/verification/instrument_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
return node

has_init = False
# Build decorator node ONCE for each class, not per loop iteration
decorator = ast.Call(
func=self._base_decorator_func,
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
*self._base_decorator_keywords,
],
)
decorator = None

# Only scan node.body once for both __init__ and decorator check
for item in node.body:
Expand All @@ -151,7 +143,16 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture":
break
else:
# No decorator found
# No decorator found - create it lazily on first use
if decorator is None:
decorator = ast.Call(
func=self._base_decorator_func,
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
*self._base_decorator_keywords,
],
)
item.decorator_list.insert(0, decorator)
self.inserted_decorator = True

Expand Down Expand Up @@ -187,6 +188,16 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
defaults=[],
)

# Build decorator for the synthetic __init__
decorator = ast.Call(
func=self._base_decorator_func,
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
*self._base_decorator_keywords,
],
)

# Create the complete function
init_func = ast.FunctionDef(
name="__init__", args=arguments, body=[super_call], decorator_list=[decorator], returns=None
Expand Down
Loading