diff --git a/codeflash/verification/instrument_codeflash_capture.py b/codeflash/verification/instrument_codeflash_capture.py index aed5a3e1b..f2e04e890 100644 --- a/codeflash/verification/instrument_codeflash_capture.py +++ b/codeflash/verification/instrument_codeflash_capture.py @@ -125,15 +125,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: return node has_init = False - # Build decorator node ONCE for each class, not per loop iteration - decorator = ast.Call( - func=self._base_decorator_func, - args=[], - keywords=[ - ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")), - *self._base_decorator_keywords, - ], - ) + decorator = None # Only scan node.body once for both __init__ and decorator check for item in node.body: @@ -151,7 +143,16 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture": break else: - # No decorator found + # No decorator found - create it lazily on first use + if decorator is None: + decorator = ast.Call( + func=self._base_decorator_func, + args=[], + keywords=[ + ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")), + *self._base_decorator_keywords, + ], + ) item.decorator_list.insert(0, decorator) self.inserted_decorator = True @@ -187,6 +188,16 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: defaults=[], ) + # Build decorator for the synthetic __init__ + decorator = ast.Call( + func=self._base_decorator_func, + args=[], + keywords=[ + ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")), + *self._base_decorator_keywords, + ], + ) + # Create the complete function init_func = ast.FunctionDef( name="__init__", args=arguments, body=[super_call], decorator_list=[decorator], returns=None