diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 61de73c32..72c6d51c3 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -208,39 +208,38 @@ def get_code_optimization_context( ) -def get_code_optimization_context_for_language( - function_to_optimize: FunctionToOptimize, - project_root_path: Path, - optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, - testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, -) -> CodeOptimizationContext: - """Extract code optimization context for non-Python languages. +def _strip_javadoc_comments(source: str) -> str: + """Strip Javadoc (/** ... */) comments from Java source code. - Uses the language support abstraction to extract code context and converts - it to the CodeOptimizationContext format expected by the pipeline. + Preserves single-line comments (//) and regular block comments (/* ... */). + """ + import re - This function supports multi-file context extraction, grouping helpers by file - and creating proper CodeStringsMarkdown with file paths for multi-file replacement. + return re.sub(r"/\*\*.*?\*/\s*", "", source, flags=re.DOTALL) + + +def _build_code_strings_for_language( + code_context, + function_to_optimize: FunctionToOptimize, + project_root_path: Path, + include_cross_file_helpers: bool = True, + strip_javadoc: bool = False, + include_same_file_helpers: bool = True, +) -> tuple[list[CodeString], list[FunctionSource], str]: + """Build CodeString list from a CodeContext with configurable reduction. Args: - function_to_optimize: The function to extract context for. + code_context: CodeContext from language support. + function_to_optimize: The target function. project_root_path: Root of the project. - optim_token_limit: Token limit for optimization context. - testgen_token_limit: Token limit for testgen context. + include_cross_file_helpers: Whether to include helpers from other files. + strip_javadoc: Whether to strip Javadoc comments from all code. + include_same_file_helpers: Whether to include same-file helper methods. Returns: - CodeOptimizationContext with target code and dependencies. + Tuple of (code_strings, helper_function_sources, read_only_context). """ - from codeflash.languages import get_language_support - - # Get language support for this function - language = Language(function_to_optimize.language) - lang_support = get_language_support(language) - - # Extract code context using language support - code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path) - # Build imports string if available imports_code = "\n".join(code_context.imports) if code_context.imports else "" @@ -251,82 +250,194 @@ def get_code_optimization_context_for_language( target_relative_path = function_to_optimize.file_path # Group helpers by file path - helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list) + helpers_by_file: dict[Path, list] = defaultdict(list) helper_function_sources = [] for helper in code_context.helper_functions: helpers_by_file[helper.file_path].append(helper) # Convert to FunctionSource for pipeline compatibility - helper_function_sources.append( - FunctionSource( - file_path=helper.file_path, - qualified_name=helper.qualified_name, - fully_qualified_name=helper.qualified_name, - only_function_name=helper.name, - source_code=helper.source_code, - jedi_definition=None, - ) + should_include = ( + (helper.file_path == function_to_optimize.file_path and include_same_file_helpers) + or (helper.file_path != function_to_optimize.file_path and include_cross_file_helpers) ) + if should_include: + helper_function_sources.append( + FunctionSource( + file_path=helper.file_path, + qualified_name=helper.qualified_name, + fully_qualified_name=helper.qualified_name, + only_function_name=helper.name, + source_code=helper.source_code, + jedi_definition=None, + ) + ) - # Build read-writable code (target file + same-file helpers + global variables) - read_writable_code_strings = [] + # Build read-writable code (target file + same-file helpers) + code_strings = [] # Combine target code with same-file helpers target_file_code = code_context.target_code - same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, []) - if same_file_helpers: - helper_code = "\n\n".join(h.source_code for h in same_file_helpers) - target_file_code = target_file_code + "\n\n" + helper_code - - # Note: code_context.read_only_context contains type definitions and global variables - # These should be passed as read-only context to the AI, not prepended to the target code - # If prepended to target code, the AI treats them as code to optimize and includes them in output + if include_same_file_helpers: + same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, []) + if same_file_helpers: + helper_code = "\n\n".join(h.source_code for h in same_file_helpers) + target_file_code = target_file_code + "\n\n" + helper_code # Add imports to target file code if imports_code: target_file_code = imports_code + "\n\n" + target_file_code - read_writable_code_strings.append( + if strip_javadoc: + target_file_code = _strip_javadoc_comments(target_file_code) + + code_strings.append( CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language) ) # Add helper files (cross-file helpers) - for file_path, file_helpers in helpers_by_file.items(): - if file_path == function_to_optimize.file_path: - continue # Already included in target file + if include_cross_file_helpers: + for file_path, file_helpers in helpers_by_file.items(): + if file_path == function_to_optimize.file_path: + continue # Already included in target file - try: - helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - helper_relative_path = file_path + try: + helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve()) + except ValueError: + helper_relative_path = file_path + + combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) + if strip_javadoc: + combined_helper_code = _strip_javadoc_comments(combined_helper_code) + + code_strings.append( + CodeString( + code=combined_helper_code, + file_path=helper_relative_path, + language=function_to_optimize.language, + ) + ) - # Combine all helpers from this file - combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) + read_only_context = code_context.read_only_context + if strip_javadoc and read_only_context: + read_only_context = _strip_javadoc_comments(read_only_context) - read_writable_code_strings.append( - CodeString( - code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language - ) + return code_strings, helper_function_sources, read_only_context + + +def get_code_optimization_context_for_language( + function_to_optimize: FunctionToOptimize, + project_root_path: Path, + optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, + testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, +) -> CodeOptimizationContext: + """Extract code optimization context for non-Python languages. + + Uses the language support abstraction to extract code context and converts + it to the CodeOptimizationContext format expected by the pipeline. + + This function supports multi-file context extraction, grouping helpers by file + and creating proper CodeStringsMarkdown with file paths for multi-file replacement. + + Applies progressive fallback when token limits are exceeded: + 1. Full context (all helpers, Javadoc intact) + 2. Remove cross-file helpers + 3. Strip Javadoc comments + 4. Remove all helpers (target code only) + + Args: + function_to_optimize: The function to extract context for. + project_root_path: Root of the project. + optim_token_limit: Token limit for optimization context. + testgen_token_limit: Token limit for testgen context. + + Returns: + CodeOptimizationContext with target code and dependencies. + + """ + from codeflash.languages import get_language_support + + # Get language support for this function + language = Language(function_to_optimize.language) + lang_support = get_language_support(language) + + # Extract code context using language support + code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path) + + # Progressive fallback strategies, ordered from most to least context + fallback_strategies = [ + {"include_cross_file_helpers": True, "strip_javadoc": False, "include_same_file_helpers": True}, + {"include_cross_file_helpers": False, "strip_javadoc": False, "include_same_file_helpers": True}, + {"include_cross_file_helpers": False, "strip_javadoc": True, "include_same_file_helpers": True}, + {"include_cross_file_helpers": False, "strip_javadoc": True, "include_same_file_helpers": False}, + ] + + fallback_descriptions = [ + "full context", + "without cross-file helpers", + "without cross-file helpers and Javadoc", + "target code only (no helpers, no Javadoc)", + ] + + code_strings = None + helper_function_sources = None + read_only_context = None + + for i, strategy in enumerate(fallback_strategies): + code_strings, helper_function_sources, read_only_context = _build_code_strings_for_language( + code_context, function_to_optimize, project_root_path, **strategy ) + read_writable_code = CodeStringsMarkdown( + code_strings=code_strings, language=function_to_optimize.language + ) + read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) + + if read_writable_tokens <= optim_token_limit: + if i > 0: + logger.debug( + "Code context exceeded token limit, using fallback: %s (%d tokens)", + fallback_descriptions[i], + read_writable_tokens, + ) + break + else: + raise ValueError("Read-writable code has exceeded token limit even after removing all helpers and Javadoc") + read_writable_code = CodeStringsMarkdown( - code_strings=read_writable_code_strings, language=function_to_optimize.language + code_strings=code_strings, language=function_to_optimize.language ) - # Build testgen context (same as read_writable for non-Python) + # Build testgen context with its own progressive fallback + # Start from the same strategy level that worked for optim + testgen_code_strings = code_strings + testgen_helpers = helper_function_sources + testgen_context = CodeStringsMarkdown( - code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language + code_strings=testgen_code_strings.copy(), language=function_to_optimize.language ) - - # Check token limits - read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) - if read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable code has exceeded token limit, cannot proceed") - testgen_tokens = encoded_tokens_len(testgen_context.markdown) + if testgen_tokens > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + # Try remaining fallback strategies for testgen + for j in range(i + 1, len(fallback_strategies)): + testgen_code_strings, testgen_helpers, read_only_context = _build_code_strings_for_language( + code_context, function_to_optimize, project_root_path, **fallback_strategies[j] + ) + testgen_context = CodeStringsMarkdown( + code_strings=testgen_code_strings.copy(), language=function_to_optimize.language + ) + testgen_tokens = encoded_tokens_len(testgen_context.markdown) + + if testgen_tokens <= testgen_token_limit: + logger.debug( + "Testgen context exceeded token limit, using fallback: %s (%d tokens)", + fallback_descriptions[j], + testgen_tokens, + ) + break + else: + raise ValueError("Testgen code context has exceeded token limit even after removing all helpers and Javadoc") # Generate code hash from all read-writable code code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest() @@ -336,7 +447,7 @@ def get_code_optimization_context_for_language( read_writable_code=read_writable_code, # Pass type definitions and globals as read-only context for the AI # This way the AI sees them as context but doesn't include them in optimized output - read_only_context_code=code_context.read_only_context, + read_only_context_code=read_only_context, hashing_code_context=read_writable_code.flat, hashing_code_context_hash=code_hash, helper_functions=helper_function_sources, diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 63a97b17b..14303e0f8 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -42,102 +42,6 @@ def _get_function_name(func: Any) -> str: # Pattern to detect primitive array types in assertions _PRIMITIVE_ARRAY_PATTERN = re.compile(r"new\s+(int|long|double|float|short|byte|char|boolean)\s*\[\s*\]") -# Pattern to match @Test annotation exactly (not @TestOnly, @TestFactory, etc.) -_TEST_ANNOTATION_RE = re.compile(r"^@Test(?:\s*\(.*\))?(?:\s.*)?$") - - -def _is_test_annotation(stripped_line: str) -> bool: - """Check if a stripped line is an @Test annotation (not @TestOnly, @TestFactory, etc.). - - Matches: - @Test - @Test(expected = ...) - @Test(timeout = 5000) - Does NOT match: - @TestOnly - @TestFactory - @TestTemplate - """ - return bool(_TEST_ANNOTATION_RE.match(stripped_line)) - - -def _find_balanced_end(text: str, start: int) -> int: - """Find the position after the closing paren that balances the opening paren at start. - - Args: - text: The source text. - start: Index of the opening parenthesis '('. - - Returns: - Index one past the matching closing ')', or -1 if not found. - - """ - if start >= len(text) or text[start] != "(": - return -1 - depth = 1 - pos = start + 1 - in_string = False - string_char = None - in_char = False - while pos < len(text) and depth > 0: - ch = text[pos] - prev = text[pos - 1] if pos > 0 else "" - if ch == "'" and not in_string and prev != "\\": - in_char = not in_char - elif ch == '"' and not in_char and prev != "\\": - if not in_string: - in_string = True - string_char = ch - elif ch == string_char: - in_string = False - string_char = None - elif not in_string and not in_char: - if ch == "(": - depth += 1 - elif ch == ")": - depth -= 1 - pos += 1 - return pos if depth == 0 else -1 - - -def _find_method_calls_balanced(line: str, func_name: str): - """Find method calls to func_name with properly balanced parentheses. - - Handles nested parentheses in arguments correctly, unlike a pure regex approach. - Returns a list of (start, end, full_call) tuples where start/end are positions - in the line and full_call is the matched text (receiver.funcName(args)). - - Args: - line: A single line of Java source code. - func_name: The method name to look for. - - Returns: - List of (start_pos, end_pos, full_call_text) tuples. - - """ - # First find all occurrences of .funcName( in the line using regex - # to locate the method name, then use balanced paren finding for args - prefix_pattern = re.compile( - rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*{re.escape(func_name)}\s*\(" - ) - results = [] - search_start = 0 - while search_start < len(line): - m = prefix_pattern.search(line, search_start) - if not m: - break - # m.end() - 1 is the position of the opening paren - open_paren_pos = m.end() - 1 - close_pos = _find_balanced_end(line, open_paren_pos) - if close_pos == -1: - # Unbalanced parens - skip this match - search_start = m.end() - continue - full_call = line[m.start():close_pos] - results.append((m.start(), close_pos, full_call)) - search_start = close_pos - return results - def _infer_array_cast_type(line: str) -> str | None: """Infer the array cast type needed for assertion methods. @@ -278,12 +182,21 @@ def instrument_existing_test( new_class_name = f"{original_class_name}__perfonlyinstrumented" # Rename all references to the original class name in the source. - # This includes the class declaration, return types, constructor calls, - # variable declarations, etc. We use word-boundary matching to avoid - # replacing substrings of other identifiers. - modified_source = re.sub( - rf"\b{re.escape(original_class_name)}\b", new_class_name, source - ) + # Uses tree-sitter to find identifier/type_identifier AST nodes, + # which correctly excludes matches inside string literals and comments. + if analyzer is None: + from codeflash.languages.java.parser import get_java_analyzer + analyzer = get_java_analyzer() + + refs = analyzer.find_identifier_references(source, original_class_name) + if refs: + source_bytes = source.encode("utf8") + new_name_bytes = new_class_name.encode("utf8") + for start, end in reversed(refs): + source_bytes = source_bytes[:start] + new_name_bytes + source_bytes[end:] + modified_source = source_bytes.decode("utf8") + else: + modified_source = source # Add timing instrumentation to test methods # Use original class name (without suffix) in timing markers for consistency with Python @@ -305,8 +218,11 @@ def instrument_existing_test( def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) -> str: """Add behavior instrumentation to test methods. + Uses tree-sitter to find @Test methods, their body boundaries, and + method invocations of func_name (with lambda-awareness via parent-chain walk). + For behavior mode, this adds: - 1. Gson import for JSON serialization + 1. SQL imports for SQLite database writes 2. SQLite database connection setup 3. Function call wrapping to capture return values 4. SQLite insert with serialized return values @@ -320,282 +236,210 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) Instrumented source code. """ - # Add necessary imports at the top of the file - # Note: We don't import java.sql.Statement because it can conflict with - # other Statement classes (e.g., com.aerospike.client.query.Statement). - # Instead, we use the fully qualified name java.sql.Statement in the code. - # Note: We don't use Gson because it may not be available as a dependency. - # Instead, we use String.valueOf() for serialization. + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + + # ── Step 1: Add imports ────────────────────────────────────────────── import_statements = [ "import java.sql.Connection;", "import java.sql.DriverManager;", "import java.sql.PreparedStatement;", ] - # Find position to insert imports (after package, before class) lines = source.split("\n") - result = [] + result_lines: list[str] = [] imports_added = False - i = 0 + idx = 0 - while i < len(lines): - line = lines[i] + while idx < len(lines): + line = lines[idx] stripped = line.strip() - # Add imports after the last existing import or before the class declaration if not imports_added: if stripped.startswith("import "): - result.append(line) - i += 1 - # Find end of imports - while i < len(lines) and lines[i].strip().startswith("import "): - result.append(lines[i]) - i += 1 - # Add our imports + result_lines.append(line) + idx += 1 + while idx < len(lines) and lines[idx].strip().startswith("import "): + result_lines.append(lines[idx]) + idx += 1 for imp in import_statements: if imp not in source: - result.append(imp) + result_lines.append(imp) imports_added = True continue - if stripped.startswith(("public class", "class")): - # No imports found, add before class - result.extend(import_statements) - result.append("") + # Use tree-sitter class detection: any class/interface/enum keyword + if stripped.startswith(("public class", "class", "public final class", + "final class", "abstract class", "public abstract class")): + result_lines.extend(import_statements) + result_lines.append("") imports_added = True - result.append(line) - i += 1 + result_lines.append(line) + idx += 1 + + source = "\n".join(result_lines) + + # ── Step 2: Find @Test methods and wrap their bodies ───────────────── + source_bytes = source.encode("utf8") + test_methods = analyzer.find_test_methods(source) + + if not test_methods: + return source - # Now add timing and SQLite instrumentation to test methods - source = "\n".join(result) lines = source.split("\n") - result = [] - i = 0 - iteration_counter = 0 - helper_added = False + replacements: list[tuple[int, int, str]] = [] - while i < len(lines): - line = lines[i] - stripped = line.strip() + for iter_id, method_info in enumerate(test_methods, start=1): + body_node = method_info.body_node - # Look for @Test annotation (not @TestOnly, @TestFactory, etc.) - if _is_test_annotation(stripped): - if not helper_added: - helper_added = True - result.append(line) - i += 1 - - # Collect any additional annotations - while i < len(lines) and lines[i].strip().startswith("@"): - result.append(lines[i]) - i += 1 - - # Now find the method signature and opening brace - method_lines = [] - while i < len(lines): - method_lines.append(lines[i]) - if "{" in lines[i]: - break - i += 1 - - # Add the method signature lines - for ml in method_lines: - result.append(ml) - i += 1 - - # We're now inside the method body - iteration_counter += 1 - iter_id = iteration_counter - - # Detect indentation - method_sig_line = method_lines[-1] if method_lines else "" - base_indent = len(method_sig_line) - len(method_sig_line.lstrip()) - indent = " " * (base_indent + 4) - - # Collect method body until we find matching closing brace - brace_depth = 1 - body_lines = [] - - while i < len(lines) and brace_depth > 0: - body_line = lines[i] - # Count braces more efficiently using string methods - open_count = body_line.count("{") - close_count = body_line.count("}") - brace_depth += open_count - close_count - - if brace_depth > 0: - body_lines.append(body_line) - i += 1 - else: - # We've hit the closing brace - i += 1 - break - - # Wrap function calls to capture return values - # Look for patterns like: obj.funcName(args) or new Class().funcName(args) - call_counter = 0 - wrapped_body_lines = [] - - # Track lambda block nesting depth to avoid wrapping calls inside lambda bodies. - # assertThrows/assertDoesNotThrow expect an Executable (void functional interface), - # and wrapping the call in a variable assignment would turn the void-compatible - # lambda into a value-returning lambda, causing a compilation error. - # Also, variables declared outside lambdas cannot be reassigned inside them - # (Java requires effectively final variables in lambda captures). - # Handles both no-arg lambdas: () -> { func(); } - # and parameterized lambdas: (a, b, c) -> { func(); } - lambda_brace_depth = 0 - - for body_line in body_lines: - # Detect block lambda openings: (...) -> { or () -> { - # Matches both () -> { and (a, b, c) -> { - is_lambda_open = bool(re.search(r"->\s*\{", body_line)) - - # Update lambda brace depth tracking for block lambdas - if is_lambda_open or lambda_brace_depth > 0: - open_braces = body_line.count("{") - close_braces = body_line.count("}") - if is_lambda_open and lambda_brace_depth == 0: - # Starting a new lambda block - only count braces from this lambda - lambda_brace_depth = open_braces - close_braces - else: - lambda_brace_depth += open_braces - close_braces - # Ensure depth doesn't go below 0 - lambda_brace_depth = max(0, lambda_brace_depth) - - inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"->\s+\S", body_line)) - - # Check if this line contains a call to the target function - if func_name in body_line and "(" in body_line: - # Skip wrapping if the function call is inside a lambda expression - if inside_lambda: - wrapped_body_lines.append(body_line) - continue - - line_indent = len(body_line) - len(body_line.lstrip()) - line_indent_str = " " * line_indent - - # Find all matches using balanced parenthesis matching - # This correctly handles nested parens like: - # obj.func(a, Rows.toRowID(frame.getIndex(), row)) - matches = _find_method_calls_balanced(body_line, func_name) - if matches: - # Process matches in reverse order to maintain correct positions - new_line = body_line - for start_pos, end_pos, full_call in reversed(matches): - call_counter += 1 - var_name = f"_cf_result{iter_id}_{call_counter}" - - # Check if we need to cast the result for assertions with primitive arrays - # This handles assertArrayEquals(int[], int[]) etc. - cast_type = _infer_array_cast_type(body_line) - var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name - - # Replace this occurrence with the variable (with cast if needed) - new_line = new_line[:start_pos] + var_with_cast + new_line[end_pos:] - - # Use 'var' instead of 'Object' to preserve the exact return type. - # This avoids boxing mismatches (e.g., assertEquals(int, Object) where - # Object is boxed Long but expected is boxed Integer). Requires Java 10+. - capture_line = f"{line_indent_str}var {var_name} = {full_call};" - wrapped_body_lines.append(capture_line) - - # Immediately serialize the captured result while the variable - # is still in scope. This is necessary because the variable may - # be declared inside a nested block (while/for/if/try) and would - # be out of scope at the end of the method body. - serialize_line = ( - f"{line_indent_str}_cf_serializedResult{iter_id} = " - f"com.codeflash.Serializer.serialize((Object) {var_name});" - ) - wrapped_body_lines.append(serialize_line) - - # Check if the line is now just a variable reference (invalid statement) - # This happens when the original line was just a void method call - # e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;" - stripped_new = new_line.strip().rstrip(";").strip() - if stripped_new and stripped_new not in (var_name, var_with_cast): - wrapped_body_lines.append(new_line) - else: - wrapped_body_lines.append(body_line) - else: - wrapped_body_lines.append(body_line) - - # Add behavior instrumentation code - behavior_start_code = [ - f"{indent}// Codeflash behavior instrumentation", - f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', - f"{indent}int _cf_iter{iter_id} = {iter_id};", - f'{indent}String _cf_mod{iter_id} = "{class_name}";', - f'{indent}String _cf_cls{iter_id} = "{class_name}";', - f'{indent}String _cf_fn{iter_id} = "{func_name}";', - f'{indent}String _cf_outputFile{iter_id} = System.getenv("CODEFLASH_OUTPUT_FILE");', - f'{indent}String _cf_testIteration{iter_id} = System.getenv("CODEFLASH_TEST_ITERATION");', - f'{indent}if (_cf_testIteration{iter_id} == null) _cf_testIteration{iter_id} = "0";', - f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");', - f"{indent}long _cf_start{iter_id} = System.nanoTime();", - f"{indent}byte[] _cf_serializedResult{iter_id} = null;", - f"{indent}try {{", - ] - result.extend(behavior_start_code) - - # Add the wrapped body lines with extra indentation. - # Serialization of captured results is already done inline (immediately - # after each capture) so the _cf_serializedResult variable is always - # assigned while the captured variable is still in scope. - for bl in wrapped_body_lines: - result.append(" " + bl) - - # Add finally block with SQLite write - method_close_indent = " " * base_indent - behavior_end_code = [ - f"{indent}}} finally {{", - f"{indent} long _cf_end{iter_id} = System.nanoTime();", - f"{indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", - f'{indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");', - f"{indent} // Write to SQLite if output file is set", - f"{indent} if (_cf_outputFile{iter_id} != null && !_cf_outputFile{iter_id}.isEmpty()) {{", - f"{indent} try {{", - f'{indent} Class.forName("org.sqlite.JDBC");', - f'{indent} try (Connection _cf_conn{iter_id} = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile{iter_id})) {{', - f"{indent} try (java.sql.Statement _cf_stmt{iter_id} = _cf_conn{iter_id}.createStatement()) {{", - f'{indent} _cf_stmt{iter_id}.execute("CREATE TABLE IF NOT EXISTS test_results (" +', - f'{indent} "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +', - f'{indent} "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +', - f'{indent} "runtime INTEGER, return_value BLOB, verification_type TEXT)");', - f"{indent} }}", - f'{indent} String _cf_sql{iter_id} = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";', - f"{indent} try (PreparedStatement _cf_pstmt{iter_id} = _cf_conn{iter_id}.prepareStatement(_cf_sql{iter_id})) {{", - f"{indent} _cf_pstmt{iter_id}.setString(1, _cf_mod{iter_id});", - f"{indent} _cf_pstmt{iter_id}.setString(2, _cf_cls{iter_id});", - f'{indent} _cf_pstmt{iter_id}.setString(3, "{class_name}Test");', - f"{indent} _cf_pstmt{iter_id}.setString(4, _cf_fn{iter_id});", - f"{indent} _cf_pstmt{iter_id}.setInt(5, _cf_loop{iter_id});", - f'{indent} _cf_pstmt{iter_id}.setString(6, _cf_iter{iter_id} + "_" + _cf_testIteration{iter_id});', - f"{indent} _cf_pstmt{iter_id}.setLong(7, _cf_dur{iter_id});", - f"{indent} _cf_pstmt{iter_id}.setBytes(8, _cf_serializedResult{iter_id});", # Kryo-serialized return value - f'{indent} _cf_pstmt{iter_id}.setString(9, "function_call");', - f"{indent} _cf_pstmt{iter_id}.executeUpdate();", - f"{indent} }}", - f"{indent} }}", - f"{indent} }} catch (Exception _cf_e{iter_id}) {{", - f'{indent} System.err.println("CodeflashHelper: SQLite error: " + _cf_e{iter_id}.getMessage());', - f"{indent} }}", - f"{indent} }}", - f"{indent}}}", - f"{method_close_indent}}}", # Method closing brace + # Extract body lines (between { and }) + body_start_line = body_node.start_point[0] + body_end_line = body_node.end_point[0] + body_lines = lines[body_start_line + 1 : body_end_line] + + brace_line = lines[body_start_line] + base_indent = len(brace_line) - len(brace_line.lstrip()) + indent = " " * (base_indent + 4) + + # ── Find method invocations via tree-sitter ────────────────────── + invocations = analyzer.find_method_invocations(body_node, source_bytes, func_name) + + # Group invocations by their 0-indexed source line + invocations_by_line: dict[int, list] = {} + for inv in invocations: + inv_line = inv.node.start_point[0] + invocations_by_line.setdefault(inv_line, []).append(inv) + + # ── Wrap function calls per body line ──────────────────────────── + call_counter = 0 + wrapped_body_lines: list[str] = [] + + for local_idx, body_line in enumerate(body_lines): + source_line_idx = body_start_line + 1 + local_idx + line_invocations = invocations_by_line.get(source_line_idx, []) + + # Filter to non-lambda invocations on a single line + actionable = [ + inv for inv in line_invocations + if not inv.in_lambda and inv.node.start_point[0] == inv.node.end_point[0] ] - result.extend(behavior_end_code) - else: - result.append(line) - i += 1 - return "\n".join(result) + if actionable: + line_indent = len(body_line) - len(body_line.lstrip()) + line_indent_str = " " * line_indent + + # Process matches in reverse column order to preserve positions + actionable.sort(key=lambda inv: inv.node.start_point[1], reverse=True) + new_line = body_line + last_var_name = None + last_var_with_cast = None + + for inv in actionable: + call_counter += 1 + var_name = f"_cf_result{iter_id}_{call_counter}" + last_var_name = var_name + + cast_type = _infer_array_cast_type(body_line) + var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name + last_var_with_cast = var_with_cast + + start_col = inv.node.start_point[1] + end_col = inv.node.end_point[1] + new_line = new_line[:start_col] + var_with_cast + new_line[end_col:] + + capture_line = f"{line_indent_str}var {var_name} = {inv.full_text};" + wrapped_body_lines.append(capture_line) + + serialize_line = ( + f"{line_indent_str}_cf_serializedResult{iter_id} = " + f"com.codeflash.Serializer.serialize((Object) {var_name});" + ) + wrapped_body_lines.append(serialize_line) + + # Skip line if it collapsed to just a bare variable reference + stripped_new = new_line.strip().rstrip(";").strip() + if stripped_new and stripped_new not in (last_var_name, last_var_with_cast): + wrapped_body_lines.append(new_line) + else: + wrapped_body_lines.append(body_line) + + # ── Build replacement body ─────────────────────────────────────── + method_close_indent = " " * base_indent + behavior_lines = [ + "{", + f"{indent}// Codeflash behavior instrumentation", + f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', + f"{indent}int _cf_iter{iter_id} = {iter_id};", + f'{indent}String _cf_mod{iter_id} = "{class_name}";', + f'{indent}String _cf_cls{iter_id} = "{class_name}";', + f'{indent}String _cf_fn{iter_id} = "{func_name}";', + f'{indent}String _cf_outputFile{iter_id} = System.getenv("CODEFLASH_OUTPUT_FILE");', + f'{indent}String _cf_testIteration{iter_id} = System.getenv("CODEFLASH_TEST_ITERATION");', + f'{indent}if (_cf_testIteration{iter_id} == null) _cf_testIteration{iter_id} = "0";', + f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");', + f"{indent}long _cf_start{iter_id} = System.nanoTime();", + f"{indent}byte[] _cf_serializedResult{iter_id} = null;", + f"{indent}try {{", + ] + + for bl in wrapped_body_lines: + behavior_lines.append(" " + bl) + + behavior_lines.extend([ + f"{indent}}} finally {{", + f"{indent} long _cf_end{iter_id} = System.nanoTime();", + f"{indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", + f'{indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");', + f"{indent} // Write to SQLite if output file is set", + f"{indent} if (_cf_outputFile{iter_id} != null && !_cf_outputFile{iter_id}.isEmpty()) {{", + f"{indent} try {{", + f'{indent} Class.forName("org.sqlite.JDBC");', + f'{indent} try (Connection _cf_conn{iter_id} = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile{iter_id})) {{', + f"{indent} try (java.sql.Statement _cf_stmt{iter_id} = _cf_conn{iter_id}.createStatement()) {{", + f'{indent} _cf_stmt{iter_id}.execute("CREATE TABLE IF NOT EXISTS test_results (" +', + f'{indent} "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +', + f'{indent} "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +', + f'{indent} "runtime INTEGER, return_value BLOB, verification_type TEXT)");', + f"{indent} }}", + f'{indent} String _cf_sql{iter_id} = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";', + f"{indent} try (PreparedStatement _cf_pstmt{iter_id} = _cf_conn{iter_id}.prepareStatement(_cf_sql{iter_id})) {{", + f"{indent} _cf_pstmt{iter_id}.setString(1, _cf_mod{iter_id});", + f"{indent} _cf_pstmt{iter_id}.setString(2, _cf_cls{iter_id});", + f'{indent} _cf_pstmt{iter_id}.setString(3, "{class_name}Test");', + f"{indent} _cf_pstmt{iter_id}.setString(4, _cf_fn{iter_id});", + f"{indent} _cf_pstmt{iter_id}.setInt(5, _cf_loop{iter_id});", + f'{indent} _cf_pstmt{iter_id}.setString(6, _cf_iter{iter_id} + "_" + _cf_testIteration{iter_id});', + f"{indent} _cf_pstmt{iter_id}.setLong(7, _cf_dur{iter_id});", + f"{indent} _cf_pstmt{iter_id}.setBytes(8, _cf_serializedResult{iter_id});", + f'{indent} _cf_pstmt{iter_id}.setString(9, "function_call");', + f"{indent} _cf_pstmt{iter_id}.executeUpdate();", + f"{indent} }}", + f"{indent} }}", + f"{indent} }} catch (Exception _cf_e{iter_id}) {{", + f'{indent} System.err.println("CodeflashHelper: SQLite error: " + _cf_e{iter_id}.getMessage());', + f"{indent} }}", + f"{indent} }}", + f"{indent}}}", + f"{method_close_indent}}}", + ]) + + replacement = "\n".join(behavior_lines) + replacements.append((body_node.start_byte, body_node.end_byte, replacement)) + + # Apply replacements in reverse byte order + for start, end, replacement in sorted(replacements, key=lambda x: x[0], reverse=True): + source_bytes = source_bytes[:start] + replacement.encode("utf8") + source_bytes[end:] + + return source_bytes.decode("utf8") def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> str: """Add timing instrumentation to test methods with inner loop for JIT warmup. + Uses tree-sitter to find @Test methods and their body boundaries, + then replaces each method body with a timing-wrapped version. + For each @Test method, this adds: 1. Inner loop that runs N iterations (controlled by CODEFLASH_INNER_ITERATIONS env var) 2. Start timing marker printed at the beginning of each iteration @@ -619,110 +463,74 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> Instrumented source code. """ - # Find all @Test methods and add timing around their bodies - # Pattern matches: @Test (with optional parameters) followed by method declaration - # We process line by line for cleaner handling + from codeflash.languages.java.parser import get_java_analyzer - lines = source.split("\n") - result = [] - i = 0 - iteration_counter = 0 + analyzer = get_java_analyzer() + source_bytes = source.encode("utf8") + test_methods = analyzer.find_test_methods(source) - while i < len(lines): - line = lines[i] - stripped = line.strip() + if not test_methods: + return source - # Look for @Test annotation (not @TestOnly, @TestFactory, etc.) - if _is_test_annotation(stripped): - result.append(line) - i += 1 - - # Collect any additional annotations - while i < len(lines) and lines[i].strip().startswith("@"): - result.append(lines[i]) - i += 1 - - # Now find the method signature and opening brace - method_lines = [] - while i < len(lines): - method_lines.append(lines[i]) - if "{" in lines[i]: - break - i += 1 - - # Add the method signature lines - result.extend(method_lines) - i += 1 - - # We're now inside the method body - iteration_counter += 1 - iter_id = iteration_counter - - # Detect indentation from method signature line (line with opening brace) - method_sig_line = method_lines[-1] if method_lines else "" - base_indent = len(method_sig_line) - len(method_sig_line.lstrip()) - indent = " " * (base_indent + 4) # Add one level of indentation - inner_indent = " " * (base_indent + 8) # Two levels for inside inner loop - inner_body_indent = " " * (base_indent + 12) # Three levels for try block body - - # Add timing instrumentation with inner loop - # Note: CODEFLASH_LOOP_INDEX must always be set - no null check, crash if missing - # CODEFLASH_INNER_ITERATIONS controls inner loop count (default: 100) - timing_start_code = [ - f"{indent}// Codeflash timing instrumentation with inner loop for JIT warmup", - f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', - f'{indent}int _cf_innerIterations{iter_id} = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));', - f'{indent}String _cf_mod{iter_id} = "{class_name}";', - f'{indent}String _cf_cls{iter_id} = "{class_name}";', - f'{indent}String _cf_fn{iter_id} = "{func_name}";', - "", - f"{indent}for (int _cf_i{iter_id} = 0; _cf_i{iter_id} < _cf_innerIterations{iter_id}; _cf_i{iter_id}++) {{", - f'{inner_indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + "######$!");', - f"{inner_indent}long _cf_start{iter_id} = System.nanoTime();", - f"{inner_indent}try {{", - ] - result.extend(timing_start_code) - - # Collect method body until we find matching closing brace - brace_depth = 1 - body_lines = [] - - while i < len(lines) and brace_depth > 0: - body_line = lines[i] - # Count braces (simple approach - doesn't handle strings/comments perfectly) - for ch in body_line: - if ch == "{": - brace_depth += 1 - elif ch == "}": - brace_depth -= 1 - - if brace_depth > 0: - body_lines.append(body_line) - i += 1 - else: - # This line contains the closing brace, but we've hit depth 0 - # Add indented body lines (inside try block, inside for loop) - for bl in body_lines: - result.append(" " + bl) # 8 extra spaces for inner loop + try - - # Add finally block and close inner loop - method_close_indent = " " * base_indent # Same level as method signature - timing_end_code = [ - f"{inner_indent}}} finally {{", - f"{inner_indent} long _cf_end{iter_id} = System.nanoTime();", - f"{inner_indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", - f'{inner_indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + ":" + _cf_dur{iter_id} + "######!");', - f"{inner_indent}}}", - f"{indent}}}", # Close for loop - f"{method_close_indent}}}", # Method closing brace - ] - result.extend(timing_end_code) - i += 1 - else: - result.append(line) - i += 1 - - return "\n".join(result) + lines = source.split("\n") + + # Build a replacement for each @Test method body. + # iter_id is assigned in forward (source) order; replacements are applied in reverse. + replacements: list[tuple[int, int, str]] = [] + + for iter_id, method_info in enumerate(test_methods, start=1): + body_node = method_info.body_node + + # Extract body lines (between { and }, exclusive of both brace lines) + body_start_line = body_node.start_point[0] # 0-indexed line of { + body_end_line = body_node.end_point[0] # 0-indexed line of } + body_lines = lines[body_start_line + 1 : body_end_line] + + # Indentation from the line containing the opening brace + brace_line = lines[body_start_line] + base_indent = len(brace_line) - len(brace_line.lstrip()) + indent = " " * (base_indent + 4) + inner_indent = " " * (base_indent + 8) + method_close_indent = " " * base_indent + + # Build the replacement body (opening { through closing }) + timing_lines = [ + "{", + f"{indent}// Codeflash timing instrumentation with inner loop for JIT warmup", + f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', + f'{indent}int _cf_innerIterations{iter_id} = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));', + f'{indent}String _cf_mod{iter_id} = "{class_name}";', + f'{indent}String _cf_cls{iter_id} = "{class_name}";', + f'{indent}String _cf_fn{iter_id} = "{func_name}";', + "", + f"{indent}for (int _cf_i{iter_id} = 0; _cf_i{iter_id} < _cf_innerIterations{iter_id}; _cf_i{iter_id}++) {{", + f'{inner_indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + "######$!");', + f"{inner_indent}long _cf_start{iter_id} = System.nanoTime();", + f"{inner_indent}try {{", + ] + + # Original body lines, indented by 8 extra spaces (for loop + try) + for bl in body_lines: + timing_lines.append(" " + bl) + + timing_lines.extend([ + f"{inner_indent}}} finally {{", + f"{inner_indent} long _cf_end{iter_id} = System.nanoTime();", + f"{inner_indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", + f'{inner_indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + ":" + _cf_dur{iter_id} + "######!");', + f"{inner_indent}}}", + f"{indent}}}", # Close for loop + f"{method_close_indent}}}", # Method closing brace + ]) + + replacement = "\n".join(timing_lines) + replacements.append((body_node.start_byte, body_node.end_byte, replacement)) + + # Apply replacements in reverse byte order to preserve earlier offsets + for start, end, replacement in sorted(replacements, key=lambda x: x[0], reverse=True): + source_bytes = source_bytes[:start] + replacement.encode("utf8") + source_bytes[end:] + + return source_bytes.decode("utf8") def create_benchmark_test( @@ -825,14 +633,16 @@ def instrument_generated_java_test( if mode == "behavior": test_code = transform_java_assertions(test_code, function_name, qualified_name) - # Extract class name from the test code - # Use pattern that starts at beginning of line to avoid matching words in comments - class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", test_code, re.MULTILINE) - if not class_match: + # Extract class name from the test code using tree-sitter + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + classes = analyzer.find_classes(test_code) + if not classes: logger.warning("Could not find class name in generated test") return test_code - original_class_name = class_match.group(1) + original_class_name = classes[0].name # Rename class based on mode if mode == "behavior": @@ -840,11 +650,17 @@ def instrument_generated_java_test( else: new_class_name = f"{original_class_name}__perfonlyinstrumented" - # Rename all references to the original class name in the source. - # This includes the class declaration, return types, constructor calls, etc. - modified_code = re.sub( - rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code - ) + # Rename all identifier references to the class name using tree-sitter + # (excludes matches inside string literals and comments) + refs = analyzer.find_identifier_references(test_code, original_class_name) + if refs: + code_bytes = test_code.encode("utf8") + new_name_bytes = new_class_name.encode("utf8") + for start, end in reversed(refs): + code_bytes = code_bytes[:start] + new_name_bytes + code_bytes[end:] + modified_code = code_bytes.decode("utf8") + else: + modified_code = test_code # For performance mode, add timing instrumentation # Use original class name (without suffix) in timing markers for consistency with Python @@ -859,32 +675,4 @@ def instrument_generated_java_test( return modified_code -def _add_import(source: str, import_statement: str) -> str: - """Add an import statement to the source. - - Args: - source: The source code. - import_statement: The import to add. - - Returns: - Source with import added. - - """ - lines = source.splitlines(keepends=True) - insert_idx = 0 - - # Find the last import or package statement - for i, line in enumerate(lines): - stripped = line.strip() - if stripped.startswith(("import ", "package ")): - insert_idx = i + 1 - elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"): - # First non-import, non-comment line - if insert_idx == 0: - insert_idx = i - break - - lines.insert(insert_idx, import_statement + "\n") - return "".join(lines) - diff --git a/codeflash/languages/java/line_profiler.py b/codeflash/languages/java/line_profiler.py index 8a59ed6e6..67895149c 100644 --- a/codeflash/languages/java/line_profiler.py +++ b/codeflash/languages/java/line_profiler.py @@ -9,7 +9,6 @@ import json import logging -import re from pathlib import Path from typing import TYPE_CHECKING @@ -93,18 +92,18 @@ def instrument_source( # Add profiler class and initialization profiler_class_code = self._generate_profiler_class() - # Insert profiler class before the package's first class - # Find the first class/interface/enum/record declaration - # Must handle any combination of modifiers: public final class, abstract class, etc. - class_pattern = re.compile( - r"^(?:(?:public|private|protected|final|abstract|static|sealed|non-sealed)\s+)*" - r"(?:class|interface|enum|record)\s+" - ) + # Insert profiler class before the package's first class declaration. + # Uses tree-sitter find_classes() to handle any modifier combination + # (public final class, abstract class, etc.). + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source_text = "".join(lines) + classes = analyzer.find_classes(source_text) import_end_idx = 0 - for i, line in enumerate(lines): - if class_pattern.match(line.strip()): - import_end_idx = i - break + if classes: + # find_classes returns 1-indexed start_line; convert to 0-indexed + import_end_idx = classes[0].start_line - 1 lines_with_profiler = ( lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:] diff --git a/codeflash/languages/java/parser.py b/codeflash/languages/java/parser.py index 72a530179..52d10726a 100644 --- a/codeflash/languages/java/parser.py +++ b/codeflash/languages/java/parser.py @@ -53,6 +53,24 @@ class JavaMethodNode: javadoc_start_line: int | None = None # Line where Javadoc comment starts +@dataclass +class TestMethodInfo: + """A @Test-annotated method found by tree-sitter analysis.""" + + name: str + method_node: Node + body_node: Node # The block (body) of the method + + +@dataclass +class MethodCallInfo: + """A method invocation found within a given AST subtree.""" + + node: Node + full_text: str + in_lambda: bool + + @dataclass class JavaClassNode: """Represents a class found by tree-sitter analysis.""" @@ -678,6 +696,170 @@ def get_package_name(self, source: str) -> str | None: return None + def find_test_methods(self, source: str) -> list[TestMethodInfo]: + """Find all @Test-annotated methods in source code. + + Identifies method_declaration nodes whose modifiers contain a + marker_annotation or annotation with name ``Test`` (not ``TestOnly``, + ``TestFactory``, etc.). + + Args: + source: The source code to analyze. + + Returns: + List of TestMethodInfo in source order. + + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + results: list[TestMethodInfo] = [] + self._walk_for_test_methods(tree.root_node, source_bytes, results) + return results + + def _walk_for_test_methods( + self, node: Node, source_bytes: bytes, results: list[TestMethodInfo] + ) -> None: + """Recursively walk the tree to find @Test methods.""" + if node.type == "method_declaration": + if self._has_test_annotation(node, source_bytes): + body = node.child_by_field_name("body") + name_node = node.child_by_field_name("name") + if body and name_node: + results.append( + TestMethodInfo( + name=self.get_node_text(name_node, source_bytes), + method_node=node, + body_node=body, + ) + ) + + for child in node.children: + self._walk_for_test_methods(child, source_bytes, results) + + def _has_test_annotation(self, method_node: Node, source_bytes: bytes) -> bool: + """Check if a method_declaration has an @Test annotation. + + Only matches exact ``Test`` name – not ``TestOnly``, ``TestFactory``, etc. + Handles both ``@Test`` (marker_annotation) and ``@Test(...)`` (annotation). + """ + for child in method_node.children: + if child.type == "modifiers": + for mod_child in child.children: + if mod_child.type in ("marker_annotation", "annotation"): + name_node = mod_child.child_by_field_name("name") + if name_node is None: + # Fallback: search direct children for identifier + for ann_child in mod_child.children: + if ann_child.type == "identifier": + name_node = ann_child + break + if name_node and self.get_node_text(name_node, source_bytes) == "Test": + return True + return False + + def find_method_invocations( + self, body_node: Node, source_bytes: bytes, func_name: str + ) -> list[MethodCallInfo]: + """Find all invocations of *func_name* within a given AST subtree. + + Checks the parent chain for ``lambda_expression`` nodes to set the + ``in_lambda`` flag. + + Args: + body_node: The subtree to search within (typically a method body block). + source_bytes: The full source code as bytes. + func_name: The method name to match. + + Returns: + List of MethodCallInfo in source order. + + """ + results: list[MethodCallInfo] = [] + self._walk_for_invocations(body_node, source_bytes, func_name, results, in_lambda=False) + return results + + def _walk_for_invocations( + self, + node: Node, + source_bytes: bytes, + func_name: str, + results: list[MethodCallInfo], + in_lambda: bool, + ) -> None: + """Recursively find method invocations matching func_name.""" + if node.type == "lambda_expression": + in_lambda = True + + if node.type == "method_invocation": + name_node = node.child_by_field_name("name") + if name_node and self.get_node_text(name_node, source_bytes) == func_name: + results.append( + MethodCallInfo( + node=node, + full_text=self.get_node_text(node, source_bytes), + in_lambda=in_lambda, + ) + ) + + for child in node.children: + self._walk_for_invocations(child, source_bytes, func_name, results, in_lambda) + + def find_identifier_references(self, source: str, name: str) -> list[tuple[int, int]]: + """Find all ``identifier`` / ``type_identifier`` AST nodes matching *name*. + + Tree-sitter naturally excludes matches inside string literals and comments, + making this safer than ``re.sub(r"\\b...\\b", ...)``. + + Args: + source: The source code to analyze. + name: The identifier name to match. + + Returns: + List of ``(start_byte, end_byte)`` pairs in source order. + + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + refs: list[tuple[int, int]] = [] + self._walk_for_identifiers(tree.root_node, source_bytes, name, refs) + return refs + + def _walk_for_identifiers( + self, node: Node, source_bytes: bytes, name: str, refs: list[tuple[int, int]] + ) -> None: + """Recursively find identifier/type_identifier nodes matching name.""" + if node.type in ("identifier", "type_identifier"): + if self.get_node_text(node, source_bytes) == name: + refs.append((node.start_byte, node.end_byte)) + for child in node.children: + self._walk_for_identifiers(child, source_bytes, name, refs) + + def find_import_insertion_point(self, source: str) -> int: + """Find the 0-indexed line after the last import or package declaration. + + Returns the line index where new import statements should be inserted. + If no imports or package found, returns 0. + + Args: + source: The source code to analyze. + + Returns: + 0-indexed line number for insertion. + + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + last_line = 0 + + for child in tree.root_node.children: + if child.type in ("import_declaration", "package_declaration"): + # end_point[0] is 0-indexed last line of the node + candidate = child.end_point[0] + 1 + if candidate > last_line: + last_line = candidate + + return last_line + def get_java_analyzer() -> JavaAnalyzer: """Get a JavaAnalyzer instance.