From c11afd33d6467c85ad887e9b603c72421b3f0d35 Mon Sep 17 00:00:00 2001
From: HeshamHM28
Date: Fri, 13 Feb 2026 00:08:31 +0000
Subject: [PATCH 1/2] refactor(java): replace regex/brace-counting with tree-sitter in instrumentation

Replaces fragile line-by-line regex scanning and manual brace counting
with tree-sitter AST analysis for Java test instrumentation. This
eliminates several classes of bugs:

- `public final class` not matching class detection patterns
- `@TestOnly` being incorrectly matched as `@Test`
- Nested parentheses breaking method call extraction
- Braces inside strings/comments corrupting method boundary detection
- Class renaming hitting matches inside comments and string literals
- Lambda detection missing edge cases

Changes:

- parser.py: Add find_test_methods(), find_method_invocations(),
  find_identifier_references(), find_import_insertion_point() helpers
- instrumentation.py: Refactor _add_timing_instrumentation and
  _add_behavior_instrumentation to use tree-sitter for @Test method
  detection and body boundary extraction via body node ranges
- instrumentation.py: Refactor class renaming to use tree-sitter
  identifier references (excludes strings/comments)
- instrumentation.py: Refactor class extraction in generated test
  instrumentation to use find_classes()
- instrumentation.py: Remove dead helpers (_is_test_annotation,
  _find_balanced_end, _find_method_calls_balanced, _add_import)
- line_profiler.py: Replace regex class detection with find_classes()

