Skip to content

Commit 7613d76

Browse files
committed
feat: per-function test quality review and repair
Add a review + repair step between test generation and baseline:
- Run behavioral tests to identify failing test functions.
- Send the tests to /ai/testgen_review with failures pre-flagged; the AI reviews passing functions for unrealistic patterns (cache warm-up, internal state manipulation, identical inputs).
- Repair flagged functions via /ai/testgen_repair.
- Loop up to MAX_TEST_REPAIR_CYCLES (default 1).
- Full baseline (behavioral + benchmarking) runs once on the final tests.
1 parent aed7161 commit 7613d76

4 files changed

Lines changed: 247 additions & 0 deletions

File tree

codeflash/api/aiservice.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
from codeflash.models.models import (
2121
AIServiceRefinerRequest,
2222
CodeStringsMarkdown,
23+
FunctionRepairInfo,
2324
OptimizationReviewResult,
2425
OptimizedCandidate,
2526
OptimizedCandidateSource,
27+
TestFileReview,
2628
)
2729
from codeflash.telemetry.posthog_cf import ph
2830
from codeflash.version import __version__ as codeflash_version
@@ -803,6 +805,98 @@ def generate_regression_tests(
803805
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
804806
return None
805807

808+
def review_generated_tests(
    self,
    tests: list[dict],
    function_source_code: str,
    function_name: str,
    trace_id: str,
    language: str = "python",
) -> list[TestFileReview]:
    """Ask the AI service which generated test functions need repair.

    Posts the tests to the ``/testgen_review`` endpoint and converts the
    response's ``reviews`` list into ``TestFileReview`` tuples.

    Args:
        tests: One dict per generated test file; forwarded unchanged as the
            ``tests`` field of the request payload.
        function_source_code: Source of the function under optimization.
        function_name: Name of the function under optimization.
        trace_id: Trace identifier forwarded to the service.
        language: Language tag forwarded to the service.

    Returns:
        One ``TestFileReview`` per entry in the response's ``reviews`` list.
        An empty list when the request fails, the service answers with a
        non-200 status, or a 200 response cannot be parsed into reviews.
    """
    payload = {
        "tests": tests,
        "function_source_code": function_source_code,
        "function_name": function_name,
        "trace_id": trace_id,
        "language": language,
        "codeflash_version": codeflash_version,
        "call_sequence": self.get_next_sequence(),
    }
    try:
        response = self.make_ai_service_request("/testgen_review", payload=payload, timeout=self.timeout)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error reviewing generated tests: {e}")
        ph("cli-testgen-review-error-caught", {"error": str(e)})
        return []

    if response.status_code == 200:
        # A 200 with a non-JSON body or a missing required key is treated like
        # any other service failure: log it and report "no reviews" rather than
        # letting the exception escape to the caller.
        try:
            data = response.json()
            return [
                TestFileReview(
                    test_index=r["test_index"],
                    functions_to_repair=[
                        FunctionRepairInfo(function_name=f["function_name"], reason=f.get("reason", ""))
                        for f in r.get("functions", [])
                    ],
                )
                for r in data.get("reviews", [])
            ]
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            error = f"malformed review response: {e!r}"
            logger.error(f"Error reviewing generated tests: {response.status_code} - {error}")
            ph("cli-testgen-review-error-response", {"response_status_code": response.status_code, "error": error})
            return []

    # Non-200: prefer the structured error message, fall back to the raw body.
    try:
        error = response.json()["error"]
    except Exception:
        error = response.text
    logger.error(f"Error reviewing generated tests: {response.status_code} - {error}")
    ph("cli-testgen-review-error-response", {"response_status_code": response.status_code, "error": error})
    return []
851+
852+
def repair_generated_tests(
    self,
    test_source: str,
    functions_to_repair: list[FunctionRepairInfo],
    function_source_code: str,
    function_to_optimize: FunctionToOptimize,
    helper_function_names: list[str],
    module_path: Path,
    test_module_path: Path,
    test_framework: str,
    test_timeout: int,
    trace_id: str,
    language: str = "python",
) -> tuple[str, str, str] | None:
    """Ask the AI service to rewrite the flagged functions of one test file.

    Posts the test source and the list of flagged functions to the
    ``/testgen_repair`` endpoint.

    Args:
        test_source: Source of the generated test file to repair.
        functions_to_repair: Functions to rewrite, each with its reason.
        function_source_code: Source of the function under optimization.
        function_to_optimize: Descriptor of the function under optimization.
        helper_function_names: Helper function names forwarded to the service.
        module_path: Module path of the code under test.
        test_module_path: Module path of the test file.
        test_framework: Test framework name forwarded to the service.
        test_timeout: Per-test timeout forwarded to the service.
        trace_id: Trace identifier forwarded to the service.
        language: Language tag forwarded to the service.

    Returns:
        ``(generated_tests, instrumented_behavior_tests, instrumented_perf_tests)``
        taken from the response, or ``None`` when the request fails, the
        service answers with a non-200 status, or a 200 response is not JSON
        or lacks one of those three keys.
    """
    payload: dict[str, Any] = {
        "test_source": test_source,
        "functions_to_repair": [{"function_name": f.function_name, "reason": f.reason} for f in functions_to_repair],
        "function_source_code": function_source_code,
        "function_to_optimize": function_to_optimize,
        "helper_function_names": helper_function_names,
        "module_path": module_path,
        "test_module_path": test_module_path,
        "test_framework": test_framework,
        "test_timeout": test_timeout,
        "trace_id": trace_id,
        "language": language,
        "python_version": platform.python_version(),
        "codeflash_version": codeflash_version,
        "call_sequence": self.get_next_sequence(),
    }
    try:
        response = self.make_ai_service_request("/testgen_repair", payload=payload, timeout=self.timeout)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error repairing generated tests: {e}")
        ph("cli-testgen-repair-error-caught", {"error": str(e)})
        return None

    if response.status_code == 200:
        # Keep the "None on failure" contract even when the service replies
        # 200 with an unusable body, so the caller can keep the original test.
        try:
            data = response.json()
            return (data["generated_tests"], data["instrumented_behavior_tests"], data["instrumented_perf_tests"])
        except (ValueError, KeyError, TypeError) as e:
            error = f"malformed repair response: {e!r}"
            logger.error(f"Error repairing generated tests: {response.status_code} - {error}")
            ph("cli-testgen-repair-error-response", {"response_status_code": response.status_code, "error": error})
            return None

    # Non-200: prefer the structured error message, fall back to the raw body.
    try:
        error = response.json()["error"]
    except Exception:
        error = response.text
    logger.error(f"Error repairing generated tests: {response.status_code} - {error}")
    ph("cli-testgen-repair-error-response", {"response_status_code": response.status_code, "error": error})
    return None
899+
806900
def get_optimization_review(
807901
self,
808902
original_code: dict[Path, str],

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
COVERAGE_THRESHOLD = 60.0
2222
MIN_TESTCASE_PASSED_THRESHOLD = 6
2323
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
24+
MAX_TEST_REPAIR_CYCLES = 1  # upper bound on review -> repair passes over generated tests before the baseline run
2425
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
2526

2627
# pytest loop stability

codeflash/models/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,16 @@ class OptimizationReviewResult(NamedTuple):
115115
explanation: str
116116

117117

118+
class FunctionRepairInfo(NamedTuple):
    """A single generated test function flagged for repair."""

    function_name: str  # name of the flagged test function
    reason: str  # explanation of why the function was flagged
121+
122+
123+
class TestFileReview(NamedTuple):
    """Review outcome for one generated test file."""

    test_index: int  # index of the reviewed test file within the reviewed batch
    functions_to_repair: list[FunctionRepairInfo]  # flagged functions; empty when nothing needs repair
126+
127+
118128
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
119129
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
120130
# of the module is foo.eggs.

codeflash/optimization/function_optimizer.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242
extract_unique_errors,
4343
file_name_from_test_module_name,
4444
get_run_tmp_file,
45+
module_name_from_file_path,
4546
normalize_by_max,
4647
restore_conftest,
4748
unified_diff_strings,
4849
)
4950
from codeflash.code_utils.config_consts import (
5051
COVERAGE_THRESHOLD,
5152
INDIVIDUAL_TESTCASE_TIMEOUT,
53+
MAX_TEST_REPAIR_CYCLES,
5254
MIN_CORRECT_CANDIDATES,
5355
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
5456
REFINED_CANDIDATE_RANKING_WEIGHTS,
@@ -763,6 +765,17 @@ def optimize_function(self) -> Result[BestOptimization, str]:
763765

764766
optimizations_set, function_references = optimization_result.unwrap()
765767

768+
review_result = self.review_and_repair_tests(
769+
generated_tests=generated_tests,
770+
code_context=code_context,
771+
original_helper_code=original_helper_code,
772+
)
773+
if not is_successful(review_result):
774+
return Failure(review_result.failure())
775+
776+
generated_tests = review_result.unwrap()
777+
778+
# Full baseline (behavioral + benchmarking) runs once on the final approved tests
766779
baseline_setup_result = self.setup_and_establish_baseline(
767780
code_context=code_context,
768781
original_helper_code=original_helper_code,
@@ -1885,6 +1898,135 @@ def setup_and_establish_baseline(
18851898
)
18861899
)
18871900

1901+
def run_behavioral_validation(
    self,
    code_context: CodeOptimizationContext,
    original_helper_code: dict[Path, str],
) -> TestResults | None:
    """Run behavioral tests only. Returns results or None if no tests ran.

    Args:
        code_context: Optimization context; its ``helper_functions`` decide
            which helper classes are passed to ``instrument_capture``.
        original_helper_code: Original helper sources, written back once the
            run finishes.

    Returns:
        The behavioral ``TestResults``, or ``None`` when the results object
        is falsy.
    """
    # Group helpers by file, keeping the first component of every dotted
    # qualified name (presumably the owning class -- confirm against
    # instrument_capture's expectations). The function being optimized is
    # excluded.
    file_path_to_helper_classes: dict[Path, set[str]] = defaultdict(set)
    for function_source in code_context.helper_functions:
        if (
            function_source.qualified_name != self.function_to_optimize.qualified_name
            and "." in function_source.qualified_name
        ):
            file_path_to_helper_classes[function_source.file_path].add(
                function_source.qualified_name.split(".")[0]
            )

    test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
    try:
        # Async instrumentation happens inside the try so that the original
        # sources are written back even if instrumentation itself fails.
        if self.function_to_optimize.is_async:
            self.instrument_async_for_mode(TestingMode.BEHAVIOR)
        self.instrument_capture(file_path_to_helper_classes)
        behavioral_results, _ = self.run_and_parse_tests(
            testing_type=TestingMode.BEHAVIOR,
            test_env=test_env,
            test_files=self.test_files,
            optimization_iteration=0,
            testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
            enable_coverage=False,
            code_context=code_context,
        )
    finally:
        # Always restore the un-instrumented function and helper sources.
        self.write_code_and_helpers(
            self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
        )
    return behavioral_results if behavioral_results else None
1936+
1937+
def review_and_repair_tests(
    self,
    generated_tests: GeneratedTestsList,
    code_context: CodeOptimizationContext,
    original_helper_code: dict[Path, str],
) -> Result[GeneratedTestsList, str]:
    """Run behavioral tests, review quality per-function, repair flagged functions.

    Flow (up to MAX_TEST_REPAIR_CYCLES):
        behavioral → collect failures → AI review passing functions → repair flagged → loop
    No benchmarking runs here — only behavioral validation.

    Args:
        generated_tests: Generated tests to review; repaired entries are
            updated in place and their behavior/perf files rewritten on disk.
        code_context: Optimization context forwarded to the behavioral run.
        original_helper_code: Original helper sources forwarded to the
            behavioral run so it can restore them.

    Returns:
        ``Success(generated_tests)`` once no repair is needed or the cycle
        limit is reached; ``Failure`` when a behavioral run yields no results.
        Tests repaired in the last cycle are not re-run here.
    """
    test_count = len(generated_tests.generated_tests)
    for cycle in range(MAX_TEST_REPAIR_CYCLES):
        # 1. Run behavioral tests
        behavioral_results = self.run_behavioral_validation(code_context, original_helper_code)
        if behavioral_results is None:
            return Failure("Generated tests failed behavioral validation.")

        # 2. Collect per-function failures grouped by behavior file path.
        #    A function may fail several times (e.g. multiple invocations);
        #    report each name once, in first-seen order.
        failed_by_file: dict[Path, list[str]] = defaultdict(list)
        for result in behavioral_results.test_results:
            if result.test_type == TestType.GENERATED_REGRESSION and not result.did_pass:
                failed_names = failed_by_file[result.file_name]
                failed_fn = result.id.test_function_name
                if failed_fn not in failed_names:
                    failed_names.append(failed_fn)

        # 3. Build review request with failed functions pre-flagged
        tests_for_review = [
            {
                "test_source": gt.generated_original_test_source,
                "test_index": i,
                "failed_test_functions": failed_by_file.get(gt.behavior_file_path, []),
            }
            for i, gt in enumerate(generated_tests.generated_tests)
        ]

        review_results = self.aiservice_client.review_generated_tests(
            tests=tests_for_review,
            function_source_code=self.function_to_optimize_source_code,
            function_name=self.function_to_optimize.function_name,
            trace_id=self.function_trace_id,
            language=self.function_to_optimize.language,
        )

        # 4. Repair test files that have flagged functions
        any_repaired = False
        for review in review_results:
            if not review.functions_to_repair:
                continue

            # test_index comes from the service response; never trust it as a
            # list index (out of range would raise, negative would silently
            # select the wrong test).
            if not isinstance(review.test_index, int) or not 0 <= review.test_index < test_count:
                logger.warning(f"Ignoring review for unknown test index {review.test_index!r}")
                continue

            gt = generated_tests.generated_tests[review.test_index]
            fn_names = ", ".join(f.function_name for f in review.functions_to_repair)
            logger.info(f"Repairing test functions in test {review.test_index} (cycle {cycle + 1}): {fn_names}")
            ph("cli-testgen-repair", {
                "test_index": review.test_index,
                "cycle": cycle + 1,
                "functions": [f.function_name for f in review.functions_to_repair],
            })

            test_module_path = Path(
                module_name_from_file_path(gt.behavior_file_path, self.test_cfg.tests_project_rootdir)
            )
            repair_result = self.aiservice_client.repair_generated_tests(
                test_source=gt.generated_original_test_source,
                functions_to_repair=review.functions_to_repair,
                function_source_code=self.function_to_optimize_source_code,
                function_to_optimize=self.function_to_optimize,
                helper_function_names=[],
                module_path=Path(self.original_module_path),
                test_module_path=test_module_path,
                test_framework=self.test_cfg.test_framework,
                test_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
                trace_id=self.function_trace_id,
                language=self.function_to_optimize.language,
            )

            if repair_result is None:
                logger.warning(f"Repair failed for test {review.test_index}, keeping original")
                continue

            repaired_source, behavior_source, perf_source = repair_result
            gt.generated_original_test_source = repaired_source
            gt.instrumented_behavior_test_source = behavior_source
            gt.instrumented_perf_test_source = perf_source

            # NOTE(review): the instrumented sources are written exactly as
            # returned by the service -- confirm the repair endpoint returns
            # them ready to run, with no client-side substitution required.
            gt.behavior_file_path.write_text(behavior_source, encoding="utf8")
            gt.perf_file_path.write_text(perf_source, encoding="utf8")
            any_repaired = True

        # Nothing needed repair — tests are good
        if not any_repaired:
            break

    return Success(generated_tests)
2029+
18882030
def find_and_process_best_optimization(
18892031
self,
18902032
optimizations_set: OptimizationSet,

0 commit comments

Comments
 (0)