1616
1717import logging
1818import re
19- from functools import lru_cache
2019from typing import TYPE_CHECKING
2120
2221if TYPE_CHECKING :
@@ -43,6 +42,102 @@ def _get_function_name(func: Any) -> str:
4342# Pattern to detect primitive array types in assertions
4443_PRIMITIVE_ARRAY_PATTERN = re .compile (r"new\s+(int|long|double|float|short|byte|char|boolean)\s*\[\s*\]" )
4544
45+ # Pattern to match @Test annotation exactly (not @TestOnly, @TestFactory, etc.)
46+ _TEST_ANNOTATION_RE = re .compile (r"^@Test(?:\s*\(.*\))?(?:\s.*)?$" )
47+
48+
49+ def _is_test_annotation (stripped_line : str ) -> bool :
50+ """Check if a stripped line is an @Test annotation (not @TestOnly, @TestFactory, etc.).
51+
52+ Matches:
53+ @Test
54+ @Test(expected = ...)
55+ @Test(timeout = 5000)
56+ Does NOT match:
57+ @TestOnly
58+ @TestFactory
59+ @TestTemplate
60+ """
61+ return bool (_TEST_ANNOTATION_RE .match (stripped_line ))
62+
63+
64+ def _find_balanced_end (text : str , start : int ) -> int :
65+ """Find the position after the closing paren that balances the opening paren at start.
66+
67+ Args:
68+ text: The source text.
69+ start: Index of the opening parenthesis '('.
70+
71+ Returns:
72+ Index one past the matching closing ')', or -1 if not found.
73+
74+ """
75+ if start >= len (text ) or text [start ] != "(" :
76+ return - 1
77+ depth = 1
78+ pos = start + 1
79+ in_string = False
80+ string_char = None
81+ in_char = False
82+ while pos < len (text ) and depth > 0 :
83+ ch = text [pos ]
84+ prev = text [pos - 1 ] if pos > 0 else ""
85+ if ch == "'" and not in_string and prev != "\\ " :
86+ in_char = not in_char
87+ elif ch == '"' and not in_char and prev != "\\ " :
88+ if not in_string :
89+ in_string = True
90+ string_char = ch
91+ elif ch == string_char :
92+ in_string = False
93+ string_char = None
94+ elif not in_string and not in_char :
95+ if ch == "(" :
96+ depth += 1
97+ elif ch == ")" :
98+ depth -= 1
99+ pos += 1
100+ return pos if depth == 0 else - 1
101+
102+
def _find_method_calls_balanced(line: str, func_name: str) -> list[tuple[int, int, str]]:
    """Find method calls to func_name with properly balanced parentheses.

    Handles nested parentheses in arguments correctly, unlike a pure regex approach.
    Returns a list of (start, end, full_call) tuples where start/end are positions
    in the line and full_call is the matched text (receiver.funcName(args)).

    The receiver may be a ``new ClassName(...)`` expression or a (possibly
    dot-qualified) identifier chain such as ``calc``, ``this.calc`` or
    ``pkg.Util``. The whole chain is captured so that substituting the call
    never leaves a dangling ``this.`` / ``pkg.`` prefix behind. A candidate
    whose receiver is itself the tail of a larger expression (for example
    ``foo().bar.func(x)``) is not reported, because only part of the call
    expression could be captured.

    Args:
        line: A single line of Java source code.
        func_name: The method name to look for.

    Returns:
        List of (start_pos, end_pos, full_call_text) tuples, in order of
        appearance and non-overlapping.

    """
    # The regex only locates "<receiver>.funcName("; the argument list is then
    # delimited with balanced-paren scanning instead of "[^)]*".
    call_prefix = re.compile(
        # Do not begin a match in the middle of an identifier or member chain.
        r"(?<![\w$.])"
        r"(?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*(?:\s*\.\s*[a-zA-Z_]\w*)*)"
        rf"\s*\.\s*{re.escape(func_name)}\s*\("
    )

    calls = []
    cursor = 0
    while True:
        found = call_prefix.search(line, cursor)
        if found is None:
            break
        # The pattern ends with "(", so the opening paren is the last matched char.
        open_paren = found.end() - 1
        call_end = _find_balanced_end(line, open_paren)
        if call_end == -1:
            # Unbalanced parens - skip this candidate and keep scanning after it.
            cursor = found.end()
            continue
        calls.append((found.start(), call_end, line[found.start():call_end]))
        # Resume after the whole call so calls nested in its arguments are not
        # reported separately (results must not overlap).
        cursor = call_end
    return calls
