Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -429,3 +429,4 @@ code_to_optimize/**/package-lock.json

# Other tools
.codeflash/
.codeflash_eval_worktrees/
94 changes: 93 additions & 1 deletion codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
FunctionRepairInfo,
OptimizationReviewResult,
OptimizedCandidate,
OptimizedCandidateSource,
TestFileReview,
)
from codeflash.models.shared_types import OptimizedCandidateSource
from codeflash.telemetry.posthog_cf import ph
from codeflash.version import __version__ as codeflash_version

Expand All @@ -35,6 +35,7 @@
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
AIServiceAdaptiveOptimizeRequest,
AIServiceBatchRefinerCandidate,
AIServiceCodeRepairRequest,
AIServiceRefinerRequest,
)
Expand Down Expand Up @@ -384,6 +385,97 @@ def optimize_code_refinement(
console.rule()
return []

def optimize_code_refinement_batch(
    self,
    *,
    original_source_code: str,
    read_only_dependency_code: str,
    original_line_profiler_results: str,
    trace_id: str,
    language: str,
    language_version: str | None,
    function_references: str | None,
    candidates: list[AIServiceBatchRefinerCandidate],
    rerun_trace_id: str | None = None,
) -> list[OptimizedCandidate]:
    """Request refinements for several candidates in one ``/batch_refinement`` call.

    The context common to every candidate (original source, dependencies,
    baseline line-profiler output, trace ids, language info) is sent once under
    ``shared_context``; each candidate contributes its own entry under
    ``candidates``.

    Args:
        original_source_code: Source of the function before optimization.
        read_only_dependency_code: Dependency code sent alongside the source.
        original_line_profiler_results: Line-profiler output for the original code.
        trace_id: Trace identifier forwarded to the service.
        language: Language name forwarded to the service and to candidate validation.
        language_version: Passed to ``add_language_metadata`` together with the shared context.
        function_references: Optional reference information forwarded as-is.
        candidates: Candidates to refine.
        rerun_trace_id: Optional trace id forwarded in the shared context and to the fallback.

    Returns:
        The validated refined candidates. An empty list is returned when the
        request raises a ``RequestException`` or the service answers with a
        status other than 200/404. On 404 the result of the sequential
        fallback is returned instead.
    """
    shared_context: dict[str, Any] = {
        "original_source_code": original_source_code,
        "read_only_dependency_code": read_only_dependency_code,
        "original_line_profiler_results": original_line_profiler_results,
        "trace_id": trace_id,
        "language": language,
        "function_references": function_references,
        "rerun_trace_id": rerun_trace_id,
    }
    self.add_language_metadata(shared_context, language_version)

    # One entry per candidate; each entry draws its own call sequence number, in candidate order.
    per_candidate: list[dict[str, Any]] = [
        {
            "optimization_id": candidate.optimization_id,
            "optimized_source_code": candidate.optimized_source_code,
            "optimized_explanation": candidate.optimized_explanation,
            "optimized_code_runtime": humanize_runtime(candidate.optimized_code_runtime),
            "original_code_runtime": humanize_runtime(candidate.original_code_runtime),
            "speedup": candidate.speedup,
            "optimized_line_profiler_results": candidate.optimized_line_profiler_results,
            "call_sequence": self.get_next_sequence(),
        }
        for candidate in candidates
    ]
    request_body: dict[str, Any] = {"shared_context": shared_context, "candidates": per_candidate}

    try:
        response = self.make_ai_service_request("/batch_refinement", payload=request_body, timeout=self.timeout)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error generating batch optimization refinements: {e}")
        ph("cli-optimize-error-caught", {"error": str(e)})
        return []

    status = response.status_code

    # NOTE(review): 404 presumably means the backend has no batch endpoint - refine one by one instead.
    if status == 404:
        return self._fallback_to_sequential_refinement(
            shared_context=shared_context, candidates=candidates, rerun_trace_id=rerun_trace_id
        )

    if status != 200:
        self.log_error_response(response, "generating batch optimized candidates", "cli-optimize-error-response")
        console.rule()
        return []

    refinements = response.json()["refinements"]
    return self._get_valid_candidates(refinements, OptimizedCandidateSource.REFINE, language=language)

def _fallback_to_sequential_refinement(
    self,
    *,
    shared_context: dict[str, Any],
    candidates: list[AIServiceBatchRefinerCandidate],
    rerun_trace_id: str | None,
) -> list[OptimizedCandidate]:
    """Refine candidates through ``optimize_code_refinement`` instead of the batch endpoint.

    Each candidate is expanded into a standalone ``AIServiceRefinerRequest`` by
    combining its own fields with the values stored in ``shared_context``.

    Args:
        shared_context: Mapping holding the fields common to all candidates.
        candidates: Candidates to convert into individual refiner requests.
        rerun_trace_id: Forwarded unchanged to ``optimize_code_refinement``.

    Returns:
        Whatever ``optimize_code_refinement`` returns for the built requests.
    """
    # Imported at call time rather than at module level.
    from codeflash.models.models import AIServiceRefinerRequest

    def as_single_request(candidate: AIServiceBatchRefinerCandidate) -> AIServiceRefinerRequest:
        # Shared values are looked up per candidate, so an empty candidate list touches no keys.
        return AIServiceRefinerRequest(
            optimization_id=candidate.optimization_id,
            original_source_code=shared_context["original_source_code"],
            read_only_dependency_code=shared_context["read_only_dependency_code"],
            original_code_runtime=candidate.original_code_runtime,
            optimized_source_code=candidate.optimized_source_code,
            optimized_explanation=candidate.optimized_explanation,
            optimized_code_runtime=candidate.optimized_code_runtime,
            speedup=candidate.speedup,
            trace_id=shared_context["trace_id"],
            original_line_profiler_results=shared_context["original_line_profiler_results"],
            optimized_line_profiler_results=candidate.optimized_line_profiler_results,
            function_references=shared_context.get("function_references"),
            language=shared_context["language"],
            language_version=shared_context.get("language_version"),
        )

    single_requests = [as_single_request(candidate) for candidate in candidates]
    return self.optimize_code_refinement(single_requests, rerun_trace_id=rerun_trace_id)

def code_repair(
self, request: AIServiceCodeRepairRequest, rerun_trace_id: str | None = None
) -> OptimizedCandidate | None:
Expand Down
7 changes: 7 additions & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,13 @@ def _build_parser() -> ArgumentParser:
)
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
parser.add_argument(
"--parallel-candidates",
type=int,
default=0,
metavar="N",
help="Evaluate up to N optimization candidates in parallel using git worktrees (0 = sequential)",
)
parser.add_argument(
"--testgen-review", default=False, action="store_true", help="Enable AI review and repair of generated tests"
)
Expand Down
158 changes: 158 additions & 0 deletions codeflash/code_utils/worktree_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations

