Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 30 additions & 26 deletions codeflash/verification/instrument_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ def __init__(
self._init_kwarg = ast.arg(arg="kwargs")
self._init_self_arg = ast.arg(arg="self", annotation=None)

# Precreate commonly reused AST fragments for classes that lack __init__
# Create the super().__init__(*args, **kwargs) Expr (reuse prebuilt pieces)
self._super_call_expr = ast.Expr(
value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg])
)
# Create function arguments: self, *args, **kwargs (reuse arg nodes)
self._init_arguments = ast.arguments(
posonlyargs=[],
args=[self._init_self_arg],
vararg=self._init_vararg,
kwonlyargs=[],
kw_defaults=[],
kwarg=self._init_kwarg,
defaults=[],
)

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
# Check if our import already exists
if node.module == "codeflash.verification.codeflash_capture" and any(
Expand Down Expand Up @@ -162,47 +178,35 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# TODO: support by saving a reference to the generated __init__ before overriding, e.g.
# _orig_init = ClassName.__init__; then calling _orig_init(self, *args, **kwargs) in the wrapper
for dec in node.decorator_list:
dec_name = None
if isinstance(dec, ast.Name):
dec_name = dec.id
elif isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name):
dec_name = dec.func.id
elif isinstance(dec, ast.Attribute):
dec_name = dec.attr
dec_name = self._expr_name(dec)
if dec_name == "dataclass":
return node

# Skip NamedTuples — their __init__ is synthesized and cannot be overwritten.
for base in node.bases:
base_name = None
if isinstance(base, ast.Name):
base_name = base.id
elif isinstance(base, ast.Attribute):
base_name = base.attr
base_name = self._expr_name(base)
if base_name == "NamedTuple":
return node

# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
super_call = ast.Expr(
value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg])
)
# Create function arguments: self, *args, **kwargs (reuse arg nodes)
arguments = ast.arguments(
posonlyargs=[],
args=[self._init_self_arg],
vararg=self._init_vararg,
kwonlyargs=[],
kw_defaults=[],
kwarg=self._init_kwarg,
defaults=[],
)
super_call = self._super_call_expr
# Create the complete function using prebuilt arguments/body but attach the class-specific decorator

# Create the complete function
init_func = ast.FunctionDef(
name="__init__", args=arguments, body=[super_call], decorator_list=[decorator], returns=None
name="__init__", args=self._init_arguments, body=[super_call], decorator_list=[decorator], returns=None
)

node.body.insert(0, init_func)
self.inserted_decorator = True

return node

def _expr_name(self, node: ast.AST) -> str | None:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
return node.func.id
if isinstance(node, ast.Attribute):
return node.attr
return None
Loading