|
11 | 11 | from typing import TYPE_CHECKING |
12 | 12 |
|
13 | 13 | from codeflash.code_utils.code_utils import encoded_tokens_len |
14 | | -from codeflash.languages.base import CodeContext, HelperFunction, Language |
| 14 | +from codeflash.languages.base import CodeContext, HelperFunction |
15 | 15 | from codeflash.languages.java.discovery import discover_functions_from_source |
16 | 16 | from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files |
17 | 17 | from codeflash.languages.java.parser import get_java_analyzer |
| 18 | +from codeflash.languages.language_enum import Language |
18 | 19 |
|
19 | 20 | if TYPE_CHECKING: |
20 | 21 | from pathlib import Path |
21 | 22 |
|
22 | 23 | from tree_sitter import Node |
23 | 24 |
|
24 | 25 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
25 | | - from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode |
| 26 | + from codeflash.languages.java.import_resolver import ResolvedImport |
| 27 | + from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, JavaMethodNode |
26 | 28 |
|
27 | 29 | logger = logging.getLogger(__name__) |
28 | 30 |
|
@@ -360,7 +362,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s |
360 | 362 |
|
361 | 363 |
|
362 | 364 | # Keep old function name for backwards compatibility |
363 | | -def _extract_class_declaration(node, source_bytes): |
| 365 | +def _extract_class_declaration(node: Node, source_bytes: bytes) -> str: |
364 | 366 | return _extract_type_declaration(node, source_bytes, "class") |
365 | 367 |
|
366 | 368 |
|
@@ -629,6 +631,8 @@ def _extract_function_source_by_lines(source: str, function: FunctionToOptimize) |
629 | 631 |
|
630 | 632 | start_line = function.doc_start_line or function.starting_line |
631 | 633 | end_line = function.ending_line |
| 634 | + if start_line is None or end_line is None: |
| 635 | + return "" |
632 | 636 |
|
633 | 637 | # Convert from 1-indexed to 0-indexed |
634 | 638 | start_idx = start_line - 1 |
@@ -672,6 +676,8 @@ def find_helper_functions( |
672 | 676 | func_id = f"{file_path}:{func.qualified_name}" |
673 | 677 | if func_id not in visited_functions: |
674 | 678 | visited_functions.add(func_id) |
| 679 | + if func.starting_line is None or func.ending_line is None: |
| 680 | + continue |
675 | 681 |
|
676 | 682 | # Extract the function source using tree-sitter for resilient lookup |
677 | 683 | func_source = extract_function_source(source, func, analyzer=analyzer) |
@@ -795,7 +801,7 @@ def extract_read_only_context(source: str, function: FunctionToOptimize, analyze |
795 | 801 | return "\n".join(context_parts) |
796 | 802 |
|
797 | 803 |
|
798 | | -def _import_to_statement(import_info) -> str: |
| 804 | +def _import_to_statement(import_info: JavaImportInfo) -> str: |
799 | 805 | """Convert a JavaImportInfo to an import statement string. |
800 | 806 |
|
801 | 807 | Args: |
@@ -898,7 +904,11 @@ def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str] |
898 | 904 |
|
899 | 905 |
|
900 | 906 | def get_java_imported_type_skeletons( |
901 | | - imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = "" |
| 907 | + imports: list[JavaImportInfo], |
| 908 | + project_root: Path, |
| 909 | + module_root: Path | None, |
| 910 | + analyzer: JavaAnalyzer, |
| 911 | + target_code: str = "", |
902 | 912 | ) -> str: |
903 | 913 | """Extract type skeletons for project-internal imported types. |
904 | 914 |
|
@@ -933,7 +943,7 @@ def get_java_imported_type_skeletons( |
933 | 943 | priority_types = _extract_type_names_from_code(target_code, analyzer) |
934 | 944 |
|
935 | 945 | # Pre-resolve all imports, expanding wildcards into individual types |
936 | | - resolved_imports: list = [] |
| 946 | + resolved_imports: list[ResolvedImport] = [] |
937 | 947 | for imp in imports: |
938 | 948 | if imp.is_wildcard: |
939 | 949 | # First try unfiltered expansion with a cap. If the package is small enough, take all types. |
@@ -978,7 +988,7 @@ def get_java_imported_type_skeletons( |
978 | 988 |
|
979 | 989 | for resolved in resolved_imports: |
980 | 990 | class_name = resolved.class_name |
981 | | - if not class_name: |
| 991 | + if not class_name or resolved.file_path is None: |
982 | 992 | continue |
983 | 993 |
|
984 | 994 | dedup_key = (str(resolved.file_path), class_name) |
@@ -1100,8 +1110,6 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja |
1100 | 1110 | continue |
1101 | 1111 |
|
1102 | 1112 | node = method.node |
1103 | | - if not node: |
1104 | | - continue |
1105 | 1113 |
|
1106 | 1114 | # Check if the method is public |
1107 | 1115 | is_public = False |
|
0 commit comments