Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codeflash/code_utils/config_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_IMPROVEMENT_THRESHOLD_JAVA = 0.02
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required
CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark
Expand Down
22 changes: 17 additions & 5 deletions codeflash/languages/java/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,23 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
)
wrapped_body_lines.append(serialize_line)

# Check if the line is now just a variable reference (invalid statement)
# This happens when the original line was just a void method call
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
# Check if the line is now just a variable reference (invalid statement).
# This happens when the original line was just a void method call:
# "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
# It also happens when assertThrows was transformed to try-catch:
# "try { func(args); } catch (...) {}" becomes
# "try { _cf_result1_1; } catch (...) {}"
# A bare variable is not a valid Java statement.
stripped_new = new_line.strip().rstrip(";").strip()
if stripped_new and stripped_new not in (var_name, var_with_cast):
is_bare_var = stripped_new in (var_name, var_with_cast)
is_try_with_bare_var = bool(re.match(
r"try\s*\{\s*(?:"
+ re.escape(var_name)
+ (r"|" + re.escape(var_with_cast) if var_with_cast != var_name else "")
+ r")\s*;\s*\}\s*catch\s*\(",
stripped_new,
))
if stripped_new and not is_bare_var and not is_try_with_bare_var:
wrapped_body_lines.append(new_line)
else:
wrapped_body_lines.append(body_line)
Expand Down Expand Up @@ -834,7 +846,7 @@ def instrument_generated_java_test(
original_class_name = class_match.group(1)


# For performance mode, add timing instrumentation
# Add mode-specific instrumentation
# Use original class name (without suffix) in timing markers for consistency with Python
if mode == "performance":

Expand Down
65 changes: 62 additions & 3 deletions codeflash/languages/java/replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ class ParsedOptimization:
target_method_source: str
new_fields: list[str] # Source text of new fields to add
new_helper_methods: list[str] # Source text of new helper methods to add
new_imports: list[str] # Import statements to add (e.g., "import java.nio.file.Files;")


def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
"""Parse optimization source to extract method and additional class members.
"""Parse optimization source to extract method, imports, and additional class members.

The new_source may contain:
- Just a method definition
Expand All @@ -48,13 +49,20 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
analyzer: JavaAnalyzer instance.

Returns:
ParsedOptimization with the method and any additional members.
ParsedOptimization with the method, imports, and any additional members.

"""
new_fields: list[str] = []
new_helper_methods: list[str] = []
target_method_source = new_source # Default to the whole source

# Extract import statements from the candidate code
new_imports: list[str] = []
for imp in analyzer.find_imports(new_source):
prefix = "import static " if imp.is_static else "import "
suffix = ".*" if imp.is_wildcard else ""
new_imports.append(f"{prefix}{imp.import_path}{suffix};")

# Check if this is a full class or just a method
classes = analyzer.find_classes(new_source)

Expand Down Expand Up @@ -92,10 +100,57 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
new_fields.append(field.source_text)

return ParsedOptimization(
target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods
target_method_source=target_method_source,
new_fields=new_fields,
new_helper_methods=new_helper_methods,
new_imports=new_imports,
)


def _add_missing_imports(source: str, candidate_imports: list[str], analyzer: JavaAnalyzer) -> str:
"""Add import statements from the optimization candidate that are missing in the original source.

Args:
source: The original source code.
candidate_imports: Import statements from the candidate (e.g., ["import java.nio.file.Files;"]).
analyzer: JavaAnalyzer instance.

Returns:
Source code with missing imports added.

"""
existing_imports = analyzer.find_imports(source)
existing_import_strs = set()
for imp in existing_imports:
prefix = "import static " if imp.is_static else "import "
suffix = ".*" if imp.is_wildcard else ""
existing_import_strs.add(f"{prefix}{imp.import_path}{suffix};")

missing_imports = [imp for imp in candidate_imports if imp not in existing_import_strs]
if not missing_imports:
return source

logger.debug("Adding %d missing imports: %s", len(missing_imports), missing_imports)

# Insert after the last existing import, or after the package declaration
lines = source.splitlines(keepends=True)
insert_line = 0

if existing_imports:
insert_line = max(imp.end_line for imp in existing_imports)
else:
# No existing imports — insert after package declaration
for i, line in enumerate(lines):
if line.strip().startswith("package "):
insert_line = i + 1
break

import_block = "".join(imp + "\n" for imp in missing_imports)
before = lines[:insert_line]
after = lines[insert_line:]
return "".join(before) + import_block + "".join(after)


def _insert_class_members(
source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer
) -> str:
Expand Down Expand Up @@ -237,6 +292,10 @@ def replace_function(
# Parse the optimization to extract components
parsed = _parse_optimization_source(new_source, func_name, analyzer)

# Add any new imports from the optimization candidate
if parsed.new_imports:
source = _add_missing_imports(source, parsed.new_imports, analyzer)

# Find the method in the original source
methods = analyzer.find_methods(source)
target_method = None
Expand Down
6 changes: 6 additions & 0 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
COVERAGE_THRESHOLD,
INDIVIDUAL_TESTCASE_TIMEOUT,
MIN_CORRECT_CANDIDATES,
MIN_IMPROVEMENT_THRESHOLD_JAVA,
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
REFINED_CANDIDATE_RANKING_WEIGHTS,
REPEAT_OPTIMIZATION_PROBABILITY,
Expand Down Expand Up @@ -1364,6 +1365,8 @@ def process_single_candidate(
eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain)

# Check if this is a successful optimization
# Use a lower threshold for Java where I/O-bound functions have smaller optimization margins
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
is_successful_opt = speedup_critic(
candidate_result,
original_code_baseline.runtime,
Expand All @@ -1372,6 +1375,7 @@ def process_single_candidate(
best_throughput_until_now=None,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
best_concurrency_ratio_until_now=None,
min_improvement_override=java_override,
) and quantity_of_tests_critic(candidate_result)

tree = self.build_runtime_info_tree(
Expand Down Expand Up @@ -2272,13 +2276,15 @@ def find_and_process_best_optimization(
fto_benchmark_timings=self.function_benchmark_timings,
total_benchmark_timings=self.total_benchmark_timings,
)
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
acceptance_reason = get_acceptance_reason(
original_runtime_ns=original_code_baseline.runtime,
optimized_runtime_ns=best_optimization.runtime,
original_async_throughput=original_code_baseline.async_throughput,
optimized_async_throughput=best_optimization.async_throughput,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
optimized_concurrency_metrics=best_optimization.concurrency_metrics,
min_improvement_override=java_override,
)
explanation = Explanation(
raw_explanation_message=best_optimization.candidate.explanation,
Expand Down
34 changes: 34 additions & 0 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,38 @@ def _verify_js_requirements(self) -> None:
except Exception as e:
logger.debug(f"Failed to verify JS requirements: {e}")

def _setup_java_runtime(self) -> None:
from pathlib import Path as _Path

from codeflash.languages.java.build_tools import add_codeflash_dependency_to_pom, install_codeflash_runtime

runtime_jar = _Path(__file__).parent.parent / "languages" / "java" / "resources" / "codeflash-runtime-1.0.0.jar"
project_root = self.args.project_root

if not runtime_jar.exists():
logger.warning("codeflash-runtime JAR not found at %s, behavior capture may not work", runtime_jar)
return

if not install_codeflash_runtime(project_root, runtime_jar):
logger.warning("Failed to install codeflash-runtime to local Maven repo")
return

# Add dependency to the test module's pom.xml (or root pom.xml)
test_root = _Path(self.args.tests_root) if hasattr(self.args, "tests_root") and self.args.tests_root else None
pom_path = project_root / "pom.xml"

# For multi-module projects, find the test module's pom.xml
if test_root and test_root != project_root:
candidate = test_root
while candidate != project_root and candidate != candidate.parent:
if (candidate / "pom.xml").exists():
pom_path = candidate / "pom.xml"
break
candidate = candidate.parent

if pom_path.exists():
add_codeflash_dependency_to_pom(pom_path)

def run_benchmarks(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
Expand Down Expand Up @@ -495,6 +527,8 @@ def run(self) -> None:
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
# Verify JS requirements before proceeding
self._verify_js_requirements()
elif is_java():
self._setup_java_runtime()
break

if self.args.all:
Expand Down
8 changes: 6 additions & 2 deletions codeflash/result/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def speedup_critic(
best_throughput_until_now: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
best_concurrency_ratio_until_now: float | None = None,
min_improvement_override: float | None = None,
) -> bool:
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.

Expand All @@ -92,7 +93,8 @@ def speedup_critic(
- Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
"""
# Runtime performance evaluation
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
noise_floor = 3 * threshold if original_code_runtime < 10000 else threshold
if not disable_gh_action_noise and env_utils.is_ci():
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode

Expand Down Expand Up @@ -146,13 +148,15 @@ def get_acceptance_reason(
optimized_async_throughput: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
min_improvement_override: float | None = None,
) -> AcceptanceReason:
"""Determine why an optimization was accepted.

Returns the primary reason for acceptance, with priority:
concurrency > throughput > runtime (for async code).
"""
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
noise_floor = 3 * threshold if original_runtime_ns < 10000 else threshold
if env_utils.is_ci():
noise_floor = noise_floor * 2

Expand Down
15 changes: 15 additions & 0 deletions codeflash/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ def generate_tests(
)

logger.debug(f"Instrumented Java tests locally for {func_name}")
logger.debug(
f"=== Java Generated Tests (raw) for {func_name} ===\n"
f"{generated_test_source}\n"
f"=== End Java Generated Tests ==="
)
logger.debug(
f"=== Java Instrumented Behavior Tests for {func_name} ===\n"
f"{instrumented_behavior_test_source}\n"
f"=== End Java Instrumented Behavior Tests ==="
)
logger.debug(
f"=== Java Instrumented Perf Tests for {func_name} ===\n"
f"{instrumented_perf_test_source}\n"
f"=== End Java Instrumented Perf Tests ==="
)
else:
# Python: instrumentation is done by aiservice, just replace temp dir placeholders
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
Expand Down
22 changes: 11 additions & 11 deletions tests/test_languages/test_java/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ def test_instrument_generated_test_behavior_mode(self):
1. Remove assertions containing the target function call
2. Capture the function return value instead
3. Rename the class with __perfinstrumented suffix
4. Add SQLite behavior instrumentation to capture return values
"""
test_code = """import org.junit.jupiter.api.Test;

Expand All @@ -932,17 +933,16 @@ def test_instrument_generated_test_behavior_mode(self):
mode="behavior",
)

