Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -429,3 +429,4 @@ code_to_optimize/**/package-lock.json

# Other tools
.codeflash/
.codeflash_eval_worktrees/
2 changes: 1 addition & 1 deletion codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
FunctionRepairInfo,
OptimizationReviewResult,
OptimizedCandidate,
OptimizedCandidateSource,
TestFileReview,
)
from codeflash.models.shared_types import OptimizedCandidateSource
from codeflash.telemetry.posthog_cf import ph
from codeflash.version import __version__ as codeflash_version

Expand Down
7 changes: 7 additions & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,13 @@ def _build_parser() -> ArgumentParser:
)
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
parser.add_argument(
"--parallel-candidates",
type=int,
default=0,
metavar="N",
help="Evaluate up to N optimization candidates in parallel using git worktrees (0 = sequential)",
)
parser.add_argument(
"--testgen-review", default=False, action="store_true", help="Enable AI review and repair of generated tests"
)
Expand Down
127 changes: 127 additions & 0 deletions codeflash/code_utils/worktree_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import contextlib
import functools
import os
import shutil
import stat
from pathlib import Path
from typing import TYPE_CHECKING, Any

import anyio

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Self

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.git_utils import git_root_dir, mirror_path


class WorktreeSlot:
    """One worktree directory, identified by its position in the pool.

    Holds the worktree ``path``, its ``index``, and the git root that is
    passed to ``mirror_path`` when translating file paths.
    """

    __slots__ = ("_git_root", "index", "path")

    def __init__(self, path: Path, index: int, git_root: Path) -> None:
        self._git_root = git_root
        self.index = index
        self.path = path

    def mirror(self, original_path: Path) -> Path:
        """Return ``mirror_path(original_path, <git root>, <slot path>)``."""
        return mirror_path(original_path, self._git_root, self.path)

    async def write_candidate(self, file_path: Path, code: str) -> None:
        """Write ``code`` (UTF-8) to the mirrored location of ``file_path``.

        Missing parent directories of the mirrored file are created first.
        """
        target = anyio.Path(self.mirror(file_path))
        await target.parent.mkdir(parents=True, exist_ok=True)
        await target.write_text(code, encoding="utf-8")

    async def restore_file(self, file_path: Path, original_code: str) -> None:
        """Overwrite the mirrored copy of ``file_path`` with ``original_code`` (UTF-8).

        Unlike ``write_candidate`` this does not create parent directories.
        """
        await anyio.Path(self.mirror(file_path)).write_text(original_code, encoding="utf-8")


