Skip to content

Commit 0a2ec48

Browse files
Merge pull request #1951 from codeflash-ai/cf-1085-cap-wildcard-import-expansion
fix: cap wildcard import expansion to avoid token explosion
2 parents 8aff48f + 33b4eb8 commit 0a2ec48

3 files changed

Lines changed: 193 additions & 120 deletions

File tree

codeflash/languages/java/context.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111
from typing import TYPE_CHECKING
1212

1313
from codeflash.code_utils.code_utils import encoded_tokens_len
14-
from codeflash.languages.base import CodeContext, HelperFunction, Language
14+
from codeflash.languages.base import CodeContext, HelperFunction
1515
from codeflash.languages.java.discovery import discover_functions_from_source
1616
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
1717
from codeflash.languages.java.parser import get_java_analyzer
18+
from codeflash.languages.language_enum import Language
1819

1920
if TYPE_CHECKING:
2021
from pathlib import Path
2122

2223
from tree_sitter import Node
2324

2425
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
25-
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
26+
from codeflash.languages.java.import_resolver import ResolvedImport
27+
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, JavaMethodNode
2628

2729
logger = logging.getLogger(__name__)
2830

@@ -360,7 +362,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
360362

361363

362364
# Keep old function name for backwards compatibility
363-
def _extract_class_declaration(node, source_bytes):
365+
def _extract_class_declaration(node: Node, source_bytes: bytes) -> str:
364366
return _extract_type_declaration(node, source_bytes, "class")
365367

366368

@@ -629,6 +631,8 @@ def _extract_function_source_by_lines(source: str, function: FunctionToOptimize)
629631

630632
start_line = function.doc_start_line or function.starting_line
631633
end_line = function.ending_line
634+
if start_line is None or end_line is None:
635+
return ""
632636

