Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 138 additions & 54 deletions codeflash/languages/java/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import logging
import re
from functools import lru_cache
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -43,6 +42,102 @@ def _get_function_name(func: Any) -> str:
# Pattern to detect primitive array types in assertions
_PRIMITIVE_ARRAY_PATTERN = re.compile(r"new\s+(int|long|double|float|short|byte|char|boolean)\s*\[\s*\]")

# Pattern to match @Test annotation exactly (not @TestOnly, @TestFactory, etc.)
_TEST_ANNOTATION_RE = re.compile(r"^@Test(?:\s*\(.*\))?(?:\s.*)?$")


def _is_test_annotation(stripped_line: str) -> bool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't treesitter help with this?

"""Check if a stripped line is an @Test annotation (not @TestOnly, @TestFactory, etc.).

Matches:
@Test
@Test(expected = ...)
@Test(timeout = 5000)
Does NOT match:
@TestOnly
@TestFactory
@TestTemplate
"""
return bool(_TEST_ANNOTATION_RE.match(stripped_line))


Comment on lines +61 to +63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 28% (0.28x) speedup for _is_test_annotation in codeflash/languages/java/instrumentation.py

⏱️ Runtime : 544 microseconds 424 microseconds (best of 220 runs)

📝 Explanation and details

The optimized code achieves a 28% runtime improvement by replacing a regex match operation with a hand-crafted string parsing algorithm. Here's why it's faster:

