Skip to content

Commit a3d81ba

Browse files
committed
refactor: remove redundant AST pre-resolution from create_function_optimizer
The factory method was resolving the function AST from original_module_ast before constructing PythonFunctionOptimizer, duplicating the _resolve_function_ast hook. Now the hook handles resolution and the factory checks the result post-construction.
1 parent a2e3a0a commit a3d81ba

2 files changed

Lines changed: 13 additions & 24 deletions

File tree

codeflash/lsp/beta.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,14 +463,10 @@ def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedIni
463463
"message": "Failed to prepare module for optimization",
464464
}
465465

466-
validated_original_code, original_module_ast = module_prep_result
466+
validated_original_code, _original_module_ast = module_prep_result
467467

468468
function_optimizer = server.optimizer.create_function_optimizer(
469-
fto,
470-
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
471-
original_module_ast=original_module_ast,
472-
original_module_path=fto.file_path,
473-
function_to_tests={},
469+
fto, function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, function_to_tests={}
474470
)
475471

476472
server.optimizer.current_function_optimizer = function_optimizer

codeflash/optimization/optimizer.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -247,24 +247,10 @@ def create_function_optimizer(
247247
function_to_optimize_source_code: str | None = "",
248248
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
249249
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
250-
original_module_ast: ast.Module | None = None,
251-
original_module_path: Path | None = None,
252250
call_graph: DependencyResolver | None = None,
253251
) -> FunctionOptimizer | None:
254-
from codeflash.languages.python.optimizer import resolve_python_function_ast
255252
from codeflash.optimization.function_optimizer import FunctionOptimizer
256253

257-
if function_to_optimize_ast is None and original_module_ast is not None:
258-
function_to_optimize_ast = resolve_python_function_ast(
259-
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
260-
)
261-
if function_to_optimize_ast is None:
262-
logger.info(
263-
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
264-
f"Skipping optimization."
265-
)
266-
return None
267-
268254
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
269255

270256
function_specific_timings = None
@@ -284,7 +270,9 @@ def create_function_optimizer(
284270
else:
285271
cls = FunctionOptimizer
286272

287-
return cls(
273+
# TODO: _resolve_function_ast re-parses source via ast.parse() per function, even when the caller already
274+
# has a parsed module AST. Consider passing the pre-parsed AST through to avoid redundant parsing.
275+
function_optimizer = cls(
288276
function_to_optimize=function_to_optimize,
289277
test_cfg=self.test_cfg,
290278
function_to_optimize_source_code=function_to_optimize_source_code,
@@ -297,6 +285,13 @@ def create_function_optimizer(
297285
replay_tests_dir=self.replay_tests_dir,
298286
call_graph=call_graph,
299287
)
288+
if function_optimizer.function_to_optimize_ast is None:
289+
logger.info(
290+
f"Function {function_to_optimize.qualified_name} not found in "
291+
f"{function_to_optimize.file_path}.\nSkipping optimization."
292+
)
293+
return None
294+
return function_optimizer
300295

301296
def prepare_module_for_optimization(
302297
self, original_module_path: Path
@@ -593,7 +588,7 @@ def run(self) -> None:
593588
continue
594589
prepared_modules[original_module_path] = module_prep_result
595590

596-
validated_original_code, original_module_ast = prepared_modules[original_module_path]
591+
validated_original_code, _original_module_ast = prepared_modules[original_module_path]
597592

598593
function_iterator_count = i + 1
599594
logger.info(
@@ -609,8 +604,6 @@ def run(self) -> None:
609604
function_to_optimize_source_code=validated_original_code[original_module_path].source_code,
610605
function_benchmark_timings=function_benchmark_timings,
611606
total_benchmark_timings=total_benchmark_timings,
612-
original_module_ast=original_module_ast,
613-
original_module_path=original_module_path,
614607
call_graph=resolver,
615608
)
616609
if function_optimizer is None:

0 commit comments

Comments
 (0)