Skip to content

Commit f3ecd22

Browse files
authored
Merge pull request #1496 from codeflash-ai/fix/java/e2e/test
[Fix] Java failing tests
2 parents 7a2a48b + c4da93c commit f3ecd22

28 files changed

Lines changed: 368 additions & 328 deletions

.github/workflows/unit-tests.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ jobs:
2424
fetch-depth: 0
2525
token: ${{ secrets.GITHUB_TOKEN }}
2626

27+
- name: Set up JDK 11
28+
uses: actions/setup-java@v4
29+
with:
30+
java-version: '11'
31+
distribution: 'temurin'
32+
cache: maven
33+
34+
- name: Build and install codeflash-runtime JAR
35+
run: |
36+
cd codeflash-java-runtime
37+
mvn clean package -q -DskipTests
38+
mvn install -q -DskipTests
39+
2740
- name: Install uv
2841
uses: astral-sh/setup-uv@v6
2942
with:

.github/workflows/windows-unit-tests.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@ jobs:
2222
fetch-depth: 0
2323
token: ${{ secrets.GITHUB_TOKEN }}
2424

25+
- name: Set up JDK 11
26+
uses: actions/setup-java@v4
27+
with:
28+
java-version: '11'
29+
distribution: 'temurin'
30+
cache: maven
31+
32+
- name: Build and install codeflash-runtime JAR
33+
run: |
34+
cd codeflash-java-runtime
35+
mvn clean package -q -DskipTests
36+
mvn install -q -DskipTests
37+
2538
- name: Install uv
2639
uses: astral-sh/setup-uv@v6
2740
with:

codeflash/benchmarking/replay_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@
66
from pathlib import Path
77
from typing import TYPE_CHECKING, Any
88

9-
from codeflash.cli_cmds.console import logger
10-
from codeflash.code_utils.formatter import sort_imports
11-
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
12-
from codeflash.verification.verification_utils import get_test_file_path
13-
149
if TYPE_CHECKING:
1510
from collections.abc import Generator
1611