633637
# Convert from 1-indexed to 0-indexed
634638
start_idx = start_line - 1
@@ -672,6 +676,8 @@ def find_helper_functions(
672676
func_id = f"{file_path}:{func.qualified_name}"
673677
if func_id not in visited_functions:
674678
visited_functions.add(func_id)
679+
if func.starting_line is None or func.ending_line is None:
680+
continue
675681

676682
# Extract the function source using tree-sitter for resilient lookup
677683
func_source = extract_function_source(source, func, analyzer=analyzer)
@@ -795,7 +801,7 @@ def extract_read_only_context(source: str, function: FunctionToOptimize, analyze
795801
return "\n".join(context_parts)
796802

797803

798-
def _import_to_statement(import_info) -> str:
804+
def _import_to_statement(import_info: JavaImportInfo) -> str:
799805
"""Convert a JavaImportInfo to an import statement string.
800806
801807
Args:
@@ -863,6 +869,10 @@ def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyz
863869

864870
# Maximum token budget for imported type skeletons to avoid bloating testgen context
865871
IMPORTED_SKELETON_TOKEN_BUDGET = 4000
872+
# Maximum types to expand from a single wildcard import before filtering to referenced types only.
873+
# Packages with more types than this (e.g. org.jooq with 870+) would waste minutes of disk I/O
874+
# and almost always exceed the token budget.
875+
MAX_WILDCARD_TYPES_UNFILTERED = 50
866876

867877

868878
def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
@@ -894,7 +904,11 @@ def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]
894904

895905

896906
def get_java_imported_type_skeletons(
897-
imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = ""
907+
imports: list[JavaImportInfo],
908+
project_root: Path,
909+
module_root: Path | None,
910+
analyzer: JavaAnalyzer,
911+
target_code: str = "",
898912
) -> str:
899913
"""Extract type skeletons for project-internal imported types.
900914
@@ -929,14 +943,32 @@ def get_java_imported_type_skeletons(
929943
priority_types = _extract_type_names_from_code(target_code, analyzer)
930944

931945
# Pre-resolve all imports, expanding wildcards into individual types
932-
resolved_imports: list = []
946+
resolved_imports: list[ResolvedImport] = []
933947
for imp in imports:
934948
if imp.is_wildcard:
935-
# Expand wildcard imports (e.g., com.aerospike.client.policy.*) into individual types
936-
expanded = resolver.expand_wildcard_import(imp.import_path)
949+
# First try unfiltered expansion with a cap. If the package is small enough, take all types.
950+
# If it's huge (e.g. org.jooq.* with 870+ types), filter to only types referenced in the target code.
951+
expanded = resolver.expand_wildcard_import(imp.import_path, max_types=MAX_WILDCARD_TYPES_UNFILTERED + 1)
952+
if len(expanded) > MAX_WILDCARD_TYPES_UNFILTERED:
953+
if priority_types:
954+
expanded = resolver.expand_wildcard_import(imp.import_path, filter_names=priority_types)
955+
logger.debug(
956+
"Wildcard %s.* exceeds %d types, filtered to %d referenced types",
957+
imp.import_path,
958+
MAX_WILDCARD_TYPES_UNFILTERED,
959+
len(expanded),
960+
)
961+
else:
962+
expanded = expanded[:MAX_WILDCARD_TYPES_UNFILTERED]
963+
logger.debug(
964+
"Wildcard %s.* exceeds %d types, capped (no target types to filter by)",
965+
imp.import_path,
966+
MAX_WILDCARD_TYPES_UNFILTERED,
967+
)
968+
elif expanded:
969+
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
937970
if expanded:
938971
resolved_imports.extend(expanded)
939-
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
940972
continue
941973

942974
resolved = resolver.resolve_import(imp)
@@ -956,7 +988,7 @@ def get_java_imported_type_skeletons(
956988

957989
for resolved in resolved_imports:
958990
class_name = resolved.class_name
959-
if not class_name:
991+
if not class_name or resolved.file_path is None:
960992
continue
961993

962994
dedup_key = (str(resolved.file_path), class_name)
@@ -1078,8 +1110,6 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja
10781110
continue
10791111

10801112
node = method.node
1081-
if not node:
1082-
continue
10831113

10841114
# Check if the method is public
10851115
is_public = False

codeflash/languages/java/import_resolver.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,20 @@ def _extract_class_name(self, import_path: str) -> str | None:
220220
return last_part
221221
return None
222222

223-
def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
223+
def expand_wildcard_import(
224+
self, import_path: str, max_types: int = 0, filter_names: set[str] | None = None
225+
) -> list[ResolvedImport]:
224226
"""Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.
225227
226228
Resolves the package path to a directory and returns a ResolvedImport for each
227229
.java file found in that directory.
230+
231+
Args:
232+
import_path: The package path (without the trailing .*).
233+
max_types: Maximum number of types to return. 0 means no limit.
234+
filter_names: If provided, only include types whose class name is in this set.
235+
228236
"""
229-
# Convert package path to directory path
230-
# e.g., "com.example.utils" -> "com/example/utils"
231237
relative_dir = import_path.replace(".", "/")
232238

233239
resolved: list[ResolvedImport] = []
@@ -237,17 +243,21 @@ def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
237243
if candidate_dir.is_dir():
238244
for java_file in candidate_dir.glob("*.java"):
239245
class_name = java_file.stem
240-
# Only include files that look like class names (start with uppercase)
241-
if class_name and class_name[0].isupper():
242-
resolved.append(
243-
ResolvedImport(
244-
import_path=f"{import_path}.{class_name}",
245-
file_path=java_file,
246-
is_external=False,
247-
is_wildcard=False,
248-
class_name=class_name,
249-
)
246+
if not class_name or not class_name[0].isupper():
247+
continue
248+
if filter_names is not None and class_name not in filter_names:
249+
continue
250+
resolved.append(
251+
ResolvedImport(
252+
import_path=f"{import_path}.{class_name}",
253+
file_path=java_file,
254+
is_external=False,
255+
is_wildcard=False,
256+
class_name=class_name,
250257
)
258+
)
259+
if max_types and len(resolved) >= max_types:
260+
return resolved
251261

252262
return resolved
253263

0 commit comments

Comments
 (0)