|
4 | 4 | from pathlib import Path |
5 | 5 | from typing import TYPE_CHECKING |
6 | 6 |
|
7 | | -from codeflash.cli_cmds.console import console, logger |
| 7 | +from rich.syntax import Syntax |
| 8 | + |
| 9 | +from codeflash.cli_cmds.console import code_print, console, logger |
| 10 | +from codeflash.code_utils.code_utils import unified_diff_strings |
8 | 11 | from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE |
9 | 12 | from codeflash.either import Failure, Success |
10 | 13 | from codeflash.languages.python.context.unused_definition_remover import ( |
|
33 | 36 | CodeStringsMarkdown, |
34 | 37 | ConcurrencyMetrics, |
35 | 38 | CoverageData, |
| 39 | + GeneratedTestsList, |
36 | 40 | OriginalCodeBaseline, |
37 | 41 | TestDiff, |
| 42 | + TestFileReview, |
38 | 43 | ) |
39 | 44 |
|
40 | 45 |
|
@@ -86,6 +91,54 @@ def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) |
86 | 91 |
|
87 | 92 | instrument_codeflash_capture(self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root) |
88 | 93 |
|
| 94 | + def display_repaired_functions( |
| 95 | + self, generated_tests: GeneratedTestsList, reviews: list[TestFileReview], original_sources: dict[int, str] |
| 96 | + ) -> None: |
| 97 | + """Display per-function diffs of repaired tests using libcst.""" |
| 98 | + import libcst as cst |
| 99 | + |
| 100 | + def extract_functions(source: str, names: set[str]) -> dict[str, str]: |
| 101 | + """Extract functions by name from top-level and class bodies.""" |
| 102 | + try: |
| 103 | + tree = cst.parse_module(source) |
| 104 | + except cst.ParserSyntaxError: |
| 105 | + return {} |
| 106 | + result: dict[str, str] = {} |
| 107 | + for node in tree.body: |
| 108 | + if isinstance(node, cst.FunctionDef) and node.name.value in names: |
| 109 | + result[node.name.value] = tree.code_for_node(node) |
| 110 | + elif isinstance(node, cst.ClassDef): |
| 111 | + for child in node.body.body: |
| 112 | + if isinstance(child, cst.FunctionDef) and child.name.value in names: |
| 113 | + result[child.name.value] = tree.code_for_node(child) |
| 114 | + return result |
| 115 | + |
| 116 | + for review in reviews: |
| 117 | + gt = generated_tests.generated_tests[review.test_index] |
| 118 | + repaired_names = {f.function_name for f in review.functions_to_repair} |
| 119 | + new_source = gt.generated_original_test_source |
| 120 | + old_source = original_sources.get(review.test_index, "") |
| 121 | + |
| 122 | + old_funcs = extract_functions(old_source, repaired_names) |
| 123 | + new_funcs = extract_functions(new_source, repaired_names) |
| 124 | + |
| 125 | + for name in repaired_names: |
| 126 | + old_func = old_funcs.get(name, "") |
| 127 | + new_func = new_funcs.get(name, "") |
| 128 | + if not new_func: |
| 129 | + continue |
| 130 | + console.rule() |
| 131 | + if old_func and old_func != new_func: |
| 132 | + diff = unified_diff_strings( |
| 133 | + old_func, new_func, fromfile=f"{name} (before)", tofile=f"{name} (after)" |
| 134 | + ) |
| 135 | + if diff: |
| 136 | + logger.info(f"Repaired: {name}") |
| 137 | + console.print(Syntax(diff, "diff", theme="monokai")) |
| 138 | + continue |
| 139 | + logger.info(f"Repaired: {name}") |
| 140 | + code_print(new_func, language=self.function_to_optimize.language) |
| 141 | + |
89 | 142 | def should_check_coverage(self) -> bool: |
90 | 143 | return True |
91 | 144 |
|
|
0 commit comments