class WorktreePool:
    """Fixed-size pool of detached git worktrees, handed out one at a time.

    Slots live under ``base_dir`` (default ``<git root>/.codeflash_eval_worktrees``)
    as ``slot-<index>`` directories checked out at ``HEAD``. Free slots are queued
    in an in-memory object stream: ``acquire`` waits until a slot is available and
    ``release`` puts it back. Use the pool as an async context manager so that
    ``cleanup`` always runs.
    """

    def __init__(self, pool_size: int = 4, base_dir: Path | None = None) -> None:
        """Record the configuration; no worktree is created until ``initialize``.

        Args:
            pool_size: Number of worktree slots to create.
            base_dir: Directory that holds the slots. Defaults to
                ``.codeflash_eval_worktrees`` under the git root.
        """
        self._pool_size = pool_size
        self._git_root = git_root_dir()
        self._base_dir = base_dir or (self._git_root / ".codeflash_eval_worktrees")
        self._slots: list[WorktreeSlot] = []
        self._send: anyio.abc.ObjectSendStream[WorktreeSlot] | None = None
        self._receive: anyio.abc.ObjectReceiveStream[WorktreeSlot] | None = None
        self._initialized = False

    async def initialize(self) -> None:
        """Create every slot concurrently and queue them for ``acquire``.

        Does nothing when the pool is already initialized. If any slot cannot be
        created, the slots that were created are removed again (best effort) and
        the original error propagates.
        """
        if self._initialized:
            return
        await anyio.Path(self._base_dir).mkdir(parents=True, exist_ok=True)

        results: list[WorktreeSlot | None] = [None] * self._pool_size
        try:
            async with anyio.create_task_group() as tg:
                for i in range(self._pool_size):
                    tg.start_soon(self._create_slot_task, i, results)
        except BaseException:
            # ``__aexit__`` is never called when ``__aenter__`` raises, and
            # ``self._slots`` is still empty, so nothing else would remove the
            # worktrees created before the failure. Roll back by path because a
            # cancelled task may have created its worktree without recording it.
            leftovers = [
                WorktreeSlot(self._base_dir / f"slot-{i}", i, self._git_root) for i in range(self._pool_size)
            ]
            # Shielded so the rollback still runs when we are being cancelled;
            # suppressed so a rollback problem never hides the original error.
            with anyio.CancelScope(shield=True), contextlib.suppress(Exception):
                await self._remove_slots(leftovers)
                await self._prune_worktrees()
            raise

        self._slots = [s for s in results if s is not None]
        self._send, self._receive = anyio.create_memory_object_stream[WorktreeSlot](self._pool_size)
        for slot in self._slots:
            await self._send.send(slot)
        self._initialized = True
        logger.debug(f"WorktreePool initialized with {len(self._slots)} slots at {self._base_dir}")

    async def _create_slot_task(self, index: int, results: list[WorktreeSlot | None]) -> None:
        """Task-group entry point: store the created slot at ``results[index]``."""
        results[index] = await self._create_slot(index)

    async def _create_slot(self, index: int) -> WorktreeSlot:
        """Create the worktree for ``index`` and return its slot.

        Raises:
            RuntimeError: If ``git worktree add`` exits with a non-zero status.
        """
        slot_dir = self._base_dir / f"slot-{index}"
        if slot_dir.exists():
            await anyio.to_thread.run_sync(functools.partial(shutil.rmtree, slot_dir, onerror=_handle_remove_readonly))

        # ``--force`` lets git reuse a path that is still registered as a worktree
        # but missing on disk. That is the state left behind when an earlier run
        # died before ``cleanup``; without it ``worktree add`` refuses the path.
        result = await anyio.run_process(
            ["git", "-C", str(self._git_root), "worktree", "add", "--force", "--detach", str(slot_dir), "HEAD"],
            check=False,
        )
        if result.returncode != 0:
            # Tolerant decoding: undecodable git output must not mask the failure.
            msg = f"Failed to create worktree slot {index}: {result.stderr.decode(errors='replace')}"
            raise RuntimeError(msg)

        # Record which process created this slot.
        pid_file = anyio.Path(slot_dir / ".codeflash_pool.pid")
        await pid_file.write_text(str(os.getpid()), encoding="utf-8")

        return WorktreeSlot(slot_dir, index, self._git_root)

    async def acquire(self) -> WorktreeSlot:
        """Wait for a free slot and return it.

        Raises:
            RuntimeError: If the pool has not been initialized (or was cleaned up).
        """
        if self._receive is None:
            msg = "WorktreePool.acquire() called before initialize()"
            raise RuntimeError(msg)
        return await self._receive.receive()

    async def release(self, slot: WorktreeSlot) -> None:
        """Return ``slot`` to the pool so another task can acquire it.

        Raises:
            RuntimeError: If the pool has not been initialized (or was cleaned up).
        """
        if self._send is None:
            msg = "WorktreePool.release() called before initialize()"
            raise RuntimeError(msg)
        await self._send.send(slot)

    async def cleanup(self) -> None:
        """Close the slot queue, delete every slot directory and prune git's records."""
        await self._close_streams()
        await self._remove_slots(self._slots)
        self._slots.clear()
        self._initialized = False
        await self._prune_worktrees()

    async def _close_streams(self) -> None:
        """Close both ends of the slot queue so the streams are not leaked."""
        for stream in (self._send, self._receive):
            if stream is not None:
                await stream.aclose()
        self._send = None
        self._receive = None

    async def _remove_slots(self, slots: list[WorktreeSlot]) -> None:
        """Delete the directories of ``slots`` concurrently."""
        async with anyio.create_task_group() as tg:
            for slot in slots:
                tg.start_soon(self._remove_slot_async, slot)

    async def _prune_worktrees(self) -> None:
        """Best effort: run ``git worktree prune`` and remove the base directory if empty."""
        if self._base_dir.exists():
            with contextlib.suppress(Exception):
                await anyio.run_process(["git", "-C", str(self._git_root), "worktree", "prune"], check=False)
            # ``rmdir`` only succeeds on an empty directory; anything left stays put.
            with contextlib.suppress(OSError):
                self._base_dir.rmdir()

    async def _remove_slot_async(self, slot: WorktreeSlot) -> None:
        """Delete one slot directory in a worker thread, if it exists."""
        if slot.path.exists():
            await anyio.to_thread.run_sync(functools.partial(shutil.rmtree, slot.path, onerror=_handle_remove_readonly))

    async def __aenter__(self) -> Self:
        await self.initialize()
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.cleanup()


def _handle_remove_readonly(func: Callable[..., Any], path: str, exc_info: tuple[Any, ...]) -> None:
    """Error hook passed as ``onerror`` to ``shutil.rmtree``.

    When the failure is a ``PermissionError``, give the owner read, write and
    execute permission on ``path`` and call ``func(path)`` again. Any other
    error is re-raised unchanged.
    """
    error = exc_info[1]
    if not isinstance(error, PermissionError):
        raise error
    Path(path).chmod(stat.S_IWUSR | stat.S_IRUSR | stat.S_IXUSR)
    func(path)
