Skip to content

Commit c09d100

Browse files
committed
feat: integrate parallel evaluator into function optimizer
Wire up the parallel evaluator in _evaluate_candidates_parallel(): - Dispatch refinement/repair immediately via ThreadPoolExecutor after each batch completes (no lazy carry-over pattern) - Pass test diffs from EvalFailure to repair requests - Add _run_line_profiler_for_winner() for post-selection line profiling - Add batch_refine endpoint to AiServiceClient
1 parent 619fd20 commit c09d100

2 files changed

Lines changed: 397 additions & 30 deletions

File tree

codeflash/api/aiservice.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
FunctionRepairInfo,
2323
OptimizationReviewResult,
2424
OptimizedCandidate,
25-
OptimizedCandidateSource,
2625
TestFileReview,
2726
)
27+
from codeflash.models.shared_types import OptimizedCandidateSource
2828
from codeflash.telemetry.posthog_cf import ph
2929
from codeflash.version import __version__ as codeflash_version
3030

@@ -35,6 +35,7 @@
3535
from codeflash.models.ExperimentMetadata import ExperimentMetadata
3636
from codeflash.models.models import (
3737
AIServiceAdaptiveOptimizeRequest,
38+
AIServiceBatchRefinerCandidate,
3839
AIServiceCodeRepairRequest,
3940
AIServiceRefinerRequest,
4041
)
@@ -384,6 +385,97 @@ def optimize_code_refinement(
384385
console.rule()
385386
return []
386387

388+
def optimize_code_refinement_batch(
389+
self,
390+
*,
391+
original_source_code: str,
392+
read_only_dependency_code: str,
393+
original_line_profiler_results: str,
394+
trace_id: str,
395+
language: str,
396+
language_version: str | None,
397+
function_references: str | None,
398+
candidates: list[AIServiceBatchRefinerCandidate],
399+
rerun_trace_id: str | None = None,
400+
) -> list[OptimizedCandidate]:
401+
shared_context: dict[str, Any] = {
402+
"original_source_code": original_source_code,
403+
"read_only_dependency_code": read_only_dependency_code,
404+
"original_line_profiler_results": original_line_profiler_results,
405+
"trace_id": trace_id,
406+
"language": language,
407+
"function_references": function_references,
408+
"rerun_trace_id": rerun_trace_id,
409+
}
410+
self.add_language_metadata(shared_context, language_version)
411+
412+
candidate_payloads: list[dict[str, Any]] = []
413+
for c in candidates:
414+
candidate_payloads.append(
415+
{
416+
"optimization_id": c.optimization_id,
417+
"optimized_source_code": c.optimized_source_code,
418+
"optimized_explanation": c.optimized_explanation,
419+
"optimized_code_runtime": humanize_runtime(c.optimized_code_runtime),
420+
"original_code_runtime": humanize_runtime(c.original_code_runtime),
421+
"speedup": c.speedup,
422+
"optimized_line_profiler_results": c.optimized_line_profiler_results,
423+
"call_sequence": self.get_next_sequence(),
424+
}
425+
)
426+
427+
payload: dict[str, Any] = {"shared_context": shared_context, "candidates": candidate_payloads}
428+
429+
try:
430+
response = self.make_ai_service_request("/batch_refinement", payload=payload, timeout=self.timeout)
431+
except requests.exceptions.RequestException as e:
432+
logger.exception(f"Error generating batch optimization refinements: {e}")
433+
ph("cli-optimize-error-caught", {"error": str(e)})
434+
return []
435+
436+
if response.status_code == 404:
437+
return self._fallback_to_sequential_refinement(
438+
shared_context=shared_context, candidates=candidates, rerun_trace_id=rerun_trace_id
439+
)
440+
441+
if response.status_code == 200:
442+
refined_optimizations = response.json()["refinements"]
443+
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
444+
445+
self.log_error_response(response, "generating batch optimized candidates", "cli-optimize-error-response")
446+
console.rule()
447+
return []
448+
449+
def _fallback_to_sequential_refinement(
450+
self,
451+
*,
452+
shared_context: dict[str, Any],
453+
candidates: list[AIServiceBatchRefinerCandidate],
454+
rerun_trace_id: str | None,
455+
) -> list[OptimizedCandidate]:
456+
from codeflash.models.models import AIServiceRefinerRequest
457+
458+
requests_list = [
459+
AIServiceRefinerRequest(
460+
optimization_id=c.optimization_id,
461+
original_source_code=shared_context["original_source_code"],
462+
read_only_dependency_code=shared_context["read_only_dependency_code"],
463+
original_code_runtime=c.original_code_runtime,
464+
optimized_source_code=c.optimized_source_code,
465+
optimized_explanation=c.optimized_explanation,
466+
optimized_code_runtime=c.optimized_code_runtime,
467+
speedup=c.speedup,
468+
trace_id=shared_context["trace_id"],
469+
original_line_profiler_results=shared_context["original_line_profiler_results"],
470+
optimized_line_profiler_results=c.optimized_line_profiler_results,
471+
function_references=shared_context.get("function_references"),
472+
language=shared_context["language"],
473+
language_version=shared_context.get("language_version"),
474+
)
475+
for c in candidates
476+
]
477+
return self.optimize_code_refinement(requests_list, rerun_trace_id=rerun_trace_id)
478+
387479
def code_repair(
388480
self, request: AIServiceCodeRepairRequest, rerun_trace_id: str | None = None
389481
) -> OptimizedCandidate | None:

0 commit comments

Comments
 (0)