Skip to content

Commit b4e233a

Browse files
committed
fix omni-java
1 parent ae4eb7c commit b4e233a

9 files changed

Lines changed: 154 additions & 22 deletions

File tree

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
1111
MAX_FUNCTION_TEST_SECONDS = 60
1212
MIN_IMPROVEMENT_THRESHOLD = 0.05
13+
MIN_IMPROVEMENT_THRESHOLD_JAVA = 0.02
1314
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
1415
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required
1516
CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark

codeflash/languages/java/instrumentation.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,23 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
508508
)
509509
wrapped_body_lines.append(serialize_line)
510510

511-
# Check if the line is now just a variable reference (invalid statement)
512-
# This happens when the original line was just a void method call
513-
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
511+
# Check if the line is now just a variable reference (invalid statement).
512+
# This happens when the original line was just a void method call:
513+
# "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
514+
# It also happens when assertThrows was transformed to try-catch:
515+
# "try { func(args); } catch (...) {}" becomes
516+
# "try { _cf_result1_1; } catch (...) {}"
517+
# A bare variable is not a valid Java statement.
514518
stripped_new = new_line.strip().rstrip(";").strip()
515-
if stripped_new and stripped_new not in (var_name, var_with_cast):
519+
is_bare_var = stripped_new in (var_name, var_with_cast)
520+
is_try_with_bare_var = bool(re.match(
521+
r"try\s*\{\s*(?:"
522+
+ re.escape(var_name)
523+
+ (r"|" + re.escape(var_with_cast) if var_with_cast != var_name else "")
524+
+ r")\s*;\s*\}\s*catch\s*\(",
525+
stripped_new,
526+
))
527+
if stripped_new and not is_bare_var and not is_try_with_bare_var:
516528
wrapped_body_lines.append(new_line)
517529
else:
518530
wrapped_body_lines.append(body_line)
@@ -834,7 +846,7 @@ def instrument_generated_java_test(
834846
original_class_name = class_match.group(1)
835847

836848

837-
# For performance mode, add timing instrumentation
849+
# Add mode-specific instrumentation
838850
# Use original class name (without suffix) in timing markers for consistency with Python
839851
if mode == "performance":
840852

codeflash/languages/java/replacement.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ class ParsedOptimization:
3333
target_method_source: str
3434
new_fields: list[str] # Source text of new fields to add
3535
new_helper_methods: list[str] # Source text of new helper methods to add
36+
new_imports: list[str] # Import statements to add (e.g., "import java.nio.file.Files;")
3637

3738

3839
def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
39-
"""Parse optimization source to extract method and additional class members.
40+
"""Parse optimization source to extract method, imports, and additional class members.
4041
4142
The new_source may contain:
4243
- Just a method definition
@@ -48,13 +49,20 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
4849
analyzer: JavaAnalyzer instance.
4950
5051
Returns:
51-
ParsedOptimization with the method and any additional members.
52+
ParsedOptimization with the method, imports, and any additional members.
5253
5354
"""
5455
new_fields: list[str] = []
5556
new_helper_methods: list[str] = []
5657
target_method_source = new_source # Default to the whole source
5758

59+
# Extract import statements from the candidate code
60+
new_imports: list[str] = []
61+
for imp in analyzer.find_imports(new_source):
62+
prefix = "import static " if imp.is_static else "import "
63+
suffix = ".*" if imp.is_wildcard else ""
64+
new_imports.append(f"{prefix}{imp.import_path}{suffix};")
65+
5866
# Check if this is a full class or just a method
5967
classes = analyzer.find_classes(new_source)
6068

@@ -92,10 +100,57 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
92100
new_fields.append(field.source_text)
93101

94102
return ParsedOptimization(
95-
target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods
103+
target_method_source=target_method_source,
104+
new_fields=new_fields,
105+
new_helper_methods=new_helper_methods,
106+
new_imports=new_imports,
96107
)
97108

98109

110+
def _add_missing_imports(source: str, candidate_imports: list[str], analyzer: JavaAnalyzer) -> str:
    """Add import statements from the optimization candidate that are missing in the original source.

    New imports are inserted directly after the last existing import. When the source has no
    imports they are inserted after the ``package`` declaration, or at the very top of the
    file when there is no package declaration either.

    Args:
        source: The original source code.
        candidate_imports: Import statements from the candidate (e.g., ["import java.nio.file.Files;"]).
        analyzer: JavaAnalyzer instance.

    Returns:
        Source code with missing imports added. The source is returned unchanged when every
        candidate import is already present.

    """
    existing_imports = analyzer.find_imports(source)

    # Render the existing imports in the same textual form as candidate_imports so the two
    # can be compared as plain strings.
    existing_import_strs = set()
    for imp in existing_imports:
        prefix = "import static " if imp.is_static else "import "
        suffix = ".*" if imp.is_wildcard else ""
        existing_import_strs.add(f"{prefix}{imp.import_path}{suffix};")

    # dict.fromkeys drops repeated candidates while keeping first-seen order, so an import
    # that the candidate lists twice is only inserted once.
    missing_imports = list(dict.fromkeys(imp for imp in candidate_imports if imp not in existing_import_strs))
    if not missing_imports:
        return source

    logger.debug("Adding %d missing imports: %s", len(missing_imports), missing_imports)

    # Insert after the last existing import, or after the package declaration
    lines = source.splitlines(keepends=True)
    insert_line = 0

    if existing_imports:
        # end_line is used directly as a slice index into `lines`.
        # NOTE(review): this presumes end_line is 1-indexed -- confirm against JavaAnalyzer.
        insert_line = max(imp.end_line for imp in existing_imports)
    else:
        # No existing imports -- insert after package declaration
        for i, line in enumerate(lines):
            if line.strip().startswith("package "):
                insert_line = i + 1
                break

    before = "".join(lines[:insert_line])
    # If the text before the insertion point does not end with a line break (e.g. the file is
    # just "package a;" with no trailing newline), add one so the first new import is not
    # glued onto the previous statement.
    if before and not before.endswith(("\n", "\r")):
        before += "\n"

    import_block = "".join(imp + "\n" for imp in missing_imports)
    return before + import_block + "".join(lines[insert_line:])
152+
153+
99154
def _insert_class_members(
100155
source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer
101156
) -> str:
@@ -237,6 +292,10 @@ def replace_function(
237292
# Parse the optimization to extract components
238293
parsed = _parse_optimization_source(new_source, func_name, analyzer)
239294

295+
# Add any new imports from the optimization candidate
296+
if parsed.new_imports:
297+
source = _add_missing_imports(source, parsed.new_imports, analyzer)
298+
240299
# Find the method in the original source
241300
methods = analyzer.find_methods(source)
242301
target_method = None

codeflash/optimization/function_optimizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
COVERAGE_THRESHOLD,
4949
INDIVIDUAL_TESTCASE_TIMEOUT,
5050
MIN_CORRECT_CANDIDATES,
51+
MIN_IMPROVEMENT_THRESHOLD_JAVA,
5152
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
5253
REFINED_CANDIDATE_RANKING_WEIGHTS,
5354
REPEAT_OPTIMIZATION_PROBABILITY,
@@ -1364,6 +1365,8 @@ def process_single_candidate(
13641365
eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain)
13651366

13661367
# Check if this is a successful optimization
1368+
# Use a lower threshold for Java where I/O-bound functions have smaller optimization margins
1369+
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
13671370
is_successful_opt = speedup_critic(
13681371
candidate_result,
13691372
original_code_baseline.runtime,
@@ -1372,6 +1375,7 @@ def process_single_candidate(
13721375
best_throughput_until_now=None,
13731376
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
13741377
best_concurrency_ratio_until_now=None,
1378+
min_improvement_override=java_override,
13751379
) and quantity_of_tests_critic(candidate_result)
13761380

13771381
tree = self.build_runtime_info_tree(
@@ -2272,13 +2276,15 @@ def find_and_process_best_optimization(
22722276
fto_benchmark_timings=self.function_benchmark_timings,
22732277
total_benchmark_timings=self.total_benchmark_timings,
22742278
)
2279+
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
22752280
acceptance_reason = get_acceptance_reason(
22762281
original_runtime_ns=original_code_baseline.runtime,
22772282
optimized_runtime_ns=best_optimization.runtime,
22782283
original_async_throughput=original_code_baseline.async_throughput,
22792284
optimized_async_throughput=best_optimization.async_throughput,
22802285
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
22812286
optimized_concurrency_metrics=best_optimization.concurrency_metrics,
2287+
min_improvement_override=java_override,
22822288
)
22832289
explanation = Explanation(
22842290
raw_explanation_message=best_optimization.candidate.explanation,

codeflash/optimization/optimizer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,38 @@ def _verify_js_requirements(self) -> None:
117117
except Exception as e:
118118
logger.debug(f"Failed to verify JS requirements: {e}")
119119

120+
def _setup_java_runtime(self) -> None:
    """Install the bundled codeflash-runtime JAR and register it as a dependency in a pom.xml.

    Logs a warning and returns early when the bundled JAR cannot be found or when installing
    it fails. Otherwise the dependency is added to the pom.xml closest to the configured tests
    root (searching upwards, stopping at the project root), falling back to the project
    root's pom.xml. Nothing is added when the chosen pom.xml does not exist.
    """
    from pathlib import Path as _Path

    from codeflash.languages.java.build_tools import add_codeflash_dependency_to_pom, install_codeflash_runtime

    java_resources = _Path(__file__).parent.parent / "languages" / "java" / "resources"
    runtime_jar = java_resources / "codeflash-runtime-1.0.0.jar"
    project_root = self.args.project_root

    if not runtime_jar.exists():
        logger.warning("codeflash-runtime JAR not found at %s, behavior capture may not work", runtime_jar)
        return

    installed = install_codeflash_runtime(project_root, runtime_jar)
    if not installed:
        logger.warning("Failed to install codeflash-runtime to local Maven repo")
        return

    # The tests root is optional on the args object; treat a missing or empty value the same.
    configured_tests_root = getattr(self.args, "tests_root", None)
    tests_dir = _Path(configured_tests_root) if configured_tests_root else None

    # Default to the root pom.xml; multi-module projects may override it below.
    pom_path = project_root / "pom.xml"

    if tests_dir and tests_dir != project_root:
        # Walk from the tests directory towards the filesystem root, stopping once the project
        # root (or the top of the path) is reached, and use the first pom.xml found on the way.
        for directory in (tests_dir, *tests_dir.parents):
            if directory == project_root or directory == directory.parent:
                break
            module_pom = directory / "pom.xml"
            if module_pom.exists():
                pom_path = module_pom
                break

    if pom_path.exists():
        add_codeflash_dependency_to_pom(pom_path)
151+
120152
def run_benchmarks(
121153
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
122154
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
@@ -495,6 +527,8 @@ def run(self) -> None:
495527
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
496528
# Verify JS requirements before proceeding
497529
self._verify_js_requirements()
530+
elif is_java():
531+
self._setup_java_runtime()
498532
break
499533

500534
if self.args.all:

codeflash/result/critic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def speedup_critic(
7272
best_throughput_until_now: int | None = None,
7373
original_concurrency_metrics: ConcurrencyMetrics | None = None,
7474
best_concurrency_ratio_until_now: float | None = None,
75+
min_improvement_override: float | None = None,
7576
) -> bool:
7677
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
7778
@@ -92,7 +93,8 @@ def speedup_critic(
9293
- Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
9394
"""
9495
# Runtime performance evaluation
95-
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
96+
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
97+
noise_floor = 3 * threshold if original_code_runtime < 10000 else threshold
9698
if not disable_gh_action_noise and env_utils.is_ci():
9799
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
98100

@@ -146,13 +148,15 @@ def get_acceptance_reason(
146148
optimized_async_throughput: int | None = None,
147149
original_concurrency_metrics: ConcurrencyMetrics | None = None,
148150
optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
151+
min_improvement_override: float | None = None,
149152
) -> AcceptanceReason:
150153
"""Determine why an optimization was accepted.
151154
152155
Returns the primary reason for acceptance, with priority:
153156
concurrency > throughput > runtime (for async code).
154157
"""
155-
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
158+
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
159+
noise_floor = 3 * threshold if original_runtime_ns < 10000 else threshold
156160
if env_utils.is_ci():
157161
noise_floor = noise_floor * 2
158162

codeflash/verification/verifier.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,21 @@ def generate_tests(
122122
)
123123

124124
logger.debug(f"Instrumented Java tests locally for {func_name}")
125+
logger.debug(
126+
f"=== Java Generated Tests (raw) for {func_name} ===\n"
127+
f"{generated_test_source}\n"
128+
f"=== End Java Generated Tests ==="
129+
)
130+
logger.debug(
131+
f"=== Java Instrumented Behavior Tests for {func_name} ===\n"
132+
f"{instrumented_behavior_test_source}\n"
133+
f"=== End Java Instrumented Behavior Tests ==="
134+
)
135+
logger.debug(
136+
f"=== Java Instrumented Perf Tests for {func_name} ===\n"
137+
f"{instrumented_perf_test_source}\n"
138+
f"=== End Java Instrumented Perf Tests ==="
139+
)
125140
else:
126141
# Python: instrumentation is done by aiservice, just replace temp dir placeholders
127142
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(

tests/test_languages/test_java/test_instrumentation.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,7 @@ def test_instrument_generated_test_behavior_mode(self):
915915
1. Remove assertions containing the target function call
916916
2. Capture the function return value instead
917917
3. Rename the class with __perfinstrumented suffix
918+
4. Add SQLite behavior instrumentation to capture return values
918919
"""
919920
test_code = """import org.junit.jupiter.api.Test;
920921
@@ -932,17 +933,16 @@ def test_instrument_generated_test_behavior_mode(self):
932933
mode="behavior",
933934
)
934935

935-
# Behavior mode transforms assertions to capture return values
936-
expected = """import org.junit.jupiter.api.Test;
937-
938-
public class CalculatorTest__perfinstrumented {
939-
@Test
940-
public void testAdd() {
941-
Object _cf_result1 = new Calculator().add(2, 2);
942-
}
943-
}
944-
"""
945-
assert result == expected
936+
# Behavior mode transforms assertions, renames class, and adds SQLite instrumentation
937+
assert "class CalculatorTest__perfinstrumented" in result
938+
assert "import java.sql.Connection;" in result
939+
assert "import java.sql.DriverManager;" in result
940+
assert "import java.sql.PreparedStatement;" in result
941+
assert "CODEFLASH_OUTPUT_FILE" in result
942+
assert "CREATE TABLE IF NOT EXISTS test_results" in result
943+
assert "INSERT INTO test_results VALUES" in result
944+
assert "_cf_serializedResult1" in result
945+
assert "com.codeflash.Serializer.serialize" in result
946946

947947
def test_instrument_generated_test_performance_mode(self):
948948
"""Test instrumenting generated test in performance mode with inner loop."""

tests/test_languages/test_java/test_replacement.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,8 @@ def test_optimize_null_checks(self, tmp_path: Path):
647647

648648
assert result is True
649649
new_code = java_file.read_text(encoding="utf-8")
650-
expected = """public class NullChecker {
650+
expected = """import java.util.Objects;
651+
public class NullChecker {
651652
public boolean isEqual(String s1, String s2) {
652653
return Objects.equals(s1, s2);
653654
}

0 commit comments

Comments
 (0)