Skip to content

Commit c5cdefe

Browse files
committed
refactor: extract PythonFunctionOptimizer subclass from FunctionOptimizer
Move 6 Python-specific methods into PythonFunctionOptimizer in languages/python/function_optimizer.py. Base class gets no-op defaults; Optimizer.create_function_optimizer dispatches to the subclass when is_python().
1 parent c7f225d commit c5cdefe

3 files changed

Lines changed: 179 additions & 104 deletions

File tree

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from collections import defaultdict
5+
from pathlib import Path
6+
from typing import TYPE_CHECKING
7+
8+
from codeflash.cli_cmds.console import console, logger
9+
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
10+
from codeflash.languages.python.context.unused_definition_remover import (
11+
detect_unused_helper_functions,
12+
revert_unused_helper_functions,
13+
)
14+
from codeflash.languages.python.optimizer import resolve_python_function_ast
15+
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
16+
from codeflash.languages.python.static_analysis.code_replacer import (
17+
add_custom_marker_to_all_tests,
18+
modify_autouse_fixture,
19+
replace_function_definitions_in_module,
20+
)
21+
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
22+
from codeflash.models.models import TestingMode, TestResults
23+
from codeflash.optimization.function_optimizer import FunctionOptimizer
24+
25+
if TYPE_CHECKING:
26+
from codeflash.languages.base import Language
27+
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
28+
29+
30+
class PythonFunctionOptimizer(FunctionOptimizer):
    """``FunctionOptimizer`` specialised for Python targets.

    Every method here overrides a same-named method of ``FunctionOptimizer``
    and routes the work to the Python static-analysis, code-replacement and
    line-profiling helpers imported by this module.
    """

    def _resolve_function_ast(
        self, source_code: str, function_name: str, parents: list
    ) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
        """Parse ``source_code`` and resolve the AST node of ``function_name``.

        Args:
            source_code: Full text of the module containing the function.
            function_name: Name of the function to look up.
            parents: Parent scopes of the function, forwarded unchanged to
                ``resolve_python_function_ast``.

        Returns:
            Whatever ``resolve_python_function_ast`` returns for the parsed module.

        """
        return resolve_python_function_ast(function_name, parents, ast.parse(source_code))

    def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
        """Store on ``self.is_numerical_code`` the ``is_numerical_code`` verdict for the read-writable code."""
        writable_source = code_context.read_writable_code.flat
        self.is_numerical_code = is_numerical_code(code_string=writable_source)

    def get_optimization_review_metrics(
        self,
        source_code: str,
        file_path: Path,
        qualified_name: str,
        project_root: Path,
        tests_root: Path,
        language: Language,
    ) -> str:
        """Return the result of ``get_opt_review_metrics`` called with the same arguments, in the same order."""
        return get_opt_review_metrics(source_code, file_path, qualified_name, project_root, tests_root, language)

    def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, list[str]] | None:
        """Disable autouse fixtures and add the custom marker for ``test_paths``.

        Returns:
            The value returned by ``modify_autouse_fixture`` (named
            ``original_conftest_content`` by the caller).

        """
        logger.info("Disabling all autouse fixtures associated with the generated test files")
        conftest_backup = modify_autouse_fixture(test_paths)

        logger.info("Add custom marker to generated test files")
        add_custom_marker_to_all_tests(test_paths)

        return conftest_backup

    def replace_function_and_helpers_with_optimized_code(
        self,
        code_context: CodeOptimizationContext,
        optimized_code: CodeStringsMarkdown,
        original_helper_code: dict[Path, str],
    ) -> bool:
        """Write ``optimized_code`` into the target function and its helper functions.

        Args:
            code_context: Supplies ``helper_functions`` and ``preexisting_objects``.
            optimized_code: Candidate code to splice into the affected modules.
            original_helper_code: Original helper sources, used to revert
                helpers that the optimized code no longer uses.

        Returns:
            The OR-combined results of ``replace_function_definitions_in_module``
            over every affected module.

        """
        target = self.function_to_optimize

        # Group the qualified names to replace by the module that defines them.
        names_by_module = defaultdict(set)
        names_by_module[target.file_path].add(target.qualified_name)
        for helper in code_context.helper_functions:
            # Skip class definitions (definition_type may be None for non-Python languages).
            if helper.definition_type == "class":
                continue
            names_by_module[helper.file_path].add(helper.qualified_name)

        any_changed = False
        for module_path, names in names_by_module.items():
            any_changed |= replace_function_definitions_in_module(
                function_names=list(names),
                optimized_code=optimized_code,
                module_abspath=module_path,
                preexisting_objects=code_context.preexisting_objects,
                project_root_path=self.project_root,
            )

        # Revert unused helper functions to their original definitions.
        unused = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
        if unused:
            revert_unused_helper_functions(self.project_root, unused, original_helper_code)

        return any_changed

    def _line_profiler_step_python(
        self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
    ) -> dict:
        """Run the Python line profiler (decorator-import based) over the test files.

        Args:
            code_context: Passed to ``add_decorator_imports`` and ``run_and_parse_tests``.
            original_helper_code: Helper module path -> original source; its keys
                are scanned for JIT decorators and the mapping is used to restore
                the files afterwards.
            candidate_index: Forwarded as ``codeflash_test_iteration`` to ``get_test_env``.

        Returns:
            The line-profile results, or ``{"timings": {}, "unit": 0, "str_out": ""}``
            when profiling is skipped or yields no test results.

        """

        def no_profile() -> dict:
            # Fresh default payload for every "nothing to report" exit.
            return {"timings": {}, "unit": 0, "str_out": ""}

        # Line profiler doesn't work with JIT compiled code, so bail out as soon
        # as the target file or any helper file contains a JIT decorator.
        if contains_jit_decorator(Path(self.function_to_optimize.file_path).read_text("utf-8")):
            logger.info(
                f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
            )
            return no_profile()

        for helper_path in original_helper_code:
            if not contains_jit_decorator(Path(helper_path).read_text("utf-8")):
                continue
            logger.info(
                f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
            )
            return no_profile()

        try:
            console.rule()

            profiling_env = self.get_test_env(
                codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
            )
            profiler_output_path = add_decorator_imports(self.function_to_optimize, code_context)
            profile_results, _ = self.run_and_parse_tests(
                testing_type=TestingMode.LINE_PROFILE,
                test_env=profiling_env,
                test_files=self.test_files,
                optimization_iteration=0,
                testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
                enable_coverage=False,
                code_context=code_context,
                line_profiler_output_file=profiler_output_path,
            )
        finally:
            # Always write the original function and helper sources back,
            # whether or not the profiling run succeeded.
            self.write_code_and_helpers(
                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
            )

        # An empty TestResults is what comes back when the run hit a timeout.
        timed_out = isinstance(profile_results, TestResults) and not profile_results.test_results
        if timed_out:
            logger.warning(
                f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
            )
            return no_profile()

        if profile_results["str_out"] == "":
            logger.warning(
                f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
            )
        return profile_results

