Skip to content

Commit 3e20a37

Browse files
more effort values
1 parent 444aab2 commit 3e20a37

2 files changed

Lines changed: 62 additions & 51 deletions

File tree

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from enum import Enum
1+
from enum import StrEnum, auto
22

33
MAX_TEST_RUN_ITERATIONS = 5
44
INDIVIDUAL_TESTCASE_TIMEOUT = 15
@@ -13,18 +13,11 @@
1313
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
1414
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
1515

16-
# Refinement
17-
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
1816
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2
19-
TOP_N_REFINEMENTS = 0.45 # top 45% of valid optimizations (based on the weighted score) are refined
2017

2118
# LSP-specific
2219
TOTAL_LOOPING_TIME_LSP = 10.0 # Kept same timing for LSP mode to avoid in increase in performance reporting
2320

24-
# Code repair
25-
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4 # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair more stricted)
26-
MAX_REPAIRS_PER_TRACE = 4 # maximum number of repairs we will do for each function
27-
2821
try:
2922
from codeflash.lsp.helpers import is_LSP_enabled
3023

@@ -37,36 +30,49 @@
3730
MAX_CONTEXT_LEN_REVIEW = 1000
3831

3932

40-
class EffortLevel(str, Enum):
41-
LOW = "low"
42-
MEDIUM = "medium"
43-
HIGH = "high"
33+
class EffortLevel(StrEnum):
34+
LOW = auto()
35+
MEDIUM = auto()
36+
HIGH = auto()
4437

4538

46-
class Effort:
47-
@staticmethod
48-
def get_number_of_optimizer_candidates(effort: str) -> int:
49-
if effort == EffortLevel.LOW.value:
50-
return 3
51-
if effort == EffortLevel.MEDIUM.value:
52-
return 4
53-
if effort == EffortLevel.HIGH.value:
54-
return 5
55-
msg = f"Invalid effort level: {effort}"
56-
raise ValueError(msg)
39+
class EffortKeys(StrEnum):
40+
N_OPTIMIZER_CANDIDATES = auto()
41+
N_OPTIMIZER_LP_CANDIDATES = auto()
42+
N_GENERATED_TESTS = auto()
43+
MAX_CODE_REPAIRS_PER_TRACE = auto()
44+
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = auto()
45+
REFINE_ALL_THRESHOLD = auto()
46+
TOP_VALID_CANDIDATES_FOR_REFINEMENT = auto()
47+
5748

58-
@staticmethod
59-
def get_number_of_optimizer_lp_candidates(effort: str) -> int:
60-
if effort == EffortLevel.LOW.value:
61-
return 3
62-
if effort == EffortLevel.MEDIUM.value:
63-
return 5
64-
if effort == EffortLevel.HIGH.value:
65-
return 6
49+
EFFORT_VALUES: dict[str, dict[EffortLevel, any]] = {
50+
EffortKeys.N_OPTIMIZER_CANDIDATES.value: {EffortLevel.LOW: 3, EffortLevel.MEDIUM: 4, EffortLevel.HIGH: 5},
51+
EffortKeys.N_OPTIMIZER_LP_CANDIDATES.value: {EffortLevel.LOW: 3, EffortLevel.MEDIUM: 5, EffortLevel.HIGH: 6},
52+
# we don't use effort with generated tests for now
53+
EffortKeys.N_GENERATED_TESTS.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 2, EffortLevel.HIGH: 2},
54+
# maximum number of repairs we will do for each function
55+
EffortKeys.MAX_CODE_REPAIRS_PER_TRACE.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 4, EffortLevel.HIGH: 5},
56+
# if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair more stricted)
57+
# on the low effort we lower the limit to 20% to be more strict (less repairs)
58+
EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT.value: {
59+
EffortLevel.LOW: 0.2,
60+
EffortLevel.MEDIUM: 0.4,
61+
EffortLevel.HIGH: 0.5,
62+
},
63+
# when valid optimizations count is N or less, refine all optimizations
64+
EffortKeys.REFINE_ALL_THRESHOLD.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 3, EffortLevel.HIGH: 4},
65+
# Top valid candidates percentage for refinements
66+
EffortKeys.TOP_VALID_CANDIDATES_FOR_REFINEMENT: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 3, EffortLevel.HIGH: 4},
67+
}
68+
69+
70+
def get_effort_value(key: EffortKeys, effort: EffortLevel) -> any:
71+
key_str = key.value
72+
if key_str in EFFORT_VALUES:
73+
if effort in EFFORT_VALUES[key_str]:
74+
return EFFORT_VALUES[key_str][effort]
6675
msg = f"Invalid effort level: {effort}"
6776
raise ValueError(msg)
68-
69-
@staticmethod
70-
def get_number_of_generated_tests(effort: str) -> int: # noqa: ARG004
71-
# we don't use effort with generated tests for now
72-
return 2
77+
msg = f"Invalid key: {key_str}"
78+
raise ValueError(msg)