Key Optimization Strategy:

  1. Fast-path prefix check: Instead of immediately invoking the regex engine, the code first checks if the string starts with "@Test" using Python's highly-optimized str.startswith() method. This eliminates ~27% of inputs (2,245 out of 8,364 valid inputs) in just a simple string comparison, avoiding the overhead of regex compilation lookup and matching.

  2. Direct string operations replace regex: For strings that pass the prefix check, the code uses explicit string operations (len(), indexing, isspace(), find()) rather than the regex engine's state machine. Python's native string operations are implemented in C and avoid the overhead of:

    • Regex pattern interpretation
    • Backtracking logic
    • Capture group management
    • Generic pattern matching machinery
  3. Algorithmic structure mirrors the pattern: The hand-rolled parser directly implements the logic of `^@test(?:\s*(.))?(?:\s.)?

    Details

📝 Explanation and details : - Exactly matches `"@test"` (5 characters) - Handles the optional `\s*\(.*\)` group by explicitly looking for `(` after optional whitespace, then finding the matching `)` - Validates trailing content must start with whitespace if present

Performance by Test Case Type:

  • Simple @Test: Up to 119-141% faster (1.77μs → 811ns) - benefits maximally from exact length check
  • With parameters like @Test(timeout=5000): 20-32% faster - still faster despite parenthesis parsing logic
  • Invalid prefixes like @TestOnly: 40-50% faster - rejected immediately by startswith() check
  • Trailing whitespace/content: 100-125% faster - simplified whitespace detection via isspace()

Type Safety Preservation:

The code maintains behavioral compatibility by explicitly checking for non-string inputs and delegating to the regex to raise the appropriate TypeError, ensuring existing error handling remains intact.

Impact Assessment:

This function appears to be used for parsing Java test annotations, likely called during code analysis or test discovery phases. The 28% speedup would be most beneficial when:

  • Processing large Java codebases with thousands of annotations
  • Running in tight loops during file scanning
  • Performance-sensitive contexts like IDE integrations or CI/CD pipelines

The optimization trades code complexity for runtime performance - the manual parsing logic is more verbose but demonstrates superior performance across all test categories.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 1076 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import pytest  # used for our unit tests
from codeflash.languages.java.instrumentation import _is_test_annotation

def test_matches_plain_test_annotation():
    # Plain @Test should be recognized as a test annotation
    codeflash_output = _is_test_annotation("@Test") # 1.77μs -> 811ns (119% faster)

def test_matches_test_with_parentheses_contents():
    # Parentheses with content immediately after @Test should match
    codeflash_output = _is_test_annotation("@Test(expected = Exception.class)") # 2.05μs -> 1.61μs (27.3% faster)
    # Parentheses with no space before '(' (common style) should also match
    codeflash_output = _is_test_annotation("@Test(timeout=5000)") # 1.00μs -> 781ns (28.2% faster)
    # Parentheses with a space before '(' should match as \s* allows that
    codeflash_output = _is_test_annotation("@Test (timeout = 5)") # 691ns -> 390ns (77.2% faster)
    # Empty parentheses are allowed (.* can match empty)
    codeflash_output = _is_test_annotation("@Test()") # 621ns -> 571ns (8.76% faster)

def test_matches_test_with_trailing_elements():
    # The regex allows optional whitespace then any trailing text, so annotations
    # followed by comments or modifiers on the same line should match
    codeflash_output = _is_test_annotation("@Test // some comment") # 2.11μs -> 982ns (115% faster)
    codeflash_output = _is_test_annotation("@Test public void foo() {}") # 971ns -> 441ns (120% faster)
    codeflash_output = _is_test_annotation("@Test (timeout=1) // trailing") # 841ns -> 321ns (162% faster)

def test_does_not_match_variants_starting_with_test_but_not_exact():
    # Should NOT match variants like @TestOnly, @TestFactory, @TestTemplate
    codeflash_output = _is_test_annotation("@TestOnly") # 1.78μs -> 1.18μs (50.8% faster)
    codeflash_output = _is_test_annotation("@TestFactory") # 771ns -> 521ns (48.0% faster)
    codeflash_output = _is_test_annotation("@TestTemplate") # 511ns -> 400ns (27.8% faster)
    # Also should not match other words that immediately follow without a space
    codeflash_output = _is_test_annotation("@Tester") # 471ns -> 380ns (23.9% faster)
    codeflash_output = _is_test_annotation("@Tested") # 450ns -> 371ns (21.3% faster)

def test_leading_or_trailing_whitespace_behavior():
    # The function expects a "stripped" line; leading whitespace prevents a match
    # because the regex anchors at the beginning of the string (^@Test)
    codeflash_output = _is_test_annotation("  @Test") # 1.11μs -> 691ns (60.9% faster)
    # Trailing whitespace should be fine because $ anchors the end and whitespace is part of the line
    codeflash_output = _is_test_annotation("@Test ") # 1.39μs -> 632ns (120% faster)
    # Newline characters inside the string will prevent match because '.' doesn't match newline
    codeflash_output = _is_test_annotation("@Test\n") # 741ns -> 350ns (112% faster)

def test_empty_string_and_missing_parenthesis_cases():
    # Empty string is not a test annotation
    codeflash_output = _is_test_annotation("") # 1.12μs -> 711ns (57.8% faster)
    # Missing closing parenthesis should not match because pattern requires a closing ')'
    codeflash_output = _is_test_annotation("@Test(timeout=5") # 1.25μs -> 1.23μs (1.70% faster)
    # Missing opening parenthesis but with a closing one is also invalid
    codeflash_output = _is_test_annotation("@Test) extra") # 631ns -> 531ns (18.8% faster)
    # Parentheses containing comments/special characters should still match (dot matches them)
    codeflash_output = _is_test_annotation("@Test(/* comment */)") # 972ns -> 622ns (56.3% faster)

def test_non_string_types_raise_type_error():
    # Passing None should raise a TypeError because re.match expects a string
    with pytest.raises(TypeError):
        _is_test_annotation(None) # 2.96μs -> 3.04μs (2.63% slower)
    # Passing an integer should raise a TypeError
    with pytest.raises(TypeError):
        _is_test_annotation(123) # 1.66μs -> 1.66μs (0.060% faster)
    # Passing bytes should raise a TypeError since the compiled pattern is str-based
    with pytest.raises(TypeError):
        _is_test_annotation(b"@Test") # 1.30μs -> 1.39μs (6.53% slower)

def test_large_scale_mixed_annotations_1000_entries():
    # Create 1000 annotation-like strings, alternating between valid and invalid,
    # and assert that the function correctly identifies them.
    total = 1000
    matches = 0
    for i in range(total):
        # Even indices are valid test annotations with slight variations
        if i % 2 == 0:
            if i % 4 == 0:
                s = "@Test"
            else:
                # include parentheses and a trailing comment for variation
                s = "@Test(timeout=100) // iteration {}".format(i)
            expected = True
        else:
            # Odd indices are intentionally invalid variants
            if i % 3 == 0:
                s = "@TestOnly{}".format(i)
            else:
                s = "  @Test"  # leading spaces make it invalid for this function
            expected = False

        # Each call should be deterministic and equal to expected
        codeflash_output = _is_test_annotation(s); result = codeflash_output # 439μs -> 352μs (24.8% faster)
        if result:
            matches += 1

def test_repeated_calls_idempotent_and_fast():
    # Call the function repeatedly on the same inputs to ensure consistent results and no state retention.
    inputs = [
        "@Test",
        "@Test(timeout=1)",
        "@TestOnly",
        "@Test extra",
        "@Test()",
        "@Test(timeout=something) trailing text",
        "",
        "  @Test",
    ]
    # Repeat 500 times to exercise repeated use (total calls 8 * 500 = 4000)
    for _ in range(500):
        # Use a comprehension to ensure the same order and deterministic results
        results = [_is_test_annotation(s) for s in inputs]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import re

# imports
import pytest
from codeflash.languages.java.instrumentation import _is_test_annotation

def test_basic_test_annotation():
    """Test that a simple @Test annotation is recognized."""
    codeflash_output = _is_test_annotation("@Test") # 1.98μs -> 822ns (141% faster)

def test_test_annotation_with_parentheses():
    """Test that @Test with empty parentheses is recognized."""
    codeflash_output = _is_test_annotation("@Test()") # 2.17μs -> 1.72μs (26.2% faster)

def test_test_annotation_with_expected_parameter():
    """Test that @Test with expected parameter is recognized."""
    codeflash_output = _is_test_annotation("@Test(expected = Exception.class)") # 2.06μs -> 1.67μs (23.2% faster)

def test_test_annotation_with_timeout_parameter():
    """Test that @Test with timeout parameter is recognized."""
    codeflash_output = _is_test_annotation("@Test(timeout = 5000)") # 2.02μs -> 1.62μs (24.7% faster)

def test_test_annotation_with_multiple_parameters():
    """Test that @Test with multiple parameters is recognized."""
    codeflash_output = _is_test_annotation("@Test(timeout = 5000, expected = Exception.class)") # 2.00μs -> 1.67μs (19.7% faster)

def test_test_only_annotation_not_recognized():
    """Test that @TestOnly annotation is NOT recognized as @Test."""
    codeflash_output = _is_test_annotation("@TestOnly") # 1.75μs -> 1.21μs (44.6% faster)

def test_test_factory_annotation_not_recognized():
    """Test that @TestFactory annotation is NOT recognized as @Test."""
    codeflash_output = _is_test_annotation("@TestFactory") # 1.69μs -> 1.21μs (39.8% faster)

def test_test_template_annotation_not_recognized():
    """Test that @TestTemplate annotation is NOT recognized as @Test."""
    codeflash_output = _is_test_annotation("@TestTemplate") # 1.68μs -> 1.16μs (44.8% faster)

def test_test_case_annotation_not_recognized():
    """Test that @TestCase annotation is NOT recognized as @Test."""
    codeflash_output = _is_test_annotation("@TestCase") # 1.68μs -> 1.14μs (47.4% faster)

def test_override_annotation_not_recognized():
    """Test that @Override annotation is NOT recognized as @Test."""
    codeflash_output = _is_test_annotation("@Override") # 1.14μs -> 681ns (67.8% faster)

def test_test_annotation_with_trailing_whitespace():
    """Test that @Test with trailing whitespace and content is recognized."""
    codeflash_output = _is_test_annotation("@Test some comment") # 2.15μs -> 1.03μs (109% faster)

def test_test_annotation_with_trailing_spaces_only():
    """Test that @Test followed by spaces is recognized."""
    codeflash_output = _is_test_annotation("@Test   ") # 2.03μs -> 952ns (114% faster)

def test_test_with_complex_parameter():
    """Test @Test with complex parameter containing special characters."""
    codeflash_output = _is_test_annotation("@Test(expected = java.io.IOException.class)") # 2.05μs -> 1.75μs (17.2% faster)

def test_test_with_nested_parentheses():
    """Test @Test with nested parentheses in parameters."""
    codeflash_output = _is_test_annotation("@Test(expected = (String) value)") # 1.96μs -> 1.74μs (12.6% faster)

def test_empty_string():
    """Test that empty string returns False."""
    codeflash_output = _is_test_annotation("") # 1.09μs -> 681ns (60.4% faster)

def test_only_whitespace():
    """Test that whitespace-only string returns False."""
    codeflash_output = _is_test_annotation("   ") # 1.05μs -> 682ns (54.3% faster)

def test_test_with_leading_whitespace():
    """Test that @Test with leading whitespace (already stripped) is False."""
    # Note: This tests a stripped line, so leading whitespace shouldn't appear
    # But if it does, it should fail to match
    codeflash_output = _is_test_annotation(" @Test") # 1.20μs -> 682ns (76.2% faster)

def test_test_lowercase():
    """Test that lowercase @test is NOT recognized (case-sensitive)."""
    codeflash_output = _is_test_annotation("@test") # 1.16μs -> 721ns (61.2% faster)

def test_test_mixed_case():
    """Test that mixed case @TEST is NOT recognized."""
    codeflash_output = _is_test_annotation("@TEST") # 1.21μs -> 651ns (86.2% faster)

def test_test_annotation_with_newline():
    """Test that @Test with embedded newline character is NOT recognized."""
    codeflash_output = _is_test_annotation("@Test\n") # 2.08μs -> 972ns (114% faster)

def test_test_annotation_with_tab():
    """Test that @Test with tab character between annotation and content."""
    codeflash_output = _is_test_annotation("@Test\tsome content") # 2.07μs -> 922ns (125% faster)

def test_partial_test_annotation():
    """Test that partial match like @Tes is NOT recognized."""
    codeflash_output = _is_test_annotation("@Tes") # 1.09μs -> 701ns (55.8% faster)

def test_test_annotation_without_at_symbol():
    """Test that Test without @ symbol is NOT recognized."""
    codeflash_output = _is_test_annotation("Test") # 1.06μs -> 671ns (58.3% faster)

def test_test_annotation_as_substring():
    """Test that @Test as substring is NOT recognized."""
    codeflash_output = _is_test_annotation("some@Test") # 1.15μs -> 681ns (69.2% faster)

def test_test_annotation_with_dots():
    """Test that fully qualified @Test annotation is NOT recognized."""
    codeflash_output = _is_test_annotation("@org.junit.Test") # 1.19μs -> 692ns (72.3% faster)

def test_test_annotation_with_trailing_parentheses_no_content():
    """Test @Test() with no parameters is recognized."""
    codeflash_output = _is_test_annotation("@Test()") # 2.10μs -> 1.64μs (28.1% faster)

def test_test_annotation_with_spaces_in_parameters():
    """Test @Test with spaces inside parameter list."""
    codeflash_output = _is_test_annotation("@Test( expected = Exception.class )") # 2.03μs -> 1.69μs (20.1% faster)

def test_test_annotation_with_multiword_trailing_content():
    """Test @Test with multiple words trailing."""
    codeflash_output = _is_test_annotation("@Test this is a comment") # 2.06μs -> 1.02μs (102% faster)

def test_test_annotation_with_special_chars_in_params():
    """Test @Test with special characters in parameters."""
    codeflash_output = _is_test_annotation("@Test(timeout = 5_000)") # 2.03μs -> 1.68μs (20.9% faster)

def test_only_at_symbol():
    """Test that only @ symbol returns False."""
    codeflash_output = _is_test_annotation("@") # 1.08μs -> 661ns (63.7% faster)

def test_test_with_underscore():
    """Test that @Test_something is NOT recognized."""
    codeflash_output = _is_test_annotation("@Test_something") # 1.72μs -> 1.21μs (42.2% faster)

def test_testable_annotation():
    """Test that @Testable is NOT recognized as @Test."""
    codeflash_output = _is_test_annotation("@Testable") # 1.67μs -> 1.17μs (42.7% faster)

def test_test_annotation_with_equals_in_param():
    """Test @Test with equation-like parameter."""
    codeflash_output = _is_test_annotation("@Test(a = b = c)") # 2.07μs -> 1.62μs (27.8% faster)

def test_test_annotation_with_empty_param():
    """Test @Test with empty parameter value."""
    codeflash_output = _is_test_annotation("@Test(value = )") # 2.02μs -> 1.64μs (23.2% faster)

def test_many_test_annotations_in_list():
    """Test processing 100 valid @Test annotations."""
    # Create a list of 100 valid @Test annotations and verify all are recognized
    annotations = ["@Test"] * 100
    results = [_is_test_annotation(ann) for ann in annotations]

def test_many_false_annotations_in_list():
    """Test processing 100 non-@Test annotations."""
    # Create a list of 100 invalid annotations and verify none are recognized
    annotations = ["@TestOnly"] * 100
    results = [_is_test_annotation(ann) for ann in annotations]

def test_mixed_annotations_large_batch():
    """Test processing 500 mixed valid and invalid annotations."""
    # Create alternating pattern of valid and invalid annotations
    annotations = ["@Test", "@TestFactory"] * 250
    results = [_is_test_annotation(ann) for ann in annotations]
    # Count True results (should be 250)
    true_count = sum(1 for r in results if r)

def test_test_annotation_with_very_long_parameters():
    """Test @Test with extremely long parameter string."""
    long_param = "@Test(" + "a = b, " * 200 + "c = d)"
    codeflash_output = _is_test_annotation(long_param); result = codeflash_output # 2.51μs -> 2.01μs (24.8% faster)

def test_test_annotation_with_very_long_trailing_content():
    """Test @Test with very long trailing content."""
    long_content = "@Test " + "x " * 500
    codeflash_output = _is_test_annotation(long_content); result = codeflash_output # 2.35μs -> 1.06μs (122% faster)

def test_batch_of_test_annotations_with_variations():
    """Test 1000 @Test annotations with various parameter combinations."""
    # Create diverse @Test annotations
    variations = [
        "@Test",
        "@Test()",
        "@Test(timeout = 5000)",
        "@Test(expected = Exception.class)",
        "@Test(timeout = 5000, expected = Exception.class)",
    ]
    # Repeat pattern 200 times to get 1000 annotations
    annotations = variations * 200
    results = [_is_test_annotation(ann) for ann in annotations]

def test_batch_of_non_test_annotations():
    """Test 1000 non-@Test annotations of various types."""
    # Create diverse non-@Test annotations
    variations = [
        "@TestOnly",
        "@TestFactory",
        "@TestTemplate",
        "@Override",
        "@Deprecated",
    ]
    # Repeat pattern 200 times to get 1000 annotations
    annotations = variations * 200
    results = [_is_test_annotation(ann) for ann in annotations]

def test_random_single_char_strings_performance():
    """Test performance with 500 single-character strings."""
    chars = ["a", "b", "@", "(", ")", " "] * 83 + ["c"]  # 500 items
    results = [_is_test_annotation(c) for c in chars]

def test_return_type_is_boolean():
    """Test that return type is always a boolean (not truthy value)."""
    # Test True case
    codeflash_output = _is_test_annotation("@Test"); result_true = codeflash_output # 1.88μs -> 902ns (109% faster)
    
    # Test False case
    codeflash_output = _is_test_annotation("@TestOnly"); result_false = codeflash_output # 1.02μs -> 781ns (30.9% faster)

def test_consistency_across_repeated_calls():
    """Test that repeated calls with same input return consistent results."""
    test_input = "@Test(timeout = 5000)"
    results = [_is_test_annotation(test_input) for _ in range(100)]

def test_test_annotation_with_ampersand():
    """Test @Test with ampersand in parameters."""
    codeflash_output = _is_test_annotation("@Test(condition = a & b)") # 2.04μs -> 1.71μs (19.3% faster)

def test_test_annotation_with_pipe():
    """Test @Test with pipe character in parameters."""
    codeflash_output = _is_test_annotation("@Test(condition = a | b)") # 2.02μs -> 1.60μs (26.3% faster)

def test_test_annotation_with_percent():
    """Test @Test with percent character."""
    codeflash_output = _is_test_annotation("@Test(timeout = 100%)") # 1.97μs -> 1.49μs (32.3% faster)

def test_test_annotation_with_caret():
    """Test @Test with caret character."""
    codeflash_output = _is_test_annotation("@Test(value = 5^2)") # 1.96μs -> 1.59μs (23.3% faster)

def test_all_whitespace_types_after_test():
    """Test @Test followed by various whitespace characters."""
    # Space
    codeflash_output = _is_test_annotation("@Test ") # 1.98μs -> 991ns (100% faster)
    # Multiple spaces
    codeflash_output = _is_test_annotation("@Test    ") # 1.01μs -> 441ns (129% faster)
    # Tab and space
    codeflash_output = _is_test_annotation("@Test \t ") # 711ns -> 320ns (122% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1472-2026-02-12T23.43.14

Click to see suggested changes
Suggested change
return bool(_TEST_ANNOTATION_RE.match(stripped_line))
# Validate input type to match regex behavior (raises TypeError for non-strings)
if not isinstance(stripped_line, str):
# Trigger the same TypeError that re.match would raise
_TEST_ANNOTATION_RE.match(stripped_line)
return False # unreachable, but for clarity
# Fast path: must start with the literal prefix
s = stripped_line
if not s.startswith("@Test"):
return False
# If exactly "@Test"
n = 5 # len("@Test")
L = len(s)
if n == L:
return True
# If the very next character is whitespace, it's a match
ch = s[n]
if ch.isspace():
return True
# Otherwise, allow optional spaces then a '(' which must have a matching ')'.
# After a matching ')', the remainder must be either empty or start with whitespace.
i = n
while i < L and s[i].isspace():
i += 1
if i < L and s[i] == "(":
# Find a closing parenthesis; if none, it cannot match (regex also would fail
# the \(.*\) requirement, and the alternative \s.* requires an initial space
# which we don't have here).
j = s.find(")", i + 1)
if j == -1:
return False
if j + 1 == L:
return True
return s[j + 1].isspace()
return False

Static Badge

def _find_balanced_end(text: str, start: int) -> int:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't treesitter help with this?

"""Find the position after the closing paren that balances the opening paren at start.

Args:
text: The source text.
start: Index of the opening parenthesis '('.

Returns:
Index one past the matching closing ')', or -1 if not found.

"""
if start >= len(text) or text[start] != "(":
return -1
depth = 1
pos = start + 1
in_string = False
string_char = None
in_char = False
while pos < len(text) and depth > 0:
ch = text[pos]
prev = text[pos - 1] if pos > 0 else ""
if ch == "'" and not in_string and prev != "\\":
in_char = not in_char
elif ch == '"' and not in_char and prev != "\\":
if not in_string:
in_string = True
string_char = ch
elif ch == string_char:
in_string = False
string_char = None
elif not in_string and not in_char:
if ch == "(":
depth += 1
elif ch == ")":
depth -= 1
pos += 1
return pos if depth == 0 else -1


def _find_method_calls_balanced(line: str, func_name: str):
"""Find method calls to func_name with properly balanced parentheses.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't treesitter help with this?


Handles nested parentheses in arguments correctly, unlike a pure regex approach.
Returns a list of (start, end, full_call) tuples where start/end are positions
in the line and full_call is the matched text (receiver.funcName(args)).

Args:
line: A single line of Java source code.
func_name: The method name to look for.

Returns:
List of (start_pos, end_pos, full_call_text) tuples.

"""
# First find all occurrences of .funcName( in the line using regex
# to locate the method name, then use balanced paren finding for args
prefix_pattern = re.compile(
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*{re.escape(func_name)}\s*\("
)
results = []
search_start = 0
while search_start < len(line):
m = prefix_pattern.search(line, search_start)
if not m:
break
# m.end() - 1 is the position of the opening paren
open_paren_pos = m.end() - 1
close_pos = _find_balanced_end(line, open_paren_pos)
if close_pos == -1:
# Unbalanced parens - skip this match
search_start = m.end()
continue
full_call = line[m.start():close_pos]
results.append((m.start(), close_pos, full_call))
search_start = close_pos
return results


def _infer_array_cast_type(line: str) -> str | None:
"""Infer the array cast type needed for assertion methods.
Expand Down Expand Up @@ -182,11 +277,13 @@ def instrument_existing_test(
else:
new_class_name = f"{original_class_name}__perfonlyinstrumented"

# Rename the class declaration in the source
# Pattern: "public class ClassName" or "class ClassName"
pattern = rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b"
replacement = rf"\1class {new_class_name}"
modified_source = re.sub(pattern, replacement, source)
# Rename all references to the original class name in the source.
# This includes the class declaration, return types, constructor calls,
# variable declarations, etc. We use word-boundary matching to avoid
# replacing substrings of other identifiers.
modified_source = re.sub(
rf"\b{re.escape(original_class_name)}\b", new_class_name, source
)

# Add timing instrumentation to test methods
# Use original class name (without suffix) in timing markers for consistency with Python
Expand Down Expand Up @@ -277,15 +374,12 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
iteration_counter = 0
helper_added = False

# Pre-compile the regex pattern once
method_call_pattern = _get_method_call_pattern(func_name)

while i < len(lines):
line = lines[i]
stripped = line.strip()

# Look for @Test annotation
if stripped.startswith("@Test"):
# Look for @Test annotation (not @TestOnly, @TestFactory, etc.)
if _is_test_annotation(stripped):
if not helper_added:
helper_added = True
result.append(line)
Expand Down Expand Up @@ -342,27 +436,20 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
call_counter = 0
wrapped_body_lines = []

# Use regex to find method calls with the target function
# Pattern matches: receiver.funcName(args) where receiver can be:
# - identifier (counter, calc, etc.)
# - new ClassName()
# - new ClassName(args)
# - this
method_call_pattern = re.compile(
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
)

# Track lambda block nesting depth to avoid wrapping calls inside lambda bodies.
# assertThrows/assertDoesNotThrow expect an Executable (void functional interface),
# and wrapping the call in a variable assignment would turn the void-compatible
# lambda into a value-returning lambda, causing a compilation error.
# Handles both expression lambdas: () -> func()
# and block lambdas: () -> { func(); }
# Also, variables declared outside lambdas cannot be reassigned inside them
# (Java requires effectively final variables in lambda captures).
# Handles both no-arg lambdas: () -> { func(); }
# and parameterized lambdas: (a, b, c) -> { func(); }
lambda_brace_depth = 0

for body_line in body_lines:
# Detect new block lambda openings: () -> {
is_lambda_open = bool(re.search(r"\(\s*\)\s*->\s*\{", body_line))
# Detect block lambda openings: (...) -> { or () -> {
# Matches both () -> { and (a, b, c) -> {
is_lambda_open = bool(re.search(r"->\s*\{", body_line))

# Update lambda brace depth tracking for block lambdas
if is_lambda_open or lambda_brace_depth > 0:
Expand All @@ -376,7 +463,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
# Ensure depth doesn't go below 0
lambda_brace_depth = max(0, lambda_brace_depth)

inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"\(\s*\)\s*->", body_line))
inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"->\s+\S", body_line))

# Check if this line contains a call to the target function
if func_name in body_line and "(" in body_line:
Expand All @@ -388,30 +475,41 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
line_indent = len(body_line) - len(body_line.lstrip())
line_indent_str = " " * line_indent

# Find all matches in the line
matches = list(method_call_pattern.finditer(body_line))
# Find all matches using balanced parenthesis matching
# This correctly handles nested parens like:
# obj.func(a, Rows.toRowID(frame.getIndex(), row))
matches = _find_method_calls_balanced(body_line, func_name)
if matches:
# Process matches in reverse order to maintain correct positions
new_line = body_line
for match in reversed(matches):
for start_pos, end_pos, full_call in reversed(matches):
call_counter += 1
var_name = f"_cf_result{iter_id}_{call_counter}"
full_call = match.group(0) # e.g., "new StringUtils().reverse(\"hello\")"

# Check if we need to cast the result for assertions with primitive arrays
# This handles assertArrayEquals(int[], int[]) etc.
cast_type = _infer_array_cast_type(body_line)
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name

# Replace this occurrence with the variable (with cast if needed)
new_line = new_line[: match.start()] + var_with_cast + new_line[match.end() :]
new_line = new_line[:start_pos] + var_with_cast + new_line[end_pos:]

# Use 'var' instead of 'Object' to preserve the exact return type.
# This avoids boxing mismatches (e.g., assertEquals(int, Object) where
# Object is boxed Long but expected is boxed Integer). Requires Java 10+.
capture_line = f"{line_indent_str}var {var_name} = {full_call};"
wrapped_body_lines.append(capture_line)

# Immediately serialize the captured result while the variable
# is still in scope. This is necessary because the variable may
# be declared inside a nested block (while/for/if/try) and would
# be out of scope at the end of the method body.
serialize_line = (
f"{line_indent_str}_cf_serializedResult{iter_id} = "
f"com.codeflash.Serializer.serialize((Object) {var_name});"
)
wrapped_body_lines.append(serialize_line)

# Check if the line is now just a variable reference (invalid statement)
# This happens when the original line was just a void method call
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
Expand All @@ -423,15 +521,6 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
else:
wrapped_body_lines.append(body_line)

# Build the serialized return value expression
# If we captured any calls, serialize the last one via Kryo; otherwise null bytes
# The (Object) cast ensures primitives get autoboxed before being passed to the method.
if call_counter > 0:
result_var = f"_cf_result{iter_id}_{call_counter}"
serialize_expr = f"com.codeflash.Serializer.serialize((Object) {result_var})"
else:
serialize_expr = "null"

# Add behavior instrumentation code
behavior_start_code = [
f"{indent}// Codeflash behavior instrumentation",
Expand All @@ -450,13 +539,13 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
]
result.extend(behavior_start_code)

# Add the wrapped body lines with extra indentation
# Add the wrapped body lines with extra indentation.
# Serialization of captured results is already done inline (immediately
# after each capture) so the _cf_serializedResult variable is always
# assigned while the captured variable is still in scope.
for bl in wrapped_body_lines:
result.append(" " + bl)

# Add serialization after the body (before finally)
result.append(f"{indent} _cf_serializedResult{iter_id} = {serialize_expr};")

# Add finally block with SQLite write
method_close_indent = " " * base_indent
behavior_end_code = [
Expand Down Expand Up @@ -543,8 +632,8 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
line = lines[i]
stripped = line.strip()

# Look for @Test annotation
if stripped.startswith("@Test"):
# Look for @Test annotation (not @TestOnly, @TestFactory, etc.)
if _is_test_annotation(stripped):
result.append(line)
i += 1

Expand Down Expand Up @@ -751,9 +840,10 @@ def instrument_generated_java_test(
else:
new_class_name = f"{original_class_name}__perfonlyinstrumented"

# Rename the class in the source
# Rename all references to the original class name in the source.
# This includes the class declaration, return types, constructor calls, etc.
modified_code = re.sub(
rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b", rf"\1class {new_class_name}", test_code
rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code
)

# For performance mode, add timing instrumentation
Expand Down Expand Up @@ -798,9 +888,3 @@ def _add_import(source: str, import_statement: str) -> str:
return "".join(lines)


@lru_cache(maxsize=128)
def _get_method_call_pattern(func_name: str):
"""Cache compiled regex patterns for method call matching."""
return re.compile(
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
)
13 changes: 8 additions & 5 deletions codeflash/languages/java/line_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import json
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -89,17 +90,19 @@ def instrument_source(
end_idx = func.ending_line
lines = lines[:start_idx] + func_lines + lines[end_idx:]

instrumented_source = "".join(lines)

# Add profiler class and initialization
profiler_class_code = self._generate_profiler_class()

# Insert profiler class before the package's first class
# Find the first class declaration
# Find the first class/interface/enum/record declaration
# Must handle any combination of modifiers: public final class, abstract class, etc.
class_pattern = re.compile(
r"^(?:(?:public|private|protected|final|abstract|static|sealed|non-sealed)\s+)*"
r"(?:class|interface|enum|record)\s+"
)
import_end_idx = 0
for i, line in enumerate(lines):
stripped = line.strip()
if stripped.startswith("public class ") or stripped.startswith("class "):
if class_pattern.match(line.strip()):
import_end_idx = i
break

Expand Down
2 changes: 1 addition & 1 deletion codeflash/languages/java/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def _run_benchmarking_tests_maven(
loop_count = 0
last_result = None

per_loop_timeout = timeout or max(120, 60 + inner_iterations)
per_loop_timeout = max(timeout or 0, 120, 60 + inner_iterations)

logger.debug("Using Maven-based benchmarking (fallback mode)")

Expand Down
Loading
Loading