codeflash/optimization/function_optimizer.py

Lines changed: 35 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import ast
43
import concurrent.futures
54
import dataclasses
65
import logging
@@ -73,18 +72,6 @@
7372
from codeflash.languages.current import current_language_support
7473
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
7574
from codeflash.languages.python.context import code_context_extractor
76-
from codeflash.languages.python.context.unused_definition_remover import (
77-
detect_unused_helper_functions,
78-
revert_unused_helper_functions,
79-
)
80-
from codeflash.languages.python.optimizer import resolve_python_function_ast
81-
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
82-
from codeflash.languages.python.static_analysis.code_replacer import (
83-
add_custom_marker_to_all_tests,
84-
modify_autouse_fixture,
85-
replace_function_definitions_in_module,
86-
)
87-
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
8875
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode, report_to_markdown_table, tree_to_markdown
8976
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
9077
from codeflash.models.ExperimentMetadata import ExperimentMetadata
@@ -135,6 +122,7 @@
135122
from codeflash.verification.verifier import generate_tests
136123

137124
if TYPE_CHECKING:
125+
import ast
138126
from argparse import Namespace
139127

140128
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@@ -459,13 +447,9 @@ def __init__(
459447
)
460448
self.language_support = current_language_support()
461449
if not function_to_optimize_ast:
462-
if not is_python():
463-
self.function_to_optimize_ast = None
464-
else:
465-
original_module_ast = ast.parse(self.function_to_optimize_source_code)
466-
self.function_to_optimize_ast = resolve_python_function_ast(
467-
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
468-
)
450+
self.function_to_optimize_ast = self._resolve_function_ast(
451+
self.function_to_optimize_source_code, function_to_optimize.function_name, function_to_optimize.parents
452+
)
469453
else:
470454
self.function_to_optimize_ast = function_to_optimize_ast
471455
self.function_to_tests = function_to_tests if function_to_tests else {}
@@ -502,6 +486,32 @@ def __init__(
502486
self.is_numerical_code: bool | None = None
503487
self.code_already_exists: bool = False
504488

489+
# --- Hooks for language-specific subclasses ---
490+
491+
def _resolve_function_ast(
492+
self, source_code: str, function_name: str, parents: list
493+
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
494+
return None
495+
496+
def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
    """Language hook: inspect ``code_context`` before optimization starts.

    The base implementation is a no-op; language-specific subclasses override it.
    """
498+
499+
def get_optimization_review_metrics(
    self,
    source_code: str,
    file_path: Path,
    qualified_name: str,
    project_root: Path,
    tests_root: Path,
    language: Language,
) -> str:
    """Language hook: produce optimization-review metrics for a function.

    The base implementation ignores every argument and returns an empty
    string; language-specific subclasses override it.
    """
    return ""
509+
510+
def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, list[str]] | None:
    """Language hook: adjust fixtures for the generated test files.

    The base implementation leaves ``test_paths`` untouched and returns
    ``None``; language-specific subclasses override it.
    """
    return None
512+
513+
# --- End hooks ---
514+
505515
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
506516
should_run_experiment = self.experiment_id is not None
507517
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
@@ -655,10 +665,7 @@ def generate_and_instrument_tests(
655665

656666
original_conftest_content = None
657667
if self.args.override_fixtures:
658-
logger.info("Disabling all autouse fixtures associated with the generated test files")
659-
original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths)
660-
logger.info("Add custom marker to generated test files")
661-
add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths)
668+
original_conftest_content = self.instrument_test_fixtures(generated_test_paths + generated_perf_test_paths)
662669

