Skip to content
54 changes: 42 additions & 12 deletions codeflash/languages/java/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@
from typing import TYPE_CHECKING

from codeflash.code_utils.code_utils import encoded_tokens_len
from codeflash.languages.base import CodeContext, HelperFunction, Language
from codeflash.languages.base import CodeContext, HelperFunction
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.language_enum import Language

if TYPE_CHECKING:
from pathlib import Path

from tree_sitter import Node

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
from codeflash.languages.java.import_resolver import ResolvedImport
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, JavaMethodNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -360,7 +362,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s


# Keep old function name for backwards compatibility
def _extract_class_declaration(node, source_bytes):
def _extract_class_declaration(node: Node, source_bytes: bytes) -> str:
return _extract_type_declaration(node, source_bytes, "class")


Expand Down Expand Up @@ -629,6 +631,8 @@ def _extract_function_source_by_lines(source: str, function: FunctionToOptimize)

start_line = function.doc_start_line or function.starting_line
end_line = function.ending_line
if start_line is None or end_line is None:
return ""

# Convert from 1-indexed to 0-indexed
start_idx = start_line - 1
Expand Down Expand Up @@ -672,6 +676,8 @@ def find_helper_functions(
func_id = f"{file_path}:{func.qualified_name}"
if func_id not in visited_functions:
visited_functions.add(func_id)
if func.starting_line is None or func.ending_line is None:
continue

# Extract the function source using tree-sitter for resilient lookup
func_source = extract_function_source(source, func, analyzer=analyzer)
Expand Down Expand Up @@ -795,7 +801,7 @@ def extract_read_only_context(source: str, function: FunctionToOptimize, analyze
return "\n".join(context_parts)


def _import_to_statement(import_info) -> str:
def _import_to_statement(import_info: JavaImportInfo) -> str:
"""Convert a JavaImportInfo to an import statement string.

Args:
Expand Down Expand Up @@ -863,6 +869,10 @@ def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyz

# Maximum token budget for imported type skeletons to avoid bloating testgen context
IMPORTED_SKELETON_TOKEN_BUDGET = 4000
# Maximum types to expand from a single wildcard import before filtering to referenced types only.
# Packages with more types than this (e.g. org.jooq with 870+) would waste minutes of disk I/O
# and almost always exceed the token budget.
MAX_WILDCARD_TYPES_UNFILTERED = 50


def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
Expand Down Expand Up @@ -894,7 +904,11 @@ def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]


def get_java_imported_type_skeletons(
imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = ""
imports: list[JavaImportInfo],
project_root: Path,
module_root: Path | None,
analyzer: JavaAnalyzer,
target_code: str = "",
) -> str:
"""Extract type skeletons for project-internal imported types.

Expand Down Expand Up @@ -929,14 +943,32 @@ def get_java_imported_type_skeletons(
priority_types = _extract_type_names_from_code(target_code, analyzer)

# Pre-resolve all imports, expanding wildcards into individual types
resolved_imports: list = []
resolved_imports: list[ResolvedImport] = []
for imp in imports:
if imp.is_wildcard:
# Expand wildcard imports (e.g., com.aerospike.client.policy.*) into individual types
expanded = resolver.expand_wildcard_import(imp.import_path)
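# First try unfiltered expansion with a cap. If the package is small enough, take all types.
# If it's huge (e.g. org.jooq.* with 870+ types), filter to only types referenced in the target code;
# when no referenced types are known, keep just the first MAX_WILDCARD_TYPES_UNFILTERED results instead.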
expanded = resolver.expand_wildcard_import(imp.import_path, max_types=MAX_WILDCARD_TYPES_UNFILTERED + 1)
if len(expanded) > MAX_WILDCARD_TYPES_UNFILTERED:
if priority_types:
expanded = resolver.expand_wildcard_import(imp.import_path, filter_names=priority_types)
logger.debug(
"Wildcard %s.* exceeds %d types, filtered to %d referenced types",
imp.import_path,
MAX_WILDCARD_TYPES_UNFILTERED,
len(expanded),
)
else:
expanded = expanded[:MAX_WILDCARD_TYPES_UNFILTERED]
logger.debug(
"Wildcard %s.* exceeds %d types, capped (no target types to filter by)",
imp.import_path,
MAX_WILDCARD_TYPES_UNFILTERED,
)
elif expanded:
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
if expanded:
resolved_imports.extend(expanded)
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
continue

resolved = resolver.resolve_import(imp)
Expand All @@ -956,7 +988,7 @@ def get_java_imported_type_skeletons(

for resolved in resolved_imports:
class_name = resolved.class_name
if not class_name:
if not class_name or resolved.file_path is None:
continue

dedup_key = (str(resolved.file_path), class_name)
Expand Down Expand Up @@ -1078,8 +1110,6 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja
continue

node = method.node
if not node:
continue

# Check if the method is public
is_public = False
Expand Down
36 changes: 23 additions & 13 deletions codeflash/languages/java/import_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,20 @@ def _extract_class_name(self, import_path: str) -> str | None:
return last_part
return None

def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
def expand_wildcard_import(
self, import_path: str, max_types: int = 0, filter_names: set[str] | None = None
) -> list[ResolvedImport]:
"""Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.

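Resolves the package path to a directory and returns a ResolvedImport for each
.java file found in that directory whose file name starts with an uppercase
letter, subject to the optional ``filter_names`` and ``max_types`` restrictions.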

Args:
import_path: The package path (without the trailing .*).
max_types: Maximum number of types to return. 0 means no limit.
filter_names: If provided, only include types whose class name is in this set.

"""
# Convert package path to directory path
# e.g., "com.example.utils" -> "com/example/utils"
relative_dir = import_path.replace(".", "/")

resolved: list[ResolvedImport] = []
Expand All @@ -237,17 +243,21 @@ def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
if candidate_dir.is_dir():
for java_file in candidate_dir.glob("*.java"):
class_name = java_file.stem
# Only include files that look like class names (start with uppercase)
if class_name and class_name[0].isupper():
resolved.append(
ResolvedImport(
import_path=f"{import_path}.{class_name}",
file_path=java_file,
is_external=False,
is_wildcard=False,
class_name=class_name,
)
if not class_name or not class_name[0].isupper():
continue
if filter_names is not None and class_name not in filter_names:
continue
resolved.append(
ResolvedImport(
import_path=f"{import_path}.{class_name}",
file_path=java_file,
is_external=False,
is_wildcard=False,
class_name=class_name,
)
)
if max_types and len(resolved) >= max_types:
return resolved

return resolved

Expand Down
Loading