|
48 | 48 | INDIVIDUAL_TESTCASE_TIMEOUT, |
49 | 49 | MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE, |
50 | 50 | MAX_REPAIRS_PER_TRACE, |
| 51 | + MIN_CORRECT_CANDIDATES, |
51 | 52 | N_TESTS_TO_GENERATE_EFFECTIVE, |
52 | 53 | REFINE_ALL_THRESHOLD, |
53 | 54 | REFINED_CANDIDATE_RANKING_WEIGHTS, |
@@ -887,6 +888,7 @@ def process_single_candidate( |
887 | 888 | baseline_results=original_code_baseline, |
888 | 889 | original_helper_code=original_helper_code, |
889 | 890 | file_path_to_helper_classes=file_path_to_helper_classes, |
| 891 | + eval_ctx=eval_ctx, |
890 | 892 | code_context=code_context, |
891 | 893 | candidate=candidate, |
892 | 894 | exp_type=exp_type, |
@@ -2045,13 +2047,20 @@ def repair_if_possible( |
2045 | 2047 | self, |
2046 | 2048 | candidate: OptimizedCandidate, |
2047 | 2049 | diffs: list[TestDiff], |
| 2050 | + eval_ctx: CandidateEvaluationContext, |
2048 | 2051 | code_context: CodeOptimizationContext, |
2049 | 2052 | test_results_count: int, |
2050 | 2053 | exp_type: str, |
2051 | 2054 | ) -> None: |
2052 | 2055 | if self.repair_counter >= MAX_REPAIRS_PER_TRACE: |
2053 | 2056 | logger.debug(f"Repair counter reached {MAX_REPAIRS_PER_TRACE}, skipping repair") |
2054 | 2057 | return |
| 2058 | + |
| 2059 | + successful_candidates_count = sum(1 for is_correct in eval_ctx.is_correct.values() if is_correct) |
| 2060 | + if successful_candidates_count >= MIN_CORRECT_CANDIDATES: |
| 2061 | + logger.debug(f"{successful_candidates_count} of the candidates were correct, no need to repair") |
| 2062 | + return |
| 2063 | + |
2055 | 2064 | if candidate.source not in (OptimizedCandidateSource.OPTIMIZE, OptimizedCandidateSource.OPTIMIZE_LP): |
2056 | 2065 | # only repair the first pass of the candidates for now |
2057 | 2066 | logger.debug(f"Candidate is a result of {candidate.source.value}, skipping repair") |
@@ -2089,6 +2098,7 @@ def run_optimized_candidate( |
2089 | 2098 | baseline_results: OriginalCodeBaseline, |
2090 | 2099 | original_helper_code: dict[Path, str], |
2091 | 2100 | file_path_to_helper_classes: dict[Path, set[str]], |
| 2101 | + eval_ctx: CandidateEvaluationContext, |
2092 | 2102 | code_context: CodeOptimizationContext, |
2093 | 2103 | candidate: OptimizedCandidate, |
2094 | 2104 | exp_type: str, |
@@ -2144,7 +2154,9 @@ def run_optimized_candidate( |
2144 | 2154 | logger.info("h3|Test results matched ✅") |
2145 | 2155 | console.rule() |
2146 | 2156 | else: |
2147 | | - self.repair_if_possible(candidate, diffs, code_context, len(candidate_behavior_results), exp_type) |
| 2157 | + self.repair_if_possible( |
| 2158 | + candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type |
| 2159 | + ) |
2148 | 2160 | return self.get_results_not_matched_error() |
2149 | 2161 |
|
2150 | 2162 | logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") |
|
0 commit comments