35 changes: 29 additions & 6 deletions codeflash/languages/python/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import custom_addopts
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages.registry import get_language_support

# Pattern to extract timing from stdout markers: !######...:<duration_ns>######!
Expand Down Expand Up @@ -92,11 +91,35 @@ def _ensure_runtime_files(project_root: Path, language: str = "javascript") -> N

def execute_test_subprocess(
cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600
) -> subprocess.CompletedProcess:
) -> subprocess.CompletedProcess[str]:
"""Execute a subprocess with the given command list, working directory, environment variables, and timeout."""
logger.debug(f"executing test run with command: {' '.join(cmd_list)}")
with custom_addopts():
run_args = get_cross_platform_subprocess_run_args(
cwd=cwd, env=env, timeout=timeout, check=False, text=True, capture_output=True
)
return subprocess.run(cmd_list, **run_args) # noqa: PLW1510
return subprocess.run(cmd_list, cwd=cwd, env=env, timeout=timeout, check=False, text=True, capture_output=True)


async def async_execute_test_subprocess(
    cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600
) -> subprocess.CompletedProcess[str]:
    """Run ``cmd_list`` via ``anyio.run_process`` and return a text ``CompletedProcess``.

    The child receives the current process environment with ``env`` layered on
    top. A non-zero exit status is not an error; it is reported through
    ``returncode``. Output is decoded as UTF-8 with undecodable bytes replaced.

    Raises:
        subprocess.TimeoutExpired: If the command does not finish within
            ``timeout`` seconds.
    """
    import os as _os

    import anyio

    def _as_text(raw: bytes | None) -> str:
        # Missing or empty output becomes an empty string.
        return raw.decode("utf-8", errors="replace") if raw else ""

    logger.debug(f"async executing test run with command: {' '.join(cmd_list)}")

    child_env = {**_os.environ, **(env or {})}

    # NOTE(review): ``custom_addopts`` stays active across the await, so
    # concurrent calls overlap inside it - confirm it tolerates that.
    with custom_addopts():
        try:
            with anyio.fail_after(timeout):
                completed = await anyio.run_process(cmd_list, cwd=cwd, env=child_env, check=False)
        except TimeoutError as e:
            # Surface the same exception type that ``subprocess.run`` raises on timeout.
            raise subprocess.TimeoutExpired(cmd_list, timeout) from e

    return subprocess.CompletedProcess(
        args=cmd_list,
        returncode=completed.returncode,
        stdout=_as_text(completed.stdout),
        stderr=_as_text(completed.stderr),
    )
9 changes: 2 additions & 7 deletions codeflash/models/function_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@
from pydantic import Field
from pydantic.dataclasses import dataclass

from codeflash.models.shared_types import FunctionParent

@dataclass(frozen=True)
class FunctionParent:
name: str
type: str

def __str__(self) -> str:
return f"{self.type}:{self.name}"
__all__ = ["FunctionParent", "FunctionToOptimize"]


@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
Expand Down
52 changes: 31 additions & 21 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
from pydantic.dataclasses import dataclass

from codeflash.models.shared_types import OptimizedCandidateSource
from codeflash.models.test_type import TestType

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,6 +51,23 @@ class AIServiceRefinerRequest:
additional_context_files: dict[str, str] | None = None # {filepath: content} for imported modules


@dataclass(frozen=True)
class AIServiceBatchRefinerCandidate:
    """Frozen record of one optimized candidate's code, explanation and timing data.

    NOTE(review): presumably the per-candidate payload of a batch refiner
    request - confirm against the code that builds the request.
    """

    optimization_id: str
    optimized_source_code: str
    optimized_explanation: str
    # NOTE(review): runtime units are not stated here - presumably nanoseconds; confirm.
    optimized_code_runtime: int
    original_code_runtime: int
    # Speedup is carried as a string, not a number.
    speedup: str
    optimized_line_profiler_results: str


@dataclass(frozen=True)
class AIServiceBatchRefinerRequest:
    """Frozen request record: one shared context plus a list of candidates.

    ``frozen=True`` only blocks attribute reassignment; the dict and list held
    here can still be mutated in place.
    """

    # NOTE(review): presumably data common to every candidate - confirm.
    shared_context: dict[str, Any]
    # Candidates are plain dicts here, not AIServiceBatchRefinerCandidate instances.
    candidates: list[dict[str, Any]]