# Behavior mode transforms assertions to capture return values
expected = """import org.junit.jupiter.api.Test;

public class CalculatorTest__perfinstrumented {
@Test
public void testAdd() {
Object _cf_result1 = new Calculator().add(2, 2);
}
}
"""
assert result == expected
# Behavior mode transforms assertions, renames class, and adds SQLite instrumentation
assert "class CalculatorTest__perfinstrumented" in result
assert "import java.sql.Connection;" in result
assert "import java.sql.DriverManager;" in result
assert "import java.sql.PreparedStatement;" in result
assert "CODEFLASH_OUTPUT_FILE" in result
assert "CREATE TABLE IF NOT EXISTS test_results" in result
assert "INSERT INTO test_results VALUES" in result
assert "_cf_serializedResult1" in result
assert "com.codeflash.Serializer.serialize" in result

def test_instrument_generated_test_performance_mode(self):
"""Test instrumenting generated test in performance mode with inner loop."""
Expand Down
3 changes: 2 additions & 1 deletion tests/test_languages/test_java/test_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ def test_optimize_null_checks(self, tmp_path: Path):

assert result is True
new_code = java_file.read_text(encoding="utf-8")
expected = """public class NullChecker {
expected = """import java.util.Objects;
public class NullChecker {
public boolean isEqual(String s1, String s2) {
return Objects.equals(s1, s2);
}
Expand Down
Loading