import contextlib
import functools
import shutil
import stat
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any

import anyio

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Self

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.git_utils import git_root_dir, mirror_path

# shutil.rmtree accepts the `onexc` error hook from Python 3.12 onward; older
# interpreters only take `onerror`. This flag selects which hook to pass.
_USE_ONEXC = sys.version_info >= (3, 12)


class WorktreeSlot:
    """One worktree directory together with its index and the repository root it mirrors."""

    __slots__ = ("_git_root", "index", "path")

    def __init__(self, path: Path, index: int, git_root: Path) -> None:
        """Store the worktree directory, its slot index, and the git root used for path mirroring."""
        self._git_root = git_root
        self.index = index
        self.path = path

    def mirror(self, original_path: Path) -> Path:
        """Return ``mirror_path(original_path, git_root, slot_path)`` for this slot."""
        return mirror_path(original_path, self._git_root, self.path)

    async def write_candidate(self, file_path: Path, code: str) -> None:
        """Write ``code`` as UTF-8 to the mirrored location of ``file_path``, creating parent directories."""
        destination = anyio.Path(self.mirror(file_path))
        parent_dir = destination.parent
        await parent_dir.mkdir(parents=True, exist_ok=True)
        await destination.write_text(code, encoding="utf-8")


class WorktreePool:
    """Pool of detached git worktrees that are handed out one at a time.

    Slots are created under ``base_dir`` (default: ``<git root>/.codeflash_eval_worktrees``)
    as ``slot-<index>`` directories, each checked out at ``HEAD``. Available
    slots are queued on an anyio memory object stream; ``acquire`` takes one
    from the stream and ``release`` puts it back.

    The pool can be used as an async context manager: entering initializes it
    and exiting cleans it up. A pool may be initialized again after cleanup.
    """

    def __init__(self, pool_size: int = 4, base_dir: Path | None = None) -> None:
        """Record the configuration; no worktrees are created until ``initialize``.

        Args:
            pool_size: Number of worktree slots to attempt to create.
            base_dir: Directory that holds the slot directories. Defaults to
                ``.codeflash_eval_worktrees`` under the git root.
        """
        self._pool_size = pool_size
        self._git_root = git_root_dir()
        self._base_dir = base_dir or (self._git_root / ".codeflash_eval_worktrees")
        self._slots: list[WorktreeSlot] = []
        self._send: anyio.abc.ObjectSendStream[WorktreeSlot] | None = None
        self._receive: anyio.abc.ObjectReceiveStream[WorktreeSlot] | None = None
        self._initialized = False
        self._closed = False

    async def initialize(self) -> None:
        """Create the worktree slots concurrently and queue them for acquisition.

        Does nothing when the pool is already initialized. Slots whose creation
        fails are logged and skipped.

        Raises:
            RuntimeError: If not a single slot could be created.
        """
        if self._initialized:
            return
        # cleanup() leaves the closed flag set; clear it so that release() on a
        # re-initialized pool returns slots instead of silently dropping them.
        self._closed = False
        await anyio.Path(self._base_dir).mkdir(parents=True, exist_ok=True)

        # Filled by index from the concurrent tasks; failed slots stay None.
        results: list[WorktreeSlot | None] = [None] * self._pool_size
        async with anyio.create_task_group() as tg:
            for i in range(self._pool_size):
                tg.start_soon(self._create_slot_task, i, results)

        self._slots = [s for s in results if s is not None]
        if not self._slots:
            msg = "Failed to create any worktree slots"
            raise RuntimeError(msg)

        # Buffer sized to the slot count, so returning a slot never has to wait.
        self._send, self._receive = anyio.create_memory_object_stream[WorktreeSlot](len(self._slots))
        for slot in self._slots:
            await self._send.send(slot)
        self._initialized = True
        logger.debug(f"WorktreePool initialized with {len(self._slots)} slots at {self._base_dir}")

    async def _create_slot_task(self, index: int, results: list[WorktreeSlot | None]) -> None:
        """Create slot ``index`` and store it in ``results``; log instead of raising on failure."""
        try:
            results[index] = await self._create_slot(index)
        except Exception as exc:
            # Best effort: one failing slot must not cancel the sibling tasks.
            logger.warning(f"Failed to create worktree slot {index}: {exc}")

    async def _create_slot(self, index: int) -> WorktreeSlot:
        """Create a detached worktree at ``HEAD`` in ``slot-<index>``.

        A leftover directory at that path is removed first.

        Raises:
            RuntimeError: If ``git worktree add`` exits with a non-zero status.
        """
        slot_dir = self._base_dir / f"slot-{index}"
        if await anyio.Path(slot_dir).exists():
            await anyio.to_thread.run_sync(functools.partial(_rmtree_safe, slot_dir))

        # --force: a directory left behind by an earlier run is usually still
        # registered as a worktree in the repository. After deleting it, a plain
        # `git worktree add` refuses the path as "missing but already registered".
        result = await anyio.run_process(
            ["git", "-C", str(self._git_root), "worktree", "add", "--force", "--detach", str(slot_dir), "HEAD"],
            check=False,
        )
        if result.returncode != 0:
            # errors="replace": undecodable git output must not mask the real failure.
            stderr_text = result.stderr.decode(errors="replace") if result.stderr else ""
            msg = f"git worktree add failed for slot {index}: {stderr_text}"
            raise RuntimeError(msg)

        return WorktreeSlot(slot_dir, index, self._git_root)

    async def acquire(self) -> WorktreeSlot:
        """Wait for and return the next available slot.

        Raises:
            RuntimeError: If the pool has never been initialized.
        """
        if self._receive is None:
            msg = "WorktreePool.acquire() called before initialize()"
            raise RuntimeError(msg)
        return await self._receive.receive()

    async def release(self, slot: WorktreeSlot) -> None:
        """Return ``slot`` to the pool; a no-op once the pool has been cleaned up.

        Raises:
            RuntimeError: If the pool has never been initialized.
        """
        if self._closed:
            return
        if self._send is None:
            msg = "WorktreePool.release() called before initialize()"
            raise RuntimeError(msg)
        with contextlib.suppress(anyio.ClosedResourceError):
            await self._send.send(slot)

    async def cleanup(self) -> None:
        """Close the streams, delete every slot directory, and prune git's worktree records.

        Failures while removing individual slots are logged and do not stop the
        cleanup. The base directory is removed only if it ends up empty.
        """
        self._closed = True

        if self._send is not None:
            await self._send.aclose()
        if self._receive is not None:
            await self._receive.aclose()

        for slot in self._slots:
            try:
                await self._remove_slot_async(slot)
            except Exception as exc:
                logger.warning(f"Failed to remove worktree slot {slot.index}: {exc}")

        self._slots.clear()
        self._initialized = False

        if await anyio.Path(self._base_dir).exists():
            # The directories are gone; let git drop the matching worktree records.
            with contextlib.suppress(Exception):
                await anyio.run_process(["git", "-C", str(self._git_root), "worktree", "prune"], check=False)
            # rmdir only succeeds on an empty directory; anything left is kept.
            with contextlib.suppress(OSError):
                await anyio.Path(self._base_dir).rmdir()

    async def _remove_slot_async(self, slot: WorktreeSlot) -> None:
        """Delete the slot's directory tree in a worker thread if it exists."""
        if await anyio.Path(slot.path).exists():
            await anyio.to_thread.run_sync(functools.partial(_rmtree_safe, slot.path))

    async def __aenter__(self) -> Self:
        """Initialize the pool and return it."""
        await self.initialize()
        return self

    async def __aexit__(self, *exc: object) -> None:
        """Clean up the pool regardless of how the block exited."""
        await self.cleanup()