codeflash/optimization/function_optimizer.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,11 @@
4545
from codeflash.code_utils.config_consts import (
4646
COVERAGE_THRESHOLD,
4747
INDIVIDUAL_TESTCASE_TIMEOUT,
48-
MAX_REPAIRS_PER_TRACE,
49-
REFINE_ALL_THRESHOLD,
5048
REFINED_CANDIDATE_RANKING_WEIGHTS,
51-
REPAIR_UNMATCHED_PERCENTAGE_LIMIT,
5249
REPEAT_OPTIMIZATION_PROBABILITY,
53-
TOP_N_REFINEMENTS,
5450
TOTAL_LOOPING_TIME_EFFECTIVE,
55-
Effort,
51+
EffortKeys,
52+
get_effort_value,
5653
)
5754
from codeflash.code_utils.deduplicate_code import normalize_code
5855
from codeflash.code_utils.edit_generated_tests import (
@@ -191,8 +188,16 @@ def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concur
191188
def _process_refinement_results(self) -> OptimizedCandidate | None:
192189
"""Process refinement results and add to queue. We generate a weighted ranking based on the runtime and diff lines and select the best (round of 45%) of valid optimizations to be refined."""
193190
future_refinements: list[concurrent.futures.Future] = []
191+
top_n_candidates = int(
192+
min(
193+
get_effort_value(EffortKeys.TOP_VALID_CANDIDATES_FOR_REFINEMENT, self.args.effort),
194+
len(self.all_refinements_data),
195+
)
196+
)
194197

195-
if len(self.all_refinements_data) <= REFINE_ALL_THRESHOLD:
198+
if top_n_candidates == len(self.all_refinements_data) or len(self.all_refinements_data) <= get_effort_value(
199+
EffortKeys.REFINE_ALL_THRESHOLD, self.args.effort
200+
):
196201
for data in self.all_refinements_data:
197202
future_refinements.append(self.refine_optimizations([data])) # noqa: PERF401
198203
else:
@@ -209,7 +214,6 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
209214
diffs_norm = normalize_by_max(diff_lens_list)
210215
# the lower the better
211216
score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
212-
top_n_candidates = int((TOP_N_REFINEMENTS * len(runtimes_list)) + 0.5)
213217
top_indecies = sorted(score_dict, key=score_dict.get)[:top_n_candidates]
214218

215219
for idx in top_indecies:
@@ -310,7 +314,7 @@ def __init__(
310314
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
311315
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
312316
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
313-
n_tests = Effort.get_number_of_generated_tests(args.effort)
317+
n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, args.effort)
314318
self.executor = concurrent.futures.ThreadPoolExecutor(
315319
max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
316320
)
@@ -360,7 +364,7 @@ def generate_and_instrument_tests(
360364
str,
361365
]:
362366
"""Generate and instrument tests for the function."""
363-
n_tests = Effort.get_number_of_generated_tests(self.args.effort)
367+
n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.args.effort)
364368
generated_test_paths = [
365369
get_test_file_path(
366370
self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
@@ -925,7 +929,7 @@ def determine_best_candidate(
925929
dependency_code=code_context.read_only_context_code,
926930
trace_id=self.get_trace_id(exp_type),
927931
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
928-
num_candidates=Effort.get_number_of_optimizer_lp_candidates(self.args.effort),
932+
num_candidates=get_effort_value(EffortKeys.N_OPTIMIZER_LP_CANDIDATES, self.args.effort),
929933
experiment_metadata=ExperimentMetadata(
930934
id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment"
931935
)
@@ -1290,7 +1294,7 @@ def generate_tests(
12901294
generated_perf_test_paths: list[Path],
12911295
) -> Result[tuple[int, GeneratedTestsList, dict[str, set[FunctionCalledInTest]], str], str]:
12921296
"""Generate unit tests and concolic tests for the function."""
1293-
n_tests = Effort.get_number_of_generated_tests(self.args.effort)
1297+
n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.args.effort)
12941298
assert len(generated_test_paths) == n_tests
12951299

12961300
# Submit test generation tasks
@@ -1352,7 +1356,7 @@ def generate_optimizations(
13521356
run_experiment: bool = False, # noqa: FBT001, FBT002
13531357
) -> Result[tuple[OptimizationSet, str], str]:
13541358
"""Generate optimization candidates for the function."""
1355-
n_candidates = Effort.get_number_of_optimizer_candidates(self.args.effort)
1359+
n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.args.effort)
13561360

13571361
future_optimization_candidates = self.executor.submit(
13581362
self.aiservice_client.optimize_python_code,
@@ -1919,8 +1923,9 @@ def repair_if_possible(
19191923
test_results_count: int,
19201924
exp_type: str,
19211925
) -> None:
1922-
if self.repair_counter >= MAX_REPAIRS_PER_TRACE:
1923-
logger.debug(f"Repair counter reached {MAX_REPAIRS_PER_TRACE}, skipping repair")
1926+
max_repairs = get_effort_value(EffortKeys.MAX_CODE_REPAIRS_PER_TRACE, self.args.effort)
1927+
if self.repair_counter >= max_repairs:
1928+
logger.debug(f"Repair counter reached {max_repairs}, skipping repair")
19241929
return
19251930
if candidate.source not in (OptimizedCandidateSource.OPTIMIZE, OptimizedCandidateSource.OPTIMIZE_LP):
19261931
# only repair the first pass of the candidates for now
@@ -1930,7 +1935,7 @@ def repair_if_possible(
19301935
logger.debug("No diffs found, skipping repair")
19311936
return
19321937
result_unmatched_perc = len(diffs) / test_results_count
1933-
if result_unmatched_perc > REPAIR_UNMATCHED_PERCENTAGE_LIMIT:
1938+
if result_unmatched_perc > get_effort_value(EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT, self.args.effort):
19341939
logger.debug(f"Result unmatched percentage is {result_unmatched_perc * 100}%, skipping repair")
19351940
return
19361941

0 commit comments

Comments
 (0)