Skip to content

Commit 0d4b854

Browse files
Merge pull request #1003 from codeflash-ai/exp/adaptive-optimization
[FEAT][EXP] Adaptive optimizations (CF-831)
2 parents a301c22 + 250be17 commit 0d4b854

4 files changed

Lines changed: 250 additions & 68 deletions

File tree

codeflash/api/aiservice.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232

3333
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
3434
from codeflash.models.ExperimentMetadata import ExperimentMetadata
35-
from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest
35+
from codeflash.models.models import (
36+
AIServiceAdaptiveOptimizeRequest,
37+
AIServiceCodeRepairRequest,
38+
AIServiceRefinerRequest,
39+
)
3640
from codeflash.result.explanation import Explanation
3741

3842

@@ -249,6 +253,38 @@ def optimize_python_code_line_profiler( # noqa: D417
249253
console.rule()
250254
return []
251255

256+
def adaptive_optimize(self, request: AIServiceAdaptiveOptimizeRequest) -> OptimizedCandidate | None:
257+
try:
258+
payload = {
259+
"trace_id": request.trace_id,
260+
"original_source_code": request.original_source_code,
261+
"candidates": request.candidates,
262+
}
263+
response = self.make_ai_service_request("/adaptive_optimize", payload=payload, timeout=120)
264+
except (requests.exceptions.RequestException, TypeError) as e:
265+
logger.exception(f"Error generating adaptive optimized candidates: {e}")
266+
ph("cli-optimize-error-caught", {"error": str(e)})
267+
return None
268+
269+
if response.status_code == 200:
270+
fixed_optimization = response.json()
271+
console.rule()
272+
273+
valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.ADAPTIVE)
274+
if not valid_candidates:
275+
logger.error("Adaptive optimization failed to generate a valid candidate.")
276+
return None
277+
278+
return valid_candidates[0]
279+
280+
try:
281+
error = response.json()["error"]
282+
except Exception:
283+
error = response.text
284+
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
285+
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
286+
return None
287+
252288
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
253289
"""Optimize the given python code for performance by making a request to the Django endpoint.
254290

codeflash/code_utils/config_consts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4 # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair more stricted)
3636
MAX_REPAIRS_PER_TRACE = 4 # maximum number of repairs we will do for each function
3737

38+
# Adaptive optimization
39+
# TODO (ali): make this configurable with effort arg once the PR is merged
40+
ADAPTIVE_OPTIMIZATION_THRESHOLD = 2 # Max adaptive optimizations per single candidate tree (for example : optimize -> refine -> adaptive -> another adaptive).
41+
MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE = 4 # maximum number of adaptive optimizations we will do for each function (this can be 2 adaptive optimizations for 2 candidates for example)
42+
3843
MAX_N_CANDIDATES = 5
3944
MAX_N_CANDIDATES_LP = 6
4045

codeflash/models/models.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ class AIServiceRefinerRequest:
4949
call_sequence: int | None = None
5050

5151

52+
# this should be possible to auto serialize
53+
@dataclass(frozen=True)
54+
class AdaptiveOptimizedCandidate:
55+
optimization_id: str
56+
source_code: str
57+
# TODO: introduce repair explanation for code repair candidates to help the llm understand the full process
58+
explanation: str
59+
source: OptimizedCandidateSource
60+
speedup: str
61+
62+
63+
@dataclass(frozen=True)
64+
class AIServiceAdaptiveOptimizeRequest:
65+
trace_id: str
66+
original_source_code: str
67+
candidates: list[AdaptiveOptimizedCandidate]
68+
69+
5270
class TestDiffScope(str, Enum):
5371
RETURN_VALUE = "return_value"
5472
STDOUT = "stdout"
@@ -442,6 +460,9 @@ def register_new_candidate(
442460
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
443461
}
444462

463+
def get_speedup_ratio(self, optimization_id: str) -> float | None:
464+
return self.speedup_ratios.get(optimization_id)
465+
445466

446467
@dataclass(frozen=True)
447468
class TestsInFile:
@@ -456,6 +477,7 @@ class OptimizedCandidateSource(str, Enum):
456477
OPTIMIZE_LP = "OPTIMIZE_LP"
457478
REFINE = "REFINE"
458479
REPAIR = "REPAIR"
480+
ADAPTIVE = "ADAPTIVE"
459481

460482

461483
@dataclass(frozen=True)

0 commit comments

Comments
 (0)