# this should be possible to auto serialize
@dataclass(frozen=True)
class AdaptiveOptimizedCandidate:
Expand Down Expand Up @@ -298,11 +316,11 @@ def flat(self) -> str:

"""
if self._cache.get("flat") is not None:
return self._cache["flat"]
return cast("str", self._cache["flat"])
self._cache["flat"] = "\n".join(
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
)
return self._cache["flat"]
return cast("str", self._cache["flat"])

@property
def markdown(self) -> str:
Expand Down Expand Up @@ -332,7 +350,7 @@ def file_to_path(self) -> dict[str, str]:

"""
try:
return self._cache["file_to_path"]
return cast("dict[str, str]", self._cache["file_to_path"])
except KeyError:
mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
self._cache["file_to_path"] = mapping
Expand Down Expand Up @@ -494,7 +512,7 @@ def _normalize_path_for_comparison(path: Path) -> str:
# Only lowercase on Windows where filesystem is case-insensitive
return resolved.lower() if sys.platform == "win32" else resolved

def __iter__(self) -> Iterator[TestFile]:
def __iter__(self) -> Iterator[TestFile]: # type: ignore[override]
return iter(self.test_files)

def __len__(self) -> int:
Expand All @@ -514,9 +532,9 @@ class CandidateEvaluationContext:
optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
is_correct: dict[str, bool] = Field(default_factory=dict)
optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
ast_code_to_id: dict = Field(default_factory=dict)
ast_code_to_id: dict[str, Any] = Field(default_factory=dict)
optimizations_post: dict[str, str] = Field(default_factory=dict)
valid_optimizations: list = Field(default_factory=list)
valid_optimizations: list[Any] = Field(default_factory=list)

def record_failed_candidate(self, optimization_id: str) -> None:
"""Record results for a failed candidate."""
Expand All @@ -543,7 +561,7 @@ def handle_duplicate_candidate(
# Copy results from the previous evaluation (use .get() in case past_opt_id was registered
# but never benchmarked due to an unhandled exception in process_single_candidate)
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id)
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id)
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id, False)
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)

# Line profiler results only available for successful runs
Expand Down Expand Up @@ -592,15 +610,6 @@ class TestsInFile:
test_type: TestType


class OptimizedCandidateSource(str, Enum):
OPTIMIZE = "OPTIMIZE"
OPTIMIZE_LP = "OPTIMIZE_LP"
REFINE = "REFINE"
REPAIR = "REPAIR"
ADAPTIVE = "ADAPTIVE"
JIT_REWRITE = "JIT_REWRITE"


@dataclass(frozen=True)
class OptimizedCandidate:
source_code: CodeStringsMarkdown
Expand Down Expand Up @@ -631,7 +640,7 @@ class OriginalCodeBaseline(BaseModel):
behavior_test_results: TestResults
benchmarking_test_results: TestResults
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
line_profile_results: dict
line_profile_results: dict[str, Any]
runtime: int
coverage_results: Optional[CoverageData]
async_throughput: Optional[int] = None
Expand Down Expand Up @@ -794,6 +803,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
)

if self.test_class_name:
assert self.test_function_name is not None
for stmt in module_node.body:
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
func_node = self.find_func_in_class(stmt, self.test_function_name)
Expand Down Expand Up @@ -884,7 +894,7 @@ def group_by_benchmarks(
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
from codeflash.code_utils.code_utils import module_name_from_file_path

test_results_by_benchmark = defaultdict(TestResults)
test_results_by_benchmark: defaultdict[BenchmarkKey, TestResults] = defaultdict(TestResults)
benchmark_module_path = {}
for benchmark_key in benchmark_keys:
benchmark_module_path[benchmark_key] = module_name_from_file_path(
Expand Down Expand Up @@ -1015,7 +1025,7 @@ def effective_loop_count(self) -> int:
return max(loop_indices) if loop_indices else 0

def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
map_gen_test_file_to_no_of_tests = Counter()
map_gen_test_file_to_no_of_tests: Counter[Path] = Counter()
for gen_test_result in self.test_results:
if (
gen_test_result.test_type == TestType.GENERATED_REGRESSION
Expand All @@ -1024,7 +1034,7 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
return map_gen_test_file_to_no_of_tests

def __iter__(self) -> Iterator[FunctionTestInvocation]:
def __iter__(self) -> Iterator[FunctionTestInvocation]: # type: ignore[override]
return iter(self.test_results)

def __len__(self) -> int:
Expand All @@ -1051,7 +1061,7 @@ def __eq__(self, other: object) -> bool:
if len(self) != len(other):
return False
original_recursion_limit = sys.getrecursionlimit()
cast("TestResults", other)
assert isinstance(other, TestResults)
for test_result in self:
other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
if other_test_result is None:
Expand Down
Loading
Loading