@@ -232,6 +227,11 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count:
232227
The number of replay tests generated
233228
234229
"""
230+
from codeflash.cli_cmds.console import logger
231+
from codeflash.code_utils.formatter import sort_imports
232+
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
233+
from codeflash.verification.verification_utils import get_test_file_path
234+
235235
count = 0
236236
try:
237237
# Connect to the database

codeflash/github/PrComment.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ class PrComment:
2626

2727
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
2828
report_table: dict[str, dict[str, int]] = {}
29-
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
29+
for test_type, report in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
3030
name = test_type.to_name()
3131
if name:
32-
report_table[name] = result
32+
report_table[name] = report
3333

34-
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
34+
json_result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
3535
"optimization_explanation": self.optimization_explanation,
3636
"best_runtime": humanize_runtime(self.best_runtime),
3737
"original_runtime": humanize_runtime(self.original_runtime),
@@ -45,10 +45,10 @@ def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[B
4545
}
4646

4747
if self.original_async_throughput is not None and self.best_async_throughput is not None:
48-
result["original_async_throughput"] = str(self.original_async_throughput)
49-
result["best_async_throughput"] = str(self.best_async_throughput)
48+
json_result["original_async_throughput"] = str(self.original_async_throughput)
49+
json_result["best_async_throughput"] = str(self.best_async_throughput)
5050

51-
return result
51+
return json_result
5252

5353

5454
class FileDiffContent(BaseModel):

codeflash/languages/java/build_tools.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import xml.etree.ElementTree as ET
1414
from dataclasses import dataclass
1515
from enum import Enum
16-
from typing import TYPE_CHECKING
17-
18-
if TYPE_CHECKING:
19-
from pathlib import Path
16+
from pathlib import Path
2017

2118
logger = logging.getLogger(__name__)
2219

@@ -311,9 +308,9 @@ def find_maven_executable(project_root: Path | None = None) -> str | None:
311308
return str(mvnw_cmd_path)
312309

313310
# Check for Maven wrapper in current directory
314-
if os.path.exists("mvnw"):
311+
if Path("mvnw").exists():
315312
return "./mvnw"
316-
if os.path.exists("mvnw.cmd"):
313+
if Path("mvnw.cmd").exists():
317314
return "mvnw.cmd"
318315

319316
# Check system Maven
@@ -347,9 +344,9 @@ def find_gradle_executable(project_root: Path | None = None) -> str | None:
347344
return str(gradlew_bat_path)
348345

349346
# Check for Gradle wrapper in current directory
350-
if os.path.exists("gradlew"):
347+
if Path("gradlew").exists():
351348
return "./gradlew"
352-
if os.path.exists("gradlew.bat"):
349+
if Path("gradlew.bat").exists():
353350
return "gradlew.bat"
354351

355352
# Check system Gradle

codeflash/languages/java/comparator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _find_java_executable() -> str | None:
9090
if platform.system() == "Darwin":
9191
# Try to extract Java home from Maven (which always finds it)
9292
try:
93-
result = subprocess.run(["mvn", "--version"], capture_output=True, text=True, timeout=10)
93+
result = subprocess.run(["mvn", "--version"], capture_output=True, text=True, timeout=10, check=False)
9494
for line in result.stdout.split("\n"):
9595
if "runtime:" in line:
9696
runtime_path = line.split("runtime:")[-1].strip()
@@ -116,7 +116,7 @@ def _find_java_executable() -> str | None:
116116
if java_path:
117117
# Verify it's a real Java, not a macOS stub
118118
try:
119-
result = subprocess.run([java_path, "--version"], capture_output=True, text=True, timeout=5)
119+
result = subprocess.run([java_path, "--version"], capture_output=True, text=True, timeout=5, check=False)
120120
if result.returncode == 0:
121121
return java_path
122122
except (subprocess.TimeoutExpired, FileNotFoundError):

codeflash/languages/java/concurrency_analyzer.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414

1515
import logging
1616
from dataclasses import dataclass
17-
from pathlib import Path
18-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, ClassVar
1918

2019
if TYPE_CHECKING:
20+
from pathlib import Path
21+
2122
from codeflash.languages.base import FunctionInfo
2223

2324
logger = logging.getLogger(__name__)
@@ -57,7 +58,7 @@ class ConcurrencyInfo:
5758
async_method_calls: list[str] = None
5859
"""List of async/concurrent method calls."""
5960

60-
def __post_init__(self):
61+
def __post_init__(self) -> None:
6162
if self.async_method_calls is None:
6263
self.async_method_calls = []
6364

@@ -66,7 +67,7 @@ class JavaConcurrencyAnalyzer:
6667
"""Analyzes Java code for concurrency patterns."""
6768

6869
# Concurrent patterns to detect
69-
COMPLETABLE_FUTURE_PATTERNS = {
70+
COMPLETABLE_FUTURE_PATTERNS: ClassVar[set[str]] = {
7071
"CompletableFuture",
7172
"supplyAsync",
7273
"runAsync",
@@ -78,7 +79,7 @@ class JavaConcurrencyAnalyzer:
7879
"anyOf",
7980
}
8081

81-
EXECUTOR_PATTERNS = {
82+
EXECUTOR_PATTERNS: ClassVar[set[str]] = {
8283
"ExecutorService",
8384
"Executors",
8485
"ThreadPoolExecutor",
@@ -91,14 +92,14 @@ class JavaConcurrencyAnalyzer:
9192
"newWorkStealingPool",
9293
}
9394

94-
VIRTUAL_THREAD_PATTERNS = {
95+
VIRTUAL_THREAD_PATTERNS: ClassVar[set[str]] = {
9596
"newVirtualThreadPerTaskExecutor",
9697
"Thread.startVirtualThread",
9798
"Thread.ofVirtual",
9899
"VirtualThreads",
99100
}
100101

101-
CONCURRENT_COLLECTION_PATTERNS = {
102+
CONCURRENT_COLLECTION_PATTERNS: ClassVar[set[str]] = {
102103
"ConcurrentHashMap",
103104
"ConcurrentLinkedQueue",
104105
"ConcurrentLinkedDeque",
@@ -111,7 +112,7 @@ class JavaConcurrencyAnalyzer:
111112
"ArrayBlockingQueue",
112113
}
113114

114-
ATOMIC_PATTERNS = {
115+
ATOMIC_PATTERNS: ClassVar[set[str]] = {
115116
"AtomicInteger",
116117
"AtomicLong",
117118
"AtomicBoolean",
@@ -121,7 +122,7 @@ class JavaConcurrencyAnalyzer:
121122
"AtomicReferenceArray",
122123
}
123124

124-
def __init__(self, analyzer=None):
125+
def __init__(self, analyzer=None) -> None:
125126
"""Initialize concurrency analyzer.
126127
127128
Args:
@@ -145,13 +146,13 @@ def analyze_function(self, func: FunctionInfo, source: str | None = None) -> Con
145146
try:
146147
source = func.file_path.read_text(encoding="utf-8")
147148
except Exception as e:
148-
logger.warning("Failed to read source for %s: %s", func.name, e)
149+
logger.warning("Failed to read source for %s: %s", func.function_name, e)
149150
return ConcurrencyInfo(is_concurrent=False, patterns=[])
150151

