@@ -102,6 +102,22 @@ def __init__(
102102 self ._init_kwarg = ast .arg (arg = "kwargs" )
103103 self ._init_self_arg = ast .arg (arg = "self" , annotation = None )
104104
105+ # Precreate commonly reused AST fragments for classes that lack __init__
106+ # Create the super().__init__(*args, **kwargs) Expr (reuse prebuilt pieces)
107+ self ._super_call_expr = ast .Expr (
108+ value = ast .Call (func = self ._super_func , args = [self ._super_starred ], keywords = [self ._super_kwarg ])
109+ )
110+ # Create function arguments: self, *args, **kwargs (reuse arg nodes)
111+ self ._init_arguments = ast .arguments (
112+ posonlyargs = [],
113+ args = [self ._init_self_arg ],
114+ vararg = self ._init_vararg ,
115+ kwonlyargs = [],
116+ kw_defaults = [],
117+ kwarg = self ._init_kwarg ,
118+ defaults = [],
119+ )
120+
105121 def visit_ImportFrom (self , node : ast .ImportFrom ) -> ast .ImportFrom :
106122 # Check if our import already exists
107123 if node .module == "codeflash.verification.codeflash_capture" and any (
@@ -162,47 +178,35 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
162178 # TODO: support by saving a reference to the generated __init__ before overriding, e.g.
163179 # _orig_init = ClassName.__init__; then calling _orig_init(self, *args, **kwargs) in the wrapper
164180 for dec in node .decorator_list :
165- dec_name = None
166- if isinstance (dec , ast .Name ):
167- dec_name = dec .id
168- elif isinstance (dec , ast .Call ) and isinstance (dec .func , ast .Name ):
169- dec_name = dec .func .id
170- elif isinstance (dec , ast .Attribute ):
171- dec_name = dec .attr
181+ dec_name = self ._expr_name (dec )
172182 if dec_name == "dataclass" :
173183 return node
174184
175185 # Skip NamedTuples — their __init__ is synthesized and cannot be overwritten.
176186 for base in node .bases :
177- base_name = None
178- if isinstance (base , ast .Name ):
179- base_name = base .id
180- elif isinstance (base , ast .Attribute ):
181- base_name = base .attr
187+ base_name = self ._expr_name (base )
182188 if base_name == "NamedTuple" :
183189 return node
184190
185191 # Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
186- super_call = ast .Expr (
187- value = ast .Call (func = self ._super_func , args = [self ._super_starred ], keywords = [self ._super_kwarg ])
188- )
189- # Create function arguments: self, *args, **kwargs (reuse arg nodes)
190- arguments = ast .arguments (
191- posonlyargs = [],
192- args = [self ._init_self_arg ],
193- vararg = self ._init_vararg ,
194- kwonlyargs = [],
195- kw_defaults = [],
196- kwarg = self ._init_kwarg ,
197- defaults = [],
198- )
192+ super_call = self ._super_call_expr
193+ # Create the complete function using prebuilt arguments/body but attach the class-specific decorator
199194
200195 # Create the complete function
201196 init_func = ast .FunctionDef (
202- name = "__init__" , args = arguments , body = [super_call ], decorator_list = [decorator ], returns = None
197+ name = "__init__" , args = self . _init_arguments , body = [super_call ], decorator_list = [decorator ], returns = None
203198 )
204199
205200 node .body .insert (0 , init_func )
206201 self .inserted_decorator = True
207202
208203 return node
204+
205+ def _expr_name (self , node : ast .AST ) -> str | None :
206+ if isinstance (node , ast .Name ):
207+ return node .id
208+ if isinstance (node , ast .Call ) and isinstance (node .func , ast .Name ):
209+ return node .func .id
210+ if isinstance (node , ast .Attribute ):
211+ return node .attr
212+ return None
0 commit comments