Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 304 additions & 29 deletions codeflash/languages/function_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# mypy: ignore-errors
from __future__ import annotations

import concurrent.futures
Expand Down Expand Up @@ -77,6 +78,7 @@
from codeflash.models.models import (
AdaptiveOptimizedCandidate,
AIServiceAdaptiveOptimizeRequest,
AIServiceBatchRefinerCandidate,
AIServiceCodeRepairRequest,
BestOptimization,
CandidateEvaluationContext,
Expand Down Expand Up @@ -1018,6 +1020,34 @@ def handle_successful_candidate(

return best_optimization, benchmark_tree

def _run_line_profiler_for_winner(
    self,
    best_optimization: BestOptimization,
    code_context: CodeOptimizationContext,
    original_helper_code: dict[Path, str],
    eval_ctx: CandidateEvaluationContext,
) -> BestOptimization:
    """Line-profile the winner of a parallel evaluation and attach the result.

    The winner's source is swapped in via
    ``replace_function_and_helpers_with_optimized_code``, then
    ``line_profiler_step`` is run. On success the ``"str_out"`` entry is
    recorded on ``eval_ctx`` under the winner's optimization id, and the full
    result is stored on ``best_optimization.line_profiler_test_results``.

    Args:
        best_optimization: The selected winner; updated in place on success.
        code_context: Optimization context passed to the replace and profile steps.
        original_helper_code: Helper-file contents keyed by path, written back
            at the end.
        eval_ctx: Evaluation context that receives the profiler text output.

    Returns:
        The same ``best_optimization`` object that was passed in.
    """
    try:
        winner = best_optimization.candidate
        self.replace_function_and_helpers_with_optimized_code(
            code_context=code_context,
            optimized_code=winner.source_code,
            original_helper_code=original_helper_code,
        )
        with progress_bar("Running line-by-line profiling"):
            profile_output = self.line_profiler_step(
                code_context=code_context,
                original_helper_code=original_helper_code,
                candidate_index=0,
            )
        eval_ctx.record_line_profiler_result(winner.optimization_id, profile_output["str_out"])
        best_optimization.line_profiler_test_results = profile_output
    except (ValueError, SyntaxError, AttributeError) as err:
        # Profiling is best-effort: the winner is still returned without it.
        logger.warning(f"Line profiler failed for winning candidate: {err}")
    finally:
        # Always write the stored function source and helper code back to disk,
        # whether or not profiling succeeded.
        self.write_code_and_helpers(
            self.function_to_optimize_source_code,
            original_helper_code,
            self.function_to_optimize.file_path,
        )
    return best_optimization

def select_best_optimization(
self,
eval_ctx: CandidateEvaluationContext,
Expand Down Expand Up @@ -1378,37 +1408,52 @@ def determine_best_candidate(
original_flat_code=code_context.read_writable_code.flat,
)
candidate_index = 0
parallel_pool_size = getattr(self.args, "parallel_candidates", 0)

# Process candidates using queue-based approach
while not processor.is_done():
candidate_node = processor.get_next_candidate()
if candidate_node is None:
logger.debug("everything done, exiting")
break
if parallel_pool_size > 1:
self._evaluate_candidates_parallel(
processor=processor,
code_context=code_context,
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
eval_ctx=eval_ctx,
exp_type=exp_type,
function_references=function_references,
normalized_original=normalized_original,
pool_size=parallel_pool_size,
)
else:
# Process candidates using queue-based approach (sequential)
while not processor.is_done():
candidate_node = processor.get_next_candidate()
if candidate_node is None:
logger.debug("everything done, exiting")
break

try:
candidate_index += 1
self.process_single_candidate(
candidate_node=candidate_node,
candidate_index=candidate_index,
total_candidates=processor.candidate_len,
code_context=code_context,
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
eval_ctx=eval_ctx,
exp_type=exp_type,
function_references=function_references,
normalized_original=normalized_original,
cached_normalized_code=processor.normalized_cache.get(candidate_node.candidate.optimization_id),
)
except KeyboardInterrupt as e:
logger.exception(f"Optimization interrupted: {e}")
raise
finally:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
try:
candidate_index += 1
self.process_single_candidate(
candidate_node=candidate_node,
candidate_index=candidate_index,
total_candidates=processor.candidate_len,
code_context=code_context,
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
eval_ctx=eval_ctx,
exp_type=exp_type,
function_references=function_references,
normalized_original=normalized_original,
cached_normalized_code=processor.normalized_cache.get(candidate_node.candidate.optimization_id),
)
except KeyboardInterrupt as e:
logger.exception(f"Optimization interrupted: {e}")
raise
finally:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)

# Select and return the best optimization
best_optimization = self.select_best_optimization(
Expand All @@ -1421,6 +1466,11 @@ def determine_best_candidate(
)

if best_optimization:
if parallel_pool_size > 1:
best_optimization = self._run_line_profiler_for_winner(
best_optimization, code_context, original_helper_code, eval_ctx
)

self.log_evaluation_results(
eval_ctx=eval_ctx,
best_optimization=best_optimization,
Expand All @@ -1431,6 +1481,231 @@ def determine_best_candidate(

return best_optimization

def _evaluate_candidates_parallel(
self,
processor: CandidateProcessor,
code_context: CodeOptimizationContext,
original_code_baseline: OriginalCodeBaseline,
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
eval_ctx: CandidateEvaluationContext,
exp_type: str,
function_references: str,
normalized_original: str,
pool_size: int,
) -> None:
"""Evaluate candidates in parallel using git worktrees and async subprocess execution.

Repeatedly pulls up to ``pool_size`` candidates from ``processor``, drops
those whose normalized code equals ``normalized_original`` or was already
registered in ``eval_ctx``, and passes the remaining batch to
``run_parallel_evaluation``. Each result is then recorded on ``eval_ctx``:
failures may trigger a code-repair request, and successes are scored and may
be queued for refinement. The method returns nothing; its output is the
state accumulated on ``eval_ctx``, ``self.future_all_code_repair`` and
(through ``_dispatch_refinement``) ``self.future_all_refinements``.

NOTE(review): this capture lost all leading indentation, so block nesting
below must be read from the statements themselves.
"""
# Imported inside the method rather than at module top.
from codeflash.optimization.parallel_evaluator import run_parallel_evaluation

# "EXP0" uses the primary AI service client; any other experiment type uses the local one.
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
assert ai_service_client is not None

# Running ordinal over every candidate pulled from the processor, including ones skipped below.
candidate_index = 0

while not processor.is_done():
# Each batch entry is (candidate node, ordinal, cached normalized code or None).
batch: list[tuple[CandidateNode, int, str | None]] = []
while len(batch) < pool_size:
candidate_node = processor.get_next_candidate()
if candidate_node is None:
break
candidate_index += 1
cached = processor.normalized_cache.get(candidate_node.candidate.optimization_id)

# Reuse the cached normalized form when it is present and truthy;
# otherwise normalize the stripped flat source now.
normalized_code = cached or self.language_support.normalize_code(
candidate_node.candidate.source_code.flat.strip()
)
if normalized_code == normalized_original:
logger.info(f"h3|Candidate {candidate_index}: Identical to original code, skipping.")
continue
if normalized_code in eval_ctx.ast_code_to_id:
logger.info(f"h3|Candidate {candidate_index}: Duplicate of a previous candidate, skipping.")
eval_ctx.handle_duplicate_candidate(
candidate_node.candidate, normalized_code, code_context.read_writable_code.flat
)
continue

eval_ctx.register_new_candidate(
normalized_code, candidate_node.candidate, code_context.read_writable_code.flat
)
# The batch keeps `cached` (possibly None), not the freshly computed normalized_code.
batch.append((candidate_node, candidate_index, cached))

# Nothing evaluable was collected this round, so stop.
if not batch:
break

logger.info(f"Evaluating batch of {len(batch)} candidates in parallel…")

# Only the first element of the returned triple is used here.
results, _, _ = run_parallel_evaluation(
optimizer=self,
candidates=batch,
code_context=code_context,
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
eval_ctx=eval_ctx,
exp_type=exp_type,
pool_size=pool_size,
)

# Process results and dispatch refinement/repair futures immediately
batch_refiner_candidates: list[AIServiceBatchRefinerCandidate] = []
# Batch entries and results are paired by position. This assumes
# run_parallel_evaluation returns one result per batch entry, in batch
# order — TODO confirm (zip stops silently at the shorter input).
for (candidate_node, _idx, _), (_, run_result) in zip(batch, results):
candidate = candidate_node.candidate

# Failure path: a missing result or an unsuccessful one.
if run_result is None or not is_successful(run_result):
eval_ctx.record_failed_candidate(candidate.optimization_id)
# The `is not None` test is already implied by the isinstance check.
if run_result is not None and isinstance(run_result, Failure):
eval_failure = run_result.failure()
repair_future = self._dispatch_repair_if_possible(
candidate,
eval_ctx,
code_context,
exp_type,
ai_service_client,
test_diffs=eval_failure.diffs,
)
if repair_future is not None:
self.future_all_code_repair.append(repair_future)
continue

# Success path: unwrap the result and record runtime and gain.
candidate_result = run_result.unwrap()
perf_gain = performance_gain(
original_runtime_ns=original_code_baseline.runtime,
optimized_runtime_ns=candidate_result.best_test_runtime,
)
eval_ctx.record_successful_candidate(
candidate.optimization_id, candidate_result.best_test_runtime, perf_gain
)

# Every best_*_until_now argument is None here; presumably each candidate
# is judged against the original baseline only — confirm against speedup_critic.
is_successful_opt = speedup_critic(
candidate_result,
original_code_baseline.runtime,
best_runtime_until_now=None,
original_async_throughput=original_code_baseline.async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
best_concurrency_ratio_until_now=None,
) and quantity_of_tests_critic(candidate_result)

if is_successful_opt:
# Placeholder line-profiler result: this method does not profile candidates.
empty_lp = {"timings": {}, "unit": 0, "str_out": ""}
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
code_context=code_context,
runtime=candidate_result.best_test_runtime,
line_profiler_test_results=empty_lp,
winning_behavior_test_results=candidate_result.behavior_test_results,
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
winning_replay_benchmarking_test_results=None,
async_throughput=candidate_result.async_throughput,
concurrency_metrics=candidate_result.concurrency_metrics,
)
eval_ctx.valid_optimizations.append(best_optimization)

# NOTE(review): with indentation lost, it cannot be determined here whether
# this append is nested under `if is_successful_opt:` above or runs for every
# successfully executed candidate — confirm against the real file.
batch_refiner_candidates.append(
AIServiceBatchRefinerCandidate(
optimization_id=candidate.optimization_id,
optimized_source_code=candidate.source_code.markdown,
optimized_explanation=candidate.explanation,
optimized_code_runtime=candidate_result.best_test_runtime,
original_code_runtime=original_code_baseline.runtime,
# Gain as a whole-number percentage; int() truncates toward zero.
speedup=f"{int(perf_gain * 100)}%",
# No per-candidate line profile is produced in this method.
optimized_line_profiler_results="",
)
)

# Dispatch refinement immediately so CandidateProcessor sees it
if batch_refiner_candidates:
self._dispatch_refinement(
batch_refiner_candidates,
code_context,
original_code_baseline,
exp_type,
function_references,
ai_service_client,
)

def _dispatch_refinement(
    self,
    batch_refiner_candidates: list[AIServiceBatchRefinerCandidate],
    code_context: CodeOptimizationContext,
    original_code_baseline: OriginalCodeBaseline,
    exp_type: str,
    function_references: str,
    ai_service_client: AiServiceClient,
) -> None:
    """Submit a refinement request to the executor and track its future.

    Two or more candidates go to ``optimize_code_refinement_batch`` in a
    single call. Exactly one candidate goes to ``optimize_code_refinement``,
    wrapped in a single ``AIServiceRefinerRequest``. The resulting future is
    appended to ``self.future_all_refinements``.

    Args:
        batch_refiner_candidates: Candidates to refine. An empty list is a
            no-op (previously it raised ``IndexError``).
        code_context: Supplies the original writable source (markdown form)
            and the read-only dependency code.
        original_code_baseline: Supplies the baseline line-profiler output
            (``line_profile_results["str_out"]``).
        exp_type: Experiment type used to derive the trace id.
        function_references: Reference text forwarded to the AI service.
        ai_service_client: Client whose refinement method is submitted.
    """
    # Nothing to refine: avoid indexing into an empty list below.
    if not batch_refiner_candidates:
        return

    # Fields shared by the batch and the single-candidate request.
    original_source = code_context.read_writable_code.markdown
    read_only_code = code_context.read_only_context_code
    baseline_profile = original_code_baseline.line_profile_results["str_out"]
    trace_id = self.get_trace_id(exp_type)
    language = self.function_to_optimize.language
    language_version = self.language_support.language_version

    if len(batch_refiner_candidates) > 1:
        future = self.executor.submit(
            ai_service_client.optimize_code_refinement_batch,
            original_source_code=original_source,
            read_only_dependency_code=read_only_code,
            original_line_profiler_results=baseline_profile,
            trace_id=trace_id,
            language=language,
            language_version=language_version,
            function_references=function_references,
            candidates=batch_refiner_candidates,
            rerun_trace_id=self.rerun_trace_id,
        )
    else:
        only = batch_refiner_candidates[0]
        future = self.executor.submit(
            ai_service_client.optimize_code_refinement,
            request=[
                AIServiceRefinerRequest(
                    optimization_id=only.optimization_id,
                    original_source_code=original_source,
                    read_only_dependency_code=read_only_code,
                    original_code_runtime=only.original_code_runtime,
                    optimized_source_code=only.optimized_source_code,
                    optimized_explanation=only.optimized_explanation,
                    optimized_code_runtime=only.optimized_code_runtime,
                    speedup=only.speedup,
                    trace_id=trace_id,
                    original_line_profiler_results=baseline_profile,
                    optimized_line_profiler_results=only.optimized_line_profiler_results,
                    function_references=function_references,
                    language=language,
                    language_version=language_version,
                )
            ],
            rerun_trace_id=self.rerun_trace_id,
        )
    self.future_all_refinements.append(future)

def _dispatch_repair_if_possible(
    self,
    candidate: OptimizedCandidate,
    eval_ctx: CandidateEvaluationContext,
    code_context: CodeOptimizationContext,
    exp_type: str,
    ai_service_client: AiServiceClient,
    test_diffs: list[TestDiff] | None = None,
) -> concurrent.futures.Future | None:
    """Queue a code-repair request for a failed candidate when it qualifies.

    A request is submitted only if all of these hold, checked in this order:
    the repair counter is below the effort-derived maximum, fewer than
    ``MIN_CORRECT_CANDIDATES`` candidates are marked correct in ``eval_ctx``,
    and the candidate's source is ``OPTIMIZE`` or ``OPTIMIZE_LP``.

    Args:
        candidate: The failed candidate to repair.
        eval_ctx: Evaluation context whose ``is_correct`` map is consulted.
        code_context: Supplies the original writable source (markdown form).
        exp_type: Experiment type suffix used when an experiment id is set.
        ai_service_client: Client whose ``code_repair`` method is submitted.
        test_diffs: Test differences to send; ``None`` is sent as an empty list.

    Returns:
        The future for the submitted repair call, or ``None`` if the
        candidate does not qualify. ``self.repair_counter`` is incremented
        only when a request is submitted.
    """
    repair_budget = get_effort_value(EffortKeys.MAX_CODE_REPAIRS_PER_TRACE, self.effort)
    if self.repair_counter >= repair_budget:
        return None

    correct_so_far = len([flag for flag in eval_ctx.is_correct.values() if flag])
    if correct_so_far >= MIN_CORRECT_CANDIDATES:
        return None

    repairable_sources = (OptimizedCandidateSource.OPTIMIZE, OptimizedCandidateSource.OPTIMIZE_LP)
    if candidate.source not in repairable_sources:
        return None

    self.repair_counter += 1

    # With an experiment id set, the last four characters of the trace id are
    # replaced by the experiment type.
    if self.experiment_id:
        trace_id = self.function_trace_id[:-4] + exp_type
    else:
        trace_id = self.function_trace_id

    repair_request = AIServiceCodeRepairRequest(
        optimization_id=candidate.optimization_id,
        original_source_code=code_context.read_writable_code.markdown,
        modified_source_code=candidate.source_code.markdown,
        test_diffs=test_diffs if test_diffs else [],
        trace_id=trace_id,
        language=self.function_to_optimize.language,
    )
    return self.executor.submit(
        ai_service_client.code_repair,
        request=repair_request,
        rerun_trace_id=self.rerun_trace_id,
    )

def call_adaptive_optimize(
self,
trace_id: str,
Expand Down