|
9 | 9 | import uuid |
10 | 10 | from collections import defaultdict |
11 | 11 | from pathlib import Path |
12 | | -from typing import TYPE_CHECKING |
| 12 | +from typing import TYPE_CHECKING, Callable |
13 | 13 |
|
14 | 14 | import libcst as cst |
15 | 15 | from rich.console import Group |
@@ -220,32 +220,65 @@ def get_next_candidate(self) -> CandidateNode | None: |
220 | 220 |
|
221 | 221 | def _handle_empty_queue(self) -> CandidateNode | None: |
222 | 222 | """Handle empty queue by checking for pending async results.""" |
223 | | - # TODO: Many duplicates here for processing functions, create a single function and set the priority of each optimization source |
224 | 223 | if not self.line_profiler_done: |
225 | | - return self._process_line_profiler_results() |
| 224 | + return self._process_candidates( |
| 225 | + [self.future_line_profile_results], |
| 226 | + "all candidates processed, await candidates from line profiler", |
| 227 | + "Added results from line profiler to candidates, total candidates now: {1}", |
| 228 | + lambda: setattr(self, "line_profiler_done", True), |
| 229 | + ) |
226 | 230 | if len(self.future_all_code_repair) > 0: |
227 | | - return self._process_code_repair() |
| 231 | + return self._process_candidates( |
| 232 | + self.future_all_code_repair, |
| 233 | + "Repairing {0} candidates", |
| 234 | + "Added {0} candidates from repair, total candidates now: {1}", |
| 235 | + lambda: self.future_all_code_repair.clear(), |
| 236 | + ) |
228 | 237 | if self.line_profiler_done and not self.refinement_done: |
229 | 238 | return self._process_refinement_results() |
230 | 239 | if len(self.future_adaptive_optimizations) > 0: |
231 | | - return self._process_adaptive_optimizations() |
| 240 | + return self._process_candidates( |
| 241 | + self.future_adaptive_optimizations, |
| 242 | + "Applying adaptive optimizations to {0} candidates", |
| 243 | + "Added {0} candidates from adaptive optimization, total candidates now: {1}", |
| 244 | + lambda: self.future_adaptive_optimizations.clear(), |
| 245 | + ) |
232 | 246 | return None # All done |
233 | 247 |
|
234 | | - def _process_line_profiler_results(self) -> CandidateNode | None: |
235 | | - """Process line profiler results and add to queue.""" |
236 | | - logger.debug("all candidates processed, await candidates from line profiler") |
237 | | - concurrent.futures.wait([self.future_line_profile_results]) |
238 | | - line_profile_results = self.future_line_profile_results.result() |
| 248 | + def _process_candidates( |
| 249 | + self, |
| 250 | + future_candidates: list[concurrent.futures.Future], |
| 251 | + loading_msg: str, |
| 252 | + success_msg: str, |
| 253 | + callback: Callable[[], None], |
| 254 | + ) -> CandidateNode | None: |
| 255 | + if len(future_candidates) == 0: |
| 256 | + return None |
| 257 | + with progress_bar( |
| 258 | + loading_msg.format(len(future_candidates)), transient=True, revert_to_print=bool(get_pr_number()) |
| 259 | + ): |
| 260 | + concurrent.futures.wait(future_candidates) |
| 261 | + candidates: list[OptimizedCandidate] = [] |
| 262 | + for future_c in future_candidates: |
| 263 | + candidate_result = future_c.result() |
| 264 | + if not candidate_result: |
| 265 | + continue |
239 | 266 |
|
240 | | - for candidate in line_profile_results: |
241 | | - self.forest.add(candidate) |
242 | | - self.candidate_queue.put(candidate) |
| 267 | + if isinstance(candidate_result, list): |
| 268 | + candidates.extend(candidate_result) |
| 269 | + else: |
| 270 | + candidates.append(candidate_result) |
243 | 271 |
|
244 | | - self.candidate_len += len(line_profile_results) |
245 | | - logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}") |
246 | | - self.line_profiler_done = True |
| 272 | + for candidate in candidates: |
| 273 | + self.forest.add(candidate) |
| 274 | + self.candidate_queue.put(candidate) |
| 275 | + self.candidate_len += 1 |
247 | 276 |
|
248 | | - return self.get_next_candidate() |
| 277 | + if len(candidates) > 0: |
| 278 | + logger.info(success_msg.format(len(candidates), self.candidate_len)) |
| 279 | + |
| 280 | + callback() |
| 281 | + return self.get_next_candidate() |
249 | 282 |
|
250 | 283 | def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future: |
251 | 284 | return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request) |
@@ -284,70 +317,12 @@ def _process_refinement_results(self) -> CandidateNode | None: |
284 | 317 | # Track total refinement calls made |
285 | 318 | self.refinement_calls_count = refinement_call_index |
286 | 319 |
|
287 | | - if future_refinements: |
288 | | - logger.info("loading|Refining generated code for improved quality and performance...") |
289 | | - |
290 | | - concurrent.futures.wait(future_refinements) |
291 | | - refinement_response = [] |
292 | | - |
293 | | - for f in future_refinements: |
294 | | - possible_refinement = f.result() |
295 | | - if len(possible_refinement) > 0: |
296 | | - refinement_response.append(possible_refinement[0]) |
297 | | - |
298 | | - for candidate in refinement_response: |
299 | | - self.forest.add(candidate) |
300 | | - self.candidate_queue.put(candidate) |
301 | | - |
302 | | - self.candidate_len += len(refinement_response) |
303 | | - if len(refinement_response) > 0: |
304 | | - logger.info( |
305 | | - f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}" |
306 | | - ) |
307 | | - console.rule() |
308 | | - self.refinement_done = True |
309 | | - |
310 | | - return self.get_next_candidate() |
311 | | - |
312 | | - def _process_code_repair(self) -> CandidateNode | None: |
313 | | - logger.info(f"loading|Repairing {len(self.future_all_code_repair)} candidates") |
314 | | - concurrent.futures.wait(self.future_all_code_repair) |
315 | | - candidates_added = 0 |
316 | | - for future_code_repair in self.future_all_code_repair: |
317 | | - possible_code_repair = future_code_repair.result() |
318 | | - if possible_code_repair: |
319 | | - self.forest.add(possible_code_repair) |
320 | | - self.candidate_queue.put(possible_code_repair) |
321 | | - self.candidate_len += 1 |
322 | | - candidates_added += 1 |
323 | | - |
324 | | - if candidates_added > 0: |
325 | | - logger.info( |
326 | | - f"Added {candidates_added} candidates from code repair, total candidates now: {self.candidate_len}" |
327 | | - ) |
328 | | - self.future_all_code_repair.clear() |
329 | | - |
330 | | - return self.get_next_candidate() |
331 | | - |
332 | | - def _process_adaptive_optimizations(self) -> CandidateNode | None: |
333 | | - logger.info(f"loading|Applying adaptive optimizations to {len(self.future_adaptive_optimizations)} candidates") |
334 | | - concurrent.futures.wait(self.future_adaptive_optimizations) |
335 | | - candidates_added = 0 |
336 | | - for future_adaptive_optimization in self.future_adaptive_optimizations: |
337 | | - possible_adaptive_optimization = future_adaptive_optimization.result() |
338 | | - if possible_adaptive_optimization: |
339 | | - self.forest.add(possible_adaptive_optimization) |
340 | | - self.candidate_queue.put(possible_adaptive_optimization) |
341 | | - self.candidate_len += 1 |
342 | | - candidates_added += 1 |
343 | | - |
344 | | - if candidates_added > 0: |
345 | | - logger.info( |
346 | | - f"Added {candidates_added} candidates from adaptive optimizations, total candidates now: {self.candidate_len}" |
347 | | - ) |
348 | | - self.future_adaptive_optimizations.clear() |
349 | | - |
350 | | - return self.get_next_candidate() |
| 320 | + return self._process_candidates( |
| 321 | + future_refinements, |
| 322 | + "Refining generated code for improved quality and performance...", |
| 323 | + "Added {0} candidates from refinement, total candidates now: {1}", |
| 324 | + lambda: setattr(self, "refinement_done", True), |
| 325 | + ) |
351 | 326 |
|
352 | 327 | def is_done(self) -> bool: |
353 | 328 | """Check if processing is complete.""" |
@@ -868,6 +843,7 @@ def process_single_candidate( |
868 | 843 |
|
869 | 844 | logger.info(f"h3|Optimization candidate {candidate_index}/{total_candidates}:") |
870 | 845 | candidate = candidate_node.candidate |
| 846 | + print(f" {' -> '.join([c.source for c in candidate_node.path_to_root()])}") |
871 | 847 | code_print( |
872 | 848 | candidate.source_code.flat, |
873 | 849 | file_name=f"candidate_{candidate_index}.py", |
|
0 commit comments