Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def optimize_code_refinement_batch(

if response.status_code == 200:
refined_optimizations = response.json()["refinements"]
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE, language=language)

self.log_error_response(response, "generating batch optimized candidates", "cli-optimize-error-response")
console.rule()
Expand Down
77 changes: 54 additions & 23 deletions codeflash/code_utils/worktree_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import contextlib
import functools
import os
import shutil
import stat
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand All @@ -17,6 +17,8 @@
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.git_utils import git_root_dir, mirror_path

_USE_ONEXC = sys.version_info >= (3, 12)


class WorktreeSlot:
__slots__ = ("_git_root", "index", "path")
Expand All @@ -34,10 +36,6 @@ async def write_candidate(self, file_path: Path, code: str) -> None:
await mirrored.parent.mkdir(parents=True, exist_ok=True)
await mirrored.write_text(code, encoding="utf-8")

async def restore_file(self, file_path: Path, original_code: str) -> None:
mirrored = anyio.Path(self.mirror(file_path))
await mirrored.write_text(original_code, encoding="utf-8")


class WorktreePool:
def __init__(self, pool_size: int = 4, base_dir: Path | None = None) -> None:
Expand All @@ -48,68 +46,86 @@ def __init__(self, pool_size: int = 4, base_dir: Path | None = None) -> None:
self._send: anyio.abc.ObjectSendStream[WorktreeSlot] | None = None
self._receive: anyio.abc.ObjectReceiveStream[WorktreeSlot] | None = None
self._initialized = False
self._closed = False

async def initialize(self) -> None:
    """Create the worktree slots and queue them for ``acquire``.

    Returns immediately if the pool is already initialized. One creation
    task is started per index in ``range(pool_size)``; entries of the result
    list that are still ``None`` once the task group finishes are dropped,
    so the pool may end up with fewer slots than requested.

    Raises:
        RuntimeError: If no slot at all was created.
    """
    if self._initialized:
        return
    await anyio.Path(self._base_dir).mkdir(parents=True, exist_ok=True)

    results: list[WorktreeSlot | None] = [None] * self._pool_size
    async with anyio.create_task_group() as tg:
        for i in range(self._pool_size):
            tg.start_soon(self._create_slot_task, i, results)

    self._slots = [s for s in results if s is not None]
    if not self._slots:
        msg = "Failed to create any worktree slots"
        raise RuntimeError(msg)

    # Size the buffer to the slots that actually exist so the sends below
    # complete without a receiver being present.
    self._send, self._receive = anyio.create_memory_object_stream[WorktreeSlot](len(self._slots))
    for slot in self._slots:
        await self._send.send(slot)
    # cleanup() leaves _closed set; clear it so release() hands slots back
    # again when the pool is initialized a second time.
    self._closed = False
    self._initialized = True
    logger.debug(f"WorktreePool initialized with {len(self._slots)} slots at {self._base_dir}")

async def _create_slot_task(self, index: int, results: list[WorktreeSlot | None]) -> None:
    """Create slot ``index`` and store it at ``results[index]``.

    Any ``Exception`` raised by ``_create_slot`` is logged as a warning and
    swallowed, leaving ``results[index]`` unchanged, so the error does not
    propagate out of this task.
    """
    try:
        results[index] = await self._create_slot(index)
    except Exception as exc:
        logger.warning(f"Failed to create worktree slot {index}: {exc}")

async def _create_slot(self, index: int) -> WorktreeSlot:
    """Create a detached git worktree at ``<base_dir>/slot-<index>``.

    A directory already present at that path is removed first (in a worker
    thread, since ``rmtree`` blocks), then ``git worktree add --detach`` is
    run against ``HEAD`` of the pool's git root.

    Returns:
        The ``WorktreeSlot`` describing the new worktree.

    Raises:
        RuntimeError: If ``git worktree add`` exits with a non-zero status.
    """
    slot_dir = self._base_dir / f"slot-{index}"
    if await anyio.Path(slot_dir).exists():
        await anyio.to_thread.run_sync(functools.partial(_rmtree_safe, slot_dir))

    result = await anyio.run_process(
        ["git", "-C", str(self._git_root), "worktree", "add", "--detach", str(slot_dir), "HEAD"], check=False
    )
    if result.returncode != 0:
        # errors="replace": a non-UTF-8 byte in git's stderr must not mask the
        # real failure with a UnicodeDecodeError.
        msg = f"git worktree add failed for slot {index}: {result.stderr.decode(errors='replace')}"
        raise RuntimeError(msg)

    return WorktreeSlot(slot_dir, index, self._git_root)

async def acquire(self) -> WorktreeSlot:
    """Wait for the next slot on the pool's receive stream and return it."""
    receive_stream = self._receive
    assert receive_stream is not None
    slot = await receive_stream.receive()
    return slot

async def release(self, slot: WorktreeSlot) -> None:
    """Hand ``slot`` back to the pool.

    Does nothing once the pool is marked closed. A ``ClosedResourceError``
    from the send stream is suppressed, so a release that races with the
    stream being closed does not raise.
    """
    if self._closed:
        return
    assert self._send is not None
    with contextlib.suppress(anyio.ClosedResourceError):
        await self._send.send(slot)

