From 3006d624ae4dd2379a1fb0e3ff1cc5c05498c6d3 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:16:27 -0700 Subject: [PATCH 01/11] add LSP module --- codeflash/lsp/__init__.py | 0 codeflash/lsp/beta.py | 231 ++++++++++++++++++++++++++++++++++++++ codeflash/lsp/server.py | 56 +++++++++ 3 files changed, 287 insertions(+) create mode 100644 codeflash/lsp/__init__.py create mode 100644 codeflash/lsp/beta.py create mode 100644 codeflash/lsp/server.py diff --git a/codeflash/lsp/__init__.py b/codeflash/lsp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py new file mode 100644 index 000000000..797e33115 --- /dev/null +++ b/codeflash/lsp/beta.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from pygls import uris + +from codeflash.either import is_successful +from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol + +if TYPE_CHECKING: + from lsprotocol import types + + from codeflash.models.models import GeneratedTestsList, OptimizationSet + + +@dataclass +class OptimizableFunctionsParams: + textDocument: types.TextDocumentIdentifier # noqa: N815 + + +@dataclass +class FunctionOptimizationParams: + textDocument: types.TextDocumentIdentifier # noqa: N815 + functionName: str # noqa: N815 + + +server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) + + +@server.feature("getOptimizableFunctions") +def get_optimizable_functions( + server: CodeflashLanguageServer, params: OptimizableFunctionsParams +) -> dict[str, list[str]]: + file_path = Path(uris.to_fs_path(params.textDocument.uri)) + server.optimizer.args.file = file_path + server.optimizer.args.previous_checkpoint_functions = False + optimizable_funcs, _ = server.optimizer.get_optimizable_functions() + path_to_qualified_names = {} + for 
path, functions in optimizable_funcs.items(): + path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions] + return path_to_qualified_names + + +@server.feature("initializeFunctionOptimization") +def initialize_function_optimization( + server: CodeflashLanguageServer, params: FunctionOptimizationParams +) -> dict[str, str]: + file_path = Path(uris.to_fs_path(params.textDocument.uri)) + server.optimizer.args.function = params.functionName + server.optimizer.args.file = file_path + optimizable_funcs, _ = server.optimizer.get_optimizable_functions() + if not optimizable_funcs: + return {"functionName": params.functionName, "status": "not found", "args": None} + fto = optimizable_funcs.popitem()[1][0] + server.optimizer.current_function_being_optimized = fto + return {"functionName": params.functionName, "status": "success", "info": fto.server_info} + + +@server.feature("discoverFunctionTests") +def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: + current_function = server.optimizer.current_function_being_optimized + + optimizable_funcs = {current_function.file_path: [current_function]} + + function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) + # mocking in order to get things going + return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)} + + +@server.feature("prepareOptimization") +def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: + current_function = server.optimizer.current_function_being_optimized + + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = server.optimizer.create_function_optimizer( + current_function, + 
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=current_function.file_path, + ) + + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + + return {"functionName": params.functionName, "status": "success", "message": "Optimization preparation completed"} + + +@server.feature("generateTests") +def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: + function_optimizer = server.optimizer.current_function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + + test_setup_result = function_optimizer.generate_and_instrument_tests( + code_context, should_run_experiment=should_run_experiment + ) + if not is_successful(test_setup_result): + return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} + generated_tests_list: GeneratedTestsList + optimizations_set: OptimizationSet + generated_tests_list, _, concolic__test_str, optimizations_set = test_setup_result.unwrap() + + generated_tests: list[str] = [ + generated_test.generated_original_test_source for generated_test in 
generated_tests_list.generated_tests + ] + optimizations_dict = { + candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation} + for candidate in optimizations_set.control + optimizations_set.experiment + } + + return { + "functionName": params.functionName, + "status": "success", + "message": {"generated_tests": generated_tests, "optimizations": optimizations_dict}, + } + + +@server.feature("performFunctionOptimization") +def perform_function_optimization( + server: CodeflashLanguageServer, params: FunctionOptimizationParams +) -> dict[str, str]: + current_function = server.optimizer.current_function_being_optimized + + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) + + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = server.optimizer.create_function_optimizer( + current_function, + function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=current_function.file_path, + ) + + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + + test_setup_result = function_optimizer.generate_and_instrument_tests( + code_context, should_run_experiment=should_run_experiment + ) + if not is_successful(test_setup_result): + return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} + ( + generated_tests, + function_to_concolic_tests, + 
concolic_test_str, + optimizations_set, + generated_test_paths, + generated_perf_test_paths, + instrumented_unittests_created_for_function, + original_conftest_content, + ) = test_setup_result.unwrap() + + baseline_setup_result = function_optimizer.setup_and_establish_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + function_to_concolic_tests=function_to_concolic_tests, + generated_test_paths=generated_test_paths, + generated_perf_test_paths=generated_perf_test_paths, + instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, + original_conftest_content=original_conftest_content, + ) + + if not is_successful(baseline_setup_result): + return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} + + ( + function_to_optimize_qualified_name, + function_to_all_tests, + original_code_baseline, + test_functions_to_remove, + file_path_to_helper_classes, + ) = baseline_setup_result.unwrap() + + best_optimization = function_optimizer.find_and_process_best_optimization( + optimizations_set=optimizations_set, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + function_to_optimize_qualified_name=function_to_optimize_qualified_name, + function_to_all_tests=function_to_all_tests, + generated_tests=generated_tests, + test_functions_to_remove=test_functions_to_remove, + concolic_test_str=concolic_test_str, + ) + + if not best_optimization: + return { + "functionName": params.functionName, + "status": "error", + "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", + } + + optimized_source = best_optimization.candidate.source_code # noqa: F841 + + return { + "functionName": params.functionName, + "status": "success", + "message": "Optimization completed successfully", + "extra": f"Speedup: 
{original_code_baseline.runtime / best_optimization.runtime:.2f}x faster", + } + + +if __name__ == "__main__": + from codeflash.cli_cmds.console import console + + console.quiet = True + server.start_io() diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py new file mode 100644 index 000000000..222a1318c --- /dev/null +++ b/codeflash/lsp/server.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from lsprotocol.types import INITIALIZE +from pygls import uris +from pygls.protocol import LanguageServerProtocol, lsp_method +from pygls.server import LanguageServer + +if TYPE_CHECKING: + from lsprotocol.types import InitializeParams, InitializeResult + + +class CodeflashLanguageServerProtocol(LanguageServerProtocol): + _server: CodeflashLanguageServer + + @lsp_method(INITIALIZE) + def lsp_initialize(self, params: InitializeParams) -> InitializeResult: + server = self._server + initialize_result: InitializeResult = super().lsp_initialize(params) + + workspace_uri = params.root_uri + if workspace_uri: + workspace_path = uris.to_fs_path(workspace_uri) + pyproject_toml_path = self._find_pyproject_toml(workspace_path) + if pyproject_toml_path: + server.initialize_optimizer(pyproject_toml_path) + server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}") + else: + server.show_message("No pyproject.toml found in workspace.") + else: + server.show_message("No workspace URI provided.") + + return initialize_result + + def _find_pyproject_toml(self, workspace_path: str) -> Path | None: + workspace_path_obj = Path(workspace_path) + for file_path in workspace_path_obj.rglob("pyproject.toml"): + return file_path.resolve() + return None + + +class CodeflashLanguageServer(LanguageServer): + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 + super().__init__(*args, **kwargs) + self.optimizer = None + + def initialize_optimizer(self, config_file: Path) -> None: + from 
codeflash.cli_cmds.cli import parse_args, process_pyproject_config + from codeflash.optimization.optimizer import Optimizer + + args = parse_args() + args.config_file = config_file + args = process_pyproject_config(args) + self.optimizer = Optimizer(args) From 1311198c369621f40a7252ba25f7d667a496ff5a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:23:26 -0700 Subject: [PATCH 02/11] extract benchmark runs --- codeflash/optimization/optimizer.py | 115 +++++++++++++++------------- pyproject.toml | 3 +- 2 files changed, 65 insertions(+), 53 deletions(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 23b5594f8..055ee502a 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -47,6 +47,66 @@ def __init__(self, args: Namespace) -> None: self.functions_checkpoint: CodeflashRunCheckpoint | None = None self.current_function_optimizer: FunctionOptimizer | None = None + def run_benchmarks( + self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int + ) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]: + """Run benchmarks for the functions to optimize and collect timing information.""" + function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] = {} + total_benchmark_timings: dict[BenchmarkKey, float] = {} + + if not (hasattr(self.args, "benchmark") and self.args.benchmark and num_optimizable_functions > 0): + return function_benchmark_timings, total_benchmark_timings + + from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator + from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin + from codeflash.benchmarking.replay_test import generate_replay_test + from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest + from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table + from 
codeflash.code_utils.env_utils import get_pr_number + + with progress_bar( + f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number()) + ): + # Insert decorator + file_path_to_source_code = defaultdict(str) + for file in file_to_funcs_to_optimize: + with file.open("r", encoding="utf8") as f: + file_path_to_source_code[file] = f.read() + try: + instrument_codeflash_trace_decorator(file_to_funcs_to_optimize) + trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace" + if trace_file.exists(): + trace_file.unlink() + + self.replay_tests_dir = Path( + tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root) + ) + trace_benchmarks_pytest( + self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file + ) # Run all tests that use pytest-benchmark + replay_count = generate_replay_test(trace_file, self.replay_tests_dir) + if replay_count == 0: + logger.info( + f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization" + ) + else: + function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) + total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) + function_to_results = validate_and_format_benchmark_table( + function_benchmark_timings, total_benchmark_timings + ) + print_benchmark_table(function_to_results) + except Exception as e: + logger.info(f"Error while tracing existing benchmarks: {e}") + logger.info("Information on existing benchmarks will not be available for this run.") + finally: + # Restore original source code + for file in file_path_to_source_code: + with file.open("w", encoding="utf8") as f: + f.write(file_path_to_source_code[file]) + + return function_benchmark_timings, total_benchmark_timings + def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, @@ -108,58 +168,9 @@ def run(self) -> None: 
module_root=self.args.module_root, previous_checkpoint_functions=self.args.previous_checkpoint_functions, ) - function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} - total_benchmark_timings: dict[BenchmarkKey, int] = {} - if self.args.benchmark and num_optimizable_functions > 0: - from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator - from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin - from codeflash.benchmarking.replay_test import generate_replay_test - from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest - from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table - - console.rule() - with progress_bar( - f"Running benchmarks in {self.args.benchmarks_root}", - transient=True, - revert_to_print=bool(get_pr_number()), - ): - # Insert decorator - file_path_to_source_code = defaultdict(str) - for file in file_to_funcs_to_optimize: - with file.open("r", encoding="utf8") as f: - file_path_to_source_code[file] = f.read() - try: - instrument_codeflash_trace_decorator(file_to_funcs_to_optimize) - trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace" - if trace_file.exists(): - trace_file.unlink() - - self.replay_tests_dir = Path( - tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.tests_root) - ) - trace_benchmarks_pytest( - self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file - ) # Run all tests that use pytest-benchmark - replay_count = generate_replay_test(trace_file, self.replay_tests_dir) - if replay_count == 0: - logger.info( - f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization" - ) - else: - function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) - total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) - function_to_results = 
validate_and_format_benchmark_table( - function_benchmark_timings, total_benchmark_timings - ) - print_benchmark_table(function_to_results) - except Exception as e: - logger.info(f"Error while tracing existing benchmarks: {e}") - logger.info("Information on existing benchmarks will not be available for this run.") - finally: - # Restore original source code - for file in file_path_to_source_code: - with file.open("w", encoding="utf8") as f: - f.write(file_path_to_source_code[file]) + function_benchmark_timings, total_benchmark_timings = self.run_benchmarks( + file_to_funcs_to_optimize, num_optimizable_functions + ) optimizations_found: int = 0 function_iterator_count: int = 0 if self.args.test_framework == "pytest": diff --git a/pyproject.toml b/pyproject.toml index 3a990197c..e056740f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,7 +237,8 @@ ignore = [ "S301", "D104", "PERF203", - "LOG015" + "LOG015", + "PLC0415" ] [tool.ruff.lint.flake8-type-checking] From 30e51b8db9fda489d860cc06380fa0d375e1550a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:24:46 -0700 Subject: [PATCH 03/11] extract get_functions_to_optimize --- codeflash/optimization/optimizer.py | 32 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 055ee502a..289aeeaaf 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -107,6 +107,22 @@ def run_benchmarks( return function_benchmark_timings, total_benchmark_timings + def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + """Discover functions to optimize.""" + from codeflash.discovery.functions_to_optimize import get_functions_to_optimize + + return get_functions_to_optimize( + optimize_all=self.args.all, + replay_test=self.args.replay_test, + file=self.args.file, + only_get_this_function=self.args.function, + 
test_cfg=self.test_cfg, + ignore_paths=self.args.ignore_paths, + project_root=self.args.project_root, + module_root=self.args.module_root, + previous_checkpoint_functions=self.args.previous_checkpoint_functions, + ) + def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, @@ -139,7 +155,6 @@ def run(self) -> None: get_first_top_level_function_or_method_ast, ) from codeflash.discovery.discover_unit_tests import discover_unit_tests - from codeflash.discovery.functions_to_optimize import get_functions_to_optimize ph("cli-optimize-run-start") logger.info("Running optimizer.") @@ -154,20 +169,7 @@ def run(self) -> None: return function_optimizer = None - file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] - num_optimizable_functions: int - # discover functions - (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize( - optimize_all=self.args.all, - replay_test=self.args.replay_test, - file=self.args.file, - only_get_this_function=self.args.function, - test_cfg=self.test_cfg, - ignore_paths=self.args.ignore_paths, - project_root=self.args.project_root, - module_root=self.args.module_root, - previous_checkpoint_functions=self.args.previous_checkpoint_functions, - ) + file_to_funcs_to_optimize, num_optimizable_functions = self.get_optimizable_functions() function_benchmark_timings, total_benchmark_timings = self.run_benchmarks( file_to_funcs_to_optimize, num_optimizable_functions ) From 2ff091b280118f414cde8d2e2fd7288760c3598f Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:30:03 -0700 Subject: [PATCH 04/11] extract create_function_optimizer --- codeflash/optimization/optimizer.py | 80 ++++++++++++++--------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 289aeeaaf..df288603d 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -131,9 +131,35 @@ 
def create_function_optimizer( function_to_optimize_source_code: str | None = "", function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, total_benchmark_timings: dict[BenchmarkKey, float] | None = None, - ) -> FunctionOptimizer: + original_module_ast: ast.Module | None = None, + original_module_path: Path | None = None, + ) -> FunctionOptimizer | None: + from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.optimization.function_optimizer import FunctionOptimizer + if function_to_optimize_ast is None and original_module_ast is not None: + function_to_optimize_ast = get_first_top_level_function_or_method_ast( + function_to_optimize.function_name, function_to_optimize.parents, original_module_ast + ) + if function_to_optimize_ast is None: + logger.info( + f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n" + f"Skipping optimization." + ) + return None + + qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) + + function_specific_timings = None + if ( + hasattr(self.args, "benchmark") + and self.args.benchmark + and function_benchmark_timings + and qualified_name_w_module in function_benchmark_timings + and total_benchmark_timings + ): + function_specific_timings = function_benchmark_timings[qualified_name_w_module] + return FunctionOptimizer( function_to_optimize=function_to_optimize, test_cfg=self.test_cfg, @@ -142,18 +168,15 @@ def create_function_optimizer( function_to_optimize_ast=function_to_optimize_ast, aiservice_client=self.aiservice_client, args=self.args, - function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None, - total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None, + function_benchmark_timings=function_specific_timings, + total_benchmark_timings=total_benchmark_timings if function_specific_timings else None, 
replay_tests_dir=self.replay_tests_dir, ) def run(self) -> None: from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from codeflash.code_utils.code_replacer import normalize_code, normalize_node - from codeflash.code_utils.static_analysis import ( - analyze_imported_modules, - get_first_top_level_function_or_method_ast, - ) + from codeflash.code_utils.static_analysis import analyze_imported_modules from codeflash.discovery.discover_unit_tests import discover_unit_tests ph("cli-optimize-run-start") @@ -245,40 +268,17 @@ def run(self) -> None: f"{function_to_optimize.qualified_name}" ) console.rule() - if not ( - function_to_optimize_ast := get_first_top_level_function_or_method_ast( - function_to_optimize.function_name, function_to_optimize.parents, original_module_ast - ) - ): - logger.info( - f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n" - f"Skipping optimization." - ) - continue - qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root( - self.args.project_root + + function_optimizer = self.create_function_optimizer( + function_to_optimize, + function_to_tests=function_to_tests, + function_to_optimize_source_code=validated_original_code[original_module_path].source_code, + function_benchmark_timings=function_benchmark_timings, + total_benchmark_timings=total_benchmark_timings, + original_module_ast=original_module_ast, + original_module_path=original_module_path, ) - if ( - self.args.benchmark - and function_benchmark_timings - and qualified_name_w_module in function_benchmark_timings - and total_benchmark_timings - ): - function_optimizer = self.create_function_optimizer( - function_to_optimize, - function_to_optimize_ast, - function_to_tests, - validated_original_code[original_module_path].source_code, - function_benchmark_timings[qualified_name_w_module], - total_benchmark_timings, - ) - else: - function_optimizer = self.create_function_optimizer( - function_to_optimize, - 
function_to_optimize_ast, - function_to_tests, - validated_original_code[original_module_path].source_code, - ) + self.current_function_optimizer = ( function_optimizer # needed to clean up from the outside of this function ) From e48e13fd509f42f91a56dadcc4cd815f316e8c96 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:35:54 -0700 Subject: [PATCH 05/11] extract module prep --- codeflash/optimization/optimizer.py | 86 ++++++++++++++++------------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index df288603d..66e600ebc 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -173,10 +173,54 @@ def create_function_optimizer( replay_tests_dir=self.replay_tests_dir, ) - def run(self) -> None: - from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint + def prepare_module_for_optimization( + self, original_module_path: Path + ) -> tuple[dict[Path, ValidCode], ast.Module] | None: from codeflash.code_utils.code_replacer import normalize_code, normalize_node from codeflash.code_utils.static_analysis import analyze_imported_modules + + logger.info(f"Examining file {original_module_path!s}…") + console.rule() + + original_module_code: str = original_module_path.read_text(encoding="utf8") + try: + original_module_ast = ast.parse(original_module_code) + except SyntaxError as e: + logger.warning(f"Syntax error parsing code in {original_module_path}: {e}") + logger.info("Skipping optimization due to file error.") + return None + normalized_original_module_code = ast.unparse(normalize_node(original_module_ast)) + validated_original_code: dict[Path, ValidCode] = { + original_module_path: ValidCode( + source_code=original_module_code, normalized_code=normalized_original_module_code + ) + } + + imported_module_analyses = analyze_imported_modules( + original_module_code, original_module_path, self.args.project_root + ) + + 
has_syntax_error = False + for analysis in imported_module_analyses: + callee_original_code = analysis.file_path.read_text(encoding="utf8") + try: + normalized_callee_original_code = normalize_code(callee_original_code) + except SyntaxError as e: + logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}") + logger.info("Skipping optimization due to helper file error.") + has_syntax_error = True + break + validated_original_code[analysis.file_path] = ValidCode( + source_code=callee_original_code, normalized_code=normalized_callee_original_code + ) + + if has_syntax_error: + return None + + return validated_original_code, original_module_ast + + def run(self) -> None: + from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from codeflash.discovery.discover_unit_tests import discover_unit_tests ph("cli-optimize-run-start") @@ -223,43 +267,11 @@ def run(self) -> None: self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) for original_module_path in file_to_funcs_to_optimize: - logger.info(f"Examining file {original_module_path!s}…") - console.rule() - - original_module_code: str = original_module_path.read_text(encoding="utf8") - try: - original_module_ast = ast.parse(original_module_code) - except SyntaxError as e: - logger.warning(f"Syntax error parsing code in {original_module_path}: {e}") - logger.info("Skipping optimization due to file error.") + module_prep_result = self.prepare_module_for_optimization(original_module_path) + if module_prep_result is None: continue - normalized_original_module_code = ast.unparse(normalize_node(original_module_ast)) - validated_original_code: dict[Path, ValidCode] = { - original_module_path: ValidCode( - source_code=original_module_code, normalized_code=normalized_original_module_code - ) - } - imported_module_analyses = analyze_imported_modules( - original_module_code, original_module_path, self.args.project_root - ) - - has_syntax_error = False - for analysis in 
imported_module_analyses: - callee_original_code = analysis.file_path.read_text(encoding="utf8") - try: - normalized_callee_original_code = normalize_code(callee_original_code) - except SyntaxError as e: - logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}") - logger.info("Skipping optimization due to helper file error.") - has_syntax_error = True - break - validated_original_code[analysis.file_path] = ValidCode( - source_code=callee_original_code, normalized_code=normalized_callee_original_code - ) - - if has_syntax_error: - continue + validated_original_code, original_module_ast = module_prep_result for function_to_optimize in file_to_funcs_to_optimize[original_module_path]: function_iterator_count += 1 From e91838c9822d7400640d6bdc602ef1b601ab134a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:39:00 -0700 Subject: [PATCH 06/11] extract test discovery --- codeflash/optimization/optimizer.py | 31 ++++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 66e600ebc..5cacb480f 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -219,9 +219,26 @@ def prepare_module_for_optimization( return validated_original_code, original_module_ast + def discover_tests( + self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] + ) -> tuple[dict[str, set[FunctionCalledInTest]], int]: + from codeflash.discovery.discover_unit_tests import discover_unit_tests + + console.rule() + start_time = time.time() + function_to_tests, num_discovered_tests = discover_unit_tests( + self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize + ) + console.rule() + logger.info( + f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" + ) + console.rule() + ph("cli-optimize-discovered-tests", 
{"num_tests": num_discovered_tests}) + return function_to_tests, num_discovered_tests + def run(self) -> None: from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint - from codeflash.discovery.discover_unit_tests import discover_unit_tests ph("cli-optimize-run-start") logger.info("Running optimizer.") @@ -252,17 +269,7 @@ def run(self) -> None: logger.info("No functions found to optimize. Exiting…") return - console.rule() - start_time = time.time() - function_to_tests, num_discovered_tests = discover_unit_tests( - self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize - ) - console.rule() - logger.info( - f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" - ) - console.rule() - ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) + function_to_tests, _ = self.discover_tests(file_to_funcs_to_optimize) if self.args.all: self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) From f52ada5eab4ef8f33a542ec3c31e978af306ed24 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:41:18 -0700 Subject: [PATCH 07/11] add pygls as a direct dependency, not transitive --- codeflash/optimization/optimizer.py | 1 + pyproject.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 5cacb480f..32a825b76 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -45,6 +45,7 @@ def __init__(self, args: Namespace) -> None: self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.replay_tests_dir = None self.functions_checkpoint: CodeflashRunCheckpoint | None = None + self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP self.current_function_optimizer: FunctionOptimizer | None = None def run_benchmarks( diff --git a/pyproject.toml b/pyproject.toml 
index e056740f8..8ae24fbe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "coverage>=7.6.4", "line_profiler>=4.2.0", "platformdirs>=4.3.7", + "pygls>=1.3.1", ] [project.urls] From ebe4f2555fe009e21c70adf5f56a34cb6f8c0e4f Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:53:39 -0700 Subject: [PATCH 08/11] implement can be optimized --- codeflash/optimization/function_optimizer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 96f0caf64..e34e6aebd 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -144,7 +144,7 @@ def __init__( self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None - def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 + def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None logger.debug(f"Function Trace ID: {self.function_trace_id}") ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id}) @@ -171,6 +171,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ): return Failure("Function optimization previously attempted, skipping.") + return Success((should_run_experiment, code_context, original_helper_code)) + + def optimize_function(self) -> Result[BestOptimization, str]: + initialization_result = self.can_be_optimized() + if not is_successful(initialization_result): + return Failure(initialization_result.failure()) + + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + code_print(code_context.read_writable_code) generated_test_paths = [ get_test_file_path( From 
0bc582fc227148bc634325a8cd0879d9c5d75b51 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 16:58:53 -0700 Subject: [PATCH 09/11] modularize optimize_function --- codeflash/optimization/function_optimizer.py | 455 ++++++++++++------- 1 file changed, 289 insertions(+), 166 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e34e6aebd..c30710079 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -143,6 +143,9 @@ def __init__( self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None + self.generate_and_instrument_tests_results: ( + tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet] | None + ) = None def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -173,14 +176,21 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P return Success((should_run_experiment, code_context, original_helper_code)) - def optimize_function(self) -> Result[BestOptimization, str]: - initialization_result = self.can_be_optimized() - if not is_successful(initialization_result): - return Failure(initialization_result.failure()) - - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - - code_print(code_context.read_writable_code) + def generate_and_instrument_tests( + self, code_context: CodeOptimizationContext, *, should_run_experiment: bool + ) -> Result[ + tuple[ + GeneratedTestsList, + dict[str, set[FunctionCalledInTest]], + str, + OptimizationSet, + list[Path], + list[Path], + set[Path], + dict | None, + ] + ]: + """Generate and instrument 
tests, returning all necessary data for optimization.""" generated_test_paths = [ get_test_file_path( self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit" @@ -211,6 +221,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: if not is_successful(generated_results): return Failure(generated_results.failure()) + generated_tests: GeneratedTestsList optimizations_set: OptimizationSet generated_tests, function_to_concolic_tests, concolic_test_str, optimizations_set = generated_results.unwrap() @@ -239,177 +250,91 @@ def optimize_function(self) -> Result[BestOptimization, str]: logger.info(f"Generated test {count_tests}/{count_tests}:") code_print(concolic_test_str) - function_to_optimize_qualified_name = self.function_to_optimize.qualified_name function_to_all_tests = { key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set()) for key in set(self.function_to_tests) | set(function_to_concolic_tests) } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) + + original_conftest_content = None if self.args.override_fixtures: logger.info("Disabling all autouse fixtures associated with the generated test files") original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths) logger.info("Add custom marker to generated test files") add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths) - # Get a dict of file_path_to_classes of fto and helpers_of_fto - file_path_to_helper_classes = defaultdict(set) - for function_source in code_context.helper_functions: - if ( - function_source.qualified_name != self.function_to_optimize.qualified_name - and "." 
in function_source.qualified_name - ): - file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0]) - - baseline_result = self.establish_original_code_baseline( # this needs better typing - code_context=code_context, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - ) - - console.rule() - paths_to_cleanup = ( - generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + return Success( + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + optimizations_set, + generated_test_paths, + generated_perf_test_paths, + instrumented_unittests_created_for_function, + original_conftest_content, + ) ) - if not is_successful(baseline_result): - if self.args.override_fixtures: - restore_conftest(original_conftest_content) - cleanup_paths(paths_to_cleanup) - return Failure(baseline_result.failure()) - - original_code_baseline, test_functions_to_remove = baseline_result.unwrap() - if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic( - original_code_baseline.coverage_results, self.args.test_framework - ): - if self.args.override_fixtures: - restore_conftest(original_conftest_content) - cleanup_paths(paths_to_cleanup) - return Failure("The threshold for test coverage was not met.") - # request for new optimizations but don't block execution, check for completion later - # adding to control and experiment set but with same traceid - best_optimization = None - for _u, (candidates, exp_type) in enumerate( - zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"]) - ): - if candidates is None: - continue + def optimize_function(self) -> Result[BestOptimization, str]: + initialization_result = self.can_be_optimized() + if not is_successful(initialization_result): + return Failure(initialization_result.failure()) - best_optimization = self.determine_best_candidate( - 
candidates=candidates, - code_context=code_context, - original_code_baseline=original_code_baseline, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - exp_type=exp_type, - ) - ph( - "cli-optimize-function-finished", - { - "function_trace_id": self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id - }, - ) + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - if best_optimization: - logger.info("Best candidate:") - code_print(best_optimization.candidate.source_code) - console.print( - Panel( - best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue" - ) - ) - processed_benchmark_info = None - if self.args.benchmark: - processed_benchmark_info = process_benchmark_data( - replay_performance_gain=best_optimization.replay_performance_gain, - fto_benchmark_timings=self.function_benchmark_timings, - total_benchmark_timings=self.total_benchmark_timings, - ) - explanation = Explanation( - raw_explanation_message=best_optimization.candidate.explanation, - winning_behavioral_test_results=best_optimization.winning_behavioral_test_results, - winning_benchmarking_test_results=best_optimization.winning_benchmarking_test_results, - original_runtime_ns=original_code_baseline.runtime, - best_runtime_ns=best_optimization.runtime, - function_name=function_to_optimize_qualified_name, - file_path=self.function_to_optimize.file_path, - benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None, - ) + code_print(code_context.read_writable_code) - self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, - optimized_code=best_optimization.candidate.source_code, - original_helper_code=original_helper_code, - ) + test_setup_result = self.generate_and_instrument_tests( # also generates optimizations + code_context, 
should_run_experiment=should_run_experiment + ) + if not is_successful(test_setup_result): + return Failure(test_setup_result.failure()) + + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + optimizations_set, + generated_test_paths, + generated_perf_test_paths, + instrumented_unittests_created_for_function, + original_conftest_content, + ) = test_setup_result.unwrap() + + baseline_setup_result = self.setup_and_establish_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + function_to_concolic_tests=function_to_concolic_tests, + generated_test_paths=generated_test_paths, + generated_perf_test_paths=generated_perf_test_paths, + instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, + original_conftest_content=original_conftest_content, + ) - new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, - explanation.file_path, - self.function_to_optimize_source_code, - optimized_function=best_optimization.candidate.source_code, - ) + if not is_successful(baseline_setup_result): + return Failure(baseline_setup_result.failure()) - original_code_combined = original_helper_code.copy() - original_code_combined[explanation.file_path] = self.function_to_optimize_source_code - new_code_combined = new_helper_code.copy() - new_code_combined[explanation.file_path] = new_code - if not self.args.no_pr: - coverage_message = ( - original_code_baseline.coverage_results.build_message() - if original_code_baseline.coverage_results - else "Coverage data not available" - ) - generated_tests = remove_functions_from_generated_tests( - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove - ) - original_runtime_by_test = ( - original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() - ) - optimized_runtime_by_test = ( - best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() - ) - # 
Add runtime comments to generated tests before creating the PR - generated_tests = add_runtime_comments_to_generated_tests( - self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test - ) - generated_tests_str = "\n\n".join( - [test.generated_original_test_source for test in generated_tests.generated_tests] - ) - existing_tests = existing_tests_source_for( - self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), - function_to_all_tests, - test_cfg=self.test_cfg, - original_runtimes_all=original_runtime_by_test, - optimized_runtimes_all=optimized_runtime_by_test, - ) - if concolic_test_str: - generated_tests_str += "\n\n" + concolic_test_str + ( + function_to_optimize_qualified_name, + function_to_all_tests, + original_code_baseline, + test_functions_to_remove, + file_path_to_helper_classes, + ) = baseline_setup_result.unwrap() - check_create_pr( - original_code=original_code_combined, - new_code=new_code_combined, - explanation=explanation, - existing_tests_source=existing_tests, - generated_original_test_source=generated_tests_str, - function_trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - coverage_message=coverage_message, - git_remote=self.args.git_remote, - ) - if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function): - self.write_code_and_helpers( - self.function_to_optimize_source_code, - original_helper_code, - self.function_to_optimize.file_path, - ) - else: - # Mark optimization success since no PR will be created - mark_optimization_success( - trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None - ) - self.log_successful_optimization(explanation, generated_tests, exp_type) + best_optimization = self.find_and_process_best_optimization( + optimizations_set=optimizations_set, + code_context=code_context, + original_code_baseline=original_code_baseline, + 
original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + function_to_optimize_qualified_name=function_to_optimize_qualified_name, + function_to_all_tests=function_to_all_tests, + generated_tests=generated_tests, + test_functions_to_remove=test_functions_to_remove, + concolic_test_str=concolic_test_str, + ) # Add function to code context hash if in gh actions @@ -843,7 +768,7 @@ def generate_tests_and_optimizations( console.rule() with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit the test generation task as future - future_tests = self.generate_and_instrument_tests( + future_tests = self.submit_test_generation_tasks( executor, testgen_context_code, [definition.fully_qualified_name for definition in helper_functions], @@ -918,16 +843,214 @@ def generate_tests_and_optimizations( logger.info(f"Generated {len(tests)} tests for {self.function_to_optimize.function_name}") console.rule() generated_tests = GeneratedTestsList(generated_tests=tests) + result = ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + OptimizationSet(control=candidates, experiment=candidates_experiment), + ) + self.generate_and_instrument_tests_results = result + return Success(result) + + def setup_and_establish_baseline( + self, + code_context: CodeOptimizationContext, + original_helper_code: dict[Path, str], + function_to_concolic_tests: dict[str, set[FunctionCalledInTest]], + generated_test_paths: list[Path], + generated_perf_test_paths: list[Path], + instrumented_unittests_created_for_function: set[Path], + original_conftest_content: str | None, + ) -> Result[ + tuple[str, dict[str, set[FunctionCalledInTest]], OriginalCodeBaseline, list[str], dict[Path, set[str]]], str + ]: + """Set up baseline context and establish original code baseline.""" + function_to_optimize_qualified_name = self.function_to_optimize.qualified_name + function_to_all_tests = { + key: self.function_to_tests.get(key, 
set()) | function_to_concolic_tests.get(key, set()) + for key in set(self.function_to_tests) | set(function_to_concolic_tests) + } + + # Get a dict of file_path_to_classes of fto and helpers_of_fto + file_path_to_helper_classes = defaultdict(set) + for function_source in code_context.helper_functions: + if ( + function_source.qualified_name != self.function_to_optimize.qualified_name + and "." in function_source.qualified_name + ): + file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0]) + + baseline_result = self.establish_original_code_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + ) + + console.rule() + paths_to_cleanup = ( + generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + ) + + if not is_successful(baseline_result): + if self.args.override_fixtures: + restore_conftest(original_conftest_content) + cleanup_paths(paths_to_cleanup) + return Failure(baseline_result.failure()) + + original_code_baseline, test_functions_to_remove = baseline_result.unwrap() + if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic( + original_code_baseline.coverage_results, self.args.test_framework + ): + if self.args.override_fixtures: + restore_conftest(original_conftest_content) + cleanup_paths(paths_to_cleanup) + return Failure("The threshold for test coverage was not met.") return Success( ( - generated_tests, - function_to_concolic_tests, - concolic_test_str, - OptimizationSet(control=candidates, experiment=candidates_experiment), + function_to_optimize_qualified_name, + function_to_all_tests, + original_code_baseline, + test_functions_to_remove, + file_path_to_helper_classes, ) ) + def find_and_process_best_optimization( + self, + optimizations_set: OptimizationSet, + code_context: CodeOptimizationContext, + original_code_baseline: 
OriginalCodeBaseline, + original_helper_code: dict[Path, str], + file_path_to_helper_classes: dict[Path, set[str]], + function_to_optimize_qualified_name: str, + function_to_all_tests: dict[str, set[FunctionCalledInTest]], + generated_tests: GeneratedTestsList, + test_functions_to_remove: list[str], + concolic_test_str: str | None, + ) -> BestOptimization | None: + """Find the best optimization candidate and process it with all required steps.""" + best_optimization = None + for _u, (candidates, exp_type) in enumerate( + zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"]) + ): + if candidates is None: + continue + + best_optimization = self.determine_best_candidate( + candidates=candidates, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + exp_type=exp_type, + ) + ph( + "cli-optimize-function-finished", + { + "function_trace_id": self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id + }, + ) + + if best_optimization: + logger.info("Best candidate:") + code_print(best_optimization.candidate.source_code) + console.print( + Panel( + best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue" + ) + ) + processed_benchmark_info = None + if self.args.benchmark: + processed_benchmark_info = process_benchmark_data( + replay_performance_gain=best_optimization.replay_performance_gain, + fto_benchmark_timings=self.function_benchmark_timings, + total_benchmark_timings=self.total_benchmark_timings, + ) + explanation = Explanation( + raw_explanation_message=best_optimization.candidate.explanation, + winning_behavioral_test_results=best_optimization.winning_behavioral_test_results, + winning_benchmarking_test_results=best_optimization.winning_benchmarking_test_results, + original_runtime_ns=original_code_baseline.runtime, + 
best_runtime_ns=best_optimization.runtime, + function_name=function_to_optimize_qualified_name, + file_path=self.function_to_optimize.file_path, + benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None, + ) + + self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=best_optimization.candidate.source_code, + original_helper_code=original_helper_code, + ) + + new_code, new_helper_code = self.reformat_code_and_helpers( + code_context.helper_functions, + explanation.file_path, + self.function_to_optimize_source_code, + optimized_function=best_optimization.candidate.source_code, + ) + + existing_tests = existing_tests_source_for( + self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), + function_to_all_tests, + tests_root=self.test_cfg.tests_root, + ) + + original_code_combined = original_helper_code.copy() + original_code_combined[explanation.file_path] = self.function_to_optimize_source_code + new_code_combined = new_helper_code.copy() + new_code_combined[explanation.file_path] = new_code + if not self.args.no_pr: + coverage_message = ( + original_code_baseline.coverage_results.build_message() + if original_code_baseline.coverage_results + else "Coverage data not available" + ) + generated_tests = remove_functions_from_generated_tests( + generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove + ) + # Add runtime comments to generated tests before creating the PR + generated_tests = add_runtime_comments_to_generated_tests( + generated_tests, + original_code_baseline.benchmarking_test_results, + best_optimization.winning_benchmarking_test_results, + ) + generated_tests_str = "\n\n".join( + [test.generated_original_test_source for test in generated_tests.generated_tests] + ) + if concolic_test_str: + generated_tests_str += "\n\n" + concolic_test_str + + check_create_pr( + original_code=original_code_combined, + 
new_code=new_code_combined, + explanation=explanation, + existing_tests_source=existing_tests, + generated_original_test_source=generated_tests_str, + function_trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + coverage_message=coverage_message, + git_remote=self.args.git_remote, + ) + if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function): + self.write_code_and_helpers( + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, + ) + else: + # Mark optimization success since no PR will be created + mark_optimization_success( + trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None + ) + self.log_successful_optimization(explanation, generated_tests, exp_type) + + return best_optimization + def establish_original_code_baseline( self, code_context: CodeOptimizationContext, @@ -1301,7 +1424,7 @@ def run_and_parse_tests( results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file) return results, coverage_results - def generate_and_instrument_tests( + def submit_test_generation_tasks( self, executor: concurrent.futures.ThreadPoolExecutor, source_code_being_tested: str, From 95fb1f08476974fafda07e5e8c0374798a44a0c6 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 17:08:22 -0700 Subject: [PATCH 10/11] fix regression --- codeflash/optimization/function_optimizer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index c30710079..54e644d66 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1012,15 +1012,26 @@ def find_and_process_best_optimization( generated_tests = remove_functions_from_generated_tests( generated_tests=generated_tests, 
test_functions_to_remove=test_functions_to_remove ) + original_runtime_by_test = ( + original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() + ) + optimized_runtime_by_test = ( + best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() + ) # Add runtime comments to generated tests before creating the PR generated_tests = add_runtime_comments_to_generated_tests( - generated_tests, - original_code_baseline.benchmarking_test_results, - best_optimization.winning_benchmarking_test_results, + self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test ) generated_tests_str = "\n\n".join( [test.generated_original_test_source for test in generated_tests.generated_tests] ) + existing_tests = existing_tests_source_for( + self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), + function_to_all_tests, + test_cfg=self.test_cfg, + original_runtimes_all=original_runtime_by_test, + optimized_runtimes_all=optimized_runtime_by_test, + ) if concolic_test_str: generated_tests_str += "\n\n" + concolic_test_str From fa0734203c98efddeef836305b78a489d61d9a63 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 25 Jun 2025 17:15:20 -0700 Subject: [PATCH 11/11] regression fix --- codeflash/optimization/function_optimizer.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 54e644d66..acae4b9fb 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -993,12 +993,6 @@ def find_and_process_best_optimization( optimized_function=best_optimization.candidate.source_code, ) - existing_tests = existing_tests_source_for( - self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), - function_to_all_tests, - tests_root=self.test_cfg.tests_root, - ) - original_code_combined = original_helper_code.copy() 
original_code_combined[explanation.file_path] = self.function_to_optimize_source_code new_code_combined = new_helper_code.copy() @@ -1059,7 +1053,6 @@ def find_and_process_best_optimization( trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None ) self.log_successful_optimization(explanation, generated_tests, exp_type) - return best_optimization def establish_original_code_baseline(