|
22 | 22 | FunctionRepairInfo, |
23 | 23 | OptimizationReviewResult, |
24 | 24 | OptimizedCandidate, |
25 | | - OptimizedCandidateSource, |
26 | 25 | TestFileReview, |
27 | 26 | ) |
| 27 | +from codeflash.models.shared_types import OptimizedCandidateSource |
28 | 28 | from codeflash.telemetry.posthog_cf import ph |
29 | 29 | from codeflash.version import __version__ as codeflash_version |
30 | 30 |
|
|
35 | 35 | from codeflash.models.ExperimentMetadata import ExperimentMetadata |
36 | 36 | from codeflash.models.models import ( |
37 | 37 | AIServiceAdaptiveOptimizeRequest, |
| 38 | + AIServiceBatchRefinerCandidate, |
38 | 39 | AIServiceCodeRepairRequest, |
39 | 40 | AIServiceRefinerRequest, |
40 | 41 | ) |
@@ -384,6 +385,97 @@ def optimize_code_refinement( |
384 | 385 | console.rule() |
385 | 386 | return [] |
386 | 387 |
|
| 388 | + def optimize_code_refinement_batch( |
| 389 | + self, |
| 390 | + *, |
| 391 | + original_source_code: str, |
| 392 | + read_only_dependency_code: str, |
| 393 | + original_line_profiler_results: str, |
| 394 | + trace_id: str, |
| 395 | + language: str, |
| 396 | + language_version: str | None, |
| 397 | + function_references: str | None, |
| 398 | + candidates: list[AIServiceBatchRefinerCandidate], |
| 399 | + rerun_trace_id: str | None = None, |
| 400 | + ) -> list[OptimizedCandidate]: |
| 401 | + shared_context: dict[str, Any] = { |
| 402 | + "original_source_code": original_source_code, |
| 403 | + "read_only_dependency_code": read_only_dependency_code, |
| 404 | + "original_line_profiler_results": original_line_profiler_results, |
| 405 | + "trace_id": trace_id, |
| 406 | + "language": language, |
| 407 | + "function_references": function_references, |
| 408 | + "rerun_trace_id": rerun_trace_id, |
| 409 | + } |
| 410 | + self.add_language_metadata(shared_context, language_version) |
| 411 | + |
| 412 | + candidate_payloads: list[dict[str, Any]] = [] |
| 413 | + for c in candidates: |
| 414 | + candidate_payloads.append( |
| 415 | + { |
| 416 | + "optimization_id": c.optimization_id, |
| 417 | + "optimized_source_code": c.optimized_source_code, |
| 418 | + "optimized_explanation": c.optimized_explanation, |
| 419 | + "optimized_code_runtime": humanize_runtime(c.optimized_code_runtime), |
| 420 | + "original_code_runtime": humanize_runtime(c.original_code_runtime), |
| 421 | + "speedup": c.speedup, |
| 422 | + "optimized_line_profiler_results": c.optimized_line_profiler_results, |
| 423 | + "call_sequence": self.get_next_sequence(), |
| 424 | + } |
| 425 | + ) |
| 426 | + |
| 427 | + payload: dict[str, Any] = {"shared_context": shared_context, "candidates": candidate_payloads} |
| 428 | + |
| 429 | + try: |
| 430 | + response = self.make_ai_service_request("/batch_refinement", payload=payload, timeout=self.timeout) |
| 431 | + except requests.exceptions.RequestException as e: |
| 432 | + logger.exception(f"Error generating batch optimization refinements: {e}") |
| 433 | + ph("cli-optimize-error-caught", {"error": str(e)}) |
| 434 | + return [] |
| 435 | + |
| 436 | + if response.status_code == 404: |
| 437 | + return self._fallback_to_sequential_refinement( |
| 438 | + shared_context=shared_context, candidates=candidates, rerun_trace_id=rerun_trace_id |
| 439 | + ) |
| 440 | + |
| 441 | + if response.status_code == 200: |
| 442 | + refined_optimizations = response.json()["refinements"] |
| 443 | + return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE) |
| 444 | + |
| 445 | + self.log_error_response(response, "generating batch optimized candidates", "cli-optimize-error-response") |
| 446 | + console.rule() |
| 447 | + return [] |
| 448 | + |
| 449 | + def _fallback_to_sequential_refinement( |
| 450 | + self, |
| 451 | + *, |
| 452 | + shared_context: dict[str, Any], |
| 453 | + candidates: list[AIServiceBatchRefinerCandidate], |
| 454 | + rerun_trace_id: str | None, |
| 455 | + ) -> list[OptimizedCandidate]: |
| 456 | + from codeflash.models.models import AIServiceRefinerRequest |
| 457 | + |
| 458 | + requests_list = [ |
| 459 | + AIServiceRefinerRequest( |
| 460 | + optimization_id=c.optimization_id, |
| 461 | + original_source_code=shared_context["original_source_code"], |
| 462 | + read_only_dependency_code=shared_context["read_only_dependency_code"], |
| 463 | + original_code_runtime=c.original_code_runtime, |
| 464 | + optimized_source_code=c.optimized_source_code, |
| 465 | + optimized_explanation=c.optimized_explanation, |
| 466 | + optimized_code_runtime=c.optimized_code_runtime, |
| 467 | + speedup=c.speedup, |
| 468 | + trace_id=shared_context["trace_id"], |
| 469 | + original_line_profiler_results=shared_context["original_line_profiler_results"], |
| 470 | + optimized_line_profiler_results=c.optimized_line_profiler_results, |
| 471 | + function_references=shared_context.get("function_references"), |
| 472 | + language=shared_context["language"], |
| 473 | + language_version=shared_context.get("language_version"), |
| 474 | + ) |
| 475 | + for c in candidates |
| 476 | + ] |
| 477 | + return self.optimize_code_refinement(requests_list, rerun_trace_id=rerun_trace_id) |
| 478 | + |
387 | 479 | def code_repair( |
388 | 480 | self, request: AIServiceCodeRepairRequest, rerun_trace_id: str | None = None |
389 | 481 | ) -> OptimizedCandidate | None: |
|
0 commit comments