Skip to content

Commit 8a75fe1

Browse files
authored
Merge pull request #1699 from codeflash-ai/extract-python-optimizer
refactor: extract PythonFunctionOptimizer subclass
2 parents 86202d4 + 4ccbffe commit 8a75fe1

11 files changed

Lines changed: 543 additions & 514 deletions

File tree

.claude/rules/architecture.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ codeflash/
1515
├── code_utils/ # Code parsing, git utilities
1616
├── models/ # Pydantic models and types
1717
├── languages/ # Multi-language support (Python, JavaScript/TypeScript)
18+
│ └── python/
19+
│ ├── function_optimizer.py # PythonFunctionOptimizer (Python-specific hooks)
20+
│ └── optimizer.py # Python module preparation & AST resolution
1821
├── setup/ # Config schema, auto-detection, first-run experience
1922
├── picklepatch/ # Serialization/deserialization utilities
2023
├── tracing/ # Function call tracing
@@ -32,7 +35,7 @@ codeflash/
3235
|------|------------|
3336
| CLI arguments & commands | `cli_cmds/cli.py` |
3437
| Optimization orchestration | `optimization/optimizer.py``run()` |
35-
| Per-function optimization | `optimization/function_optimizer.py` |
38+
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py` (Python subclass) |
3639
| Function discovery | `discovery/functions_to_optimize.py` |
3740
| Context extraction | `languages/<lang>/context/code_context_extractor.py` |
3841
| Test execution | `verification/test_runner.py`, `verification/pytest_plugin.py` |
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING
6+
7+
from codeflash.cli_cmds.console import console, logger
8+
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
9+
from codeflash.languages.python.context.unused_definition_remover import (
10+
detect_unused_helper_functions,
11+
revert_unused_helper_functions,
12+
)
13+
from codeflash.languages.python.optimizer import resolve_python_function_ast
14+
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
15+
from codeflash.languages.python.static_analysis.code_replacer import (
16+
add_custom_marker_to_all_tests,
17+
modify_autouse_fixture,
18+
)
19+
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
20+
from codeflash.models.models import TestingMode, TestResults
21+
from codeflash.optimization.function_optimizer import FunctionOptimizer
22+
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results
23+
24+
if TYPE_CHECKING:
25+
from typing import Any
26+
27+
from codeflash.languages.base import Language
28+
from codeflash.models.function_types import FunctionParent
29+
from codeflash.models.models import (
30+
CodeOptimizationContext,
31+
CodeStringsMarkdown,
32+
ConcurrencyMetrics,
33+
CoverageData,
34+
OriginalCodeBaseline,
35+
TestDiff,
36+
)
37+
38+
39+
class PythonFunctionOptimizer(FunctionOptimizer):
40+
def _resolve_function_ast(
41+
self, source_code: str, function_name: str, parents: list[FunctionParent]
42+
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
43+
original_module_ast = ast.parse(source_code)
44+
return resolve_python_function_ast(function_name, parents, original_module_ast)
45+
46+
def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
47+
self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat)
48+
49+
def get_optimization_review_metrics(
50+
self,
51+
source_code: str,
52+
file_path: Path,
53+
qualified_name: str,
54+
project_root: Path,
55+
tests_root: Path,
56+
language: Language,
57+
) -> str:
58+
return get_opt_review_metrics(source_code, file_path, qualified_name, project_root, tests_root, language)
59+
60+
def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, list[str]] | None:
61+
logger.info("Disabling all autouse fixtures associated with the generated test files")
62+
original_conftest_content = modify_autouse_fixture(test_paths)
63+
logger.info("Add custom marker to generated test files")
64+
add_custom_marker_to_all_tests(test_paths)
65+
return original_conftest_content
66+
67+
def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) -> None:
68+
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
69+
70+
instrument_codeflash_capture(self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root)
71+
72+
def should_check_coverage(self) -> bool:
73+
return True
74+
75+
def collect_async_metrics(
76+
self,
77+
benchmarking_results: TestResults,
78+
code_context: CodeOptimizationContext,
79+
helper_code: dict[Path, str],
80+
test_env: dict[str, str],
81+
) -> tuple[int | None, ConcurrencyMetrics | None]:
82+
if not self.function_to_optimize.is_async:
83+
return None, None
84+
85+
async_throughput = calculate_function_throughput_from_test_results(
86+
benchmarking_results, self.function_to_optimize.function_name
87+
)
88+
logger.debug(f"Async function throughput: {async_throughput} calls/second")
89+
90+
concurrency_metrics = self.run_concurrency_benchmark(
91+
code_context=code_context, original_helper_code=helper_code, test_env=test_env
92+
)
93+
if concurrency_metrics:
94+
logger.debug(
95+
f"Concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
96+
f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
97+
)
98+
return async_throughput, concurrency_metrics
99+
100+
def instrument_async_for_mode(self, mode: TestingMode) -> None:
101+
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
102+
103+
add_async_decorator_to_function(
104+
self.function_to_optimize.file_path, self.function_to_optimize, mode, project_root=self.project_root
105+
)
106+
107+
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
108+
return False
109+
110+
def parse_line_profile_test_results(
111+
self, line_profiler_output_file: Path | None
112+
) -> tuple[TestResults | dict, CoverageData | None]:
113+
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
114+
115+
return parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
116+
117+
def compare_candidate_results(
118+
self,
119+
baseline_results: OriginalCodeBaseline,
120+
candidate_behavior_results: TestResults,
121+
optimization_candidate_index: int,
122+
) -> tuple[bool, list[TestDiff]]:
123+
from codeflash.verification.equivalence import compare_test_results
124+
125+
return compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
126+
127+
def replace_function_and_helpers_with_optimized_code(
128+
self,
129+
code_context: CodeOptimizationContext,
130+
optimized_code: CodeStringsMarkdown,
131+
original_helper_code: dict[Path, str],
132+
) -> bool:
133+
did_update = super().replace_function_and_helpers_with_optimized_code(
134+
code_context, optimized_code, original_helper_code
135+
)
136+
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
137+
if unused_helpers:
138+
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
139+
return did_update
140+
141+
def line_profiler_step(
142+
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
143+
) -> dict[str, Any]:
144+
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
145+
if contains_jit_decorator(candidate_fto_code):
146+
logger.info(
147+
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
148+
)
149+
return {"timings": {}, "unit": 0, "str_out": ""}
150+
151+
for module_abspath in original_helper_code:
152+
candidate_helper_code = Path(module_abspath).read_text("utf-8")
153+
if contains_jit_decorator(candidate_helper_code):
154+
logger.info(
155+
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
156+
)
157+
return {"timings": {}, "unit": 0, "str_out": ""}
158+
159+
try:
160+
console.rule()
161+
162+
test_env = self.get_test_env(
163+
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
164+
)
165+
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
166+
line_profile_results, _ = self.run_and_parse_tests(
167+
testing_type=TestingMode.LINE_PROFILE,
168+
test_env=test_env,
169+
test_files=self.test_files,
170+
optimization_iteration=0,
171+
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
172+
enable_coverage=False,
173+
code_context=code_context,
174+
line_profiler_output_file=line_profiler_output_file,
175+
)
176+
finally:
177+
self.write_code_and_helpers(
178+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
179+
)
180+
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
181+
logger.warning(
182+
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
183+
)
184+
return {"timings": {}, "unit": 0, "str_out": ""}
185+
if line_profile_results["str_out"] == "":
186+
logger.warning(
187+
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
188+
)
189+
return line_profile_results
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from typing import TYPE_CHECKING
5+
6+
from codeflash.cli_cmds.console import logger
7+
from codeflash.models.models import ValidCode
8+
9+
if TYPE_CHECKING:
10+
from pathlib import Path
11+
12+
from codeflash.models.function_types import FunctionParent
13+
14+
15+
def prepare_python_module(
16+
original_module_code: str, original_module_path: Path, project_root: Path
17+
) -> tuple[dict[Path, ValidCode], ast.Module] | None:
18+
"""Parse a Python module, normalize its code, and validate imported callee modules.
19+
20+
Returns a mapping of file paths to ValidCode (for the module and its imported callees)
21+
plus the parsed AST, or None on syntax error.
22+
"""
23+
from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node
24+
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
25+
26+
try:
27+
original_module_ast = ast.parse(original_module_code)
28+
except SyntaxError as e:
29+
logger.warning(f"Syntax error parsing code in {original_module_path}: {e}")
30+
logger.info("Skipping optimization due to file error.")
31+
return None
32+
33+
normalized_original_module_code = ast.unparse(normalize_node(original_module_ast))
34+
validated_original_code: dict[Path, ValidCode] = {
35+
original_module_path: ValidCode(
36+
source_code=original_module_code, normalized_code=normalized_original_module_code
37+
)
38+
}
39+
40+
imported_module_analyses = analyze_imported_modules(original_module_code, original_module_path, project_root)
41+
42+
for analysis in imported_module_analyses:
43+
callee_original_code = analysis.file_path.read_text(encoding="utf8")
44+
try:
45+
normalized_callee_original_code = normalize_code(callee_original_code)
46+
except SyntaxError as e:
47+
logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}")
48+
logger.info("Skipping optimization due to helper file error.")
49+
return None
50+
validated_original_code[analysis.file_path] = ValidCode(
51+
source_code=callee_original_code, normalized_code=normalized_callee_original_code
52+
)
53+
54+
return validated_original_code, original_module_ast
55+
56+
57+
def resolve_python_function_ast(
58+
function_name: str, parents: list[FunctionParent], module_ast: ast.Module
59+
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
60+
"""Look up a function/method AST node in a parsed Python module."""
61+
from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast
62+
63+
return get_first_top_level_function_or_method_ast(function_name, parents, module_ast)

codeflash/lsp/beta.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,14 +463,10 @@ def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedIni
463463
"message": "Failed to prepare module for optimization",
464464
}
465465

466-
validated_original_code, original_module_ast = module_prep_result
466+
validated_original_code, _original_module_ast = module_prep_result
467467

468468
function_optimizer = server.optimizer.create_function_optimizer(
469-
fto,
470-
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
471-
original_module_ast=original_module_ast,
472-
original_module_path=fto.file_path,
473-
function_to_tests={},
469+
fto, function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, function_to_tests={}
474470
)
475471

476472
server.optimizer.current_function_optimizer = function_optimizer

0 commit comments

Comments
 (0)