663670
return Success(
664671
(
@@ -678,7 +685,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
678685
if not is_successful(initialization_result):
679686
return Failure(initialization_result.failure())
680687
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
681-
self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat)
688+
self.analyze_code_characteristics(code_context)
682689
code_print(
683690
code_context.read_writable_code.flat,
684691
file_name=self.function_to_optimize.file_path,
@@ -1498,30 +1505,7 @@ def replace_function_and_helpers_with_optimized_code(
14981505
optimized_code: CodeStringsMarkdown,
14991506
original_helper_code: dict[Path, str],
15001507
) -> bool:
1501-
did_update = False
1502-
read_writable_functions_by_file_path = defaultdict(set)
1503-
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
1504-
self.function_to_optimize.qualified_name
1505-
)
1506-
for helper_function in code_context.helper_functions:
1507-
# Skip class definitions (definition_type may be None for non-Python languages)
1508-
if helper_function.definition_type != "class":
1509-
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
1510-
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
1511-
did_update |= replace_function_definitions_in_module(
1512-
function_names=list(qualified_names),
1513-
optimized_code=optimized_code,
1514-
module_abspath=module_abspath,
1515-
preexisting_objects=code_context.preexisting_objects,
1516-
project_root_path=self.project_root,
1517-
)
1518-
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
1519-
1520-
# Revert unused helper functions to their original definitions
1521-
if unused_helpers:
1522-
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
1523-
1524-
return did_update
1508+
return False
15251509

15261510
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
15271511
try:
@@ -1841,7 +1825,7 @@ def generate_optimizations(
18411825
)
18421826

18431827
future_references = self.executor.submit(
1844-
get_opt_review_metrics,
1828+
self.get_optimization_review_metrics,
18451829
self.function_to_optimize_source_code,
18461830
self.function_to_optimize.file_path,
18471831
self.function_to_optimize.qualified_name,
@@ -2960,58 +2944,7 @@ def line_profiler_step(
29602944
def _line_profiler_step_python(
29612945
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
29622946
) -> dict:
2963-
"""Python-specific line profiler using decorator imports."""
2964-
# Check if candidate code contains JIT decorators - line profiler doesn't work with JIT compiled code
2965-
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
2966-
if contains_jit_decorator(candidate_fto_code):
2967-
logger.info(
2968-
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
2969-
)
2970-
return {"timings": {}, "unit": 0, "str_out": ""}
2971-
2972-
# Check helper code for JIT decorators
2973-
for module_abspath in original_helper_code:
2974-
candidate_helper_code = Path(module_abspath).read_text("utf-8")
2975-
if contains_jit_decorator(candidate_helper_code):
2976-
logger.info(
2977-
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
2978-
)
2979-
return {"timings": {}, "unit": 0, "str_out": ""}
2980-
2981-
try:
2982-
console.rule()
2983-
2984-
test_env = self.get_test_env(
2985-
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
2986-
)
2987-
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
2988-
line_profile_results, _ = self.run_and_parse_tests(
2989-
testing_type=TestingMode.LINE_PROFILE,
2990-
test_env=test_env,
2991-
test_files=self.test_files,
2992-
optimization_iteration=0,
2993-
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
2994-
enable_coverage=False,
2995-
code_context=code_context,
2996-
line_profiler_output_file=line_profiler_output_file,
2997-
)
2998-
finally:
2999-
# Remove codeflash capture
3000-
self.write_code_and_helpers(
3001-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
3002-
)
3003-
# this will happen when a timeoutexpired exception happens
3004-
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
3005-
logger.warning(
3006-
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
3007-
)
3008-
# set default value for line profiler results
3009-
return {"timings": {}, "unit": 0, "str_out": ""}
3010-
if line_profile_results["str_out"] == "":
3011-
logger.warning(
3012-
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
3013-
)
3014-
return line_profile_results
2947+
return {"timings": {}, "unit": 0, "str_out": ""}
30152948

30162949
def run_concurrency_benchmark(
30172950
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]

codeflash/optimization/optimizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131
from codeflash.code_utils.time_utils import humanize_runtime
3232
from codeflash.either import is_successful
33-
from codeflash.languages import current_language_support, is_javascript, set_current_language
33+
from codeflash.languages import current_language_support, is_javascript, is_python, set_current_language
3434
from codeflash.lsp.helpers import is_subagent_mode
3535
from codeflash.models.models import ValidCode
3636
from codeflash.telemetry.posthog_cf import ph
@@ -277,7 +277,14 @@ def create_function_optimizer(
277277
):
278278
function_specific_timings = function_benchmark_timings[qualified_name_w_module]
279279

280-
return FunctionOptimizer(
280+
if is_python():
281+
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
282+
283+
cls = PythonFunctionOptimizer
284+
else:
285+
cls = FunctionOptimizer
286+
287+
return cls(
281288
function_to_optimize=function_to_optimize,
282289
test_cfg=self.test_cfg,
283290
function_to_optimize_source_code=function_to_optimize_source_code,

0 commit comments

Comments
 (0)