def _rmtree_safe(path: Path) -> None:
    """Recursively delete ``path``, retrying entries that fail with a permission error.

    The retry hook is passed as ``onexc`` or ``onerror`` depending on ``_USE_ONEXC``.
    """
    if not _USE_ONEXC:
        shutil.rmtree(path, onerror=_handle_remove_readonly_onerror)
        return
    shutil.rmtree(path, onexc=_handle_remove_readonly_onexc)


def _handle_remove_readonly_onexc(func: Callable[..., Any], path: str, exc: BaseException) -> None:
    """``shutil.rmtree`` ``onexc`` hook: on a permission error, grant owner rwx and retry.

    Any exception that is not a ``PermissionError`` is re-raised unchanged.
    """
    if not isinstance(exc, PermissionError):
        raise exc
    owner_rwx = stat.S_IWUSR | stat.S_IRUSR | stat.S_IXUSR
    Path(path).chmod(owner_rwx)
    # Retry the operation that failed, now that the owner has full access.
    func(path)


def _handle_remove_readonly_onerror(func: Callable[..., Any], path: str, exc_info: tuple[Any, ...]) -> None:
    """``shutil.rmtree`` ``onerror`` hook: on a permission error, grant owner rwx and retry.

    ``exc_info[1]`` is the exception instance; anything other than a
    ``PermissionError`` is re-raised unchanged.
    """
    error = exc_info[1]
    if not isinstance(error, PermissionError):
        raise error
    owner_rwx = stat.S_IWUSR | stat.S_IRUSR | stat.S_IXUSR
    Path(path).chmod(owner_rwx)
    # Retry the operation that failed, now that the owner has full access.
    func(path)
Loading
Loading