All 33 instrumentation unit tests pass with byte-for-byte identical output.
Co-Authored-By: Claude Opus 4.6 --- codeflash/languages/java/instrumentation.py | 764 +++++++------------- codeflash/languages/java/line_profiler.py | 23 +- codeflash/languages/java/parser.py | 182 +++++ 3 files changed, 469 insertions(+), 500 deletions(-) diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 63a97b17b..14303e0f8 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -42,102 +42,6 @@ def _get_function_name(func: Any) -> str: # Pattern to detect primitive array types in assertions _PRIMITIVE_ARRAY_PATTERN = re.compile(r"new\s+(int|long|double|float|short|byte|char|boolean)\s*\[\s*\]") -# Pattern to match @Test annotation exactly (not @TestOnly, @TestFactory, etc.) -_TEST_ANNOTATION_RE = re.compile(r"^@Test(?:\s*\(.*\))?(?:\s.*)?$") - - -def _is_test_annotation(stripped_line: str) -> bool: - """Check if a stripped line is an @Test annotation (not @TestOnly, @TestFactory, etc.). - - Matches: - @Test - @Test(expected = ...) - @Test(timeout = 5000) - Does NOT match: - @TestOnly - @TestFactory - @TestTemplate - """ - return bool(_TEST_ANNOTATION_RE.match(stripped_line)) - - -def _find_balanced_end(text: str, start: int) -> int: - """Find the position after the closing paren that balances the opening paren at start. - - Args: - text: The source text. - start: Index of the opening parenthesis '('. - - Returns: - Index one past the matching closing ')', or -1 if not found. 
- - """ - if start >= len(text) or text[start] != "(": - return -1 - depth = 1 - pos = start + 1 - in_string = False - string_char = None - in_char = False - while pos < len(text) and depth > 0: - ch = text[pos] - prev = text[pos - 1] if pos > 0 else "" - if ch == "'" and not in_string and prev != "\\": - in_char = not in_char - elif ch == '"' and not in_char and prev != "\\": - if not in_string: - in_string = True - string_char = ch - elif ch == string_char: - in_string = False - string_char = None - elif not in_string and not in_char: - if ch == "(": - depth += 1 - elif ch == ")": - depth -= 1 - pos += 1 - return pos if depth == 0 else -1 - - -def _find_method_calls_balanced(line: str, func_name: str): - """Find method calls to func_name with properly balanced parentheses. - - Handles nested parentheses in arguments correctly, unlike a pure regex approach. - Returns a list of (start, end, full_call) tuples where start/end are positions - in the line and full_call is the matched text (receiver.funcName(args)). - - Args: - line: A single line of Java source code. - func_name: The method name to look for. - - Returns: - List of (start_pos, end_pos, full_call_text) tuples. 
- - """ - # First find all occurrences of .funcName( in the line using regex - # to locate the method name, then use balanced paren finding for args - prefix_pattern = re.compile( - rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*{re.escape(func_name)}\s*\(" - ) - results = [] - search_start = 0 - while search_start < len(line): - m = prefix_pattern.search(line, search_start) - if not m: - break - # m.end() - 1 is the position of the opening paren - open_paren_pos = m.end() - 1 - close_pos = _find_balanced_end(line, open_paren_pos) - if close_pos == -1: - # Unbalanced parens - skip this match - search_start = m.end() - continue - full_call = line[m.start():close_pos] - results.append((m.start(), close_pos, full_call)) - search_start = close_pos - return results - def _infer_array_cast_type(line: str) -> str | None: """Infer the array cast type needed for assertion methods. @@ -278,12 +182,21 @@ def instrument_existing_test( new_class_name = f"{original_class_name}__perfonlyinstrumented" # Rename all references to the original class name in the source. - # This includes the class declaration, return types, constructor calls, - # variable declarations, etc. We use word-boundary matching to avoid - # replacing substrings of other identifiers. - modified_source = re.sub( - rf"\b{re.escape(original_class_name)}\b", new_class_name, source - ) + # Uses tree-sitter to find identifier/type_identifier AST nodes, + # which correctly excludes matches inside string literals and comments. 
+ if analyzer is None: + from codeflash.languages.java.parser import get_java_analyzer + analyzer = get_java_analyzer() + + refs = analyzer.find_identifier_references(source, original_class_name) + if refs: + source_bytes = source.encode("utf8") + new_name_bytes = new_class_name.encode("utf8") + for start, end in reversed(refs): + source_bytes = source_bytes[:start] + new_name_bytes + source_bytes[end:] + modified_source = source_bytes.decode("utf8") + else: + modified_source = source # Add timing instrumentation to test methods # Use original class name (without suffix) in timing markers for consistency with Python @@ -305,8 +218,11 @@ def instrument_existing_test( def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) -> str: """Add behavior instrumentation to test methods. + Uses tree-sitter to find @Test methods, their body boundaries, and + method invocations of func_name (with lambda-awareness via parent-chain walk). + For behavior mode, this adds: - 1. Gson import for JSON serialization + 1. SQL imports for SQLite database writes 2. SQLite database connection setup 3. Function call wrapping to capture return values 4. SQLite insert with serialized return values @@ -320,282 +236,210 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) Instrumented source code. """ - # Add necessary imports at the top of the file - # Note: We don't import java.sql.Statement because it can conflict with - # other Statement classes (e.g., com.aerospike.client.query.Statement). - # Instead, we use the fully qualified name java.sql.Statement in the code. - # Note: We don't use Gson because it may not be available as a dependency. - # Instead, we use String.valueOf() for serialization. 
+ from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + + # ── Step 1: Add imports ────────────────────────────────────────────── import_statements = [ "import java.sql.Connection;", "import java.sql.DriverManager;", "import java.sql.PreparedStatement;", ] - # Find position to insert imports (after package, before class) lines = source.split("\n") - result = [] + result_lines: list[str] = [] imports_added = False - i = 0 + idx = 0 - while i < len(lines): - line = lines[i] + while idx < len(lines): + line = lines[idx] stripped = line.strip() - # Add imports after the last existing import or before the class declaration if not imports_added: if stripped.startswith("import "): - result.append(line) - i += 1 - # Find end of imports - while i < len(lines) and lines[i].strip().startswith("import "): - result.append(lines[i]) - i += 1 - # Add our imports + result_lines.append(line) + idx += 1 + while idx < len(lines) and lines[idx].strip().startswith("import "): + result_lines.append(lines[idx]) + idx += 1 for imp in import_statements: if imp not in source: - result.append(imp) + result_lines.append(imp) imports_added = True continue - if stripped.startswith(("public class", "class")): - # No imports found, add before class - result.extend(import_statements) - result.append("") + # Use tree-sitter class detection: any class/interface/enum keyword + if stripped.startswith(("public class", "class", "public final class", + "final class", "abstract class", "public abstract class")): + result_lines.extend(import_statements) + result_lines.append("") imports_added = True - result.append(line) - i += 1 + result_lines.append(line) + idx += 1 + + source = "\n".join(result_lines) + + # ── Step 2: Find @Test methods and wrap their bodies ───────────────── + source_bytes = source.encode("utf8") + test_methods = analyzer.find_test_methods(source) + + if not test_methods: + return source - # Now add timing and SQLite instrumentation to test 
methods - source = "\n".join(result) lines = source.split("\n") - result = [] - i = 0 - iteration_counter = 0 - helper_added = False + replacements: list[tuple[int, int, str]] = [] - while i < len(lines): - line = lines[i] - stripped = line.strip() + for iter_id, method_info in enumerate(test_methods, start=1): + body_node = method_info.body_node - # Look for @Test annotation (not @TestOnly, @TestFactory, etc.) - if _is_test_annotation(stripped): - if not helper_added: - helper_added = True - result.append(line) - i += 1 - - # Collect any additional annotations - while i < len(lines) and lines[i].strip().startswith("@"): - result.append(lines[i]) - i += 1 - - # Now find the method signature and opening brace - method_lines = [] - while i < len(lines): - method_lines.append(lines[i]) - if "{" in lines[i]: - break - i += 1 - - # Add the method signature lines - for ml in method_lines: - result.append(ml) - i += 1 - - # We're now inside the method body - iteration_counter += 1 - iter_id = iteration_counter - - # Detect indentation - method_sig_line = method_lines[-1] if method_lines else "" - base_indent = len(method_sig_line) - len(method_sig_line.lstrip()) - indent = " " * (base_indent + 4) - - # Collect method body until we find matching closing brace - brace_depth = 1 - body_lines = [] - - while i < len(lines) and brace_depth > 0: - body_line = lines[i] - # Count braces more efficiently using string methods - open_count = body_line.count("{") - close_count = body_line.count("}") - brace_depth += open_count - close_count - - if brace_depth > 0: - body_lines.append(body_line) - i += 1 - else: - # We've hit the closing brace - i += 1 - break - - # Wrap function calls to capture return values - # Look for patterns like: obj.funcName(args) or new Class().funcName(args) - call_counter = 0 - wrapped_body_lines = [] - - # Track lambda block nesting depth to avoid wrapping calls inside lambda bodies. 
- # assertThrows/assertDoesNotThrow expect an Executable (void functional interface), - # and wrapping the call in a variable assignment would turn the void-compatible - # lambda into a value-returning lambda, causing a compilation error. - # Also, variables declared outside lambdas cannot be reassigned inside them - # (Java requires effectively final variables in lambda captures). - # Handles both no-arg lambdas: () -> { func(); } - # and parameterized lambdas: (a, b, c) -> { func(); } - lambda_brace_depth = 0 - - for body_line in body_lines: - # Detect block lambda openings: (...) -> { or () -> { - # Matches both () -> { and (a, b, c) -> { - is_lambda_open = bool(re.search(r"->\s*\{", body_line)) - - # Update lambda brace depth tracking for block lambdas - if is_lambda_open or lambda_brace_depth > 0: - open_braces = body_line.count("{") - close_braces = body_line.count("}") - if is_lambda_open and lambda_brace_depth == 0: - # Starting a new lambda block - only count braces from this lambda - lambda_brace_depth = open_braces - close_braces - else: - lambda_brace_depth += open_braces - close_braces - # Ensure depth doesn't go below 0 - lambda_brace_depth = max(0, lambda_brace_depth) - - inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"->\s+\S", body_line)) - - # Check if this line contains a call to the target function - if func_name in body_line and "(" in body_line: - # Skip wrapping if the function call is inside a lambda expression - if inside_lambda: - wrapped_body_lines.append(body_line) - continue - - line_indent = len(body_line) - len(body_line.lstrip()) - line_indent_str = " " * line_indent - - # Find all matches using balanced parenthesis matching - # This correctly handles nested parens like: - # obj.func(a, Rows.toRowID(frame.getIndex(), row)) - matches = _find_method_calls_balanced(body_line, func_name) - if matches: - # Process matches in reverse order to maintain correct positions - new_line = body_line - for start_pos, end_pos, full_call 
in reversed(matches): - call_counter += 1 - var_name = f"_cf_result{iter_id}_{call_counter}" - - # Check if we need to cast the result for assertions with primitive arrays - # This handles assertArrayEquals(int[], int[]) etc. - cast_type = _infer_array_cast_type(body_line) - var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name - - # Replace this occurrence with the variable (with cast if needed) - new_line = new_line[:start_pos] + var_with_cast + new_line[end_pos:] - - # Use 'var' instead of 'Object' to preserve the exact return type. - # This avoids boxing mismatches (e.g., assertEquals(int, Object) where - # Object is boxed Long but expected is boxed Integer). Requires Java 10+. - capture_line = f"{line_indent_str}var {var_name} = {full_call};" - wrapped_body_lines.append(capture_line) - - # Immediately serialize the captured result while the variable - # is still in scope. This is necessary because the variable may - # be declared inside a nested block (while/for/if/try) and would - # be out of scope at the end of the method body. 
- serialize_line = ( - f"{line_indent_str}_cf_serializedResult{iter_id} = " - f"com.codeflash.Serializer.serialize((Object) {var_name});" - ) - wrapped_body_lines.append(serialize_line) - - # Check if the line is now just a variable reference (invalid statement) - # This happens when the original line was just a void method call - # e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;" - stripped_new = new_line.strip().rstrip(";").strip() - if stripped_new and stripped_new not in (var_name, var_with_cast): - wrapped_body_lines.append(new_line) - else: - wrapped_body_lines.append(body_line) - else: - wrapped_body_lines.append(body_line) - - # Add behavior instrumentation code - behavior_start_code = [ - f"{indent}// Codeflash behavior instrumentation", - f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', - f"{indent}int _cf_iter{iter_id} = {iter_id};", - f'{indent}String _cf_mod{iter_id} = "{class_name}";', - f'{indent}String _cf_cls{iter_id} = "{class_name}";', - f'{indent}String _cf_fn{iter_id} = "{func_name}";', - f'{indent}String _cf_outputFile{iter_id} = System.getenv("CODEFLASH_OUTPUT_FILE");', - f'{indent}String _cf_testIteration{iter_id} = System.getenv("CODEFLASH_TEST_ITERATION");', - f'{indent}if (_cf_testIteration{iter_id} == null) _cf_testIteration{iter_id} = "0";', - f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");', - f"{indent}long _cf_start{iter_id} = System.nanoTime();", - f"{indent}byte[] _cf_serializedResult{iter_id} = null;", - f"{indent}try {{", - ] - result.extend(behavior_start_code) - - # Add the wrapped body lines with extra indentation. - # Serialization of captured results is already done inline (immediately - # after each capture) so the _cf_serializedResult variable is always - # assigned while the captured variable is still in scope. 
- for bl in wrapped_body_lines: - result.append(" " + bl) - - # Add finally block with SQLite write - method_close_indent = " " * base_indent - behavior_end_code = [ - f"{indent}}} finally {{", - f"{indent} long _cf_end{iter_id} = System.nanoTime();", - f"{indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", - f'{indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");', - f"{indent} // Write to SQLite if output file is set", - f"{indent} if (_cf_outputFile{iter_id} != null && !_cf_outputFile{iter_id}.isEmpty()) {{", - f"{indent} try {{", - f'{indent} Class.forName("org.sqlite.JDBC");', - f'{indent} try (Connection _cf_conn{iter_id} = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile{iter_id})) {{', - f"{indent} try (java.sql.Statement _cf_stmt{iter_id} = _cf_conn{iter_id}.createStatement()) {{", - f'{indent} _cf_stmt{iter_id}.execute("CREATE TABLE IF NOT EXISTS test_results (" +', - f'{indent} "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +', - f'{indent} "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +', - f'{indent} "runtime INTEGER, return_value BLOB, verification_type TEXT)");', - f"{indent} }}", - f'{indent} String _cf_sql{iter_id} = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";', - f"{indent} try (PreparedStatement _cf_pstmt{iter_id} = _cf_conn{iter_id}.prepareStatement(_cf_sql{iter_id})) {{", - f"{indent} _cf_pstmt{iter_id}.setString(1, _cf_mod{iter_id});", - f"{indent} _cf_pstmt{iter_id}.setString(2, _cf_cls{iter_id});", - f'{indent} _cf_pstmt{iter_id}.setString(3, "{class_name}Test");', - f"{indent} _cf_pstmt{iter_id}.setString(4, _cf_fn{iter_id});", - f"{indent} _cf_pstmt{iter_id}.setInt(5, _cf_loop{iter_id});", - f'{indent} _cf_pstmt{iter_id}.setString(6, _cf_iter{iter_id} + "_" + _cf_testIteration{iter_id});', - 
f"{indent} _cf_pstmt{iter_id}.setLong(7, _cf_dur{iter_id});", - f"{indent} _cf_pstmt{iter_id}.setBytes(8, _cf_serializedResult{iter_id});", # Kryo-serialized return value - f'{indent} _cf_pstmt{iter_id}.setString(9, "function_call");', - f"{indent} _cf_pstmt{iter_id}.executeUpdate();", - f"{indent} }}", - f"{indent} }}", - f"{indent} }} catch (Exception _cf_e{iter_id}) {{", - f'{indent} System.err.println("CodeflashHelper: SQLite error: " + _cf_e{iter_id}.getMessage());', - f"{indent} }}", - f"{indent} }}", - f"{indent}}}", - f"{method_close_indent}}}", # Method closing brace + # Extract body lines (between { and }) + body_start_line = body_node.start_point[0] + body_end_line = body_node.end_point[0] + body_lines = lines[body_start_line + 1 : body_end_line] + + brace_line = lines[body_start_line] + base_indent = len(brace_line) - len(brace_line.lstrip()) + indent = " " * (base_indent + 4) + + # ── Find method invocations via tree-sitter ────────────────────── + invocations = analyzer.find_method_invocations(body_node, source_bytes, func_name) + + # Group invocations by their 0-indexed source line + invocations_by_line: dict[int, list] = {} + for inv in invocations: + inv_line = inv.node.start_point[0] + invocations_by_line.setdefault(inv_line, []).append(inv) + + # ── Wrap function calls per body line ──────────────────────────── + call_counter = 0 + wrapped_body_lines: list[str] = [] + + for local_idx, body_line in enumerate(body_lines): + source_line_idx = body_start_line + 1 + local_idx + line_invocations = invocations_by_line.get(source_line_idx, []) + + # Filter to non-lambda invocations on a single line + actionable = [ + inv for inv in line_invocations + if not inv.in_lambda and inv.node.start_point[0] == inv.node.end_point[0] ] - result.extend(behavior_end_code) - else: - result.append(line) - i += 1 - return "\n".join(result) + if actionable: + line_indent = len(body_line) - len(body_line.lstrip()) + line_indent_str = " " * line_indent + + # Process 
matches in reverse column order to preserve positions + actionable.sort(key=lambda inv: inv.node.start_point[1], reverse=True) + new_line = body_line + last_var_name = None + last_var_with_cast = None + + for inv in actionable: + call_counter += 1 + var_name = f"_cf_result{iter_id}_{call_counter}" + last_var_name = var_name + + cast_type = _infer_array_cast_type(body_line) + var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name + last_var_with_cast = var_with_cast + + start_col = inv.node.start_point[1] + end_col = inv.node.end_point[1] + new_line = new_line[:start_col] + var_with_cast + new_line[end_col:] + + capture_line = f"{line_indent_str}var {var_name} = {inv.full_text};" + wrapped_body_lines.append(capture_line) + + serialize_line = ( + f"{line_indent_str}_cf_serializedResult{iter_id} = " + f"com.codeflash.Serializer.serialize((Object) {var_name});" + ) + wrapped_body_lines.append(serialize_line) + + # Skip line if it collapsed to just a bare variable reference + stripped_new = new_line.strip().rstrip(";").strip() + if stripped_new and stripped_new not in (last_var_name, last_var_with_cast): + wrapped_body_lines.append(new_line) + else: + wrapped_body_lines.append(body_line) + + # ── Build replacement body ─────────────────────────────────────── + method_close_indent = " " * base_indent + behavior_lines = [ + "{", + f"{indent}// Codeflash behavior instrumentation", + f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', + f"{indent}int _cf_iter{iter_id} = {iter_id};", + f'{indent}String _cf_mod{iter_id} = "{class_name}";', + f'{indent}String _cf_cls{iter_id} = "{class_name}";', + f'{indent}String _cf_fn{iter_id} = "{func_name}";', + f'{indent}String _cf_outputFile{iter_id} = System.getenv("CODEFLASH_OUTPUT_FILE");', + f'{indent}String _cf_testIteration{iter_id} = System.getenv("CODEFLASH_TEST_ITERATION");', + f'{indent}if (_cf_testIteration{iter_id} == null) _cf_testIteration{iter_id} = "0";', + 
f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");', + f"{indent}long _cf_start{iter_id} = System.nanoTime();", + f"{indent}byte[] _cf_serializedResult{iter_id} = null;", + f"{indent}try {{", + ] + + for bl in wrapped_body_lines: + behavior_lines.append(" " + bl) + + behavior_lines.extend([ + f"{indent}}} finally {{", + f"{indent} long _cf_end{iter_id} = System.nanoTime();", + f"{indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", + f'{indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");', + f"{indent} // Write to SQLite if output file is set", + f"{indent} if (_cf_outputFile{iter_id} != null && !_cf_outputFile{iter_id}.isEmpty()) {{", + f"{indent} try {{", + f'{indent} Class.forName("org.sqlite.JDBC");', + f'{indent} try (Connection _cf_conn{iter_id} = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile{iter_id})) {{', + f"{indent} try (java.sql.Statement _cf_stmt{iter_id} = _cf_conn{iter_id}.createStatement()) {{", + f'{indent} _cf_stmt{iter_id}.execute("CREATE TABLE IF NOT EXISTS test_results (" +', + f'{indent} "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +', + f'{indent} "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +', + f'{indent} "runtime INTEGER, return_value BLOB, verification_type TEXT)");', + f"{indent} }}", + f'{indent} String _cf_sql{iter_id} = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";', + f"{indent} try (PreparedStatement _cf_pstmt{iter_id} = _cf_conn{iter_id}.prepareStatement(_cf_sql{iter_id})) {{", + f"{indent} _cf_pstmt{iter_id}.setString(1, _cf_mod{iter_id});", + f"{indent} _cf_pstmt{iter_id}.setString(2, _cf_cls{iter_id});", + f'{indent} _cf_pstmt{iter_id}.setString(3, 
"{class_name}Test");', + f"{indent} _cf_pstmt{iter_id}.setString(4, _cf_fn{iter_id});", + f"{indent} _cf_pstmt{iter_id}.setInt(5, _cf_loop{iter_id});", + f'{indent} _cf_pstmt{iter_id}.setString(6, _cf_iter{iter_id} + "_" + _cf_testIteration{iter_id});', + f"{indent} _cf_pstmt{iter_id}.setLong(7, _cf_dur{iter_id});", + f"{indent} _cf_pstmt{iter_id}.setBytes(8, _cf_serializedResult{iter_id});", + f'{indent} _cf_pstmt{iter_id}.setString(9, "function_call");', + f"{indent} _cf_pstmt{iter_id}.executeUpdate();", + f"{indent} }}", + f"{indent} }}", + f"{indent} }} catch (Exception _cf_e{iter_id}) {{", + f'{indent} System.err.println("CodeflashHelper: SQLite error: " + _cf_e{iter_id}.getMessage());', + f"{indent} }}", + f"{indent} }}", + f"{indent}}}", + f"{method_close_indent}}}", + ]) + + replacement = "\n".join(behavior_lines) + replacements.append((body_node.start_byte, body_node.end_byte, replacement)) + + # Apply replacements in reverse byte order + for start, end, replacement in sorted(replacements, key=lambda x: x[0], reverse=True): + source_bytes = source_bytes[:start] + replacement.encode("utf8") + source_bytes[end:] + + return source_bytes.decode("utf8") def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> str: """Add timing instrumentation to test methods with inner loop for JIT warmup. + Uses tree-sitter to find @Test methods and their body boundaries, + then replaces each method body with a timing-wrapped version. + For each @Test method, this adds: 1. Inner loop that runs N iterations (controlled by CODEFLASH_INNER_ITERATIONS env var) 2. Start timing marker printed at the beginning of each iteration @@ -619,110 +463,74 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> Instrumented source code. 
""" - # Find all @Test methods and add timing around their bodies - # Pattern matches: @Test (with optional parameters) followed by method declaration - # We process line by line for cleaner handling + from codeflash.languages.java.parser import get_java_analyzer - lines = source.split("\n") - result = [] - i = 0 - iteration_counter = 0 + analyzer = get_java_analyzer() + source_bytes = source.encode("utf8") + test_methods = analyzer.find_test_methods(source) - while i < len(lines): - line = lines[i] - stripped = line.strip() + if not test_methods: + return source - # Look for @Test annotation (not @TestOnly, @TestFactory, etc.) - if _is_test_annotation(stripped): - result.append(line) - i += 1 - - # Collect any additional annotations - while i < len(lines) and lines[i].strip().startswith("@"): - result.append(lines[i]) - i += 1 - - # Now find the method signature and opening brace - method_lines = [] - while i < len(lines): - method_lines.append(lines[i]) - if "{" in lines[i]: - break - i += 1 - - # Add the method signature lines - result.extend(method_lines) - i += 1 - - # We're now inside the method body - iteration_counter += 1 - iter_id = iteration_counter - - # Detect indentation from method signature line (line with opening brace) - method_sig_line = method_lines[-1] if method_lines else "" - base_indent = len(method_sig_line) - len(method_sig_line.lstrip()) - indent = " " * (base_indent + 4) # Add one level of indentation - inner_indent = " " * (base_indent + 8) # Two levels for inside inner loop - inner_body_indent = " " * (base_indent + 12) # Three levels for try block body - - # Add timing instrumentation with inner loop - # Note: CODEFLASH_LOOP_INDEX must always be set - no null check, crash if missing - # CODEFLASH_INNER_ITERATIONS controls inner loop count (default: 100) - timing_start_code = [ - f"{indent}// Codeflash timing instrumentation with inner loop for JIT warmup", - f'{indent}int _cf_loop{iter_id} = 
Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', - f'{indent}int _cf_innerIterations{iter_id} = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));', - f'{indent}String _cf_mod{iter_id} = "{class_name}";', - f'{indent}String _cf_cls{iter_id} = "{class_name}";', - f'{indent}String _cf_fn{iter_id} = "{func_name}";', - "", - f"{indent}for (int _cf_i{iter_id} = 0; _cf_i{iter_id} < _cf_innerIterations{iter_id}; _cf_i{iter_id}++) {{", - f'{inner_indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + "######$!");', - f"{inner_indent}long _cf_start{iter_id} = System.nanoTime();", - f"{inner_indent}try {{", - ] - result.extend(timing_start_code) - - # Collect method body until we find matching closing brace - brace_depth = 1 - body_lines = [] - - while i < len(lines) and brace_depth > 0: - body_line = lines[i] - # Count braces (simple approach - doesn't handle strings/comments perfectly) - for ch in body_line: - if ch == "{": - brace_depth += 1 - elif ch == "}": - brace_depth -= 1 - - if brace_depth > 0: - body_lines.append(body_line) - i += 1 - else: - # This line contains the closing brace, but we've hit depth 0 - # Add indented body lines (inside try block, inside for loop) - for bl in body_lines: - result.append(" " + bl) # 8 extra spaces for inner loop + try - - # Add finally block and close inner loop - method_close_indent = " " * base_indent # Same level as method signature - timing_end_code = [ - f"{inner_indent}}} finally {{", - f"{inner_indent} long _cf_end{iter_id} = System.nanoTime();", - f"{inner_indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", - f'{inner_indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + ":" + _cf_dur{iter_id} + "######!");', - f"{inner_indent}}}", - f"{indent}}}", # Close 
for loop - f"{method_close_indent}}}", # Method closing brace - ] - result.extend(timing_end_code) - i += 1 - else: - result.append(line) - i += 1 - - return "\n".join(result) + lines = source.split("\n") + + # Build a replacement for each @Test method body. + # iter_id is assigned in forward (source) order; replacements are applied in reverse. + replacements: list[tuple[int, int, str]] = [] + + for iter_id, method_info in enumerate(test_methods, start=1): + body_node = method_info.body_node + + # Extract body lines (between { and }, exclusive of both brace lines) + body_start_line = body_node.start_point[0] # 0-indexed line of { + body_end_line = body_node.end_point[0] # 0-indexed line of } + body_lines = lines[body_start_line + 1 : body_end_line] + + # Indentation from the line containing the opening brace + brace_line = lines[body_start_line] + base_indent = len(brace_line) - len(brace_line.lstrip()) + indent = " " * (base_indent + 4) + inner_indent = " " * (base_indent + 8) + method_close_indent = " " * base_indent + + # Build the replacement body (opening { through closing }) + timing_lines = [ + "{", + f"{indent}// Codeflash timing instrumentation with inner loop for JIT warmup", + f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));', + f'{indent}int _cf_innerIterations{iter_id} = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));', + f'{indent}String _cf_mod{iter_id} = "{class_name}";', + f'{indent}String _cf_cls{iter_id} = "{class_name}";', + f'{indent}String _cf_fn{iter_id} = "{func_name}";', + "", + f"{indent}for (int _cf_i{iter_id} = 0; _cf_i{iter_id} < _cf_innerIterations{iter_id}; _cf_i{iter_id}++) {{", + f'{inner_indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + "######$!");', + f"{inner_indent}long _cf_start{iter_id} = System.nanoTime();", + f"{inner_indent}try {{", + ] + + 
# Original body lines, indented by 8 extra spaces (for loop + try) + for bl in body_lines: + timing_lines.append(" " + bl) + + timing_lines.extend([ + f"{inner_indent}}} finally {{", + f"{inner_indent} long _cf_end{iter_id} = System.nanoTime();", + f"{inner_indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};", + f'{inner_indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + ":" + _cf_dur{iter_id} + "######!");', + f"{inner_indent}}}", + f"{indent}}}", # Close for loop + f"{method_close_indent}}}", # Method closing brace + ]) + + replacement = "\n".join(timing_lines) + replacements.append((body_node.start_byte, body_node.end_byte, replacement)) + + # Apply replacements in reverse byte order to preserve earlier offsets + for start, end, replacement in sorted(replacements, key=lambda x: x[0], reverse=True): + source_bytes = source_bytes[:start] + replacement.encode("utf8") + source_bytes[end:] + + return source_bytes.decode("utf8") def create_benchmark_test( @@ -825,14 +633,16 @@ def instrument_generated_java_test( if mode == "behavior": test_code = transform_java_assertions(test_code, function_name, qualified_name) - # Extract class name from the test code - # Use pattern that starts at beginning of line to avoid matching words in comments - class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", test_code, re.MULTILINE) - if not class_match: + # Extract class name from the test code using tree-sitter + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + classes = analyzer.find_classes(test_code) + if not classes: logger.warning("Could not find class name in generated test") return test_code - original_class_name = class_match.group(1) + original_class_name = classes[0].name # Rename class based on mode if mode == "behavior": @@ -840,11 +650,17 @@ def instrument_generated_java_test( else: 
new_class_name = f"{original_class_name}__perfonlyinstrumented" - # Rename all references to the original class name in the source. - # This includes the class declaration, return types, constructor calls, etc. - modified_code = re.sub( - rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code - ) + # Rename all identifier references to the class name using tree-sitter + # (excludes matches inside string literals and comments) + refs = analyzer.find_identifier_references(test_code, original_class_name) + if refs: + code_bytes = test_code.encode("utf8") + new_name_bytes = new_class_name.encode("utf8") + for start, end in reversed(refs): + code_bytes = code_bytes[:start] + new_name_bytes + code_bytes[end:] + modified_code = code_bytes.decode("utf8") + else: + modified_code = test_code # For performance mode, add timing instrumentation # Use original class name (without suffix) in timing markers for consistency with Python @@ -859,32 +675,4 @@ def instrument_generated_java_test( return modified_code -def _add_import(source: str, import_statement: str) -> str: - """Add an import statement to the source. - - Args: - source: The source code. - import_statement: The import to add. - - Returns: - Source with import added. 
- - """ - lines = source.splitlines(keepends=True) - insert_idx = 0 - - # Find the last import or package statement - for i, line in enumerate(lines): - stripped = line.strip() - if stripped.startswith(("import ", "package ")): - insert_idx = i + 1 - elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"): - # First non-import, non-comment line - if insert_idx == 0: - insert_idx = i - break - - lines.insert(insert_idx, import_statement + "\n") - return "".join(lines) - diff --git a/codeflash/languages/java/line_profiler.py b/codeflash/languages/java/line_profiler.py index 8a59ed6e6..67895149c 100644 --- a/codeflash/languages/java/line_profiler.py +++ b/codeflash/languages/java/line_profiler.py @@ -9,7 +9,6 @@ import json import logging -import re from pathlib import Path from typing import TYPE_CHECKING @@ -93,18 +92,18 @@ def instrument_source( # Add profiler class and initialization profiler_class_code = self._generate_profiler_class() - # Insert profiler class before the package's first class - # Find the first class/interface/enum/record declaration - # Must handle any combination of modifiers: public final class, abstract class, etc. - class_pattern = re.compile( - r"^(?:(?:public|private|protected|final|abstract|static|sealed|non-sealed)\s+)*" - r"(?:class|interface|enum|record)\s+" - ) + # Insert profiler class before the package's first class declaration. + # Uses tree-sitter find_classes() to handle any modifier combination + # (public final class, abstract class, etc.). 
+ from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source_text = "".join(lines) + classes = analyzer.find_classes(source_text) import_end_idx = 0 - for i, line in enumerate(lines): - if class_pattern.match(line.strip()): - import_end_idx = i - break + if classes: + # find_classes returns 1-indexed start_line; convert to 0-indexed + import_end_idx = classes[0].start_line - 1 lines_with_profiler = ( lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:] diff --git a/codeflash/languages/java/parser.py b/codeflash/languages/java/parser.py index 72a530179..52d10726a 100644 --- a/codeflash/languages/java/parser.py +++ b/codeflash/languages/java/parser.py @@ -53,6 +53,24 @@ class JavaMethodNode: javadoc_start_line: int | None = None # Line where Javadoc comment starts +@dataclass +class TestMethodInfo: + """A @Test-annotated method found by tree-sitter analysis.""" + + name: str + method_node: Node + body_node: Node # The block (body) of the method + + +@dataclass +class MethodCallInfo: + """A method invocation found within a given AST subtree.""" + + node: Node + full_text: str + in_lambda: bool + + @dataclass class JavaClassNode: """Represents a class found by tree-sitter analysis.""" @@ -678,6 +696,170 @@ def get_package_name(self, source: str) -> str | None: return None + def find_test_methods(self, source: str) -> list[TestMethodInfo]: + """Find all @Test-annotated methods in source code. + + Identifies method_declaration nodes whose modifiers contain a + marker_annotation or annotation with name ``Test`` (not ``TestOnly``, + ``TestFactory``, etc.). + + Args: + source: The source code to analyze. + + Returns: + List of TestMethodInfo in source order. 
+ + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + results: list[TestMethodInfo] = [] + self._walk_for_test_methods(tree.root_node, source_bytes, results) + return results + + def _walk_for_test_methods( + self, node: Node, source_bytes: bytes, results: list[TestMethodInfo] + ) -> None: + """Recursively walk the tree to find @Test methods.""" + if node.type == "method_declaration": + if self._has_test_annotation(node, source_bytes): + body = node.child_by_field_name("body") + name_node = node.child_by_field_name("name") + if body and name_node: + results.append( + TestMethodInfo( + name=self.get_node_text(name_node, source_bytes), + method_node=node, + body_node=body, + ) + ) + + for child in node.children: + self._walk_for_test_methods(child, source_bytes, results) + + def _has_test_annotation(self, method_node: Node, source_bytes: bytes) -> bool: + """Check if a method_declaration has an @Test annotation. + + Only matches exact ``Test`` name – not ``TestOnly``, ``TestFactory``, etc. + Handles both ``@Test`` (marker_annotation) and ``@Test(...)`` (annotation). + """ + for child in method_node.children: + if child.type == "modifiers": + for mod_child in child.children: + if mod_child.type in ("marker_annotation", "annotation"): + name_node = mod_child.child_by_field_name("name") + if name_node is None: + # Fallback: search direct children for identifier + for ann_child in mod_child.children: + if ann_child.type == "identifier": + name_node = ann_child + break + if name_node and self.get_node_text(name_node, source_bytes) == "Test": + return True + return False + + def find_method_invocations( + self, body_node: Node, source_bytes: bytes, func_name: str + ) -> list[MethodCallInfo]: + """Find all invocations of *func_name* within a given AST subtree. + + Checks the parent chain for ``lambda_expression`` nodes to set the + ``in_lambda`` flag. + + Args: + body_node: The subtree to search within (typically a method body block). 
+ source_bytes: The full source code as bytes. + func_name: The method name to match. + + Returns: + List of MethodCallInfo in source order. + + """ + results: list[MethodCallInfo] = [] + self._walk_for_invocations(body_node, source_bytes, func_name, results, in_lambda=False) + return results + + def _walk_for_invocations( + self, + node: Node, + source_bytes: bytes, + func_name: str, + results: list[MethodCallInfo], + in_lambda: bool, + ) -> None: + """Recursively find method invocations matching func_name.""" + if node.type == "lambda_expression": + in_lambda = True + + if node.type == "method_invocation": + name_node = node.child_by_field_name("name") + if name_node and self.get_node_text(name_node, source_bytes) == func_name: + results.append( + MethodCallInfo( + node=node, + full_text=self.get_node_text(node, source_bytes), + in_lambda=in_lambda, + ) + ) + + for child in node.children: + self._walk_for_invocations(child, source_bytes, func_name, results, in_lambda) + + def find_identifier_references(self, source: str, name: str) -> list[tuple[int, int]]: + """Find all ``identifier`` / ``type_identifier`` AST nodes matching *name*. + + Tree-sitter naturally excludes matches inside string literals and comments, + making this safer than ``re.sub(r"\\b...\\b", ...)``. + + Args: + source: The source code to analyze. + name: The identifier name to match. + + Returns: + List of ``(start_byte, end_byte)`` pairs in source order. 
+ + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + refs: list[tuple[int, int]] = [] + self._walk_for_identifiers(tree.root_node, source_bytes, name, refs) + return refs + + def _walk_for_identifiers( + self, node: Node, source_bytes: bytes, name: str, refs: list[tuple[int, int]] + ) -> None: + """Recursively find identifier/type_identifier nodes matching name.""" + if node.type in ("identifier", "type_identifier"): + if self.get_node_text(node, source_bytes) == name: + refs.append((node.start_byte, node.end_byte)) + for child in node.children: + self._walk_for_identifiers(child, source_bytes, name, refs) + + def find_import_insertion_point(self, source: str) -> int: + """Find the 0-indexed line after the last import or package declaration. + + Returns the line index where new import statements should be inserted. + If no imports or package found, returns 0. + + Args: + source: The source code to analyze. + + Returns: + 0-indexed line number for insertion. + + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + last_line = 0 + + for child in tree.root_node.children: + if child.type in ("import_declaration", "package_declaration"): + # end_point[0] is 0-indexed last line of the node + candidate = child.end_point[0] + 1 + if candidate > last_line: + last_line = candidate + + return last_line + def get_java_analyzer() -> JavaAnalyzer: """Get a JavaAnalyzer instance. From 501efb5da875ea2620495df7690a0eaf758f19af Mon Sep 17 00:00:00 2001 From: HeshamHM28 Date: Fri, 13 Feb 2026 00:43:35 +0000 Subject: [PATCH 2/2] fix(java): add progressive fallback for token limit in code context extraction Previously, get_code_optimization_context_for_language() would raise a hard ValueError when the extracted code context exceeded the 16,000 token limit, causing 93% of Java functions in large projects to fail optimization. 
This was because Java's helper traversal (max_depth=2) pulls in transitive dependencies, and type skeleton wrapping adds all class fields and constructors. This commit adds a 4-stage progressive fallback strategy: 1. Full context (all helpers, Javadoc intact) 2. Remove cross-file helpers (keep same-file helpers only) 3. Strip Javadoc comments from all code 4. Remove all helpers (target code only) Each stage is tried in order until the token limit is satisfied, with debug logging when a fallback is used. The same fallback applies independently to both optim and testgen token limits. Also extracts the code string building logic into a reusable _build_code_strings_for_language() helper and adds a _strip_javadoc_comments() utility for removing /** ... */ blocks while preserving other comments. Co-Authored-By: Claude Opus 4.6 --- codeflash/context/code_context_extractor.py | 249 ++++++++++++++------ 1 file changed, 180 insertions(+), 69 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 61de73c32..72c6d51c3 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -208,39 +208,38 @@ def get_code_optimization_context( ) -def get_code_optimization_context_for_language( - function_to_optimize: FunctionToOptimize, - project_root_path: Path, - optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, - testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, -) -> CodeOptimizationContext: - """Extract code optimization context for non-Python languages. +def _strip_javadoc_comments(source: str) -> str: + """Strip Javadoc (/** ... */) comments from Java source code. - Uses the language support abstraction to extract code context and converts - it to the CodeOptimizationContext format expected by the pipeline. + Preserves single-line comments (//) and regular block comments (/* ... */). 
+ """ + import re - This function supports multi-file context extraction, grouping helpers by file - and creating proper CodeStringsMarkdown with file paths for multi-file replacement. + return re.sub(r"/\*\*.*?\*/\s*", "", source, flags=re.DOTALL) + + +def _build_code_strings_for_language( + code_context, + function_to_optimize: FunctionToOptimize, + project_root_path: Path, + include_cross_file_helpers: bool = True, + strip_javadoc: bool = False, + include_same_file_helpers: bool = True, +) -> tuple[list[CodeString], list[FunctionSource], str]: + """Build CodeString list from a CodeContext with configurable reduction. Args: - function_to_optimize: The function to extract context for. + code_context: CodeContext from language support. + function_to_optimize: The target function. project_root_path: Root of the project. - optim_token_limit: Token limit for optimization context. - testgen_token_limit: Token limit for testgen context. + include_cross_file_helpers: Whether to include helpers from other files. + strip_javadoc: Whether to strip Javadoc comments from all code. + include_same_file_helpers: Whether to include same-file helper methods. Returns: - CodeOptimizationContext with target code and dependencies. + Tuple of (code_strings, helper_function_sources, read_only_context). 
""" - from codeflash.languages import get_language_support - - # Get language support for this function - language = Language(function_to_optimize.language) - lang_support = get_language_support(language) - - # Extract code context using language support - code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path) - # Build imports string if available imports_code = "\n".join(code_context.imports) if code_context.imports else "" @@ -251,82 +250,194 @@ def get_code_optimization_context_for_language( target_relative_path = function_to_optimize.file_path # Group helpers by file path - helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list) + helpers_by_file: dict[Path, list] = defaultdict(list) helper_function_sources = [] for helper in code_context.helper_functions: helpers_by_file[helper.file_path].append(helper) # Convert to FunctionSource for pipeline compatibility - helper_function_sources.append( - FunctionSource( - file_path=helper.file_path, - qualified_name=helper.qualified_name, - fully_qualified_name=helper.qualified_name, - only_function_name=helper.name, - source_code=helper.source_code, - jedi_definition=None, - ) + should_include = ( + (helper.file_path == function_to_optimize.file_path and include_same_file_helpers) + or (helper.file_path != function_to_optimize.file_path and include_cross_file_helpers) ) + if should_include: + helper_function_sources.append( + FunctionSource( + file_path=helper.file_path, + qualified_name=helper.qualified_name, + fully_qualified_name=helper.qualified_name, + only_function_name=helper.name, + source_code=helper.source_code, + jedi_definition=None, + ) + ) - # Build read-writable code (target file + same-file helpers + global variables) - read_writable_code_strings = [] + # Build read-writable code (target file + same-file helpers) + code_strings = [] # Combine target code with same-file helpers target_file_code = code_context.target_code - 
same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, []) - if same_file_helpers: - helper_code = "\n\n".join(h.source_code for h in same_file_helpers) - target_file_code = target_file_code + "\n\n" + helper_code - - # Note: code_context.read_only_context contains type definitions and global variables - # These should be passed as read-only context to the AI, not prepended to the target code - # If prepended to target code, the AI treats them as code to optimize and includes them in output + if include_same_file_helpers: + same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, []) + if same_file_helpers: + helper_code = "\n\n".join(h.source_code for h in same_file_helpers) + target_file_code = target_file_code + "\n\n" + helper_code # Add imports to target file code if imports_code: target_file_code = imports_code + "\n\n" + target_file_code - read_writable_code_strings.append( + if strip_javadoc: + target_file_code = _strip_javadoc_comments(target_file_code) + + code_strings.append( CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language) ) # Add helper files (cross-file helpers) - for file_path, file_helpers in helpers_by_file.items(): - if file_path == function_to_optimize.file_path: - continue # Already included in target file + if include_cross_file_helpers: + for file_path, file_helpers in helpers_by_file.items(): + if file_path == function_to_optimize.file_path: + continue # Already included in target file - try: - helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - helper_relative_path = file_path + try: + helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve()) + except ValueError: + helper_relative_path = file_path + + combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) + if strip_javadoc: + combined_helper_code = _strip_javadoc_comments(combined_helper_code) + + 
code_strings.append( + CodeString( + code=combined_helper_code, + file_path=helper_relative_path, + language=function_to_optimize.language, + ) + ) - # Combine all helpers from this file - combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) + read_only_context = code_context.read_only_context + if strip_javadoc and read_only_context: + read_only_context = _strip_javadoc_comments(read_only_context) - read_writable_code_strings.append( - CodeString( - code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language - ) + return code_strings, helper_function_sources, read_only_context + + +def get_code_optimization_context_for_language( + function_to_optimize: FunctionToOptimize, + project_root_path: Path, + optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, + testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, +) -> CodeOptimizationContext: + """Extract code optimization context for non-Python languages. + + Uses the language support abstraction to extract code context and converts + it to the CodeOptimizationContext format expected by the pipeline. + + This function supports multi-file context extraction, grouping helpers by file + and creating proper CodeStringsMarkdown with file paths for multi-file replacement. + + Applies progressive fallback when token limits are exceeded: + 1. Full context (all helpers, Javadoc intact) + 2. Remove cross-file helpers + 3. Strip Javadoc comments + 4. Remove all helpers (target code only) + + Args: + function_to_optimize: The function to extract context for. + project_root_path: Root of the project. + optim_token_limit: Token limit for optimization context. + testgen_token_limit: Token limit for testgen context. + + Returns: + CodeOptimizationContext with target code and dependencies. 
+ + """ + from codeflash.languages import get_language_support + + # Get language support for this function + language = Language(function_to_optimize.language) + lang_support = get_language_support(language) + + # Extract code context using language support + code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path) + + # Progressive fallback strategies, ordered from most to least context + fallback_strategies = [ + {"include_cross_file_helpers": True, "strip_javadoc": False, "include_same_file_helpers": True}, + {"include_cross_file_helpers": False, "strip_javadoc": False, "include_same_file_helpers": True}, + {"include_cross_file_helpers": False, "strip_javadoc": True, "include_same_file_helpers": True}, + {"include_cross_file_helpers": False, "strip_javadoc": True, "include_same_file_helpers": False}, + ] + + fallback_descriptions = [ + "full context", + "without cross-file helpers", + "without cross-file helpers and Javadoc", + "target code only (no helpers, no Javadoc)", + ] + + code_strings = None + helper_function_sources = None + read_only_context = None + + for i, strategy in enumerate(fallback_strategies): + code_strings, helper_function_sources, read_only_context = _build_code_strings_for_language( + code_context, function_to_optimize, project_root_path, **strategy ) + read_writable_code = CodeStringsMarkdown( + code_strings=code_strings, language=function_to_optimize.language + ) + read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) + + if read_writable_tokens <= optim_token_limit: + if i > 0: + logger.debug( + "Code context exceeded token limit, using fallback: %s (%d tokens)", + fallback_descriptions[i], + read_writable_tokens, + ) + break + else: + raise ValueError("Read-writable code has exceeded token limit even after removing all helpers and Javadoc") + read_writable_code = CodeStringsMarkdown( - code_strings=read_writable_code_strings, language=function_to_optimize.language + 
code_strings=code_strings, language=function_to_optimize.language ) - # Build testgen context (same as read_writable for non-Python) + # Build testgen context with its own progressive fallback + # Start from the same strategy level that worked for optim + testgen_code_strings = code_strings + testgen_helpers = helper_function_sources + testgen_context = CodeStringsMarkdown( - code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language + code_strings=testgen_code_strings.copy(), language=function_to_optimize.language ) - - # Check token limits - read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) - if read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable code has exceeded token limit, cannot proceed") - testgen_tokens = encoded_tokens_len(testgen_context.markdown) + if testgen_tokens > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + # Try remaining fallback strategies for testgen + for j in range(i + 1, len(fallback_strategies)): + testgen_code_strings, testgen_helpers, read_only_context = _build_code_strings_for_language( + code_context, function_to_optimize, project_root_path, **fallback_strategies[j] + ) + testgen_context = CodeStringsMarkdown( + code_strings=testgen_code_strings.copy(), language=function_to_optimize.language + ) + testgen_tokens = encoded_tokens_len(testgen_context.markdown) + + if testgen_tokens <= testgen_token_limit: + logger.debug( + "Testgen context exceeded token limit, using fallback: %s (%d tokens)", + fallback_descriptions[j], + testgen_tokens, + ) + break + else: + raise ValueError("Testgen code context has exceeded token limit even after removing all helpers and Javadoc") # Generate code hash from all read-writable code code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest() @@ -336,7 +447,7 @@ def get_code_optimization_context_for_language( read_writable_code=read_writable_code, # Pass 
type definitions and globals as read-only context for the AI # This way the AI sees them as context but doesn't include them in optimized output - read_only_context_code=code_context.read_only_context, + read_only_context_code=read_only_context, hashing_code_context=read_writable_code.flat, hashing_code_context_hash=code_hash, helper_functions=helper_function_sources,