Skip to content

Commit a47cfad

Browse files
committed
feat: integrate parallel evaluator into function optimizer
Wires the parallel evaluation path into _evaluate_candidates:
- Checks the --parallel-candidates flag to branch between sequential and parallel evaluation
- Batches candidates with dedup/normalization gating
- Dispatches repair and refinement futures from evaluation results
- Calls _run_line_profiler_for_winner after selection

New methods: _evaluate_candidates_parallel, _dispatch_refinement, _dispatch_repair_if_possible.
1 parent f69f20f commit a47cfad

1 file changed

Lines changed: 275 additions & 29 deletions

File tree

codeflash/languages/function_optimizer.py

Lines changed: 275 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from codeflash.models.models import (
7979
AdaptiveOptimizedCandidate,
8080
AIServiceAdaptiveOptimizeRequest,
81+
AIServiceBatchRefinerCandidate,
8182
AIServiceCodeRepairRequest,
8283
BestOptimization,
8384
CandidateEvaluationContext,
@@ -1407,37 +1408,52 @@ def determine_best_candidate(
14071408
original_flat_code=code_context.read_writable_code.flat,
14081409
)
14091410
candidate_index = 0
1411+
parallel_pool_size = getattr(self.args, "parallel_candidates", 0)
14101412

1411-
# Process candidates using queue-based approach
1412-
while not processor.is_done():
1413-
candidate_node = processor.get_next_candidate()
1414-
if candidate_node is None:
1415-
logger.debug("everything done, exiting")
1416-
break
1413+
if parallel_pool_size > 1:
1414+
self._evaluate_candidates_parallel(
1415+
processor=processor,
1416+
code_context=code_context,
1417+
original_code_baseline=original_code_baseline,
1418+
original_helper_code=original_helper_code,
1419+
file_path_to_helper_classes=file_path_to_helper_classes,
1420+
eval_ctx=eval_ctx,
1421+
exp_type=exp_type,
1422+
function_references=function_references,
1423+
normalized_original=normalized_original,
1424+
pool_size=parallel_pool_size,
1425+
)
1426+
else:
1427+
# Process candidates using queue-based approach (sequential)
1428+
while not processor.is_done():
1429+
candidate_node = processor.get_next_candidate()
1430+
if candidate_node is None:
1431+
logger.debug("everything done, exiting")
1432+
break
14171433

1418-
try:
1419-
candidate_index += 1
1420-
self.process_single_candidate(
1421-
candidate_node=candidate_node,
1422-
candidate_index=candidate_index,
1423-
total_candidates=processor.candidate_len,
1424-
code_context=code_context,
1425-
original_code_baseline=original_code_baseline,
1426-
original_helper_code=original_helper_code,
1427-
file_path_to_helper_classes=file_path_to_helper_classes,
1428-
eval_ctx=eval_ctx,
1429-
exp_type=exp_type,
1430-
function_references=function_references,
1431-
normalized_original=normalized_original,
1432-
cached_normalized_code=processor.normalized_cache.get(candidate_node.candidate.optimization_id),
1433-
)
1434-
except KeyboardInterrupt as e:
1435-
logger.exception(f"Optimization interrupted: {e}")
1436-
raise
1437-
finally:
1438-
self.write_code_and_helpers(
1439-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1440-
)
1434+
try:
1435+
candidate_index += 1
1436+
self.process_single_candidate(
1437+
candidate_node=candidate_node,
1438+
candidate_index=candidate_index,
1439+
total_candidates=processor.candidate_len,
1440+
code_context=code_context,
1441+
original_code_baseline=original_code_baseline,
1442+
original_helper_code=original_helper_code,
1443+
file_path_to_helper_classes=file_path_to_helper_classes,
1444+
eval_ctx=eval_ctx,
1445+
exp_type=exp_type,
1446+
function_references=function_references,
1447+
normalized_original=normalized_original,
1448+
cached_normalized_code=processor.normalized_cache.get(candidate_node.candidate.optimization_id),
1449+
)
1450+
except KeyboardInterrupt as e:
1451+
logger.exception(f"Optimization interrupted: {e}")
1452+
raise
1453+
finally:
1454+
self.write_code_and_helpers(
1455+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1456+
)
14411457

14421458
# Select and return the best optimization
14431459
best_optimization = self.select_best_optimization(
@@ -1450,6 +1466,11 @@ def determine_best_candidate(
14501466
)
14511467

14521468
if best_optimization:
1469+
if parallel_pool_size > 1:
1470+
best_optimization = self._run_line_profiler_for_winner(
1471+
best_optimization, code_context, original_helper_code, eval_ctx
1472+
)
1473+
14531474
self.log_evaluation_results(
14541475
eval_ctx=eval_ctx,
14551476
best_optimization=best_optimization,
@@ -1460,6 +1481,231 @@ def determine_best_candidate(
14601481

14611482
return best_optimization
14621483

1484+
def _evaluate_candidates_parallel(
    self,
    processor: CandidateProcessor,
    code_context: CodeOptimizationContext,
    original_code_baseline: OriginalCodeBaseline,
    original_helper_code: dict[Path, str],
    file_path_to_helper_classes: dict[Path, set[str]],
    eval_ctx: CandidateEvaluationContext,
    exp_type: str,
    function_references: str,
    normalized_original: str,
    pool_size: int,
) -> None:
    """Drain ``processor`` in batches and evaluate every batch with ``run_parallel_evaluation``.

    Each round pulls up to ``pool_size`` candidates from ``processor``. Candidates whose
    normalized code equals ``normalized_original`` or is already present in
    ``eval_ctx.ast_code_to_id`` are skipped; the rest are registered in ``eval_ctx`` and
    evaluated together. Failed candidates may trigger a repair request, and candidates
    accepted by the critics are recorded in ``eval_ctx.valid_optimizations`` and sent
    for refinement at the end of the round.

    The function returns nothing; all outcomes are written to ``eval_ctx``,
    ``self.future_all_code_repair`` and ``self.future_all_refinements``.
    """
    # Imported lazily, at call time, exactly as in the original implementation.
    from codeflash.optimization.parallel_evaluator import run_parallel_evaluation

    ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
    assert ai_service_client is not None

    # Running counter used only for log messages and batch bookkeeping.
    candidate_index = 0

    while not processor.is_done():
        # ---- Phase 1: collect up to ``pool_size`` new, unique candidates -------------
        batch: list[tuple[CandidateNode, int, str | None]] = []
        while len(batch) < pool_size:
            node = processor.get_next_candidate()
            if node is None:
                break
            candidate_index += 1
            cached = processor.normalized_cache.get(node.candidate.optimization_id)

            # Normalize only when the processor has no usable cached form.
            normalized_code = cached or self.language_support.normalize_code(
                node.candidate.source_code.flat.strip()
            )

            if normalized_code == normalized_original:
                logger.info(f"h3|Candidate {candidate_index}: Identical to original code, skipping.")
                continue

            if normalized_code in eval_ctx.ast_code_to_id:
                logger.info(f"h3|Candidate {candidate_index}: Duplicate of a previous candidate, skipping.")
                eval_ctx.handle_duplicate_candidate(
                    node.candidate, normalized_code, code_context.read_writable_code.flat
                )
                continue

            eval_ctx.register_new_candidate(normalized_code, node.candidate, code_context.read_writable_code.flat)
            # The batch stores the cache lookup result (possibly None), not the freshly normalized code.
            batch.append((node, candidate_index, cached))

        if not batch:
            break

        # ---- Phase 2: evaluate the whole batch ---------------------------------------
        logger.info(f"Evaluating batch of {len(batch)} candidates in parallel…")

        results, _, _ = run_parallel_evaluation(
            optimizer=self,
            candidates=batch,
            code_context=code_context,
            original_code_baseline=original_code_baseline,
            original_helper_code=original_helper_code,
            file_path_to_helper_classes=file_path_to_helper_classes,
            eval_ctx=eval_ctx,
            exp_type=exp_type,
            pool_size=pool_size,
        )

        # ---- Phase 3: record outcomes and dispatch repair / refinement futures -------
        # NOTE(review): pairing by position assumes ``results`` is ordered like ``batch`` — confirm
        # against run_parallel_evaluation.
        refiner_inputs: list[AIServiceBatchRefinerCandidate] = []
        for (node, _idx, _cached), (_, run_result) in zip(batch, results):
            candidate = node.candidate

            if run_result is None or not is_successful(run_result):
                eval_ctx.record_failed_candidate(candidate.optimization_id)
                # isinstance() is already False for None, so no separate None check is needed.
                if isinstance(run_result, Failure):
                    eval_failure = run_result.failure()
                    repair_future = self._dispatch_repair_if_possible(
                        candidate,
                        eval_ctx,
                        code_context,
                        exp_type,
                        ai_service_client,
                        test_diffs=eval_failure.diffs,
                    )
                    if repair_future is not None:
                        self.future_all_code_repair.append(repair_future)
                continue

            candidate_result = run_result.unwrap()
            perf_gain = performance_gain(
                original_runtime_ns=original_code_baseline.runtime,
                optimized_runtime_ns=candidate_result.best_test_runtime,
            )
            eval_ctx.record_successful_candidate(
                candidate.optimization_id, candidate_result.best_test_runtime, perf_gain
            )

            # Each candidate is judged against the original baseline only (no "best so far" values).
            is_successful_opt = speedup_critic(
                candidate_result,
                original_code_baseline.runtime,
                best_runtime_until_now=None,
                original_async_throughput=original_code_baseline.async_throughput,
                best_throughput_until_now=None,
                original_concurrency_metrics=original_code_baseline.concurrency_metrics,
                best_concurrency_ratio_until_now=None,
            ) and quantity_of_tests_critic(candidate_result)

            if is_successful_opt:
                # Placeholder line-profiler payload; no profiler output is produced in this method.
                empty_lp = {"timings": {}, "unit": 0, "str_out": ""}
                eval_ctx.valid_optimizations.append(
                    BestOptimization(
                        candidate=candidate,
                        helper_functions=code_context.helper_functions,
                        code_context=code_context,
                        runtime=candidate_result.best_test_runtime,
                        line_profiler_test_results=empty_lp,
                        winning_behavior_test_results=candidate_result.behavior_test_results,
                        winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
                        winning_replay_benchmarking_test_results=None,
                        async_throughput=candidate_result.async_throughput,
                        concurrency_metrics=candidate_result.concurrency_metrics,
                    )
                )

                # NOTE(review): only candidates accepted by both critics are assumed to be queued
                # for refinement — confirm the intended nesting of this append.
                refiner_inputs.append(
                    AIServiceBatchRefinerCandidate(
                        optimization_id=candidate.optimization_id,
                        optimized_source_code=candidate.source_code.markdown,
                        optimized_explanation=candidate.explanation,
                        optimized_code_runtime=candidate_result.best_test_runtime,
                        original_code_runtime=original_code_baseline.runtime,
                        speedup=f"{int(perf_gain * 100)}%",
                        optimized_line_profiler_results="",
                    )
                )

        # Dispatch refinement before the next round so the processor can see the new future.
        if refiner_inputs:
            self._dispatch_refinement(
                refiner_inputs,
                code_context,
                original_code_baseline,
                exp_type,
                function_references,
                ai_service_client,
            )
1627+
1628+
def _dispatch_refinement(
    self,
    batch_refiner_candidates: list[AIServiceBatchRefinerCandidate],
    code_context: CodeOptimizationContext,
    original_code_baseline: OriginalCodeBaseline,
    exp_type: str,
    function_references: str,
    ai_service_client: AiServiceClient,
) -> None:
    """Submit a refinement request to ``self.executor`` and track the resulting future.

    Two or more candidates are sent in a single ``optimize_code_refinement_batch``
    call; exactly one candidate is wrapped in an ``AIServiceRefinerRequest`` and sent
    through ``optimize_code_refinement``. The future is appended to
    ``self.future_all_refinements``.

    Args:
        batch_refiner_candidates: Candidates to refine. An empty list is a no-op.
        code_context: Supplies the original source (markdown) and read-only context code.
        original_code_baseline: Supplies the original line-profiler output (``"str_out"``).
        exp_type: Experiment type, passed to ``self.get_trace_id``.
        function_references: Forwarded unchanged to the AI service.
        ai_service_client: Client whose refinement endpoint is invoked on the executor.
    """
    # Nothing to refine: return instead of failing on ``batch_refiner_candidates[0]`` below.
    if not batch_refiner_candidates:
        return

    if len(batch_refiner_candidates) > 1:
        future = self.executor.submit(
            ai_service_client.optimize_code_refinement_batch,
            original_source_code=code_context.read_writable_code.markdown,
            read_only_dependency_code=code_context.read_only_context_code,
            original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
            trace_id=self.get_trace_id(exp_type),
            language=self.function_to_optimize.language,
            language_version=self.language_support.language_version,
            function_references=function_references,
            candidates=batch_refiner_candidates,
            rerun_trace_id=self.rerun_trace_id,
        )
    else:
        # Single candidate: use the non-batch endpoint, which takes a list of requests.
        c = batch_refiner_candidates[0]
        future = self.executor.submit(
            ai_service_client.optimize_code_refinement,
            request=[
                AIServiceRefinerRequest(
                    optimization_id=c.optimization_id,
                    original_source_code=code_context.read_writable_code.markdown,
                    read_only_dependency_code=code_context.read_only_context_code,
                    original_code_runtime=c.original_code_runtime,
                    optimized_source_code=c.optimized_source_code,
                    optimized_explanation=c.optimized_explanation,
                    optimized_code_runtime=c.optimized_code_runtime,
                    speedup=c.speedup,
                    trace_id=self.get_trace_id(exp_type),
                    original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
                    optimized_line_profiler_results=c.optimized_line_profiler_results,
                    function_references=function_references,
                    language=self.function_to_optimize.language,
                    language_version=self.language_support.language_version,
                )
            ],
            rerun_trace_id=self.rerun_trace_id,
        )
    self.future_all_refinements.append(future)
1676+
1677+
def _dispatch_repair_if_possible(
    self,
    candidate: OptimizedCandidate,
    eval_ctx: CandidateEvaluationContext,
    code_context: CodeOptimizationContext,
    exp_type: str,
    ai_service_client: AiServiceClient,
    test_diffs: list[TestDiff] | None = None,
) -> concurrent.futures.Future | None:
    """Queue a code-repair request for ``candidate`` when it qualifies.

    Returns ``None`` without submitting anything when any of these holds:
    the repair counter has reached the effort-based maximum, at least
    ``MIN_CORRECT_CANDIDATES`` candidates are already marked correct in
    ``eval_ctx.is_correct``, or the candidate's source is neither ``OPTIMIZE``
    nor ``OPTIMIZE_LP``. Otherwise increments ``self.repair_counter`` and returns
    the future produced by ``self.executor.submit``.
    """
    repair_budget = get_effort_value(EffortKeys.MAX_CODE_REPAIRS_PER_TRACE, self.effort)
    if self.repair_counter >= repair_budget:
        return None

    # Count the candidates whose correctness flag is truthy.
    correct_so_far = sum(1 for flag in eval_ctx.is_correct.values() if flag)
    if correct_so_far >= MIN_CORRECT_CANDIDATES:
        return None

    repairable_sources = (OptimizedCandidateSource.OPTIMIZE, OptimizedCandidateSource.OPTIMIZE_LP)
    if candidate.source not in repairable_sources:
        return None

    # The counter is bumped before the request is built, so a slot is consumed even if building fails.
    self.repair_counter += 1

    # NOTE(review): this looks equivalent to self.get_trace_id(exp_type) — confirm before unifying.
    if self.experiment_id:
        trace_id = self.function_trace_id[:-4] + exp_type
    else:
        trace_id = self.function_trace_id

    repair_request = AIServiceCodeRepairRequest(
        optimization_id=candidate.optimization_id,
        original_source_code=code_context.read_writable_code.markdown,
        modified_source_code=candidate.source_code.markdown,
        test_diffs=test_diffs or [],
        trace_id=trace_id,
        language=self.function_to_optimize.language,
    )
    return self.executor.submit(
        ai_service_client.code_repair, request=repair_request, rerun_trace_id=self.rerun_trace_id
    )
1708+
14631709
def call_adaptive_optimize(
14641710
self,
14651711
trace_id: str,

0 commit comments

Comments
 (0)