refactor(java): replace regex/brace-counting with tree-sitter in instrumentation#1473

Closed
HeshamHM28 wants to merge 2 commits into omni-java from refactor/tree-sitter-instrumentation

@HeshamHM28
Contributor

Summary

  • Replace fragile line-by-line regex scanning and manual brace counting in Java instrumentation with tree-sitter AST analysis
  • Add find_test_methods(), find_method_invocations(), find_identifier_references(), find_import_insertion_point() helpers to JavaAnalyzer
  • Refactor _add_timing_instrumentation and _add_behavior_instrumentation to use tree-sitter for @Test method detection and body boundary extraction
  • Refactor class renaming to use tree-sitter identifier references (excludes matches inside strings/comments)
  • Remove dead helpers: _is_test_annotation, _find_balanced_end, _find_method_calls_balanced, _add_import
  • Fix line_profiler.py class detection to use find_classes() instead of regex
  • Net reduction of ~30 lines while eliminating multiple bug classes (modifier combinations, nested parens, brace-in-string corruption, comment/string false matches, lambda edge cases)

Test plan

  • All 33 instrumentation unit tests pass with byte-for-byte identical output
  • All 31 parser tests pass unchanged
  • Verified tree-sitter helpers correctly handle: @Test vs @TestOnly discrimination, lambda-aware method invocation detection, identifier references excluding strings/comments, import insertion point detection
  • Verified edge cases: public final class, @Disabled @Test, inner/nested classes, empty test methods, nested braces

🤖 Generated with Claude Code

…rumentation

Replaces fragile line-by-line regex scanning and manual brace counting with
tree-sitter AST analysis for Java test instrumentation. This eliminates several
classes of bugs:

- `public final class` not matching class detection patterns
- `@TestOnly` being incorrectly matched as `@Test`
- Nested parentheses breaking method call extraction
- Braces inside strings/comments corrupting method boundary detection
- Class renaming hitting matches inside comments and string literals
- Lambda detection missing edge cases
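
Two of the bug classes above can be reproduced with a few lines of standard-library Python. This is a minimal sketch, not the code the PR removed: `JAVA_SNIPPET` and `naive_body_end` are hypothetical names invented for illustration, and the regexes are assumed stand-ins for whatever patterns the old helpers used.

```python
import re

# Hypothetical Java input: a non-test annotation and a brace inside a string literal.
JAVA_SNIPPET = '''
class Demo {
    @TestOnly
    void helper() { String s = "}"; }
}
'''

# A plain prefix pattern also fires on @TestOnly.
naive_hits = re.findall(r"@Test", JAVA_SNIPPET)
# A word boundary avoids this particular false match, but the pattern still
# knows nothing about strings or comments.
strict_hits = re.findall(r"@Test\b", JAVA_SNIPPET)


def naive_body_end(source: str, open_idx: int) -> int:
    """Return the index of the '}' that plain brace counting pairs with source[open_idx]."""
    depth = 0
    for i in range(open_idx, len(source)):
        if source[i] == "{":
            depth += 1
        elif source[i] == "}":
            depth -= 1
            if depth == 0:
                return i
    return -1


body_open = JAVA_SNIPPET.index("{ String")
body_end = naive_body_end(JAVA_SNIPPET, body_open)
# The counter stops at the brace inside the string literal, not at the real end of the body.
print(repr(JAVA_SNIPPET[body_end - 1 : body_end + 2]))
```

A parser-based approach sidesteps both problems because string literals and annotations are distinct node types in the syntax tree, so no textual pattern has to guess where they begin and end.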

Changes:
- parser.py: Add find_test_methods(), find_method_invocations(),
  find_identifier_references(), find_import_insertion_point() helpers
- instrumentation.py: Refactor _add_timing_instrumentation and
  _add_behavior_instrumentation to use tree-sitter for @Test method
  detection and body boundary extraction via body node ranges
- instrumentation.py: Refactor class renaming to use tree-sitter
  identifier references (excludes strings/comments)
- instrumentation.py: Refactor class extraction in generated test
  instrumentation to use find_classes()
- instrumentation.py: Remove dead helpers (_is_test_annotation,
  _find_balanced_end, _find_method_calls_balanced, _add_import)
- line_profiler.py: Replace regex class detection with find_classes()

All 33 instrumentation unit tests pass with byte-for-byte identical output.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@codeflash-ai
Contributor

codeflash-ai Bot commented Feb 13, 2026

⚡️ Codeflash found optimizations for this PR

📄 3,150% (31.50x) speedup for JavaAnalyzer._has_test_annotation in codeflash/languages/java/parser.py

⏱️ Runtime : 124 milliseconds → 3.80 milliseconds (best of 5 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch refactor/tree-sitter-instrumentation).

Comment on lines +790 to +806
if node.type == "lambda_expression":
    in_lambda = True

if node.type == "method_invocation":
    name_node = node.child_by_field_name("name")
    if name_node and self.get_node_text(name_node, source_bytes) == func_name:
        results.append(
            MethodCallInfo(
                node=node,
                full_text=self.get_node_text(node, source_bytes),
                in_lambda=in_lambda,
            )
        )

for child in node.children:
    self._walk_for_invocations(child, source_bytes, func_name, results, in_lambda)

⚡️Codeflash found 12% (0.12x) speedup for JavaAnalyzer.find_method_invocations in codeflash/languages/java/parser.py

⏱️ Runtime : 3.39 milliseconds → 3.03 milliseconds (best of 14 runs)

📝 Explanation and details

This optimization achieves a 12% runtime improvement by applying two complementary performance enhancements to the AST traversal logic:

Key Optimizations

1. Iterative Stack-Based Traversal (Eliminates Recursion Overhead)

The original code uses recursive calls to _walk_for_invocations, incurring Python function call overhead (~28.7% of time spent on recursion in the line profiler). The optimized version replaces this with an explicit stack-based iteration:

stack: list[tuple[Node, bool]] = [(node, in_lambda)]
while stack:
    current, current_in_lambda = stack.pop()
    # ... process node ...
    for child in reversed(children):
        stack.append((child, current_in_lambda))

This eliminates the cost of recursive function calls while maintaining identical traversal order (children are pushed in reverse to preserve source order when popping).
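
The ordering claim can be checked on a toy tree. The sketch below is illustrative only: `ToyNode` is a hypothetical stand-in for a tree-sitter node, and both walkers record `(type, in_lambda)` pairs where the real code builds `MethodCallInfo` objects.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical stand-in for a tree-sitter node; only the fields used here."""
    type: str
    children: list["ToyNode"] = field(default_factory=list)


def walk_recursive(node, in_lambda=False, out=None):
    """Pre-order traversal via recursion, propagating the lambda flag downward."""
    out = [] if out is None else out
    if node.type == "lambda_expression":
        in_lambda = True
    out.append((node.type, in_lambda))
    for child in node.children:
        walk_recursive(child, in_lambda, out)
    return out


def walk_iterative(root, in_lambda=False):
    """Same traversal with an explicit stack of (node, flag) pairs."""
    out = []
    stack = [(root, in_lambda)]
    while stack:
        current, flag = stack.pop()
        if current.type == "lambda_expression":
            flag = True
        out.append((current.type, flag))
        # Push in reverse so the leftmost child is popped first.
        for child in reversed(current.children):
            stack.append((child, flag))
    return out


tree = ToyNode("block", [
    ToyNode("call_a"),
    ToyNode("lambda_expression", [ToyNode("call_b")]),
    ToyNode("call_c"),
])
print(walk_iterative(tree) == walk_recursive(tree))
```

Because the flag travels with each stack entry, a sibling that follows a lambda (`call_c` here) is not marked as being inside it, which matches the behavior of the recursive version where `in_lambda` is a per-call argument.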

2. Byte-Level Comparison (Avoids Redundant UTF-8 Decoding)

The original code decodes every candidate name node to UTF-8 via get_node_text() for comparison (~24% of time in line profiler). The optimized version pre-encodes func_name once to bytes and performs direct byte slice comparison:

name_bytes = func_name.encode("utf8")  # Once per invocation
# Later:
if src[name_node.start_byte : name_node.end_byte] == name_bytes:

This avoids thousands of redundant UTF-8 decode operations for non-matching nodes, which is particularly effective given that most nodes won't match the target function name.
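
A small sketch of why the two comparisons agree, assuming the source is valid UTF-8 and that `get_node_text` amounts to slicing and decoding the byte range (an assumption about its behavior, not something shown in this thread). `span_of`, `matches_by_decoding`, and `matches_by_bytes` are hypothetical helpers for illustration.

```python
source_bytes = "void t() { callλ(); other(); }".encode("utf8")


def span_of(text: str) -> tuple[int, int]:
    """Byte offsets of text in source_bytes; stands in for a node's start_byte/end_byte."""
    needle = text.encode("utf8")
    start = source_bytes.index(needle)
    return start, start + len(needle)


def matches_by_decoding(span: tuple[int, int], func_name: str) -> bool:
    """Original style: decode the candidate slice, then compare strings."""
    start, end = span
    return source_bytes[start:end].decode("utf8") == func_name


def matches_by_bytes(span: tuple[int, int], name_bytes: bytes) -> bool:
    """Optimized style: compare the raw slice against a pre-encoded name."""
    start, end = span
    return source_bytes[start:end] == name_bytes


target = "callλ"
target_bytes = target.encode("utf8")  # encoded once, reused for every candidate
for candidate in ("callλ", "other"):
    span = span_of(candidate)
    print(candidate, matches_by_decoding(span, target), matches_by_bytes(span, target_bytes))
```

Since UTF-8 encoding maps each string to exactly one byte sequence, equality of the encoded bytes implies equality of the decoded strings, so the byte comparison does not change which nodes match, including names with multibyte characters.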

Performance Profile

The annotated tests show the optimization excels at scale:

  • Small inputs (single invocations): 13-44% slower due to stack/encoding setup overhead
  • Large inputs (1000+ invocations): 4-14% faster as the savings from reduced decoding and recursion overhead compound

The change was accepted on the basis of runtime: real-world usage likely involves analyzing larger AST subtrees, where these optimizations provide a net benefit. The method would typically be called on method bodies containing many statements and expressions, so the large-input results are more relevant than the toy examples.

Correctness verification report:

Test | Status
⚙️ Existing Unit Tests | 🔘 None Found
🌀 Generated Regression Tests | 24 Passed
⏪ Replay Tests | 🔘 None Found
🔎 Concolic Coverage Tests | 🔘 None Found
📊 Tests Coverage | 100.0%
🌀 Click to see Generated Regression Tests
import importlib  # to import the module under test so we can inject MethodCallInfo
import types  # to create simple namespace objects to act like tree-sitter nodes
from dataclasses import \
    dataclass  # to create a lightweight container for MethodCallInfo

import pytest  # used for our unit tests
from codeflash.languages.java.parser import JavaAnalyzer

# import the module and class under test
parser_module = importlib.import_module("codeflash.languages.java.parser")
JavaAnalyzer = parser_module.JavaAnalyzer  # class we will instantiate and test

# Helper to build lightweight "Node-like" objects without defining a new class.
# We use types.SimpleNamespace instances and attach the attributes and callable expected by the analyzer.
def make_node(node_type: str, start: int = 0, end: int = 0, name_node: object | None = None, children: list | None = None):
    """
    Create a lightweight node-like object with the attributes used by JavaAnalyzer:
      - .type : string
      - .start_byte and .end_byte : ints used by get_node_text
      - .children : list of child nodes
      - .child_by_field_name(name) : callable returning the named child or None
    """
    if children is None:
        children = []

    # Map the 'name' field to the provided name_node if present.
    field_map = {"name": name_node} if name_node is not None else {}

    # Define a callable that returns the mapped field or None.
    def child_by_field_name(name):
        return field_map.get(name)

    # Build the simple namespace node with the required attributes.
    node = types.SimpleNamespace(
        type=node_type,
        start_byte=start,
        end_byte=end,
        children=children,
        child_by_field_name=child_by_field_name,
    )
    return node

def test_single_method_invocation_matches_name():
    # Create source bytes for slicing by start/end bytes.
    # "doThing()" is the method invocation; name "doThing" sits within it.
    source = b"doThing();"
    # Create a name node that covers the substring "doThing" (bytes 0..7)
    name_node = types.SimpleNamespace(type="identifier", start_byte=0, end_byte=7, children=[], child_by_field_name=lambda n: None)
    # Create a method_invocation node that covers the whole call "doThing()"
    invocation_node = make_node("method_invocation", start=0, end=9, name_node=name_node, children=[])
    # Top-level body node that contains the invocation
    body = make_node("block", start=0, end=9, children=[invocation_node])

    analyzer = JavaAnalyzer()  # create a real analyzer instance
    # Call find_method_invocations looking for "doThing"
    codeflash_output = analyzer.find_method_invocations(body, source, "doThing"); results = codeflash_output # 7.35μs -> 8.43μs (12.7% slower)
    info = results[0]

def test_multiple_invocations_preserve_source_order():
    # Build a source with two invocations in order.
    source = b"first();second();"
    # Create name nodes for "first" and "second"
    name1 = types.SimpleNamespace(type="identifier", start_byte=0, end_byte=5, children=[], child_by_field_name=lambda n: None)
    inv1 = make_node("method_invocation", start=0, end=7, name_node=name1, children=[])
    # "first();" occupies bytes 0..8, so "second" starts at byte 8.
    name2 = types.SimpleNamespace(type="identifier", start_byte=8, end_byte=14, children=[], child_by_field_name=lambda n: None)
    inv2 = make_node("method_invocation", start=8, end=16, name_node=name2, children=[])
    # Body contains both invocations in order
    body = make_node("block", start=0, end=16, children=[inv1, inv2])

    analyzer = JavaAnalyzer()
    codeflash_output = analyzer.find_method_invocations(body, source, "second"); results = codeflash_output # 3.52μs -> 4.80μs (26.7% slower)

def test_non_matching_names_are_ignored():
    # Create a method_invocation whose name does not match the search term.
    source = b"other();"
    name_node = types.SimpleNamespace(type="identifier", start_byte=0, end_byte=5, children=[], child_by_field_name=lambda n: None)
    invocation = make_node("method_invocation", start=0, end=7, name_node=name_node, children=[])
    body = make_node("block", start=0, end=7, children=[invocation])

    analyzer = JavaAnalyzer()
    codeflash_output = analyzer.find_method_invocations(body, source, "targetName"); results = codeflash_output # 2.46μs -> 3.49μs (29.3% slower)

def test_method_invocation_without_name_node_is_ignored():
    # If a method_invocation exists but child_by_field_name("name") returns None, it should be ignored.
    source = b"unknown();"
    # Create invocation without a name child
    invocation = make_node("method_invocation", start=0, end=9, name_node=None, children=[])
    body = make_node("block", start=0, end=9, children=[invocation])

    analyzer = JavaAnalyzer()
    codeflash_output = analyzer.find_method_invocations(body, source, "unknown"); results = codeflash_output # 1.67μs -> 2.77μs (39.7% slower)

def test_invocation_inside_lambda_sets_in_lambda_flag():
    # Build a structure where a lambda_expression node contains a method_invocation.
    source = b"(() -> target());"
    # name node covering "target"
    name_node = types.SimpleNamespace(type="identifier", start_byte=7, end_byte=13, children=[], child_by_field_name=lambda n: None)
    invocation = make_node("method_invocation", start=7, end=15, name_node=name_node, children=[])
    # The lambda node contains the invocation as a child
    lambda_node = make_node("lambda_expression", start=1, end=15, children=[invocation])
    # Body is the lambda node itself (starting search here should set in_lambda True for its subtree)
    body = lambda_node

    analyzer = JavaAnalyzer()
    codeflash_output = analyzer.find_method_invocations(body, source, "target"); results = codeflash_output # 2.63μs -> 3.47μs (24.0% slower)

def test_empty_body_returns_empty_list():
    # An empty block (no children) should return an empty list.
    source = b""
    body = make_node("block", start=0, end=0, children=[])

    analyzer = JavaAnalyzer()
    codeflash_output = analyzer.find_method_invocations(body, source, "anything"); results = codeflash_output # 972ns -> 1.73μs (43.9% slower)

def test_empty_function_name_matches_empty_name_node():
    # If searching for an empty func_name, and a name node has an empty span, it should match.
    source = b""  # empty source
    # Create a name node with start==end==0 so its text decodes to empty string
    name_node = types.SimpleNamespace(type="identifier", start_byte=0, end_byte=0, children=[], child_by_field_name=lambda n: None)
    invocation = make_node("method_invocation", start=0, end=0, name_node=name_node, children=[])
    body = make_node("block", start=0, end=0, children=[invocation])

    analyzer = JavaAnalyzer()
    codeflash_output = analyzer.find_method_invocations(body, source, ""); results = codeflash_output # 3.58μs -> 4.81μs (25.6% slower)

def test_passing_none_for_body_raises_attribute_error():
    # Passing None for body_node should raise an AttributeError when the implementation tries to access .type
    analyzer = JavaAnalyzer()
    with pytest.raises(AttributeError):
        analyzer.find_method_invocations(None, b"", "anything") # 3.68μs -> 4.63μs (20.6% slower)

def test_special_characters_in_source_and_name():
    # Ensure UTF-8 content is handled correctly by get_node_text
    # Include a multibyte UTF-8 character (e.g., 'λ') in the source.
    # We still use bytes for source_bytes.
    source = "callλ();".encode("utf8")
    # find the byte range for "callλ" - approximate by encoding prefix lengths
    # "call" (4 bytes) + "λ" (2 bytes in UTF-8) => total 6 bytes
    name_node = types.SimpleNamespace(type="identifier", start_byte=0, end_byte=6, children=[], child_by_field_name=lambda n: None)
    invocation = make_node("method_invocation", start=0, end=8, name_node=name_node, children=[])
    body = make_node("block", start=0, end=8, children=[invocation])

    analyzer = JavaAnalyzer()
    # Search using the exact decoded string (must match get_node_text decoding)
    codeflash_output = analyzer.find_method_invocations(body, source, "callλ"); results = codeflash_output # 4.71μs -> 5.83μs (19.2% slower)

def test_large_scale_many_invocations_performance_and_correctness():
    # Construct a body node with 1000 children, alternating matching and non-matching invocations.
    # This validates both performance at a modest scale and correctness of counting/order.
    count = 1000
    source_parts = []
    children = []
    # We'll build source bytes and corresponding nodes; each invocation occupies a distinct byte range.
    byte_offset = 0
    for i in range(count):
        if i % 2 == 0:
            # matching invocation "matchX();"
            name_text = f"match{i}"
            full_text = f"{name_text}();"
            name_bytes = name_text.encode("utf8")
            full_bytes = full_text.encode("utf8")
            # create nodes with appropriate byte offsets
            name_node = types.SimpleNamespace(type="identifier", start_byte=byte_offset, end_byte=byte_offset + len(name_bytes), children=[], child_by_field_name=lambda n: None)
            inv_node = make_node("method_invocation", start=byte_offset, end=byte_offset + len(full_bytes), name_node=name_node, children=[])
            children.append(inv_node)
            source_parts.append(full_bytes)
            byte_offset += len(full_bytes)
        else:
            # non-matching invocation "otherX();"
            name_text = f"other{i}"
            full_text = f"{name_text}();"
            name_bytes = name_text.encode("utf8")
            full_bytes = full_text.encode("utf8")
            name_node = types.SimpleNamespace(type="identifier", start_byte=byte_offset, end_byte=byte_offset + len(name_bytes), children=[], child_by_field_name=lambda n: None)
            inv_node = make_node("method_invocation", start=byte_offset, end=byte_offset + len(full_bytes), name_node=name_node, children=[])
            children.append(inv_node)
            source_parts.append(full_bytes)
            byte_offset += len(full_bytes)

    # Join all bytes to form the source
    source = b"".join(source_parts)
    body = make_node("block", start=0, end=len(source), children=children)

    analyzer = JavaAnalyzer()
    # Search for "match0": only even indices use the "match" prefix, and each name is unique,
    # so exactly one invocation carries this name.
    codeflash_output = analyzer.find_method_invocations(body, source, "match0"); results_match0 = codeflash_output # 479μs -> 463μs (3.61% faster)

    # Now search for the bare prefix "match".
    # The implementation matches on full identifier equality; since we built unique names like "match0", "match2", ...
    # searching for "match" should yield zero results because there is no exact "match" identifier.
    codeflash_output = analyzer.find_method_invocations(body, source, "match"); results_match = codeflash_output # 481μs -> 425μs (13.1% faster)

    # Finally, confirm that searching for any of the even names returns the correct count of 1 each.
    for idx in range(0, 10, 2):  # test first 5 matches to keep test small and deterministic
        codeflash_output = analyzer.find_method_invocations(body, source, f"match{idx}"); results = codeflash_output # 2.40ms -> 2.11ms (13.8% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1473-2026-02-13T00.37.31

Click to see suggested changes
Suggested change
if node.type == "lambda_expression":
    in_lambda = True
if node.type == "method_invocation":
    name_node = node.child_by_field_name("name")
    if name_node and self.get_node_text(name_node, source_bytes) == func_name:
        results.append(
            MethodCallInfo(
                node=node,
                full_text=self.get_node_text(node, source_bytes),
                in_lambda=in_lambda,
            )
        )
for child in node.children:
    self._walk_for_invocations(child, source_bytes, func_name, results, in_lambda)

# Pre-encode func_name to bytes to avoid decoding many short name nodes.
name_bytes = func_name.encode("utf8")
src = source_bytes
# Use an explicit stack of (node, in_lambda_flag) to avoid recursion overhead.
stack: list[tuple[Node, bool]] = [(node, in_lambda)]
while stack:
    current, current_in_lambda = stack.pop()
    # If this node is a lambda, mark the flag for its subtree.
    if current.type == "lambda_expression":
        current_in_lambda = True
    # Check for method_invocation nodes and compare the raw bytes of the name node.
    if current.type == "method_invocation":
        name_node = current.child_by_field_name("name")
        if name_node:
            if src[name_node.start_byte : name_node.end_byte] == name_bytes:
                results.append(
                    MethodCallInfo(
                        node=current,
                        full_text=self.get_node_text(current, source_bytes),
                        in_lambda=current_in_lambda,
                    )
                )
    # Push children onto the stack in reverse order to preserve source order.
    children = current.children
    if children:
        for child in reversed(children):
            stack.append((child, current_in_lambda))

…xtraction

Previously, get_code_optimization_context_for_language() would raise a
hard ValueError when the extracted code context exceeded the 16,000
token limit, causing 93% of Java functions in large projects to fail
optimization. This was because Java's helper traversal (max_depth=2)
pulls in transitive dependencies, and type skeleton wrapping adds all
class fields and constructors.

This commit adds a 4-stage progressive fallback strategy:
1. Full context (all helpers, Javadoc intact)
2. Remove cross-file helpers (keep same-file helpers only)
3. Strip Javadoc comments from all code
4. Remove all helpers (target code only)

Each stage is tried in order until the token limit is satisfied, with
debug logging when a fallback is used. The same fallback applies
independently to both optim and testgen token limits.
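
The staged fallback described above can be sketched as a loop over lazily built candidates. Everything here is hypothetical: `first_fitting_stage`, the stage names, the toy Java strings, and the whitespace-based token counter are illustration only, and the sketch assumes the stages are cumulative. The commit message does not say what happens if even the last stage exceeds the limit, so raising `ValueError` there is an assumption.

```python
from typing import Callable

Stage = tuple[str, Callable[[], str]]


def first_fitting_stage(
    stages: list[Stage], count_tokens: Callable[[str], int], limit: int
) -> tuple[str, str]:
    """Try each stage in order; return (stage_name, code) for the first one within the limit."""
    for name, build in stages:
        code = build()
        if count_tokens(code) <= limit:
            return name, code
    # Assumption: fail hard only when every stage is too large.
    raise ValueError("code context exceeds token limit at every stage")


# Toy inputs standing in for extracted Java context.
JAVADOC = "/** Adds two numbers. */ "
TARGET = "int add(int a, int b) { return a + b; }"
SAME_FILE = "int one() { return 1; }"
CROSS_FILE = "int two() { return 2; } " * 10

stages: list[Stage] = [
    ("full", lambda: JAVADOC + TARGET + " " + SAME_FILE + " " + CROSS_FILE),
    ("no_cross_file_helpers", lambda: JAVADOC + TARGET + " " + SAME_FILE),
    ("no_javadoc", lambda: TARGET + " " + SAME_FILE),
    ("target_only", lambda: TARGET),
]


def toy_token_count(code: str) -> int:
    """Crude stand-in for a real tokenizer: count whitespace-separated chunks."""
    return len(code.split())


# The same stage list can be evaluated independently against different limits.
for limit in (100, 30, 20, 12):
    name, _ = first_fitting_stage(stages, toy_token_count, limit)
    print(limit, name)
```

Building each stage lazily means the cheaper, smaller contexts are only constructed when the larger ones fail to fit.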

Also extracts the code string building logic into a reusable
_build_code_strings_for_language() helper and adds a _strip_javadoc_comments()
utility for removing /** ... */ blocks while preserving other comments.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Comment on lines +216 to +220
import re

This function supports multi-file context extraction, grouping helpers by file
and creating proper CodeStringsMarkdown with file paths for multi-file replacement.
return re.sub(r"/\*\*.*?\*/\s*", "", source, flags=re.DOTALL)


⚡️Codeflash found 24% (0.24x) speedup for _strip_javadoc_comments in codeflash/context/code_context_extractor.py

⏱️ Runtime : 1.78 milliseconds → 1.43 milliseconds (best of 241 runs)

📝 Explanation and details

The optimized code achieves a 24% speedup by replacing regex-based pattern matching with a manual string scanning approach using Python's built-in str.find() method.

Key Optimization:
The original implementation uses re.sub(r"/\*\*.*?\*/\s*", "", source, flags=re.DOTALL) which incurs significant overhead from:

  1. Regex compilation and pattern matching engine
  2. The re.DOTALL flag enabling . to match newlines
  3. Non-greedy matching (.*?) which requires backtracking

The optimized version eliminates this overhead by:

  1. Using str.find("/**") to locate Javadoc starts - a simple C-level string search
  2. Using str.find("*/", idx + 3) to find the closing delimiter
  3. Manually scanning whitespace with s[j].isspace() instead of regex \s*
  4. Building the result with string slicing and "".join(parts)
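
The four steps above can be sketched as follows. The suggested implementation itself is not shown in this thread, so `strip_javadoc_scan` is a reconstruction from the description and not the bot's actual code; `strip_javadoc_regex` repeats the original `re.sub` call for comparison.

```python
import re


def strip_javadoc_regex(source: str) -> str:
    """Original approach, as quoted in the review comment."""
    return re.sub(r"/\*\*.*?\*/\s*", "", source, flags=re.DOTALL)


def strip_javadoc_scan(source: str) -> str:
    """Reconstructed str.find()-based scan that follows the four steps listed above."""
    parts = []
    pos = 0
    n = len(source)
    while True:
        start = source.find("/**", pos)
        if start == -1:
            break
        end = source.find("*/", start + 3)
        if end == -1:
            # Unterminated Javadoc: leave the remainder untouched.
            break
        parts.append(source[pos:start])
        # Skip whitespace that follows the closing delimiter, like the regex's \s*.
        j = end + 2
        while j < n and source[j].isspace():
            j += 1
        pos = j
    parts.append(source[pos:])
    return "".join(parts)


samples = [
    "/** Doc 1 */ int x = 5; /** Doc 2 */ int y = 10;",
    "/* not a javadoc */\n// just a line\nclass C {}",
    "/** opening only",
    "int a = 0;/**doc*/int b = 1;",
]
for sample in samples:
    print(strip_javadoc_scan(sample) == strip_javadoc_regex(sample))
```

Starting the search for the closing delimiter at `start + 3` mirrors the regex, which requires the full `/**` opener before `*/` can match, so `/**/` on its own is not treated as a complete Javadoc block by either version.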

Performance Characteristics:

  • Small inputs (single/few Javadocs): 50-100% faster due to avoiding regex overhead entirely
  • Medium inputs (dozens of Javadocs): 30-70% faster, benefiting from simpler string operations
  • Large Javadocs (thousands of characters): Up to 796% faster on very large single comments, as str.find() is more efficient than regex backtracking
  • Many Javadocs (hundreds/thousands): Shows some regression (30-50% slower) because the manual loop has more overhead per iteration than regex's optimized matching, but the overall 24% improvement indicates this case is less common in real workloads

Trade-offs:
The optimization performs exceptionally well when Javadoc comments contain large amounts of text, or when there are relatively few comments to process. The slight regression on inputs with hundreds of consecutive small Javadocs is offset by dramatic gains on large comments and typical real-world Java source files with moderate documentation density.

The line profiler shows the optimized version spends most of its time in str.find() calls (32.1% combined) and whitespace scanning (28.7%), whereas the original spends 97.3% of its time inside the regex engine.

Correctness verification report:

Test | Status
⚙️ Existing Unit Tests | 🔘 None Found
🌀 Generated Regression Tests | 54 Passed
⏪ Replay Tests | 🔘 None Found
🔎 Concolic Coverage Tests | 🔘 None Found
📊 Tests Coverage | 100.0%
🌀 Click to see Generated Regression Tests
import pytest  # used for our unit tests
from codeflash.context.code_context_extractor import _strip_javadoc_comments

def test_remove_simple_javadoc():
    # Simple Javadoc comment before a class should be removed entirely along with following whitespace/newline.
    source = "/** Simple comment */\npublic class A {}"
    # After stripping, the Javadoc and its trailing newline/space should be gone.
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 4.00μs -> 2.47μs (61.8% faster)

def test_preserve_single_line_and_regular_block_comments():
    # Ensure that single-line (//) comments and regular block comments (/* ... */) are preserved,
    # while Javadoc (/** ... */) is removed.
    source = (
        "int x = 0; // single-line comment\n"
        "/* regular block comment */\n"
        "/** Javadoc to remove */\n"
        "int y = 1;"
    )
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 3.86μs -> 2.39μs (61.1% faster)
    # Regular comments remain; Javadoc is removed along with the newline trailing it.
    expected = "int x = 0; // single-line comment\n/* regular block comment */\nint y = 1;"

def test_remove_multiline_javadoc_with_leading_stars():
    # Javadoc often has stars on each line; these should be removed entirely.
    source = "/**\n * This is a Javadoc\n * spanning multiple lines\n */\nvoid m(){}"
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 3.85μs -> 2.24μs (72.0% faster)

def test_javadoc_with_comment_like_sequences_inside():
    # Javadoc might contain sequences that look like other comments; they must be swallowed as part of the Javadoc.
    source = "/** Contains /* and // inside the Javadoc */\nint z;"
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 3.71μs -> 2.30μs (61.3% faster)

def test_no_javadoc_returns_same_string():
    # If there is no Javadoc comment, the source should remain exactly the same.
    source = "/* not a javadoc */\n// just a line\nclass C {}"
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 2.66μs -> 1.28μs (107% faster)

def test_empty_string_returns_empty():
    # Empty input should return empty output.
    codeflash_output = _strip_javadoc_comments("") # 2.56μs -> 1.25μs (104% faster)

def test_none_raises_type_error():
    # Passing a non-string (None) should raise a TypeError due to re.sub expecting a string.
    with pytest.raises(TypeError):
        _strip_javadoc_comments(None) # 5.35μs -> 2.34μs (129% faster)

def test_javadoc_at_end_of_file_removes_trailing_whitespace():
    # Javadoc at EOF should be removed and trailing whitespace after it trimmed.
    source = "class C{}\n/** End comment */   "
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 3.54μs -> 2.58μs (37.2% faster)

def test_javadoc_adjacent_to_code_no_unwanted_space_removal():
    # When Javadoc is directly adjacent to code (no intervening whitespace), removal should not delete code tokens.
    source = "int a = 0;/**doc*/int b = 1;"
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 3.58μs -> 2.29μs (56.3% faster)

def test_large_scale_many_javadocs():
    # Build a large source string containing 1000 Javadoc comments interleaved with code lines.
    n = 1000  # scale up to 1000 as required
    # Each block: code line, then a Javadoc comment, then a newline.
    source = "".join(f"code{i};\n/** Comment {i} */\n" for i in range(n))
    # Expected result: each code line remains followed by its newline; Javadocs and the newline immediately after them are removed.
    expected = "".join(f"code{i};\n" for i in range(n))
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 235μs -> 367μs (36.0% slower)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import re

# imports
import pytest
from codeflash.context.code_context_extractor import _strip_javadoc_comments

def test_single_javadoc_comment_simple():
    """Test removal of a simple single-line Javadoc comment."""
    source = "/** This is a Javadoc comment */ public class Foo {}"
    expected = "public class Foo {}"
    codeflash_output = _strip_javadoc_comments(source) # 4.21μs -> 2.53μs (66.1% faster)

def test_multiline_javadoc_comment():
    """Test removal of a multi-line Javadoc comment."""
    source = """/**
 * This is a multi-line Javadoc comment
 * with multiple lines
 */ public class Foo {}"""
    expected = "public class Foo {}"
    codeflash_output = _strip_javadoc_comments(source) # 4.20μs -> 2.52μs (66.8% faster)

def test_javadoc_with_trailing_whitespace():
    """Test that trailing whitespace after Javadoc is removed."""
    source = "/** Comment */   public void method() {}"
    expected = "public void method() {}"
    codeflash_output = _strip_javadoc_comments(source) # 3.68μs -> 2.47μs (49.1% faster)

def test_preserves_single_line_comments():
    """Test that single-line comments (//) are preserved."""
    source = "// This is a comment\npublic class Foo {}"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 2.78μs -> 1.30μs (114% faster)

def test_preserves_block_comments():
    """Test that regular block comments (/* ... */) are preserved."""
    source = "/* This is a block comment */ public class Foo {}"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 2.72μs -> 1.30μs (110% faster)

def test_multiple_javadoc_comments():
    """Test removal of multiple Javadoc comments in sequence."""
    source = "/** Comment 1 */ public class Foo { /** Comment 2 */ public void bar() {} }"
    expected = "public class Foo { public void bar() {} }"
    codeflash_output = _strip_javadoc_comments(source) # 4.02μs -> 3.10μs (29.7% faster)

def test_javadoc_before_class():
    """Test typical Javadoc comment before class declaration."""
    source = """/**
 * Main application class
 */
public class Application {}"""
    expected = "public class Application {}"  # the regex's trailing \s* also consumes the newline after */
    codeflash_output = _strip_javadoc_comments(source) # 3.67μs -> 2.37μs (54.7% faster)

def test_javadoc_before_method():
    """Test typical Javadoc comment before method declaration."""
    source = """/**
 * Processes data
 * @param data the input data
 * @return processed result
 */
public String process(String data) {}"""
    expected = "public String process(String data) {}"  # the regex's trailing \s* also consumes the newline after */
    codeflash_output = _strip_javadoc_comments(source) # 4.25μs -> 2.39μs (77.7% faster)

def test_code_between_javadoc_comments():
    """Test that code between multiple Javadoc comments is preserved."""
    source = "/** Doc 1 */ int x = 5; /** Doc 2 */ int y = 10;"
    expected = "int x = 5; int y = 10;"  # whitespace directly after each Javadoc is removed with it
    codeflash_output = _strip_javadoc_comments(source) # 3.80μs -> 2.88μs (32.2% faster)

def test_javadoc_with_special_characters():
    """Test Javadoc comments containing special characters."""
    source = "/** Comment with @param, @return, and <html> tags */ class Foo {}"
    expected = "class Foo {}"  # whitespace directly after the Javadoc is removed with it
    codeflash_output = _strip_javadoc_comments(source) # 3.82μs -> 2.27μs (67.9% faster)

def test_empty_string():
    """Test with empty string input."""
    source = ""
    expected = ""
    codeflash_output = _strip_javadoc_comments(source) # 2.60μs -> 1.25μs (108% faster)

def test_only_javadoc_comment():
    """Test with only a Javadoc comment and nothing else."""
    source = "/** Just a comment */"
    expected = ""
    codeflash_output = _strip_javadoc_comments(source) # 3.31μs -> 1.93μs (71.6% faster)

def test_javadoc_with_newlines():
    """Test Javadoc comment spanning many lines with various newline styles."""
    source = "/**\n\n\n* Multiple newlines\n\n*/"
    expected = ""
    codeflash_output = _strip_javadoc_comments(source) # 3.37μs -> 1.89μs (78.0% faster)

def test_nested_asterisks_in_javadoc():
    """Test Javadoc containing nested asterisks."""
    source = "/** Comment with ** asterisks ** inside */ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.80μs -> 2.31μs (64.6% faster)

def test_javadoc_with_url():
    """Test Javadoc containing a URL with slashes."""
    source = "/** See https://example.com for details */ class Foo {}"
    expected = "class Foo {}"
    codeflash_output = _strip_javadoc_comments(source) # 3.68μs -> 2.33μs (58.1% faster)

def test_incomplete_javadoc_start():
    """Test that incomplete Javadoc start (/** without closing */) is not matched."""
    source = "/** Incomplete comment\nclass Foo {}"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 3.48μs -> 1.74μs (99.6% faster)

def test_incomplete_javadoc_end():
    """Test that incomplete Javadoc end (*/ without /** start) is not matched."""
    source = "class Foo {} Incomplete end */ of comment"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 2.67μs -> 1.32μs (102% faster)

def test_javadoc_only_opening():
    """Test with only /** opening with no closing."""
    source = "/** opening only"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 3.23μs -> 1.67μs (93.7% faster)

def test_single_slash_star_not_javadoc():
    """Test that single /* is not treated as Javadoc."""
    source = "/* This is not javadoc */ code"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 2.71μs -> 1.23μs (120% faster)

def test_javadoc_with_tabs():
    """Test Javadoc comment with tab characters."""
    source = "/**\t\t* Comment with tabs\t*/ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.61μs -> 2.34μs (54.0% faster)

def test_javadoc_adjacent_to_code_no_space():
    """Test Javadoc immediately adjacent to code without spaces."""
    source = "/**Doc*/code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.27μs -> 2.16μs (51.4% faster)

def test_multiple_consecutive_javadoc():
    """Test multiple consecutive Javadoc comments."""
    source = "/** Doc 1 */ /** Doc 2 */ /** Doc 3 */ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.88μs -> 3.17μs (22.3% faster)

def test_javadoc_with_closing_brace_inside():
    """Test Javadoc comment containing closing braces."""
    source = "/** Comment with } and { braces */ class Foo {}"
    expected = "class Foo {}"
    codeflash_output = _strip_javadoc_comments(source) # 3.55μs -> 2.23μs (59.4% faster)

def test_javadoc_with_regex_characters():
    """Test Javadoc containing regex special characters."""
    source = "/** Comment with [abc], (x|y), and . characters */ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.69μs -> 2.24μs (64.8% faster)

def test_single_character_javadoc():
    """Test Javadoc comment with only a single character."""
    source = "/** a */ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.27μs -> 2.15μs (52.0% faster)

def test_javadoc_with_only_whitespace():
    """Test Javadoc comment containing only whitespace."""
    source = "/**     */ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.34μs -> 2.11μs (58.8% faster)

def test_javadoc_with_newline_and_spaces():
    """Test Javadoc followed by newline and spaces."""
    source = "/** Doc */\n   \npublic class Foo {}"
    expected = "public class Foo {}"
    codeflash_output = _strip_javadoc_comments(source) # 3.43μs -> 2.44μs (40.4% faster)

def test_line_comment_looks_like_javadoc_start():
    """Test line comment containing /** pattern (should be preserved)."""
    source = "// This comment has /** in it\ncode"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 3.42μs -> 1.85μs (84.4% faster)

def test_block_comment_with_javadoc_inside():
    """Test regular block comment that contains /**; the function works on raw text, so the span from /** through */ is still removed."""
    source = "/* This looks like /** but is not */ code"
    expected = "/* This looks like code"
    codeflash_output = _strip_javadoc_comments(source) # 3.56μs -> 2.36μs (50.6% faster)

def test_string_with_javadoc_pattern():
    """Test that Javadoc pattern inside a string is still removed (regex operates on raw text)."""
    source = 'String s = "/** not in string */"; code'
    expected = 'String s = ""; code'
    codeflash_output = _strip_javadoc_comments(source) # 3.67μs -> 2.30μs (59.8% faster)

def test_very_long_javadoc():
    """Test removal of a very long Javadoc comment."""
    long_comment = "/** " + "x" * 1000 + " */"
    source = long_comment + " code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 11.5μs -> 3.41μs (237% faster)

def test_javadoc_with_asterisk_at_start_of_lines():
    """Test typical Javadoc with asterisks at the start of each line."""
    source = """/**
 * Line 1
 * Line 2
 * Line 3
 */
code"""
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.75μs -> 2.30μs (63.0% faster)

def test_javadoc_immediately_followed_by_newline():
    """Test Javadoc comment immediately followed by newline (whitespace after */)."""
    source = "/** Comment */\ncode"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.37μs -> 2.13μs (58.3% faster)

def test_unicode_in_javadoc():
    """Test Javadoc containing Unicode characters."""
    source = "/** Comment with Unicode: \u00e9\u00e0\u00fc */ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.70μs -> 2.45μs (51.1% faster)

def test_empty_javadoc():
    """Test /**/ (an empty block comment): no separate closing */ follows the /** opener, so nothing is removed."""
    source = "/**/ code"
    expected = source
    codeflash_output = _strip_javadoc_comments(source) # 3.25μs -> 1.64μs (99.0% faster)

def test_whitespace_between_opening_and_closing():
    """Test Javadoc with only whitespace between opening and closing."""
    source = "/**   \n  \t  */ code"
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source) # 3.39μs -> 2.24μs (51.4% faster)

def test_many_javadoc_comments():
    """Test removal of 100 Javadoc comments in a single source string."""
    # Build a source string with 100 Javadoc comments interspersed with code
    parts = []
    for i in range(100):
        parts.append(f"/** Comment {i} */ code_{i};")
    source = "\n".join(parts)
    
    # Verify that all Javadoc comments are removed
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 27.0μs -> 40.8μs (33.9% slower)

def test_large_javadoc_comment():
    """Test removal of a Javadoc comment with a very large amount of content."""
    # Create a large Javadoc comment with 1000 lines of text
    large_content = "\n".join([f" * Line {i}" for i in range(1000)])
    source = f"/**\n{large_content}\n */ code"
    
    expected = "code"
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 102μs -> 13.0μs (685% faster)

def test_alternating_comments_and_code():
    """Test with 500 alternating Javadoc comments and code blocks."""
    parts = []
    for i in range(500):
        parts.append(f"/** Doc {i} */")
        parts.append(f"int x{i} = {i};")
    source = "\n".join(parts)
    
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 113μs -> 197μs (42.6% slower)

def test_nested_structure_with_many_javadocs():
    """Test removal of Javadoc comments from a complex nested structure."""
    source = ""
    for class_idx in range(10):
        source += f"/** Class {class_idx} */\nclass Class{class_idx} {{\n"
        for method_idx in range(20):
            source += f"  /** Method {method_idx} */\n"
            source += f"  public void method{method_idx}() {{}}\n"
        source += "}\n"
    
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 52.5μs -> 103μs (49.2% slower)

def test_performance_with_many_false_positives():
    """Test performance when there are many /* */ comments (not javadoc) to skip over."""
    # Create source with many regular block comments that shouldn't be removed
    parts = []
    for i in range(200):
        parts.append(f"/* Regular comment {i} */")
        parts.append(f"code_{i};")
    source = "\n".join(parts)
    
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 6.14μs -> 6.43μs (4.50% slower)

def test_single_long_line_with_many_javadocs():
    """Test processing of a very long single line with many Javadoc comments."""
    # Create a single very long line with multiple Javadoc comments
    parts = []
    for i in range(300):
        parts.append(f"/** Doc{i} */ x{i};")
    source = " ".join(parts)
    
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 60.2μs -> 112μs (46.7% slower)

def test_large_string_mostly_javadoc():
    """Test processing a large string that is mostly Javadoc comments."""
    # Create a source where most content is Javadoc
    large_javadoc = "/** " + ("x" * 50000) + " */"
    source = large_javadoc + " actual_code;"
    
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 412μs -> 46.0μs (796% faster)

def test_1000_javadoc_removals():
    """Test removal efficiency with 1000 Javadoc comments."""
    # Generate 1000 javadoc comments with varying content lengths
    parts = []
    for i in range(1000):
        content_length = (i % 100) + 1  # Vary content length from 1 to 100 chars
        javadoc = f"/** {'a' * content_length} */"
        parts.append(javadoc)
        parts.append(f"code{i};")
    source = "\n".join(parts)
    
    codeflash_output = _strip_javadoc_comments(source); result = codeflash_output # 604μs -> 446μs (35.5% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run: git merge codeflash/optimize-pr1473-2026-02-13T00.49.56

Suggested change
import re
This function supports multi-file context extraction, grouping helpers by file
and creating proper CodeStringsMarkdown with file paths for multi-file replacement.
    return re.sub(r"/\*\*.*?\*/\s*", "", source, flags=re.DOTALL)
    # Manual scan to remove "/** ... */" occurrences and any following whitespace,
    # avoiding regex overhead while preserving the original behavior.
    s = source
    n = len(s)
    i = 0
    parts: list[str] = []
    while True:
        idx = s.find("/**", i)
        if idx == -1:
            parts.append(s[i:])
            break
        parts.append(s[i:idx])
        # Find the closing '*/' that comes after the initial '/**'.
        # (Start searching at idx + 3 so we don't reuse the second '*' of '/**'.)
        end = s.find("*/", idx + 3)
        if end == -1:
            # No closing delimiter found; preserve the rest unchanged.
            parts.append(s[idx:])
            break
        j = end + 2
        # Skip any whitespace characters following the closing '*/' (equivalent to \s*).
        while j < n and s[j].isspace():
            j += 1
        i = j
    return "".join(parts)
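
The suggested change claims the manual scan preserves the behavior of the regex it replaces. A minimal sketch of how that claim could be spot-checked is below. The function names strip_javadoc_regex and strip_javadoc_scan and the SAMPLES list are hypothetical wrappers introduced here for illustration; the bodies are taken from the regex line and the manual scan shown in the suggestion.

```python
import re


def strip_javadoc_regex(source: str) -> str:
    """Regex form: remove each /** ... */ span plus trailing whitespace."""
    return re.sub(r"/\*\*.*?\*/\s*", "", source, flags=re.DOTALL)


def strip_javadoc_scan(source: str) -> str:
    """Manual scan from the suggestion, wrapped in a function (hypothetical name)."""
    n = len(source)
    i = 0
    parts: list[str] = []
    while True:
        idx = source.find("/**", i)
        if idx == -1:
            # No further opener: keep the remainder as-is.
            parts.append(source[i:])
            break
        parts.append(source[i:idx])
        # Search from idx + 3 so the second '*' of '/**' is not reused.
        end = source.find("*/", idx + 3)
        if end == -1:
            # Unterminated comment: keep the remainder unchanged.
            parts.append(source[idx:])
            break
        j = end + 2
        # Skip whitespace after '*/' (mirrors the trailing \s* in the regex).
        while j < n and source[j].isspace():
            j += 1
        i = j
    return "".join(parts)


# Illustrative inputs drawn from the generated tests above.
SAMPLES = [
    "",
    "/** Doc */ code",
    "/** Doc 1 */ int x = 5; /** Doc 2 */ int y = 10;",
    "/** Incomplete comment\nclass Foo {}",
    "/* This is not javadoc */ code",
    "/**/ code",
    'String s = "/** not in string */"; code',
]

for sample in SAMPLES:
    assert strip_javadoc_regex(sample) == strip_javadoc_scan(sample), sample
```

Both forms work on raw text rather than on a parse tree, so they agree on the string-literal case too. Agreement on a handful of samples is only a spot check, not a proof of equivalence.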


@codeflash-ai
Contributor

codeflash-ai Bot commented Feb 13, 2026

⚡️ Codeflash found optimizations for this PR

📄 28% (0.28x) speedup for _build_code_strings_for_language in codeflash/context/code_context_extractor.py

⏱️ Runtime : 51.7 milliseconds → 40.4 milliseconds (best of 74 runs)
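
The headline percentage can be reproduced from the two reported runtimes. The sketch below assumes speedup is defined as original time divided by optimized time, minus one; that definition is an inference from the numbers in this report, not something the bot states, and speedup_percent is a hypothetical helper name.

```python
def speedup_percent(original: float, optimized: float) -> float:
    """Relative speedup as a percentage (assumed definition: original / optimized - 1)."""
    return (original / optimized - 1.0) * 100.0


# Runtimes in milliseconds, taken from the report above.
print(f"{speedup_percent(51.7, 40.4):.0f}%")  # prints "28%"
```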

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch refactor/tree-sitter-instrumentation).


@HeshamHM28 HeshamHM28 closed this Feb 16, 2026