Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ jobs:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}

- name: Set up JDK 11
uses: actions/setup-java@v4
with:
java-version: '11'
distribution: 'temurin'
cache: maven

- name: Build and install codeflash-runtime JAR
run: |
cd codeflash-java-runtime
mvn clean package -q -DskipTests
mvn install -q -DskipTests

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
Expand Down
13 changes: 13 additions & 0 deletions .github/workflows/windows-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ jobs:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}

- name: Set up JDK 11
uses: actions/setup-java@v4
with:
java-version: '11'
distribution: 'temurin'
cache: maven

- name: Build and install codeflash-runtime JAR
run: |
cd codeflash-java-runtime
mvn clean package -q -DskipTests
mvn install -q -DskipTests

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
Expand Down
10 changes: 5 additions & 5 deletions codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path

if TYPE_CHECKING:
from collections.abc import Generator

Expand Down Expand Up @@ -232,6 +227,11 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count:
The number of replay tests generated

"""
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path

count = 0
try:
# Connect to the database
Expand Down
12 changes: 6 additions & 6 deletions codeflash/github/PrComment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ class PrComment:

def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
report_table: dict[str, dict[str, int]] = {}
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
for test_type, report in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
name = test_type.to_name()
if name:
report_table[name] = result
report_table[name] = report

result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
json_result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
"optimization_explanation": self.optimization_explanation,
"best_runtime": humanize_runtime(self.best_runtime),
"original_runtime": humanize_runtime(self.original_runtime),
Expand All @@ -45,10 +45,10 @@ def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[B
}

if self.original_async_throughput is not None and self.best_async_throughput is not None:
result["original_async_throughput"] = str(self.original_async_throughput)
result["best_async_throughput"] = str(self.best_async_throughput)
json_result["original_async_throughput"] = str(self.original_async_throughput)
json_result["best_async_throughput"] = str(self.best_async_throughput)

return result
return json_result


class FileDiffContent(BaseModel):
Expand Down
13 changes: 5 additions & 8 deletions codeflash/languages/java/build_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path
from pathlib import Path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -311,9 +308,9 @@ def find_maven_executable(project_root: Path | None = None) -> str | None:
return str(mvnw_cmd_path)

# Check for Maven wrapper in current directory
if os.path.exists("mvnw"):
if Path("mvnw").exists():
return "./mvnw"
if os.path.exists("mvnw.cmd"):
if Path("mvnw.cmd").exists():
return "mvnw.cmd"

# Check system Maven
Expand Down Expand Up @@ -347,9 +344,9 @@ def find_gradle_executable(project_root: Path | None = None) -> str | None:
return str(gradlew_bat_path)

# Check for Gradle wrapper in current directory
if os.path.exists("gradlew"):
if Path("gradlew").exists():
return "./gradlew"
if os.path.exists("gradlew.bat"):
if Path("gradlew.bat").exists():
return "gradlew.bat"

# Check system Gradle
Expand Down
4 changes: 2 additions & 2 deletions codeflash/languages/java/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _find_java_executable() -> str | None:
if platform.system() == "Darwin":
# Try to extract Java home from Maven (which always finds it)
try:
result = subprocess.run(["mvn", "--version"], capture_output=True, text=True, timeout=10)
result = subprocess.run(["mvn", "--version"], capture_output=True, text=True, timeout=10, check=False)
Comment thread
HeshamHM28 marked this conversation as resolved.
for line in result.stdout.split("\n"):
if "runtime:" in line:
runtime_path = line.split("runtime:")[-1].strip()
Expand All @@ -116,7 +116,7 @@ def _find_java_executable() -> str | None:
if java_path:
# Verify it's a real Java, not a macOS stub
try:
result = subprocess.run([java_path, "--version"], capture_output=True, text=True, timeout=5)
result = subprocess.run([java_path, "--version"], capture_output=True, text=True, timeout=5, check=False)
if result.returncode == 0:
return java_path
except (subprocess.TimeoutExpired, FileNotFoundError):
Expand Down
25 changes: 13 additions & 12 deletions codeflash/languages/java/concurrency_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar

if TYPE_CHECKING:
from pathlib import Path

from codeflash.languages.base import FunctionInfo

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,7 +58,7 @@ class ConcurrencyInfo:
async_method_calls: list[str] = None
"""List of async/concurrent method calls."""

def __post_init__(self):
def __post_init__(self) -> None:
if self.async_method_calls is None:
self.async_method_calls = []

Expand All @@ -66,7 +67,7 @@ class JavaConcurrencyAnalyzer:
"""Analyzes Java code for concurrency patterns."""

# Concurrent patterns to detect
COMPLETABLE_FUTURE_PATTERNS = {
COMPLETABLE_FUTURE_PATTERNS: ClassVar[set[str]] = {
"CompletableFuture",
"supplyAsync",
"runAsync",
Expand All @@ -78,7 +79,7 @@ class JavaConcurrencyAnalyzer:
"anyOf",
}

EXECUTOR_PATTERNS = {
EXECUTOR_PATTERNS: ClassVar[set[str]] = {
"ExecutorService",
"Executors",
"ThreadPoolExecutor",
Expand All @@ -91,14 +92,14 @@ class JavaConcurrencyAnalyzer:
"newWorkStealingPool",
}

VIRTUAL_THREAD_PATTERNS = {
VIRTUAL_THREAD_PATTERNS: ClassVar[set[str]] = {
"newVirtualThreadPerTaskExecutor",
"Thread.startVirtualThread",
"Thread.ofVirtual",
"VirtualThreads",
}

CONCURRENT_COLLECTION_PATTERNS = {
CONCURRENT_COLLECTION_PATTERNS: ClassVar[set[str]] = {
"ConcurrentHashMap",
"ConcurrentLinkedQueue",
"ConcurrentLinkedDeque",
Expand All @@ -111,7 +112,7 @@ class JavaConcurrencyAnalyzer:
"ArrayBlockingQueue",
}

ATOMIC_PATTERNS = {
ATOMIC_PATTERNS: ClassVar[set[str]] = {
"AtomicInteger",
"AtomicLong",
"AtomicBoolean",
Expand All @@ -121,7 +122,7 @@ class JavaConcurrencyAnalyzer:
"AtomicReferenceArray",
}

def __init__(self, analyzer=None):
def __init__(self, analyzer=None) -> None:
"""Initialize concurrency analyzer.

