@@ -178,14 +178,18 @@ def wrap_target_calls_with_treesitter(
178178 # Build line byte-start offsets for mapping calls to body_lines indices
179179 line_byte_starts = []
180180 offset = 0
181- for line in body_lines :
181+ # Precompute encoded lines to avoid repeated encoding
182+ encoded_lines = [line .encode ("utf8" ) for line in body_lines ]
183+ for encoded_line in encoded_lines :
182184 line_byte_starts .append (offset )
183- offset += len (line .encode ("utf8" )) + 1 # +1 for \n from join
185+ offset += len (encoded_line ) + 1 # +1 for \n from join
186+
187+ # Group non-lambda and non-complex-expression calls by their line index
184188
185189 # Group non-lambda and non-complex-expression calls by their line index
186190 calls_by_line : dict [int , list [dict [str , Any ]]] = {}
187191 for call in calls :
188- if call ["in_lambda" ] or call . get ( "in_complex" , False ) :
192+ if call ["skip_instrumentation" ] :
189193 logger .debug ("Skipping behavior instrumentation for call in lambda or complex expression" )
190194 continue
191195 line_idx = _byte_to_line_index (call ["start_byte" ], line_byte_starts )
@@ -202,7 +206,8 @@ def wrap_target_calls_with_treesitter(
202206 line_calls = sorted (calls_by_line [line_idx ], key = lambda c : c ["start_byte" ], reverse = True )
203207 line_indent_str = " " * (len (body_line ) - len (body_line .lstrip ()))
204208 line_byte_start = line_byte_starts [line_idx ]
205- line_bytes = body_line .encode ("utf8" )
209+ line_bytes = encoded_lines [line_idx ]
210+
206211
207212 new_line = body_line
208213 # Track cumulative char shift from earlier edits on this line
@@ -291,14 +296,17 @@ def _collect_calls(
291296 if parent_type == "expression_statement" :
292297 es_start = parent .start_byte - prefix_len
293298 es_end = parent .end_byte - prefix_len
299+
300+ # Compute skip flags once during collection
301+ skip_instrumentation = _should_skip_instrumentation (node )
302+
294303 out .append (
295304 {
296305 "start_byte" : start ,
297306 "end_byte" : end ,
298307 "full_call" : analyzer .get_node_text (node , wrapper_bytes ),
299308 "parent_type" : parent_type ,
300- "in_lambda" : _is_inside_lambda (node ),
301- "in_complex" : _is_inside_complex_expression (node ),
309+ "skip_instrumentation" : skip_instrumentation ,
302310 "es_start_byte" : es_start ,
303311 "es_end_byte" : es_end ,
304312 }
@@ -328,8 +336,7 @@ def _infer_array_cast_type(line: str) -> str | None:
328336
329337 """
330338 # Only apply to assertion methods that take arrays
331- assertion_methods = ("assertArrayEquals" , "assertArrayNotEquals" )
332- if not any (method in line for method in assertion_methods ):
339+ if "assertArrayEquals" not in line and "assertArrayNotEquals" not in line :
333340 return None
334341
335342 # Look for primitive array type in the line (usually the first/expected argument)
@@ -1191,3 +1198,86 @@ def _add_import(source: str, import_statement: str) -> str:
11911198
11921199 lines .insert (insert_idx , import_statement + "\n " )
11931200 return "" .join (lines )
1201+
1202+
1203+
1204+
1205+
def _should_skip_instrumentation(node: Any) -> bool:
    """Return True when a call node must not be wrapped with instrumentation.

    A call is skipped in two situations:

    * it sits anywhere inside a ``lambda_expression`` of the enclosing method,
      including inside the statements of a block-bodied lambda
      (``() -> { foo(); }``);
    * it is nested in a "complex" expression (cast, ternary, array access,
      binary/unary, parenthesized or instanceof expression) within its own
      statement.

    Args:
        node: Node exposing ``parent`` and ``type`` attributes (used here as a
            tree-sitter node for the call being considered).

    Returns:
        True if the call should be left un-instrumented, False otherwise.

    """
    # Once the walk leaves the call's own statement, complex-expression
    # ancestors belong to some outer statement and no longer apply; only an
    # enclosing lambda is still relevant from that point on.
    left_own_statement = False

    current = node.parent
    while current is not None:
        node_type = current.type

        # A lambda anywhere between the call and its method means skip.
        if node_type == "lambda_expression":
            return True

        # The enclosing method is the outer limit of the search.
        if node_type == "method_declaration":
            return False

        if not left_own_statement:
            if node_type in {
                "block",
                "if_statement",
                "for_statement",
                "while_statement",
                "try_statement",
                "expression_statement",
            }:
                # Keep walking: the statement may be part of a lambda body.
                left_own_statement = True
            elif node_type in {
                "cast_expression",
                "ternary_expression",
                "array_access",
                "binary_expression",
                "unary_expression",
                "parenthesized_expression",
                "instanceof_expression",
            }:
                logger.debug("Found complex expression parent: %s", node_type)
                return True

        current = current.parent
    return False
0 commit comments