Skip to content

Commit ace7524

Browse files
refactor the candidate processor
1 parent: def5caf · commit: ace7524

2 files changed

Lines changed: 58 additions & 81 deletions

File tree

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ class AIServiceRefinerRequest:
5454
class AdaptiveOptimizedCandidate:
5555
optimization_id: str
5656
source_code: str
57+
# TODO: introduce repair explanation for code repair candidates to help the llm understand the full process
5758
explanation: str
5859
source: OptimizedCandidateSource
5960
speedup: str

codeflash/optimization/function_optimizer.py

Lines changed: 57 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import uuid
1010
from collections import defaultdict
1111
from pathlib import Path
12-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, Callable
1313

1414
import libcst as cst
1515
from rich.console import Group
@@ -220,32 +220,65 @@ def get_next_candidate(self) -> CandidateNode | None:
220220

221221
def _handle_empty_queue(self) -> CandidateNode | None:
222222
"""Handle empty queue by checking for pending async results."""
223-
# TODO: Many duplicates here for processing functions, create a single function and set the priority of each optimization source
224223
if not self.line_profiler_done:
225-
return self._process_line_profiler_results()
224+
return self._process_candidates(
225+
[self.future_line_profile_results],
226+
"all candidates processed, await candidates from line profiler",
227+
"Added results from line profiler to candidates, total candidates now: {1}",
228+
lambda: setattr(self, "line_profiler_done", True),
229+
)
226230
if len(self.future_all_code_repair) > 0:
227-
return self._process_code_repair()
231+
return self._process_candidates(
232+
self.future_all_code_repair,
233+
"Repairing {0} candidates",
234+
"Added {0} candidates from repair, total candidates now: {1}",
235+
lambda: self.future_all_code_repair.clear(),
236+
)
228237
if self.line_profiler_done and not self.refinement_done:
229238
return self._process_refinement_results()
230239
if len(self.future_adaptive_optimizations) > 0:
231-
return self._process_adaptive_optimizations()
240+
return self._process_candidates(
241+
self.future_adaptive_optimizations,
242+
"Applying adaptive optimizations to {0} candidates",
243+
"Added {0} candidates from adaptive optimization, total candidates now: {1}",
244+
lambda: self.future_adaptive_optimizations.clear(),
245+
)
232246
return None # All done
233247

234-
def _process_line_profiler_results(self) -> CandidateNode | None:
235-
"""Process line profiler results and add to queue."""
236-
logger.debug("all candidates processed, await candidates from line profiler")
237-
concurrent.futures.wait([self.future_line_profile_results])
238-
line_profile_results = self.future_line_profile_results.result()
248+
def _process_candidates(
249+
self,
250+
future_candidates: list[concurrent.futures.Future],
251+
loading_msg: str,
252+
success_msg: str,
253+
callback: Callable[[], None],
254+
) -> CandidateNode | None:
255+
if len(future_candidates) == 0:
256+
return None
257+
with progress_bar(
258+
loading_msg.format(len(future_candidates)), transient=True, revert_to_print=bool(get_pr_number())
259+
):
260+
concurrent.futures.wait(future_candidates)
261+
candidates: list[OptimizedCandidate] = []
262+
for future_c in future_candidates:
263+
candidate_result = future_c.result()
264+
if not candidate_result:
265+
continue
239266

240-
for candidate in line_profile_results:
241-
self.forest.add(candidate)
242-
self.candidate_queue.put(candidate)
267+
if isinstance(candidate_result, list):
268+
candidates.extend(candidate_result)
269+
else:
270+
candidates.append(candidate_result)
243271

244-
self.candidate_len += len(line_profile_results)
245-
logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
246-
self.line_profiler_done = True
272+
for candidate in candidates:
273+
self.forest.add(candidate)
274+
self.candidate_queue.put(candidate)
275+
self.candidate_len += 1
247276

248-
return self.get_next_candidate()
277+
if len(candidates) > 0:
278+
logger.info(success_msg.format(len(candidates), self.candidate_len))
279+
280+
callback()
281+
return self.get_next_candidate()
249282

250283
def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
251284
return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)
@@ -284,70 +317,12 @@ def _process_refinement_results(self) -> CandidateNode | None:
284317
# Track total refinement calls made
285318
self.refinement_calls_count = refinement_call_index
286319

287-
if future_refinements:
288-
logger.info("loading|Refining generated code for improved quality and performance...")
289-
290-
concurrent.futures.wait(future_refinements)
291-
refinement_response = []
292-
293-
for f in future_refinements:
294-
possible_refinement = f.result()
295-
if len(possible_refinement) > 0:
296-
refinement_response.append(possible_refinement[0])
297-
298-
for candidate in refinement_response:
299-
self.forest.add(candidate)
300-
self.candidate_queue.put(candidate)
301-
302-
self.candidate_len += len(refinement_response)
303-
if len(refinement_response) > 0:
304-
logger.info(
305-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
306-
)
307-
console.rule()
308-
self.refinement_done = True
309-
310-
return self.get_next_candidate()
311-
312-
def _process_code_repair(self) -> CandidateNode | None:
313-
logger.info(f"loading|Repairing {len(self.future_all_code_repair)} candidates")
314-
concurrent.futures.wait(self.future_all_code_repair)
315-
candidates_added = 0
316-
for future_code_repair in self.future_all_code_repair:
317-
possible_code_repair = future_code_repair.result()
318-
if possible_code_repair:
319-
self.forest.add(possible_code_repair)
320-
self.candidate_queue.put(possible_code_repair)
321-
self.candidate_len += 1
322-
candidates_added += 1
323-
324-
if candidates_added > 0:
325-
logger.info(
326-
f"Added {candidates_added} candidates from code repair, total candidates now: {self.candidate_len}"
327-
)
328-
self.future_all_code_repair.clear()
329-
330-
return self.get_next_candidate()
331-
332-
def _process_adaptive_optimizations(self) -> CandidateNode | None:
333-
logger.info(f"loading|Applying adaptive optimizations to {len(self.future_adaptive_optimizations)} candidates")
334-
concurrent.futures.wait(self.future_adaptive_optimizations)
335-
candidates_added = 0
336-
for future_adaptive_optimization in self.future_adaptive_optimizations:
337-
possible_adaptive_optimization = future_adaptive_optimization.result()
338-
if possible_adaptive_optimization:
339-
self.forest.add(possible_adaptive_optimization)
340-
self.candidate_queue.put(possible_adaptive_optimization)
341-
self.candidate_len += 1
342-
candidates_added += 1
343-
344-
if candidates_added > 0:
345-
logger.info(
346-
f"Added {candidates_added} candidates from adaptive optimizations, total candidates now: {self.candidate_len}"
347-
)
348-
self.future_adaptive_optimizations.clear()
349-
350-
return self.get_next_candidate()
320+
return self._process_candidates(
321+
future_refinements,
322+
"Refining generated code for improved quality and performance...",
323+
"Added {0} candidates from refinement, total candidates now: {1}",
324+
lambda: setattr(self, "refinement_done", True),
325+
)
351326

352327
def is_done(self) -> bool:
353328
"""Check if processing is complete."""
@@ -868,6 +843,7 @@ def process_single_candidate(
868843

869844
logger.info(f"h3|Optimization candidate {candidate_index}/{total_candidates}:")
870845
candidate = candidate_node.candidate
846+
print(f" {' -> '.join([c.source for c in candidate_node.path_to_root()])}")
871847
code_print(
872848
candidate.source_code.flat,
873849
file_name=f"candidate_{candidate_index}.py",

0 commit comments

Comments
 (0)