async def cleanup(self) -> None:
    """Shut the pool down and remove its worktrees.

    Marks the pool closed, closes both ends of the slot stream, removes each
    slot directory (a failure is logged as a warning and the remaining slots
    are still processed), then resets the slot list and the initialized
    flag. If the base directory still exists, ``git worktree prune`` is run
    and the directory is removed; both steps are best-effort.
    """
    # Set first so release() calls arriving during teardown become no-ops.
    self._closed = True

    if self._send is not None:
        await self._send.aclose()
    if self._receive is not None:
        await self._receive.aclose()

    for slot in self._slots:
        try:
            await self._remove_slot_async(slot)
        except Exception as exc:
            logger.warning(f"Failed to remove worktree slot {slot.index}: {exc}")

    self._slots.clear()
    self._initialized = False

    if await anyio.Path(self._base_dir).exists():
        with contextlib.suppress(Exception):
            await anyio.run_process(["git", "-C", str(self._git_root), "worktree", "prune"], check=False)
        # rmdir raises OSError when the directory is not empty; in that case
        # it is deliberately left in place.
        with contextlib.suppress(OSError):
            await anyio.Path(self._base_dir).rmdir()

async def _remove_slot_async(self, slot: WorktreeSlot) -> None:
    """Delete the slot's directory tree if it exists.

    The removal runs in a worker thread because ``rmtree`` is blocking.
    """
    if await anyio.Path(slot.path).exists():
        await anyio.to_thread.run_sync(functools.partial(_rmtree_safe, slot.path))

async def __aenter__(self) -> Self:
await self.initialize()
Expand All @@ -119,7 +135,22 @@ async def __aexit__(self, *exc: object) -> None:
await self.cleanup()


def _handle_remove_readonly(func: Callable[..., Any], path: str, exc_info: tuple[Any, ...]) -> None:
def _rmtree_safe(path: Path) -> None:
    """Remove the tree at ``path`` with the read-only retry handler installed.

    The handler is passed as ``onexc`` when ``_USE_ONEXC`` is true and as
    ``onerror`` otherwise.
    """
    if not _USE_ONEXC:
        shutil.rmtree(path, onerror=_handle_remove_readonly_onerror)
        return
    shutil.rmtree(path, onexc=_handle_remove_readonly_onexc)


def _handle_remove_readonly_onexc(func: Callable[..., Any], path: str, exc: BaseException) -> None:
if isinstance(exc, PermissionError):
Path(path).chmod(stat.S_IWUSR | stat.S_IRUSR | stat.S_IXUSR)
func(path)
else:
raise exc


def _handle_remove_readonly_onerror(func: Callable[..., Any], path: str, exc_info: tuple[Any, ...]) -> None:
if isinstance(exc_info[1], PermissionError):
Path(path).chmod(stat.S_IWUSR | stat.S_IRUSR | stat.S_IXUSR)
func(path)
Expand Down
43 changes: 30 additions & 13 deletions codeflash/languages/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ def _run_line_profiler_for_winner(
)
eval_ctx.record_line_profiler_result(best_optimization.candidate.optimization_id, lp_results["str_out"])
best_optimization.line_profiler_test_results = lp_results
except (ValueError, SyntaxError, AttributeError) as e:
except (ValueError, SyntaxError, AttributeError, Exception) as e:
logger.warning(f"Line profiler failed for winning candidate: {e}")
finally:
self.write_code_and_helpers(
Expand Down Expand Up @@ -1541,8 +1541,6 @@ def _evaluate_candidates_parallel(
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
eval_ctx=eval_ctx,
exp_type=exp_type,
pool_size=pool_size,
)

Expand Down Expand Up @@ -1602,18 +1600,34 @@ def _evaluate_candidates_parallel(
)
eval_ctx.valid_optimizations.append(best_optimization)

batch_refiner_candidates.append(
AIServiceBatchRefinerCandidate(
optimization_id=candidate.optimization_id,
optimized_source_code=candidate.source_code.markdown,
optimized_explanation=candidate.explanation,
optimized_code_runtime=candidate_result.best_test_runtime,
original_code_runtime=original_code_baseline.runtime,
speedup=f"{int(perf_gain * 100)}%",
optimized_line_profiler_results="",
)
current_tree_candidates = candidate_node.path_to_root()
is_candidate_refined_before = any(
c.source == OptimizedCandidateSource.REFINE for c in current_tree_candidates
)

if is_candidate_refined_before:
future_adaptive = self.call_adaptive_optimize(
trace_id=self.get_trace_id(exp_type),
original_source_code=code_context.read_writable_code.markdown,
prev_candidates=current_tree_candidates,
eval_ctx=eval_ctx,
ai_service_client=ai_service_client,
)
if future_adaptive:
self.future_adaptive_optimizations.append(future_adaptive)
else:
batch_refiner_candidates.append(
AIServiceBatchRefinerCandidate(
optimization_id=candidate.optimization_id,
optimized_source_code=candidate.source_code.markdown,
optimized_explanation=candidate.explanation,
optimized_code_runtime=candidate_result.best_test_runtime,
original_code_runtime=original_code_baseline.runtime,
speedup=f"{int(perf_gain * 100)}%",
optimized_line_profiler_results="",
)
)

# Dispatch refinement immediately so CandidateProcessor sees it
if batch_refiner_candidates:
self._dispatch_refinement(
Expand Down Expand Up @@ -1684,6 +1698,9 @@ def _dispatch_repair_if_possible(
test_diffs: list[TestDiff] | None = None,
) -> concurrent.futures.Future | None:
"""Submit a code repair request if the candidate is eligible."""
if not test_diffs:
return None

max_repairs = get_effort_value(EffortKeys.MAX_CODE_REPAIRS_PER_TRACE, self.effort)
if self.repair_counter >= max_repairs:
return None
Expand Down
6 changes: 0 additions & 6 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,6 @@ class AIServiceBatchRefinerCandidate:
optimized_line_profiler_results: str


@dataclass(frozen=True)
class AIServiceBatchRefinerRequest:
shared_context: dict[str, Any]
candidates: list[dict[str, Any]]


# this should be possible to auto serialize
@dataclass(frozen=True)
class AdaptiveOptimizedCandidate:
Expand Down
Loading
Loading