Skip to content

Commit d95a4e1

Browse files
committed
fix: clear stale raw_generated_test_source after repair and fix repair count
- Clear raw_generated_test_source after repair so the next review cycle uses the repaired source instead of the pre-repair LLM output.
- Fix repaired_count to track per-repair-call successes instead of overcounting all flagged functions.
- Clean up verifier tuple unpacking to match the declared return type.
1 parent c428587 commit d95a4e1

2 files changed

Lines changed: 20 additions & 10 deletions

File tree

codeflash/optimization/function_optimizer.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1903,15 +1903,20 @@ def display_repaired_functions(
19031903
import libcst as cst
19041904

19051905
def extract_functions(source: str, names: set[str]) -> dict[str, str]:
1906+
"""Extract functions by name from top-level and class bodies."""
19061907
try:
19071908
tree = cst.parse_module(source)
1908-
return {
1909-
node.name.value: tree.code_for_node(node)
1910-
for node in tree.body
1911-
if isinstance(node, cst.FunctionDef) and node.name.value in names
1912-
}
19131909
except cst.ParserSyntaxError:
19141910
return {}
1911+
result: dict[str, str] = {}
1912+
for node in tree.body:
1913+
if isinstance(node, cst.FunctionDef) and node.name.value in names:
1914+
result[node.name.value] = tree.code_for_node(node)
1915+
elif isinstance(node, cst.ClassDef):
1916+
for child in node.body.body:
1917+
if isinstance(child, cst.FunctionDef) and child.name.value in names:
1918+
result[child.name.value] = tree.code_for_node(child)
1919+
return result
19151920

19161921
for review in reviews:
19171922
gt = generated_tests.generated_tests[review.test_index]
@@ -2062,7 +2067,7 @@ def review_and_repair_tests(
20622067
console.print(Panel(issues_tree, title=f"Test Review (cycle {cycle + 1})", border_style="yellow"))
20632068

20642069
any_repaired = False
2065-
repaired_count = 0
2070+
repaired_count = 0 # tracks individual repair API successes (one per review, not per function)
20662071
# Snapshot original sources before repair so we can show diffs
20672072
original_sources: dict[int, str] = {
20682073
r.test_index: generated_tests.generated_tests[r.test_index].generated_original_test_source
@@ -2117,11 +2122,13 @@ def review_and_repair_tests(
21172122
gt.generated_original_test_source = repaired_source
21182123
gt.instrumented_behavior_test_source = behavior_source
21192124
gt.instrumented_perf_test_source = perf_source
2125+
# Clear stale LLM output so the next review cycle sends repaired source
2126+
gt.raw_generated_test_source = None
21202127

21212128
gt.behavior_file_path.write_text(behavior_source, encoding="utf8")
21222129
gt.perf_file_path.write_text(perf_source, encoding="utf8")
21232130
any_repaired = True
2124-
repaired_count += len(review.functions_to_repair)
2131+
repaired_count += 1
21252132

21262133
if any_repaired:
21272134
generated_tests = self.language_support.postprocess_generated_tests(
@@ -2130,7 +2137,7 @@ def review_and_repair_tests(
21302137
project_root=self.project_root,
21312138
source_file_path=self.function_to_optimize.file_path,
21322139
)
2133-
console.print(f" [green]Repaired {repaired_count} test function(s)[/green]")
2140+
console.print(f" [green]Repaired {repaired_count} test file(s)[/green]")
21342141
self.display_repaired_functions(generated_tests, all_to_repair, original_sources)
21352142
with progress_bar("Re-validating repaired tests..."):
21362143
validation = self.run_behavioral_validation(
@@ -2141,6 +2148,8 @@ def review_and_repair_tests(
21412148
behavioral_results, coverage_results = validation
21422149

21432150
console.rule()
2151+
# When all repair API calls failed (any_repaired=False), behavioral_results are from before
2152+
# the repair attempts. This is correct since no test code actually changed.
21442153
return Success((generated_tests, behavioral_results, coverage_results))
21452154

21462155
def find_and_process_best_optimization(

codeflash/verification/verifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ def generate_tests(
7474
is_numerical_code=is_numerical_code,
7575
)
7676
if response and isinstance(response, tuple) and len(response) >= 3:
77-
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source, *rest = response
78-
raw_generated_tests = rest[0] if rest else None
77+
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source, raw_generated_tests = (
78+
response
79+
)
7980

8081
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = (
8182
lang_support.process_generated_test_strings(

0 commit comments

Comments
 (0)