Skip to content

Commit b88f946

Browse files
committed
refactor: move CST-based display_repaired_functions to PythonFunctionOptimizer
The libcst-based diff display is Python-specific. Move it to the Python subclass and keep a simple log-only fallback in the base class so JS/TS can override with their own parser later.
1 parent d95a4e1 commit b88f946

2 files changed

Lines changed: 57 additions & 44 deletions

File tree

codeflash/languages/python/function_optimizer.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING
66

7-
from codeflash.cli_cmds.console import console, logger
7+
from rich.syntax import Syntax
8+
9+
from codeflash.cli_cmds.console import code_print, console, logger
10+
from codeflash.code_utils.code_utils import unified_diff_strings
811
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
912
from codeflash.either import Failure, Success
1013
from codeflash.languages.python.context.unused_definition_remover import (
@@ -33,8 +36,10 @@
3336
CodeStringsMarkdown,
3437
ConcurrencyMetrics,
3538
CoverageData,
39+
GeneratedTestsList,
3640
OriginalCodeBaseline,
3741
TestDiff,
42+
TestFileReview,
3843
)
3944

4045

@@ -86,6 +91,54 @@ def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]])
8691

8792
instrument_codeflash_capture(self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root)
8893

94+
def display_repaired_functions(
95+
self, generated_tests: GeneratedTestsList, reviews: list[TestFileReview], original_sources: dict[int, str]
96+
) -> None:
97+
"""Display per-function diffs of repaired tests using libcst."""
98+
import libcst as cst
99+
100+
def extract_functions(source: str, names: set[str]) -> dict[str, str]:
101+
"""Extract functions by name from top-level and class bodies."""
102+
try:
103+
tree = cst.parse_module(source)
104+
except cst.ParserSyntaxError:
105+
return {}
106+
result: dict[str, str] = {}
107+
for node in tree.body:
108+
if isinstance(node, cst.FunctionDef) and node.name.value in names:
109+
result[node.name.value] = tree.code_for_node(node)
110+
elif isinstance(node, cst.ClassDef):
111+
for child in node.body.body:
112+
if isinstance(child, cst.FunctionDef) and child.name.value in names:
113+
result[child.name.value] = tree.code_for_node(child)
114+
return result
115+
116+
for review in reviews:
117+
gt = generated_tests.generated_tests[review.test_index]
118+
repaired_names = {f.function_name for f in review.functions_to_repair}
119+
new_source = gt.generated_original_test_source
120+
old_source = original_sources.get(review.test_index, "")
121+
122+
old_funcs = extract_functions(old_source, repaired_names)
123+
new_funcs = extract_functions(new_source, repaired_names)
124+
125+
for name in repaired_names:
126+
old_func = old_funcs.get(name, "")
127+
new_func = new_funcs.get(name, "")
128+
if not new_func:
129+
continue
130+
console.rule()
131+
if old_func and old_func != new_func:
132+
diff = unified_diff_strings(
133+
old_func, new_func, fromfile=f"{name} (before)", tofile=f"{name} (after)"
134+
)
135+
if diff:
136+
logger.info(f"Repaired: {name}")
137+
console.print(Syntax(diff, "diff", theme="monokai"))
138+
continue
139+
logger.info(f"Repaired: {name}")
140+
code_print(new_func, language=self.function_to_optimize.language)
141+
89142
def should_check_coverage(self) -> bool:
90143
return True
91144

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,50 +1899,10 @@ def setup_and_establish_baseline(
18991899
def display_repaired_functions(
19001900
self, generated_tests: GeneratedTestsList, reviews: list[TestFileReview], original_sources: dict[int, str]
19011901
) -> None:
1902-
"""Display diffs of repaired functions showing what changed."""
1903-
import libcst as cst
1904-
1905-
def extract_functions(source: str, names: set[str]) -> dict[str, str]:
1906-
"""Extract functions by name from top-level and class bodies."""
1907-
try:
1908-
tree = cst.parse_module(source)
1909-
except cst.ParserSyntaxError:
1910-
return {}
1911-
result: dict[str, str] = {}
1912-
for node in tree.body:
1913-
if isinstance(node, cst.FunctionDef) and node.name.value in names:
1914-
result[node.name.value] = tree.code_for_node(node)
1915-
elif isinstance(node, cst.ClassDef):
1916-
for child in node.body.body:
1917-
if isinstance(child, cst.FunctionDef) and child.name.value in names:
1918-
result[child.name.value] = tree.code_for_node(child)
1919-
return result
1920-
1902+
"""Display repaired functions. Override in language subclasses for richer diff output."""
19211903
for review in reviews:
1922-
gt = generated_tests.generated_tests[review.test_index]
1923-
repaired_names = {f.function_name for f in review.functions_to_repair}
1924-
new_source = gt.generated_original_test_source
1925-
old_source = original_sources.get(review.test_index, "")
1926-
1927-
old_funcs = extract_functions(old_source, repaired_names)
1928-
new_funcs = extract_functions(new_source, repaired_names)
1929-
1930-
for name in repaired_names:
1931-
old_func = old_funcs.get(name, "")
1932-
new_func = new_funcs.get(name, "")
1933-
if not new_func:
1934-
continue
1935-
console.rule()
1936-
if old_func and old_func != new_func:
1937-
diff = unified_diff_strings(
1938-
old_func, new_func, fromfile=f"{name} (before)", tofile=f"{name} (after)"
1939-
)
1940-
if diff:
1941-
logger.info(f"Repaired: {name}")
1942-
console.print(Syntax(diff, "diff", theme="monokai"))
1943-
continue
1944-
logger.info(f"Repaired: {name}")
1945-
code_print(new_func, language=self.function_to_optimize.language)
1904+
for f in review.functions_to_repair:
1905+
logger.info(f"Repaired: {f.function_name}")
19461906

19471907
def build_helper_classes_map(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]:
19481908
"""Build a mapping of file paths to helper class names from code context."""

0 commit comments

Comments
 (0)