151152
# Extract function source
152153
lines = source.splitlines()
153-
func_start = func.start_line - 1 # Convert to 0-indexed
154-
func_end = func.end_line
154+
func_start = func.starting_line - 1 # Convert to 0-indexed
155+
func_end = func.ending_line
155156
func_source = "\n".join(lines[func_start:func_end])
156157

157158
# Detect patterns

codeflash/languages/java/line_profiler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
import json
1111
import logging
1212
import re
13-
from pathlib import Path
1413
from typing import TYPE_CHECKING
1514

1615
if TYPE_CHECKING:
16+
from pathlib import Path
17+
1718
from tree_sitter import Node
1819

1920
from codeflash.languages.base import FunctionInfo
@@ -101,7 +102,7 @@ def instrument_source(self, source: str, file_path: Path, functions: list[Functi
101102
import_end_idx = i
102103
break
103104

104-
lines_with_profiler = lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
105+
lines_with_profiler = [*lines[:import_end_idx], profiler_class_code + "\n", *lines[import_end_idx:]]
105106

106107
result = "".join(lines_with_profiler)
107108
if not analyzer.validate_syntax(result):
@@ -298,8 +299,7 @@ def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path:
298299
and not stripped.startswith("//")
299300
and not stripped.startswith("/*")
300301
and not stripped.startswith("*")
301-
and stripped != "}"
302-
and stripped != "};"
302+
and stripped not in ("}", "};")
303303
):
304304
# Get indentation
305305
indent = len(line) - len(line.lstrip())
@@ -434,8 +434,8 @@ def parse_results(profile_file: Path) -> dict:
434434
result["str_out"] = format_line_profile_results(result)
435435
return result
436436

437-
except Exception as e:
438-
logger.error("Failed to parse line profile results: %s", e)
437+
except Exception:
438+
logger.exception("Failed to parse line profile results")
439439
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}
440440

441441

codeflash/languages/java/support.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
4141
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult
42+
from codeflash.languages.java.concurrency_analyzer import ConcurrencyInfo
4243

4344
logger = logging.getLogger(__name__)
4445

@@ -111,7 +112,7 @@ def find_helper_functions(self, function: FunctionToOptimize, project_root: Path
111112
"""Find helper functions called by the target function."""
112113
return find_helper_functions(function, project_root, analyzer=self._analyzer)
113114

114-
def analyze_concurrency(self, function: FunctionInfo, source: str | None = None):
115+
def analyze_concurrency(self, function: FunctionToOptimize, source: str | None = None) -> ConcurrencyInfo:
115116
"""Analyze a function for concurrency patterns.
116117
117118
Args:
@@ -319,8 +320,8 @@ def instrument_source_for_line_profiler(
319320
func_info.file_path.write_text(instrumented, encoding="utf-8")
320321

321322
return True
322-
except Exception as e:
323-
logger.error("Failed to instrument %s for line profiling: %s", func_info.function_name, e)
323+
except Exception:
324+
logger.exception("Failed to instrument %s for line profiling", func_info.function_name)
324325
return False
325326

326327
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:

0 commit comments

Comments
 (0)