Args:
Expand All @@ -145,13 +146,13 @@ def analyze_function(self, func: FunctionInfo, source: str | None = None) -> Con
try:
source = func.file_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning("Failed to read source for %s: %s", func.name, e)
logger.warning("Failed to read source for %s: %s", func.function_name, e)
return ConcurrencyInfo(is_concurrent=False, patterns=[])

# Extract function source
lines = source.splitlines()
func_start = func.start_line - 1 # Convert to 0-indexed
func_end = func.end_line
func_start = func.starting_line - 1 # Convert to 0-indexed
func_end = func.ending_line
func_source = "\n".join(lines[func_start:func_end])

# Detect patterns
Expand Down
12 changes: 6 additions & 6 deletions codeflash/languages/java/line_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
import json
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path

from tree_sitter import Node

from codeflash.languages.base import FunctionInfo
Expand Down Expand Up @@ -101,7 +102,7 @@ def instrument_source(self, source: str, file_path: Path, functions: list[Functi
import_end_idx = i
break

lines_with_profiler = lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
lines_with_profiler = [*lines[:import_end_idx], profiler_class_code + "\n", *lines[import_end_idx:]]

result = "".join(lines_with_profiler)
if not analyzer.validate_syntax(result):
Expand Down Expand Up @@ -298,8 +299,7 @@ def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path:
and not stripped.startswith("//")
and not stripped.startswith("/*")
and not stripped.startswith("*")
and stripped != "}"
and stripped != "};"
and stripped not in ("}", "};")
):
# Get indentation
indent = len(line) - len(line.lstrip())
Expand Down Expand Up @@ -434,8 +434,8 @@ def parse_results(profile_file: Path) -> dict:
result["str_out"] = format_line_profile_results(result)
return result

except Exception as e:
logger.error("Failed to parse line profile results: %s", e)
except Exception:
logger.exception("Failed to parse line profile results")
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}


Expand Down
7 changes: 4 additions & 3 deletions codeflash/languages/java/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult
from codeflash.languages.java.concurrency_analyzer import ConcurrencyInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,7 +112,7 @@ def find_helper_functions(self, function: FunctionToOptimize, project_root: Path
"""Find helper functions called by the target function."""
return find_helper_functions(function, project_root, analyzer=self._analyzer)

def analyze_concurrency(self, function: FunctionInfo, source: str | None = None):
def analyze_concurrency(self, function: FunctionToOptimize, source: str | None = None) -> ConcurrencyInfo:
"""Analyze a function for concurrency patterns.

Args:
Expand Down Expand Up @@ -319,8 +320,8 @@ def instrument_source_for_line_profiler(
func_info.file_path.write_text(instrumented, encoding="utf-8")

return True
except Exception as e:
logger.error("Failed to instrument %s for line profiling: %s", func_info.function_name, e)
except Exception:
logger.exception("Failed to instrument %s for line profiling", func_info.function_name)
return False

def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
Expand Down
Loading
Loading