140+
46141
47142def _infer_array_cast_type (line : str ) -> str | None :
48143 """Infer the array cast type needed for assertion methods.
@@ -182,11 +277,13 @@ def instrument_existing_test(
182277 else :
183278 new_class_name = f"{ original_class_name } __perfonlyinstrumented"
184279
185- # Rename the class declaration in the source
186- # Pattern: "public class ClassName" or "class ClassName"
187- pattern = rf"\b(public\s+)?class\s+{ re .escape (original_class_name )} \b"
188- replacement = rf"\1class { new_class_name } "
189- modified_source = re .sub (pattern , replacement , source )
280+ # Rename all references to the original class name in the source.
281+ # This includes the class declaration, return types, constructor calls,
282+ # variable declarations, etc. We use word-boundary matching to avoid
283+ # replacing substrings of other identifiers.
284+ modified_source = re .sub (
285+ rf"\b{ re .escape (original_class_name )} \b" , new_class_name , source
286+ )
190287
191288 # Add timing instrumentation to test methods
192289 # Use original class name (without suffix) in timing markers for consistency with Python
@@ -277,15 +374,12 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
277374 iteration_counter = 0
278375 helper_added = False
279376
280- # Pre-compile the regex pattern once
281- method_call_pattern = _get_method_call_pattern (func_name )
282-
283377 while i < len (lines ):
284378 line = lines [i ]
285379 stripped = line .strip ()
286380
287- # Look for @Test annotation
288- if stripped . startswith ( "@Test" ):
381+ # Look for @Test annotation (not @TestOnly, @TestFactory, etc.)
382+ if _is_test_annotation ( stripped ):
289383 if not helper_added :
290384 helper_added = True
291385 result .append (line )
@@ -342,27 +436,20 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
342436 call_counter = 0
343437 wrapped_body_lines = []
344438
345- # Use regex to find method calls with the target function
346- # Pattern matches: receiver.funcName(args) where receiver can be:
347- # - identifier (counter, calc, etc.)
348- # - new ClassName()
349- # - new ClassName(args)
350- # - this
351- method_call_pattern = re .compile (
352- rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({ re .escape (func_name )} )\s*\(([^)]*)\)" , re .MULTILINE
353- )
354-
355439 # Track lambda block nesting depth to avoid wrapping calls inside lambda bodies.
356440 # assertThrows/assertDoesNotThrow expect an Executable (void functional interface),
357441 # and wrapping the call in a variable assignment would turn the void-compatible
358442 # lambda into a value-returning lambda, causing a compilation error.
359- # Handles both expression lambdas: () -> func()
360- # and block lambdas: () -> { func(); }
443+ # Also, variables declared outside lambdas cannot be reassigned inside them
444+ # (Java requires effectively final variables in lambda captures).
445+ # Handles both no-arg lambdas: () -> { func(); }
446+ # and parameterized lambdas: (a, b, c) -> { func(); }
361447 lambda_brace_depth = 0
362448
363449 for body_line in body_lines :
364- # Detect new block lambda openings: () -> {
365- is_lambda_open = bool (re .search (r"\(\s*\)\s*->\s*\{" , body_line ))
450+ # Detect block lambda openings: (...) -> { or () -> {
451+ # Matches both () -> { and (a, b, c) -> {
452+ is_lambda_open = bool (re .search (r"->\s*\{" , body_line ))
366453
367454 # Update lambda brace depth tracking for block lambdas
368455 if is_lambda_open or lambda_brace_depth > 0 :
@@ -376,7 +463,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
376463 # Ensure depth doesn't go below 0
377464 lambda_brace_depth = max (0 , lambda_brace_depth )
378465
379- inside_lambda = lambda_brace_depth > 0 or bool (re .search (r"\(\s*\)\s*-> " , body_line ))
466+ inside_lambda = lambda_brace_depth > 0 or bool (re .search (r"->\s+\S " , body_line ))
380467
381468 # Check if this line contains a call to the target function
382469 if func_name in body_line and "(" in body_line :
@@ -388,30 +475,41 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
388475 line_indent = len (body_line ) - len (body_line .lstrip ())
389476 line_indent_str = " " * line_indent
390477
391- # Find all matches in the line
392- matches = list (method_call_pattern .finditer (body_line ))
478+ # Find all matches using balanced parenthesis matching
479+ # This correctly handles nested parens like:
480+ # obj.func(a, Rows.toRowID(frame.getIndex(), row))
481+ matches = _find_method_calls_balanced (body_line , func_name )
393482 if matches :
394483 # Process matches in reverse order to maintain correct positions
395484 new_line = body_line
396- for match in reversed (matches ):
485+ for start_pos , end_pos , full_call in reversed (matches ):
397486 call_counter += 1
398487 var_name = f"_cf_result{ iter_id } _{ call_counter } "
399- full_call = match .group (0 ) # e.g., "new StringUtils().reverse(\"hello\")"
400488
401489 # Check if we need to cast the result for assertions with primitive arrays
402490 # This handles assertArrayEquals(int[], int[]) etc.
403491 cast_type = _infer_array_cast_type (body_line )
404492 var_with_cast = f"({ cast_type } ){ var_name } " if cast_type else var_name
405493
406494 # Replace this occurrence with the variable (with cast if needed)
407- new_line = new_line [: match . start () ] + var_with_cast + new_line [match . end () :]
495+ new_line = new_line [:start_pos ] + var_with_cast + new_line [end_pos :]
408496
409497 # Use 'var' instead of 'Object' to preserve the exact return type.
410498 # This avoids boxing mismatches (e.g., assertEquals(int, Object) where
411499 # Object is boxed Long but expected is boxed Integer). Requires Java 10+.
412500 capture_line = f"{ line_indent_str } var { var_name } = { full_call } ;"
413501 wrapped_body_lines .append (capture_line )
414502
503+ # Immediately serialize the captured result while the variable
504+ # is still in scope. This is necessary because the variable may
505+ # be declared inside a nested block (while/for/if/try) and would
506+ # be out of scope at the end of the method body.
507+ serialize_line = (
508+ f"{ line_indent_str } _cf_serializedResult{ iter_id } = "
509+ f"com.codeflash.Serializer.serialize((Object) { var_name } );"
510+ )
511+ wrapped_body_lines .append (serialize_line )
512+
415513 # Check if the line is now just a variable reference (invalid statement)
416514 # This happens when the original line was just a void method call
417515 # e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
@@ -423,15 +521,6 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
423521 else :
424522 wrapped_body_lines .append (body_line )
425523
426- # Build the serialized return value expression
427- # If we captured any calls, serialize the last one via Kryo; otherwise null bytes
428- # The (Object) cast ensures primitives get autoboxed before being passed to the method.
429- if call_counter > 0 :
430- result_var = f"_cf_result{ iter_id } _{ call_counter } "
431- serialize_expr = f"com.codeflash.Serializer.serialize((Object) { result_var } )"
432- else :
433- serialize_expr = "null"
434-
435524 # Add behavior instrumentation code
436525 behavior_start_code = [
437526 f"{ indent } // Codeflash behavior instrumentation" ,
@@ -450,13 +539,13 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
450539 ]
451540 result .extend (behavior_start_code )
452541
453- # Add the wrapped body lines with extra indentation
542+ # Add the wrapped body lines with extra indentation.
543+ # Serialization of captured results is already done inline (immediately
544+ # after each capture) so the _cf_serializedResult variable is always
545+ # assigned while the captured variable is still in scope.
454546 for bl in wrapped_body_lines :
455547 result .append (" " + bl )
456548
457- # Add serialization after the body (before finally)
458- result .append (f"{ indent } _cf_serializedResult{ iter_id } = { serialize_expr } ;" )
459-
460549 # Add finally block with SQLite write
461550 method_close_indent = " " * base_indent
462551 behavior_end_code = [
@@ -543,8 +632,8 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
543632 line = lines [i ]
544633 stripped = line .strip ()
545634
546- # Look for @Test annotation
547- if stripped . startswith ( "@Test" ):
635+ # Look for @Test annotation (not @TestOnly, @TestFactory, etc.)
636+ if _is_test_annotation ( stripped ):
548637 result .append (line )
549638 i += 1
550639
@@ -751,9 +840,10 @@ def instrument_generated_java_test(
751840 else :
752841 new_class_name = f"{ original_class_name } __perfonlyinstrumented"
753842
754- # Rename the class in the source
843+ # Rename all references to the original class name in the source.
844+ # This includes the class declaration, return types, constructor calls, etc.
755845 modified_code = re .sub (
756- rf"\b(public\s+)?class\s+ { re .escape (original_class_name )} \b" , rf"\1class { new_class_name } " , test_code
846+ rf"\b{ re .escape (original_class_name )} \b" , new_class_name , test_code
757847 )
758848
759849 # For performance mode, add timing instrumentation
@@ -798,9 +888,3 @@ def _add_import(source: str, import_statement: str) -> str:
798888 return "" .join (lines )
799889
800890
801- @lru_cache (maxsize = 128 )
802- def _get_method_call_pattern (func_name : str ):
803- """Cache compiled regex patterns for method call matching."""
804- return re .compile (
805- rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({ re .escape (func_name )} )\s*\(([^)]*)\)" , re .MULTILINE
806- )
0 commit comments