Skip to content

Commit cbc66e2

Browse files
authored
Merge pull request #1689 from codeflash-ai/consolidate-python-discovery
refactor: consolidate Python function discovery to CST path only
2 parents 50a2538 + 2b2caf4 commit cbc66e2

5 files changed

Lines changed: 91 additions & 211 deletions

File tree

codeflash/discovery/functions_to_optimize.py

Lines changed: 31 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import os
55
import random
66
import warnings
7-
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
87
from collections import defaultdict
98
from functools import cache
109
from pathlib import Path
@@ -16,7 +15,7 @@
1615
from rich.tree import Tree
1716

1817
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
19-
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
18+
from codeflash.cli_cmds.console import console, logger
2019
from codeflash.code_utils.code_utils import (
2120
exit_with_message,
2221
is_class_defined_in_file,
@@ -47,10 +46,6 @@
4746

4847
from rich.text import Text
4948

50-
_property_id = "property"
51-
52-
_ast_name = ast.Name
53-
5449

5550
@dataclass(frozen=True)
5651
class FunctionProperties:
@@ -73,9 +68,9 @@ def visit_Return(self, node: cst.Return) -> None:
7368
class FunctionVisitor(cst.CSTVisitor):
7469
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
7570

76-
def __init__(self, file_path: str) -> None:
71+
def __init__(self, file_path: Path) -> None:
7772
super().__init__()
78-
self.file_path: str = file_path
73+
self.file_path: Path = file_path
7974
self.functions: list[FunctionToOptimize] = []
8075

8176
@staticmethod
@@ -91,15 +86,26 @@ def is_pytest_fixture(node: cst.FunctionDef) -> bool:
9186
return True
9287
return False
9388

89+
@staticmethod
90+
def is_property(node: cst.FunctionDef) -> bool:
91+
for decorator in node.decorators:
92+
dec = decorator.decorator
93+
if isinstance(dec, cst.Name) and dec.value in ("property", "cached_property"):
94+
return True
95+
return False
96+
9497
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
9598
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
9699
node.visit(return_visitor)
97-
if return_visitor.has_return_statement and not self.is_pytest_fixture(node):
100+
if return_visitor.has_return_statement and not self.is_pytest_fixture(node) and not self.is_property(node):
98101
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
99102
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
100103
ast_parents: list[FunctionParent] = []
101104
while parents is not None:
102-
if isinstance(parents, (cst.FunctionDef, cst.ClassDef)):
105+
if isinstance(parents, cst.FunctionDef):
106+
# Skip nested functions — only discover top-level and class-level functions
107+
return
108+
if isinstance(parents, cst.ClassDef):
103109
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
104110
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
105111
self.functions.append(
@@ -114,32 +120,6 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
114120
)
115121

116122

117-
def find_functions_with_return_statement(ast_module: ast.Module, file_path: Path) -> list[FunctionToOptimize]:
118-
results: list[FunctionToOptimize] = []
119-
# (node, parent_path) — iterative DFS avoids RecursionError on deeply nested ASTs
120-
stack: list[tuple[ast.AST, list[FunctionParent]]] = [(ast_module, [])]
121-
while stack:
122-
node, ast_path = stack.pop()
123-
if isinstance(node, (FunctionDef, AsyncFunctionDef)):
124-
if function_has_return_statement(node) and not function_is_a_property(node):
125-
results.append(
126-
FunctionToOptimize(
127-
function_name=node.name,
128-
file_path=file_path,
129-
parents=ast_path[:],
130-
is_async=isinstance(node, AsyncFunctionDef),
131-
)
132-
)
133-
# Don't recurse into function bodies (matches original visitor behaviour)
134-
continue
135-
child_path = (
136-
[*ast_path, FunctionParent(node.name, node.__class__.__name__)] if isinstance(node, ClassDef) else ast_path
137-
)
138-
for child in reversed(list(ast.iter_child_nodes(node))):
139-
stack.append((child, child_path))
140-
return results
141-
142-
143123
# =============================================================================
144124
# Multi-language support helpers
145125
# =============================================================================
@@ -250,23 +230,6 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
250230
return True, None
251231

252232

253-
def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
254-
"""Find all optimizable functions in a Python file using AST parsing.
255-
256-
This is the original Python implementation preserved for backward compatibility.
257-
"""
258-
functions: dict[Path, list[FunctionToOptimize]] = {}
259-
with file_path.open(encoding="utf8") as f:
260-
try:
261-
ast_module = ast.parse(f.read())
262-
except Exception as e:
263-
if DEBUG_MODE:
264-
logger.exception(e)
265-
return functions
266-
functions[file_path] = find_functions_with_return_statement(ast_module, file_path)
267-
return functions
268-
269-
270233
def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
271234
"""Find all optimizable functions using the language support abstraction.
272235
@@ -280,7 +243,6 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
280243
try:
281244
lang_support = get_language_support(file_path)
282245
criteria = FunctionFilterCriteria(require_return=True)
283-
# discover_functions already returns FunctionToOptimize objects
284246
functions[file_path] = lang_support.discover_functions(file_path, criteria)
285247
except Exception as e:
286248
logger.debug(f"Failed to discover functions in {file_path}: {e}")
@@ -302,7 +264,7 @@ def get_functions_to_optimize(
302264
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
303265
"Only one of optimize_all, replay_test, or file should be provided"
304266
)
305-
functions: dict[str, list[FunctionToOptimize]]
267+
functions: dict[Path, list[FunctionToOptimize]]
306268
trace_file_path: Path | None = None
307269
is_lsp = is_LSP_enabled()
308270
with warnings.catch_warnings():
@@ -319,7 +281,7 @@ def get_functions_to_optimize(
319281
logger.info("!lsp|Finding all functions in the file '%s'…", file)
320282
console.rule()
321283
file = Path(file) if isinstance(file, str) else file
322-
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
284+
functions = find_all_functions_in_file(file)
323285
if only_get_this_function is not None:
324286
split_function = only_get_this_function.split(".")
325287
if len(split_function) > 2:
@@ -354,6 +316,7 @@ def get_functions_to_optimize(
354316
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
355317
)
356318

319+
assert found_function is not None
357320
# For JavaScript/TypeScript, verify that the function (or its parent class) is exported
358321
# Non-exported functions cannot be imported by tests
359322
if found_function.language in ("javascript", "typescript"):
@@ -397,7 +360,7 @@ def get_functions_to_optimize(
397360
return filtered_modified_functions, functions_count, trace_file_path
398361

399362

400-
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]:
363+
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]:
401364
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
402365
return get_functions_within_lines(modified_lines)
403366

@@ -438,7 +401,7 @@ def closest_matching_file_function_name(
438401
closest_match = function
439402
closest_file = file_path
440403

441-
if closest_match is not None:
404+
if closest_match is not None and closest_file is not None:
442405
return closest_file, closest_match
443406
return None
444407

@@ -472,13 +435,13 @@ def levenshtein_distance(s1: str, s2: str) -> int:
472435
return previous[len1]
473436

474437

475-
def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
438+
def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionToOptimize]]:
476439
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
477440
return get_functions_within_lines(modified_lines)
478441

479442

480-
def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str, list[FunctionToOptimize]]:
481-
functions: dict[str, list[FunctionToOptimize]] = {}
443+
def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]:
444+
functions: dict[Path, list[FunctionToOptimize]] = {}
482445
for path_str, lines_in_file in modified_lines.items():
483446
path = Path(path_str)
484447
if not path.exists():
@@ -490,9 +453,9 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
490453
except Exception as e:
491454
logger.exception(e)
492455
continue
493-
function_lines = FunctionVisitor(file_path=str(path))
456+
function_lines = FunctionVisitor(file_path=path)
494457
wrapper.visit(function_lines)
495-
functions[str(path)] = [
458+
functions[path] = [
496459
function_to_optimize
497460
for function_to_optimize in function_lines.functions
498461
if (start_line := function_to_optimize.starting_line) is not None
@@ -504,7 +467,7 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
504467

505468
def get_all_files_and_functions(
506469
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
507-
) -> dict[str, list[FunctionToOptimize]]:
470+
) -> dict[Path, list[FunctionToOptimize]]:
508471
"""Get all optimizable functions from files in the module root.
509472
510473
Args:
@@ -516,9 +479,8 @@ def get_all_files_and_functions(
516479
Dictionary mapping file paths to lists of FunctionToOptimize.
517480
518481
"""
519-
functions: dict[str, list[FunctionToOptimize]] = {}
482+
functions: dict[Path, list[FunctionToOptimize]] = {}
520483
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
521-
# Find all the functions in the file
522484
functions.update(find_all_functions_in_file(file_path).items())
523485
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
524486
# Helpful if an optimize-all run is stuck and we restart it.
@@ -545,16 +507,6 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
545507
if not is_language_supported(file_path):
546508
return {}
547509

548-
try:
549-
lang_support = get_language_support(file_path)
550-
except Exception:
551-
return {}
552-
553-
# Route to Python-specific implementation for backward compatibility
554-
if lang_support.language == Language.PYTHON:
555-
return _find_all_functions_in_python_file(file_path)
556-
557-
# Use language support abstraction for other languages
558510
return _find_all_functions_via_language_support(file_path)
559511

560512

@@ -833,7 +785,7 @@ def filter_functions(
833785
disable_logs: bool = False,
834786
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
835787
resolved_project_root = project_root.resolve()
836-
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
788+
filtered_modified_functions: dict[Path, list[FunctionToOptimize]] = {}
837789
blocklist_funcs = get_blocklisted_functions()
838790
logger.debug(f"Blocklisted functions: {blocklist_funcs}")
839791
# Remove any function that we don't want to optimize
@@ -940,7 +892,7 @@ def is_test_file(file_path_normalized: str) -> bool:
940892
functions_tmp.append(function)
941893
_functions = functions_tmp
942894

943-
filtered_modified_functions[file_path] = _functions
895+
filtered_modified_functions[file_path_path] = _functions
944896
functions_count += len(_functions)
945897

946898
if not disable_logs:
@@ -961,7 +913,7 @@ def is_test_file(file_path_normalized: str) -> bool:
961913
if len(tree.children) > 0:
962914
console.print(tree)
963915
console.rule()
964-
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
916+
return {k: v for k, v in filtered_modified_functions.items() if v}, functions_count
965917

966918

967919
def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool:
@@ -984,31 +936,3 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
984936
file_path in submodule_paths
985937
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
986938
)
987-
988-
989-
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
990-
# Custom DFS, return True as soon as a Return node is found
991-
stack: list[ast.AST] = list(function_node.body)
992-
while stack:
993-
node = stack.pop()
994-
if isinstance(node, ast.Return):
995-
return True
996-
# Only push child nodes that are statements; Return nodes are statements,
997-
# so this preserves correctness while avoiding unnecessary traversal into expr/Name/etc.
998-
for field in getattr(node, "_fields", ()):
999-
child = getattr(node, field, None)
1000-
if isinstance(child, list):
1001-
for item in child:
1002-
if isinstance(item, ast.stmt):
1003-
stack.append(item)
1004-
elif isinstance(child, ast.stmt):
1005-
stack.append(child)
1006-
return False
1007-
1008-
1009-
def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:
1010-
for node in function_node.decorator_list: # noqa: SIM110
1011-
# Use isinstance rather than type(...) is ... for better performance with single inheritance trees like ast
1012-
if isinstance(node, _ast_name) and node.id == _property_id:
1013-
return True
1014-
return False

codeflash/languages/python/support.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -130,56 +130,39 @@ def discover_functions(
130130

131131
criteria = filter_criteria or FunctionFilterCriteria()
132132

133-
try:
134-
# Read and parse the file using libcst with metadata
135-
source = file_path.read_text(encoding="utf-8")
136-
try:
137-
tree = cst.parse_module(source)
138-
except Exception:
139-
return []
140-
141-
# Use the libcst-based FunctionVisitor for accurate line numbers
142-
wrapper = cst.metadata.MetadataWrapper(tree)
143-
function_visitor = FunctionVisitor(file_path=str(file_path))
144-
wrapper.visit(function_visitor)
145-
146-
functions: list[FunctionToOptimize] = []
147-
for func in function_visitor.functions:
148-
if not isinstance(func, FunctionToOptimize):
149-
continue
150-
151-
# Apply filter criteria
152-
if not criteria.include_async and func.is_async:
153-
continue
154-
155-
if not criteria.include_methods and func.parents:
156-
continue
157-
158-
# Check for return statement requirement (FunctionVisitor already filters this)
159-
# but we double-check here for consistency
160-
if criteria.require_return and func.starting_line is None:
161-
continue
162-
163-
# Add is_method field based on parents
164-
func_with_is_method = FunctionToOptimize(
165-
function_name=func.function_name,
166-
file_path=file_path,
167-
parents=func.parents,
168-
starting_line=func.starting_line,
169-
ending_line=func.ending_line,
170-
starting_col=func.starting_col,
171-
ending_col=func.ending_col,
172-
is_async=func.is_async,
173-
is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents),
174-
language="python",
175-
)
176-
functions.append(func_with_is_method)
177-
178-
return functions
133+
source = file_path.read_text(encoding="utf-8")
134+
tree = cst.parse_module(source)
135+
136+
wrapper = cst.metadata.MetadataWrapper(tree)
137+
function_visitor = FunctionVisitor(file_path=file_path)
138+
wrapper.visit(function_visitor)
139+
140+
functions: list[FunctionToOptimize] = []
141+
for func in function_visitor.functions:
142+
if not criteria.include_async and func.is_async:
143+
continue
144+
145+
if not criteria.include_methods and func.parents:
146+
continue
147+
148+
if criteria.require_return and func.starting_line is None:
149+
continue
150+
151+
func_with_is_method = FunctionToOptimize(
152+
function_name=func.function_name,
153+
file_path=file_path,
154+
parents=func.parents,
155+
starting_line=func.starting_line,
156+
ending_line=func.ending_line,
157+
starting_col=func.starting_col,
158+
ending_col=func.ending_col,
159+
is_async=func.is_async,
160+
is_method=len(func.parents) > 0 and any(p.type == "ClassDef" for p in func.parents),
161+
language="python",
162+
)
163+
functions.append(func_with_is_method)
179164

180-
except Exception as e:
181-
logger.warning("Failed to discover functions in %s: %s", file_path, e)
182-
return []
165+
return functions
183166

184167
def discover_tests(
185168
self, test_root: Path, source_functions: Sequence[FunctionToOptimize]

0 commit comments

Comments
 (0)