44import os
55import random
66import warnings
7- from _ast import AsyncFunctionDef , ClassDef , FunctionDef
87from collections import defaultdict
98from functools import cache
109from pathlib import Path
1615from rich .tree import Tree
1716
1817from codeflash .api .cfapi import get_blocklisted_functions , is_function_being_optimized_again
19- from codeflash .cli_cmds .console import DEBUG_MODE , console , logger
18+ from codeflash .cli_cmds .console import console , logger
2019from codeflash .code_utils .code_utils import (
2120 exit_with_message ,
2221 is_class_defined_in_file ,
4746
4847from rich .text import Text
4948
50- _property_id = "property"
51-
52- _ast_name = ast .Name
53-
5449
5550@dataclass (frozen = True )
5651class FunctionProperties :
@@ -73,9 +68,9 @@ def visit_Return(self, node: cst.Return) -> None:
7368class FunctionVisitor (cst .CSTVisitor ):
7469 METADATA_DEPENDENCIES = (cst .metadata .PositionProvider , cst .metadata .ParentNodeProvider )
7570
def __init__(self, file_path: Path) -> None:
    """Prepare a visitor for a single source file.

    Args:
        file_path: Path of the file being visited; kept on the instance
            unchanged.
    """
    super().__init__()
    # Starts empty; visit_FunctionDef appends to it as it walks the tree.
    self.functions: list[FunctionToOptimize] = []
    self.file_path: Path = file_path
8075
8176 @staticmethod
@@ -91,15 +86,26 @@ def is_pytest_fixture(node: cst.FunctionDef) -> bool:
9186 return True
9287 return False
9388
@staticmethod
def is_property(node: cst.FunctionDef) -> bool:
    """Return True when *node* is decorated as a property.

    Both the bare form (``@property``, ``@cached_property``) and the dotted
    form (``@functools.cached_property``) are recognised. For a dotted
    decorator only the final attribute name is compared.

    Args:
        node: The function definition whose decorators are inspected.

    Returns:
        True if any decorator names ``property`` or ``cached_property``,
        False otherwise (including when there are no decorators).
    """
    property_names = ("property", "cached_property")
    for decorator in node.decorators:
        dec = decorator.decorator
        if isinstance(dec, cst.Name) and dec.value in property_names:
            return True
        # A dotted decorator such as ``functools.cached_property`` is an
        # Attribute node, not a Name, so a Name-only check would miss it.
        if isinstance(dec, cst.Attribute) and dec.attr.value in property_names:
            return True
    return False
96+
9497 def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
9598 return_visitor : ReturnStatementVisitor = ReturnStatementVisitor ()
9699 node .visit (return_visitor )
97- if return_visitor .has_return_statement and not self .is_pytest_fixture (node ):
100+ if return_visitor .has_return_statement and not self .is_pytest_fixture (node ) and not self . is_property ( node ) :
98101 pos : CodeRange = self .get_metadata (cst .metadata .PositionProvider , node )
99102 parents : CSTNode | None = self .get_metadata (cst .metadata .ParentNodeProvider , node )
100103 ast_parents : list [FunctionParent ] = []
101104 while parents is not None :
102- if isinstance (parents , (cst .FunctionDef , cst .ClassDef )):
105+ if isinstance (parents , cst .FunctionDef ):
106+ # Skip nested functions — only discover top-level and class-level functions
107+ return
108+ if isinstance (parents , cst .ClassDef ):
103109 ast_parents .append (FunctionParent (parents .name .value , parents .__class__ .__name__ ))
104110 parents = self .get_metadata (cst .metadata .ParentNodeProvider , parents , default = None )
105111 self .functions .append (
@@ -114,32 +120,6 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
114120 )
115121
116122
def find_functions_with_return_statement(ast_module: ast.Module, file_path: Path) -> list[FunctionToOptimize]:
    """Collect every non-property function in *ast_module* that contains a ``return``.

    The tree is walked with an explicit stack rather than recursion, so deeply
    nested modules cannot trigger a RecursionError. Function bodies are never
    descended into, which means functions nested inside other functions are not
    reported. Enclosing classes are recorded as the function's parents.

    Args:
        ast_module: Parsed module to search.
        file_path: Path stored on every FunctionToOptimize that is produced.

    Returns:
        The matching functions, in source order.
    """
    found: list[FunctionToOptimize] = []
    # Each entry pairs a node with the chain of classes that encloses it.
    pending: list[tuple[ast.AST, list[FunctionParent]]] = [(ast_module, [])]
    while pending:
        current, enclosing = pending.pop()

        if isinstance(current, (FunctionDef, AsyncFunctionDef)):
            if function_has_return_statement(current) and not function_is_a_property(current):
                found.append(
                    FunctionToOptimize(
                        function_name=current.name,
                        file_path=file_path,
                        parents=list(enclosing),
                        is_async=isinstance(current, AsyncFunctionDef),
                    )
                )
            # Stop here: the body of a function is deliberately not explored.
            continue

        if isinstance(current, ClassDef):
            next_enclosing = [*enclosing, FunctionParent(current.name, current.__class__.__name__)]
        else:
            next_enclosing = enclosing

        # Push children back-to-front so they are popped in source order.
        pending.extend((child, next_enclosing) for child in reversed(list(ast.iter_child_nodes(current))))
    return found
141-
142-
143123# =============================================================================
144124# Multi-language support helpers
145125# =============================================================================
@@ -250,23 +230,6 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
250230 return True , None
251231
252232
def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
    """Discover optimizable functions in one Python file via the ``ast`` module.

    Errors raised while reading or parsing the file are swallowed on purpose:
    they are logged only when DEBUG_MODE is on, and an empty mapping is
    returned. Errors raised while opening the file are not caught.

    Args:
        file_path: Python source file to inspect.

    Returns:
        ``{file_path: [functions...]}`` on success, or ``{}`` if the file
        could not be read or parsed.
    """
    discovered: dict[Path, list[FunctionToOptimize]] = {}
    with file_path.open(encoding="utf8") as source_file:
        try:
            module_tree = ast.parse(source_file.read())
        except Exception as e:
            if DEBUG_MODE:
                logger.exception(e)
        else:
            discovered[file_path] = find_functions_with_return_statement(module_tree, file_path)
    return discovered
268-
269-
270233def _find_all_functions_via_language_support (file_path : Path ) -> dict [Path , list [FunctionToOptimize ]]:
271234 """Find all optimizable functions using the language support abstraction.
272235
@@ -280,7 +243,6 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
280243 try :
281244 lang_support = get_language_support (file_path )
282245 criteria = FunctionFilterCriteria (require_return = True )
283- # discover_functions already returns FunctionToOptimize objects
284246 functions [file_path ] = lang_support .discover_functions (file_path , criteria )
285247 except Exception as e :
286248 logger .debug (f"Failed to discover functions in { file_path } : { e } " )
@@ -302,7 +264,7 @@ def get_functions_to_optimize(
302264 assert sum ([bool (optimize_all ), bool (replay_test ), bool (file )]) <= 1 , (
303265 "Only one of optimize_all, replay_test, or file should be provided"
304266 )
305- functions : dict [str , list [FunctionToOptimize ]]
267+ functions : dict [Path , list [FunctionToOptimize ]]
306268 trace_file_path : Path | None = None
307269 is_lsp = is_LSP_enabled ()
308270 with warnings .catch_warnings ():
@@ -319,7 +281,7 @@ def get_functions_to_optimize(
319281 logger .info ("!lsp|Finding all functions in the file '%s'…" , file )
320282 console .rule ()
321283 file = Path (file ) if isinstance (file , str ) else file
322- functions : dict [ Path , list [ FunctionToOptimize ]] = find_all_functions_in_file (file )
284+ functions = find_all_functions_in_file (file )
323285 if only_get_this_function is not None :
324286 split_function = only_get_this_function .split ("." )
325287 if len (split_function ) > 2 :
@@ -354,6 +316,7 @@ def get_functions_to_optimize(
354316 f"Function { only_get_this_function } not found in file { file } \n or the function does not have a 'return' statement or is a property"
355317 )
356318
319+ assert found_function is not None
357320 # For JavaScript/TypeScript, verify that the function (or its parent class) is exported
358321 # Non-exported functions cannot be imported by tests
359322 if found_function .language in ("javascript" , "typescript" ):
@@ -397,7 +360,7 @@ def get_functions_to_optimize(
397360 return filtered_modified_functions , functions_count , trace_file_path
398361
399362
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]:
    """Return the functions that overlap lines changed in the git diff.

    Args:
        uncommitted_changes: Forwarded unchanged to ``get_git_diff``.

    Returns:
        Whatever ``get_functions_within_lines`` yields for the per-file
        modified line numbers reported by ``get_git_diff``.
    """
    changed_lines_by_file: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
    functions_in_diff = get_functions_within_lines(changed_lines_by_file)
    return functions_in_diff
403366
@@ -438,7 +401,7 @@ def closest_matching_file_function_name(
438401 closest_match = function
439402 closest_file = file_path
440403
441- if closest_match is not None :
404+ if closest_match is not None and closest_file is not None :
442405 return closest_file , closest_match
443406 return None
444407
@@ -472,13 +435,13 @@ def levenshtein_distance(s1: str, s2: str) -> int:
472435 return previous [len1 ]
473436
474437
def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionToOptimize]]:
    """Return the functions that overlap lines changed by a single commit.

    Args:
        commit_hash: Commit identifier, passed to ``get_git_diff`` as
            ``only_this_commit``.

    Returns:
        Whatever ``get_functions_within_lines`` yields for the per-file
        modified line numbers reported for that commit.
    """
    changed_lines_by_file: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
    functions_in_commit = get_functions_within_lines(changed_lines_by_file)
    return functions_in_commit
478441
479442
480- def get_functions_within_lines (modified_lines : dict [str , list [int ]]) -> dict [str , list [FunctionToOptimize ]]:
481- functions : dict [str , list [FunctionToOptimize ]] = {}
443+ def get_functions_within_lines (modified_lines : dict [str , list [int ]]) -> dict [Path , list [FunctionToOptimize ]]:
444+ functions : dict [Path , list [FunctionToOptimize ]] = {}
482445 for path_str , lines_in_file in modified_lines .items ():
483446 path = Path (path_str )
484447 if not path .exists ():
@@ -490,9 +453,9 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
490453 except Exception as e :
491454 logger .exception (e )
492455 continue
493- function_lines = FunctionVisitor (file_path = str ( path ) )
456+ function_lines = FunctionVisitor (file_path = path )
494457 wrapper .visit (function_lines )
495- functions [str ( path ) ] = [
458+ functions [path ] = [
496459 function_to_optimize
497460 for function_to_optimize in function_lines .functions
498461 if (start_line := function_to_optimize .starting_line ) is not None
@@ -504,7 +467,7 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
504467
505468def get_all_files_and_functions (
506469 module_root_path : Path , ignore_paths : list [Path ], language : Language | None = None
507- ) -> dict [str , list [FunctionToOptimize ]]:
470+ ) -> dict [Path , list [FunctionToOptimize ]]:
508471 """Get all optimizable functions from files in the module root.
509472
510473 Args:
@@ -516,9 +479,8 @@ def get_all_files_and_functions(
516479 Dictionary mapping file paths to lists of FunctionToOptimize.
517480
518481 """
519- functions : dict [str , list [FunctionToOptimize ]] = {}
482+ functions : dict [Path , list [FunctionToOptimize ]] = {}
520483 for file_path in get_files_for_language (module_root_path , ignore_paths , language ):
521- # Find all the functions in the file
522484 functions .update (find_all_functions_in_file (file_path ).items ())
523485 # Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
524486 # Helpful if an optimize-all run is stuck and we restart it.
@@ -545,16 +507,6 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
545507 if not is_language_supported (file_path ):
546508 return {}
547509
548- try :
549- lang_support = get_language_support (file_path )
550- except Exception :
551- return {}
552-
553- # Route to Python-specific implementation for backward compatibility
554- if lang_support .language == Language .PYTHON :
555- return _find_all_functions_in_python_file (file_path )
556-
557- # Use language support abstraction for other languages
558510 return _find_all_functions_via_language_support (file_path )
559511
560512
@@ -833,7 +785,7 @@ def filter_functions(
833785 disable_logs : bool = False ,
834786) -> tuple [dict [Path , list [FunctionToOptimize ]], int ]:
835787 resolved_project_root = project_root .resolve ()
836- filtered_modified_functions : dict [str , list [FunctionToOptimize ]] = {}
788+ filtered_modified_functions : dict [Path , list [FunctionToOptimize ]] = {}
837789 blocklist_funcs = get_blocklisted_functions ()
838790 logger .debug (f"Blocklisted functions: { blocklist_funcs } " )
839791 # Remove any function that we don't want to optimize
@@ -940,7 +892,7 @@ def is_test_file(file_path_normalized: str) -> bool:
940892 functions_tmp .append (function )
941893 _functions = functions_tmp
942894
943- filtered_modified_functions [file_path ] = _functions
895+ filtered_modified_functions [file_path_path ] = _functions
944896 functions_count += len (_functions )
945897
946898 if not disable_logs :
@@ -961,7 +913,7 @@ def is_test_file(file_path_normalized: str) -> bool:
961913 if len (tree .children ) > 0 :
962914 console .print (tree )
963915 console .rule ()
964- return {Path ( k ) : v for k , v in filtered_modified_functions .items () if v }, functions_count
916+ return {k : v for k , v in filtered_modified_functions .items () if v }, functions_count
965917
966918
967919def filter_files_optimized (file_path : Path , tests_root : Path , ignore_paths : list [Path ], module_root : Path ) -> bool :
@@ -984,31 +936,3 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
984936 file_path in submodule_paths
985937 or any (file_path .is_relative_to (submodule_path ) for submodule_path in submodule_paths )
986938 )
987-
988-
# Node kinds whose fields can hold further statements. ``except`` handlers and
# ``match`` cases are not ``ast.stmt`` subclasses, yet their bodies are statement
# lists, so they must be walked too. ``ast.match_case`` only exists on 3.10+.
_STATEMENT_CONTAINER_TYPES: tuple = (
    ast.stmt,
    ast.excepthandler,
    *((ast.match_case,) if hasattr(ast, "match_case") else ()),
)


def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
    """Report whether *function_node* contains a ``return`` statement anywhere.

    The body is searched with an explicit stack and stops at the first
    ``ast.Return`` found. Only nodes that can contain statements are followed
    (statements, ``except`` handlers and ``match`` cases); expression subtrees
    are skipped because a ``return`` can never appear inside one.

    Nested function and class definitions are statements, so a ``return``
    inside a nested function also counts.

    Args:
        function_node: The function definition to inspect.

    Returns:
        True if a ``return`` statement is reachable through statement
        containers, False otherwise.
    """
    stack: list[ast.AST] = list(function_node.body)
    while stack:
        node = stack.pop()
        if isinstance(node, ast.Return):
            return True
        for field in getattr(node, "_fields", ()):
            child = getattr(node, field, None)
            if isinstance(child, list):
                stack.extend(item for item in child if isinstance(item, _STATEMENT_CONTAINER_TYPES))
            elif isinstance(child, _STATEMENT_CONTAINER_TYPES):
                stack.append(child)
    return False
1007-
1008-
def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:
    """Report whether *function_node* is decorated with a bare ``@property``.

    A decorator matches only when it is an ``_ast_name`` instance whose ``id``
    equals ``_property_id``. Dotted or called decorators never match.

    Args:
        function_node: The function definition whose decorators are checked.

    Returns:
        True if at least one decorator matches, False otherwise.
    """
    return any(
        isinstance(decorator, _ast_name) and decorator.id == _property_id
        for decorator in function_node.decorator_list
    )
0 commit comments